diff --git a/docs/en/api/metrics.rst b/docs/en/api/metrics.rst
index c0d7761f..9578358a 100644
--- a/docs/en/api/metrics.rst
+++ b/docs/en/api/metrics.rst
@@ -49,6 +49,7 @@ Metrics
    DOTAMeanAP
    ROUGE
    NaturalImageQualityEvaluator
+   OneMinusNormEditDistance
    Perplexity
    KeypointEndPointError
    KeypointAUC
diff --git a/docs/zh_cn/api/metrics.rst b/docs/zh_cn/api/metrics.rst
index c0d7761f..9578358a 100644
--- a/docs/zh_cn/api/metrics.rst
+++ b/docs/zh_cn/api/metrics.rst
@@ -49,6 +49,7 @@ Metrics
    DOTAMeanAP
    ROUGE
    NaturalImageQualityEvaluator
+   OneMinusNormEditDistance
    Perplexity
    KeypointEndPointError
    KeypointAUC
diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py
index 9d7a66aa..a20aac86 100644
--- a/mmeval/metrics/__init__.py
+++ b/mmeval/metrics/__init__.py
@@ -22,6 +22,7 @@
 from .mse import MeanSquaredError
 from .niqe import NaturalImageQualityEvaluator
 from .oid_map import OIDMeanAP
