From c43b7c04e5294efe318cde036c68ba23a5f284af Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Wed, 15 Feb 2023 17:29:27 +0800 Subject: [PATCH 1/6] [Feature] Add OneMinusNormEditDistance for OCR Task --- mmeval/metrics/__init__.py | 4 +- .../metrics/one_minus_norm_edit_distance.py | 82 +++++++++++++++++++ requirements/optional.txt | 1 + setup.cfg | 2 +- .../test_one_minus_norm_edit_distance.py | 23 ++++++ 5 files changed, 110 insertions(+), 2 deletions(-) create mode 100644 mmeval/metrics/one_minus_norm_edit_distance.py create mode 100644 tests/test_metrics/test_one_minus_norm_edit_distance.py diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py index 9d7a66aa..a20aac86 100644 --- a/mmeval/metrics/__init__.py +++ b/mmeval/metrics/__init__.py @@ -22,6 +22,7 @@ from .mse import MeanSquaredError from .niqe import NaturalImageQualityEvaluator from .oid_map import OIDMeanAP +from .one_minus_norm_edit_distance import OneMinusNormEditDistance from .pck_accuracy import JhmdbPCKAccuracy, MpiiPCKAccuracy, PCKAccuracy from .perplexity import Perplexity from .precision_recall_f1score import (MultiLabelPrecisionRecallF1score, @@ -46,7 +47,8 @@ 'ConnectivityError', 'ROUGE', 'Perplexity', 'KeypointEndPointError', 'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator', 'WordAccuracy', 'PrecisionRecallF1score', - 'SingleLabelPrecisionRecallF1score', 'MultiLabelPrecisionRecallF1score' + 'SingleLabelPrecisionRecallF1score', 'MultiLabelPrecisionRecallF1score', + 'OneMinusNormEditDistance' ] _deprecated_msg = ( diff --git a/mmeval/metrics/one_minus_norm_edit_distance.py b/mmeval/metrics/one_minus_norm_edit_distance.py new file mode 100644 index 00000000..660cd5a4 --- /dev/null +++ b/mmeval/metrics/one_minus_norm_edit_distance.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from typing import TYPE_CHECKING, Dict, List, Sequence + +from mmeval.core import BaseMetric +from mmeval.utils import try_import + +if TYPE_CHECKING: + from rapidfuzz.distance import Levenshtein +else: + distance = try_import('rapidfuzz.distance') + if distance is not None: + Levenshtein = distance.Levenshtein + + +class OneMinusNormEditDistance(BaseMetric): + """One minus NED metric for text recognition task. + + Args: + letter_case (str): There are three options to alter the letter cases + - unchanged: Do not change prediction texts and labels. + - upper: Convert prediction texts and labels into uppercase + characters. + - lower: Convert prediction texts and labels into lowercase + characters. + Usually, it only works for English characters. Defaults to + 'unchanged'. + valid_symbol (str): Valid characters. Defaults to + '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'. + + Example: + >>> from mmeval import OneMinusNormEditDistance + >>> metric = OneMinusNormEditDistance() + >>> metric(['helL', 'HEL'], ['hello', 'HELLO']) + {'1-N.E.D': 0.6} + >>> metric = OneMinusNormEditDistance(letter_case='upper') + >>> metric(['helL', 'HEL'], ['hello', 'HELLO']) + {'1-N.E.D': 0.7} + """ + + def __init__(self, + letter_case: str = 'unchanged', + valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]', + **kwargs): + super().__init__(**kwargs) + + assert letter_case in ['unchanged', 'upper', 'lower'] + self.letter_case = letter_case + self.valid_symbol = re.compile(valid_symbol) + + def add(self, predictions: Sequence[str], labels: Sequence[str]): # type: ignore # yapf: disable # noqa: E501 + """Process one batch of data and predictions. + + Args: + predictions (list[str]): The prediction texts. + labels (list[str]): The ground truth texts. + """ + for pred, label in zip(predictions, labels): + if self.letter_case in ['upper', 'lower']: + pred = getattr(pred, self.letter_case)() + label = getattr(label, self.letter_case)() + label = self.valid_symbol.sub('', label) + pred = self.valid_symbol.sub('', pred) + norm_ed = Levenshtein.normalized_distance(pred, label) + self._results.append(norm_ed) + + def compute_metric(self, results: List[float]) -> Dict: + """Compute the metrics from processed results. + + Args: + results (list[float]): The processed results of each batch. + + Returns: + dict[str, float]: Nested dicts as results. + - 1-N.E.D (float): One minus the normalized edit distance. + """ + gt_word_num = len(results) + norm_ed_sum = sum(results) + normalized_edit_distance = norm_ed_sum / max(1.0, gt_word_num) + eval_res = {} + eval_res['1-N.E.D'] = 1.0 - normalized_edit_distance + return eval_res diff --git a/requirements/optional.txt b/requirements/optional.txt index 02d2642f..24645f58 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,3 +1,4 @@ opencv-python!=4.5.5.62,!=4.5.5.64 pycocotools +rapidfuzz shapely diff --git a/setup.cfg b/setup.cfg index b35a1556..3d5e4513 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true [codespell] skip = *.ipynb quiet-level = 3 -ignore-words-list = dota, rouge +ignore-words-list = dota, rouge, ned [mypy] allow_redefinition = True diff --git a/tests/test_metrics/test_one_minus_norm_edit_distance.py b/tests/test_metrics/test_one_minus_norm_edit_distance.py new file mode 100644 index 00000000..b2b3ac38 --- /dev/null +++ b/tests/test_metrics/test_one_minus_norm_edit_distance.py @@ -0,0 +1,23 @@ +import pytest + +from mmeval import OneMinusNormEditDistance + + +def test_init(): + with pytest.raises(AssertionError): + OneMinusNormEditDistance(letter_case='fake') + + +def test_one_minus_norm_edit_distance_metric(): + metric = OneMinusNormEditDistance(letter_case='lower') + res = metric(['helL', 'HEL'], ['hello', 'HELLO']) + assert abs(res['1-N.E.D'] - 0.7) < 1e-7 + metric = OneMinusNormEditDistance(letter_case='upper') + res = metric(['helL', 'HEL'], ['hello', 'HELLO']) + assert abs(res['1-N.E.D'] - 0.7) < 1e-7 + metric = OneMinusNormEditDistance() + res = metric(['helL', 'HEL'], ['hello', 'HELLO']) + assert abs(res['1-N.E.D'] - 0.6) < 1e-7 + + +test_one_minus_norm_edit_distance_metric() From 3b2eef30ddcb419d4e983aa70f8bf5155b04aafd Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Fri, 17 Feb 2023 11:37:50 +0800 Subject: [PATCH 2/6] fix comment --- mmeval/metrics/one_minus_norm_edit_distance.py | 5 +++-- .../test_one_minus_norm_edit_distance.py | 18 ++++++------------ 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/mmeval/metrics/one_minus_norm_edit_distance.py b/mmeval/metrics/one_minus_norm_edit_distance.py index 660cd5a4..73c5c198 100644 --- a/mmeval/metrics/one_minus_norm_edit_distance.py +++ b/mmeval/metrics/one_minus_norm_edit_distance.py @@ -25,8 +25,9 @@ class OneMinusNormEditDistance(BaseMetric): characters. Usually, it only works for English characters. Defaults to 'unchanged'. - valid_symbol (str): Valid characters. Defaults to - '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'. + valid_symbol (str): A regular expression to filter out invalid or + not cared characters. Defaults to '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'. + **kwargs: Keyword parameters passed to :class:`BaseMetric`. Example: >>> from mmeval import OneMinusNormEditDistance diff --git a/tests/test_metrics/test_one_minus_norm_edit_distance.py b/tests/test_metrics/test_one_minus_norm_edit_distance.py index b2b3ac38..3582d7db 100644 --- a/tests/test_metrics/test_one_minus_norm_edit_distance.py +++ b/tests/test_metrics/test_one_minus_norm_edit_distance.py @@ -8,16 +8,10 @@ def test_init(): OneMinusNormEditDistance(letter_case='fake') -def test_one_minus_norm_edit_distance_metric(): - metric = OneMinusNormEditDistance(letter_case='lower') +@pytest.mark.parametrize( + argnames=['letter_case', 'expected'], + argvalues=[('unchanged', 0.6), ('upper', 0.7), ('lower', 0.7)]) +def test_one_minus_norm_edit_distance_metric(letter_case, expected): + metric = OneMinusNormEditDistance(letter_case=letter_case) res = metric(['helL', 'HEL'], ['hello', 'HELLO']) - assert abs(res['1-N.E.D'] - 0.7) < 1e-7 - metric = OneMinusNormEditDistance(letter_case='upper') - res = metric(['helL', 'HEL'], ['hello', 'HELLO']) - assert abs(res['1-N.E.D'] - 0.7) < 1e-7 - metric = OneMinusNormEditDistance() - res = metric(['helL', 'HEL'], ['hello', 'HELLO']) - assert abs(res['1-N.E.D'] - 0.6) < 1e-7 - - -test_one_minus_norm_edit_distance_metric() + assert abs(res['1-N.E.D'] - expected) < 1e-7 From 6d51dbc369461815595a39fef7c2d842112ff96f Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Fri, 17 Feb 2023 13:40:36 +0800 Subject: [PATCH 3/6] add api doc --- docs/en/api/metrics.rst | 1 + docs/zh_cn/api/metrics.rst | 1 + tests/test_metrics/test_one_minus_norm_edit_distance.py | 5 +++++ 3 files changed, 7 insertions(+) diff --git a/docs/en/api/metrics.rst b/docs/en/api/metrics.rst index c0d7761f..9578358a 100644 --- a/docs/en/api/metrics.rst +++ b/docs/en/api/metrics.rst @@ -49,6 +49,7 @@ Metrics DOTAMeanAP ROUGE NaturalImageQualityEvaluator + OneMinusNormEditDistance Perplexity KeypointEndPointError KeypointAUC diff --git a/docs/zh_cn/api/metrics.rst b/docs/zh_cn/api/metrics.rst index c0d7761f..9578358a 100644 --- a/docs/zh_cn/api/metrics.rst +++ b/docs/zh_cn/api/metrics.rst @@ -49,6 +49,7 @@ Metrics DOTAMeanAP ROUGE NaturalImageQualityEvaluator + OneMinusNormEditDistance Perplexity KeypointEndPointError KeypointAUC diff --git a/tests/test_metrics/test_one_minus_norm_edit_distance.py b/tests/test_metrics/test_one_minus_norm_edit_distance.py index 3582d7db..e488f928 100644 --- a/tests/test_metrics/test_one_minus_norm_edit_distance.py +++ b/tests/test_metrics/test_one_minus_norm_edit_distance.py @@ -15,3 +15,8 @@ def test_one_minus_norm_edit_distance_metric(letter_case, expected): metric = OneMinusNormEditDistance(letter_case=letter_case) res = metric(['helL', 'HEL'], ['hello', 'HELLO']) assert abs(res['1-N.E.D'] - expected) < 1e-7 + metric.reset() + for pred, label in zip(['helL', 'HEL'], ['hello', 'HELLO']): + metric.add([pred], [label]) + res = metric.compute() + assert abs(res['1-N.E.D'] - expected) < 1e-7 From 6283ff86bab69df146a8ec45ad6c898c060f4c70 Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Mon, 27 Feb 2023 10:53:26 +0800 Subject: [PATCH 4/6] fix doc comment --- mmeval/metrics/one_minus_norm_edit_distance.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mmeval/metrics/one_minus_norm_edit_distance.py b/mmeval/metrics/one_minus_norm_edit_distance.py index 73c5c198..b6b2687f 100644 --- a/mmeval/metrics/one_minus_norm_edit_distance.py +++ b/mmeval/metrics/one_minus_norm_edit_distance.py @@ -14,15 +14,17 @@ class OneMinusNormEditDistance(BaseMetric): - """One minus NED metric for text recognition task. + r"""One minus NED metric for text recognition task. Args: letter_case (str): There are three options to alter the letter cases + - unchanged: Do not change prediction texts and labels. - upper: Convert prediction texts and labels into uppercase characters. - lower: Convert prediction texts and labels into lowercase characters. + Usually, it only works for English characters. Defaults to 'unchanged'. valid_symbol (str): A regular expression to filter out invalid or From e0a934ec547a8a60a422df3a5dbf294b2f541f8d Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Tue, 28 Feb 2023 15:10:50 +0800 Subject: [PATCH 5/6] rename valid to invalid --- mmeval/metrics/one_minus_norm_edit_distance.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mmeval/metrics/one_minus_norm_edit_distance.py b/mmeval/metrics/one_minus_norm_edit_distance.py index b6b2687f..5c5e5be8 100644 --- a/mmeval/metrics/one_minus_norm_edit_distance.py +++ b/mmeval/metrics/one_minus_norm_edit_distance.py @@ -27,7 +27,7 @@ class OneMinusNormEditDistance(BaseMetric): Usually, it only works for English characters. Defaults to 'unchanged'. - valid_symbol (str): A regular expression to filter out invalid or + invalid_symbol (str): A regular expression to filter out invalid or not cared characters. Defaults to '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'. **kwargs: Keyword parameters passed to :class:`BaseMetric`. @@ -43,13 +43,13 @@ class OneMinusNormEditDistance(BaseMetric): def __init__(self, letter_case: str = 'unchanged', - valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]', + invalid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]', **kwargs): super().__init__(**kwargs) assert letter_case in ['unchanged', 'upper', 'lower'] self.letter_case = letter_case - self.valid_symbol = re.compile(valid_symbol) + self.invalid_symbol = re.compile(invalid_symbol) def add(self, predictions: Sequence[str], labels: Sequence[str]): # type: ignore # yapf: disable # noqa: E501 """Process one batch of data and predictions. @@ -62,8 +62,8 @@ def add(self, predictions: Sequence[str], labels: Sequence[str]): # type: ignor if self.letter_case in ['upper', 'lower']: pred = getattr(pred, self.letter_case)() label = getattr(label, self.letter_case)() - label = self.valid_symbol.sub('', label) - pred = self.valid_symbol.sub('', pred) + label = self.invalid_symbol.sub('', label) + pred = self.invalid_symbol.sub('', pred) norm_ed = Levenshtein.normalized_distance(pred, label) self._results.append(norm_ed) From dbeb3ea04b65f8b905f00ccd401b14759586ce5f Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Mon, 6 Mar 2023 10:25:31 +0800 Subject: [PATCH 6/6] fix comment --- .../metrics/one_minus_norm_edit_distance.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/mmeval/metrics/one_minus_norm_edit_distance.py b/mmeval/metrics/one_minus_norm_edit_distance.py index 5c5e5be8..f6bd0bec 100644 --- a/mmeval/metrics/one_minus_norm_edit_distance.py +++ b/mmeval/metrics/one_minus_norm_edit_distance.py @@ -21,17 +21,17 @@ class OneMinusNormEditDistance(BaseMetric): - unchanged: Do not change prediction texts and labels. - upper: Convert prediction texts and labels into uppercase - characters. + characters. - lower: Convert prediction texts and labels into lowercase - characters. + characters. Usually, it only works for English characters. Defaults to 'unchanged'. invalid_symbol (str): A regular expression to filter out invalid or - not cared characters. Defaults to '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'. + not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'. **kwargs: Keyword parameters passed to :class:`BaseMetric`. - Example: + Examples: >>> from mmeval import OneMinusNormEditDistance >>> metric = OneMinusNormEditDistance() >>> metric(['helL', 'HEL'], ['hello', 'HELLO']) @@ -43,7 +43,7 @@ class OneMinusNormEditDistance(BaseMetric): def __init__(self, letter_case: str = 'unchanged', - invalid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]', + invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]', **kwargs): super().__init__(**kwargs) @@ -51,14 +51,14 @@ def __init__(self, self.letter_case = letter_case self.invalid_symbol = re.compile(invalid_symbol) - def add(self, predictions: Sequence[str], labels: Sequence[str]): # type: ignore # yapf: disable # noqa: E501 + def add(self, predictions: Sequence[str], groundtruths: Sequence[str]): # type: ignore # yapf: disable # noqa: E501 """Process one batch of data and predictions. Args: predictions (list[str]): The prediction texts. - labels (list[str]): The ground truth texts. + groundtruths (list[str]): The ground truth texts. """ - for pred, label in zip(predictions, labels): + for pred, label in zip(predictions, groundtruths): if self.letter_case in ['upper', 'lower']: pred = getattr(pred, self.letter_case)() label = getattr(label, self.letter_case)() @@ -75,11 +75,12 @@ def compute_metric(self, results: List[float]) -> Dict: Returns: dict[str, float]: Nested dicts as results. - - 1-N.E.D (float): One minus the normalized edit distance. + + - 1-N.E.D (float): One minus the normalized edit distance. """ gt_word_num = len(results) norm_ed_sum = sum(results) normalized_edit_distance = norm_ed_sum / max(1.0, gt_word_num) - eval_res = {} - eval_res['1-N.E.D'] = 1.0 - normalized_edit_distance - return eval_res + metric_results = {} + metric_results['1-N.E.D'] = 1.0 - normalized_edit_distance + return metric_results