diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py index 171e7196..94971f1c 100644 --- a/mmeval/metrics/__init__.py +++ b/mmeval/metrics/__init__.py @@ -18,6 +18,7 @@ from .mse import MeanSquaredError from .multi_label import AveragePrecision, MultiLabelMetric from .oid_map import OIDMeanAP +from .one_minus_norm_edit_distance import OneMinusNormEditDistance from .pck_accuracy import JhmdbPCKAccuracy, MpiiPCKAccuracy, PCKAccuracy from .proposal_recall import ProposalRecall from .psnr import PeakSignalNoiseRatio @@ -36,7 +37,7 @@ 'StructuralSimilarity', 'SignalNoiseRatio', 'MultiLabelMetric', 'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP', 'SumAbsoluteDifferences', 'GradientError', 'MattingMeanSquaredError', - 'ConnectivityError', 'ROUGE' + 'ConnectivityError', 'ROUGE', 'OneMinusNormEditDistance' ] _deprecated_msg = ( diff --git a/mmeval/metrics/one_minus_norm_edit_distance.py b/mmeval/metrics/one_minus_norm_edit_distance.py new file mode 100644 index 00000000..660cd5a4 --- /dev/null +++ b/mmeval/metrics/one_minus_norm_edit_distance.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from typing import TYPE_CHECKING, Dict, List, Sequence + +from mmeval.core import BaseMetric +from mmeval.utils import try_import + +if TYPE_CHECKING: + from rapidfuzz.distance import Levenshtein +else: + distance = try_import('rapidfuzz.distance') + if distance is not None: + Levenshtein = distance.Levenshtein + + +class OneMinusNormEditDistance(BaseMetric): + """One minus NED metric for text recognition task. + + Args: + letter_case (str): There are three options to alter the letter cases + - unchanged: Do not change prediction texts and labels. + - upper: Convert prediction texts and labels into uppercase + characters. + - lower: Convert prediction texts and labels into lowercase + characters. + Usually, it only works for English characters. Defaults to + 'unchanged'. + valid_symbol (str): Valid characters. Defaults to + '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'. + + Example: + >>> from mmeval import OneMinusNormEditDistance + >>> metric = OneMinusNormEditDistance() + >>> metric(['helL', 'HEL'], ['hello', 'HELLO']) + {'1-N.E.D': 0.6} + >>> metric = OneMinusNormEditDistance(letter_case='upper') + >>> metric(['helL', 'HEL'], ['hello', 'HELLO']) + {'1-N.E.D': 0.7} + """ + + def __init__(self, + letter_case: str = 'unchanged', + valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]', + **kwargs): + super().__init__(**kwargs) + + assert letter_case in ['unchanged', 'upper', 'lower'] + self.letter_case = letter_case + self.valid_symbol = re.compile(valid_symbol) + + def add(self, predictions: Sequence[str], labels: Sequence[str]): # type: ignore # yapf: disable # noqa: E501 + """Process one batch of data and predictions. + + Args: + predictions (list[str]): The prediction texts. + labels (list[str]): The ground truth texts. + """ + for pred, label in zip(predictions, labels): + if self.letter_case in ['upper', 'lower']: + pred = getattr(pred, self.letter_case)() + label = getattr(label, self.letter_case)() + label = self.valid_symbol.sub('', label) + pred = self.valid_symbol.sub('', pred) + norm_ed = Levenshtein.normalized_distance(pred, label) + self._results.append(norm_ed) + + def compute_metric(self, results: List[float]) -> Dict: + """Compute the metrics from processed results. + + Args: + results (list[float]): The processed results of each batch. + + Returns: + dict[str, float]: Nested dicts as results. + - 1-N.E.D (float): One minus the normalized edit distance. + """ + gt_word_num = len(results) + norm_ed_sum = sum(results) + normalized_edit_distance = norm_ed_sum / max(1.0, gt_word_num) + eval_res = {} + eval_res['1-N.E.D'] = 1.0 - normalized_edit_distance + return eval_res diff --git a/setup.cfg b/setup.cfg index b35a1556..3d5e4513 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true [codespell] skip = *.ipynb quiet-level = 3 -ignore-words-list = dota, rouge +ignore-words-list = dota, rouge, ned [mypy] allow_redefinition = True diff --git a/tests/test_metrics/test_one_minus_norm_edit_distance.py b/tests/test_metrics/test_one_minus_norm_edit_distance.py new file mode 100644 index 00000000..b2b3ac38 --- /dev/null +++ b/tests/test_metrics/test_one_minus_norm_edit_distance.py @@ -0,0 +1,23 @@ +import pytest + +from mmeval import OneMinusNormEditDistance + + +def test_init(): + with pytest.raises(AssertionError): + OneMinusNormEditDistance(letter_case='fake') + + +def test_one_minus_norm_edit_distance_metric(): + metric = OneMinusNormEditDistance(letter_case='lower') + res = metric(['helL', 'HEL'], ['hello', 'HELLO']) + assert abs(res['1-N.E.D'] - 0.7) < 1e-7 + metric = OneMinusNormEditDistance(letter_case='upper') + res = metric(['helL', 'HEL'], ['hello', 'HELLO']) + assert abs(res['1-N.E.D'] - 0.7) < 1e-7 + metric = OneMinusNormEditDistance() + res = metric(['helL', 'HEL'], ['hello', 'HELLO']) + assert abs(res['1-N.E.D'] - 0.6) < 1e-7 + + +test_one_minus_norm_edit_distance_metric()