+from .one_minus_norm_edit_distance import OneMinusNormEditDistance
 from .pck_accuracy import JhmdbPCKAccuracy, MpiiPCKAccuracy, PCKAccuracy
 from .perplexity import Perplexity
 from .precision_recall_f1score import (MultiLabelPrecisionRecallF1score,
@@ -46,7 +47,8 @@
     'ConnectivityError', 'ROUGE', 'Perplexity', 'KeypointEndPointError',
     'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator',
     'WordAccuracy', 'PrecisionRecallF1score',
-    'SingleLabelPrecisionRecallF1score', 'MultiLabelPrecisionRecallF1score'
+    'SingleLabelPrecisionRecallF1score', 'MultiLabelPrecisionRecallF1score',
+    'OneMinusNormEditDistance'
 ]
 
 _deprecated_msg = (
diff --git a/mmeval/metrics/one_minus_norm_edit_distance.py b/mmeval/metrics/one_minus_norm_edit_distance.py
new file mode 100644
index 00000000..f6bd0bec
--- /dev/null
+++ b/mmeval/metrics/one_minus_norm_edit_distance.py
@@ -0,0 +1,100 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import re
+from typing import TYPE_CHECKING, Dict, List, Sequence
+
+from mmeval.core import BaseMetric
+from mmeval.utils import try_import
+
+if TYPE_CHECKING:
+    from rapidfuzz.distance import Levenshtein
+else:
+    distance = try_import('rapidfuzz.distance')
+    # Bind ``None`` when rapidfuzz is missing so that the metric can raise a
+    # clear ImportError instead of failing later with a NameError.
+    Levenshtein = distance.Levenshtein if distance is not None else None
+
+
+class OneMinusNormEditDistance(BaseMetric):
+    r"""One minus NED metric for text recognition task.
+
+    The normalized edit distance (N.E.D) of every prediction/label pair is
+    accumulated, and the final score is one minus the mean of those values.
+
+    Args:
+        letter_case (str): There are three options to alter the letter cases
+
+            - unchanged: Do not change prediction texts and labels.
+            - upper: Convert prediction texts and labels into uppercase
+              characters.
+            - lower: Convert prediction texts and labels into lowercase
+              characters.
+
+            Usually, it only works for English characters. Defaults to
+            'unchanged'.
+        invalid_symbol (str): A regular expression to filter out invalid or
+            not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'.
+        **kwargs: Keyword parameters passed to :class:`BaseMetric`.
+
+    Raises:
+        AssertionError: If ``letter_case`` is not one of the three options.
+        ImportError: If the optional dependency ``rapidfuzz`` is missing.
+
+    Examples:
+        >>> from mmeval import OneMinusNormEditDistance
+        >>> metric = OneMinusNormEditDistance()
+        >>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
+        {'1-N.E.D': 0.6}
+        >>> metric = OneMinusNormEditDistance(letter_case='upper')
+        >>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
+        {'1-N.E.D': 0.7}
+    """
+
+    def __init__(self,
+                 letter_case: str = 'unchanged',
+                 invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]',
+                 **kwargs):
+        super().__init__(**kwargs)
+
+        assert letter_case in ['unchanged', 'upper', 'lower']
+        if Levenshtein is None:
+            raise ImportError(
+                'OneMinusNormEditDistance requires `rapidfuzz`. Please '
+                'install it with `pip install rapidfuzz`.')
+        self.letter_case = letter_case
+        self.invalid_symbol = re.compile(invalid_symbol)
+
+    def add(self, predictions: Sequence[str], groundtruths: Sequence[str]):  # type: ignore # yapf: disable # noqa: E501
+        """Process one batch of data and predictions.
+
+        Args:
+            predictions (list[str]): The prediction texts.
+            groundtruths (list[str]): The ground truth texts.
+        """
+        for pred, label in zip(predictions, groundtruths):
+            if self.letter_case in ['upper', 'lower']:
+                pred = getattr(pred, self.letter_case)()
+                label = getattr(label, self.letter_case)()
+            label = self.invalid_symbol.sub('', label)
+            pred = self.invalid_symbol.sub('', pred)
+            norm_ed = Levenshtein.normalized_distance(pred, label)
+            self._results.append(norm_ed)
+
+    def compute_metric(self, results: List[float]) -> Dict:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list[float]): The normalized edit distance of each
+                prediction/label pair.
+
+        Returns:
+            dict[str, float]: The computed metric.
+
+            - 1-N.E.D (float): One minus the normalized edit distance.
+        """
+        gt_word_num = len(results)
+        norm_ed_sum = sum(results)
+        # ``max`` avoids a division by zero when no sample was added.
+        normalized_edit_distance = norm_ed_sum / max(1.0, gt_word_num)
+        metric_results = {}
+        metric_results['1-N.E.D'] = 1.0 - normalized_edit_distance
+        return metric_results
diff --git a/requirements/optional.txt b/requirements/optional.txt
index 02d2642f..24645f58 100644
--- a/requirements/optional.txt
+++ b/requirements/optional.txt
@@ -1,3 +1,4 @@
 opencv-python!=4.5.5.62,!=4.5.5.64
 pycocotools
+rapidfuzz
 shapely
diff --git a/setup.cfg b/setup.cfg
index b35a1556..3d5e4513 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -18,7 +18,7 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
 [codespell]
 skip = *.ipynb
 quiet-level = 3
-ignore-words-list = dota, rouge
+ignore-words-list = dota, rouge, ned
 
 [mypy]
 allow_redefinition = True
diff --git a/tests/test_metrics/test_one_minus_norm_edit_distance.py b/tests/test_metrics/test_one_minus_norm_edit_distance.py
new file mode 100644
index 00000000..e488f928
--- /dev/null
+++ b/tests/test_metrics/test_one_minus_norm_edit_distance.py
@@ -0,0 +1,22 @@
+import pytest
+
+from mmeval import OneMinusNormEditDistance
+
+
+def test_init():
+    with pytest.raises(AssertionError):
+        OneMinusNormEditDistance(letter_case='fake')
+
+
+@pytest.mark.parametrize(
+    argnames=['letter_case', 'expected'],
+    argvalues=[('unchanged', 0.6), ('upper', 0.7), ('lower', 0.7)])
+def test_one_minus_norm_edit_distance_metric(letter_case, expected):
+    metric = OneMinusNormEditDistance(letter_case=letter_case)
+    res = metric(['helL', 'HEL'], ['hello', 'HELLO'])
+    assert abs(res['1-N.E.D'] - expected) < 1e-7
+    metric.reset()
+    for pred, label in zip(['helL', 'HEL'], ['hello', 'HELLO']):
+        metric.add([pred], [label])
+    res = metric.compute()
+    assert abs(res['1-N.E.D'] - expected) < 1e-7