From 4af8bef2ac65a5bc3f6620298b2abe131d2b94d0 Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Wed, 15 Feb 2023 17:02:34 +0800 Subject: [PATCH 1/5] [Feature] Add WordAccuracy for OCR Task --- mmeval/metrics/__init__.py | 4 +- mmeval/metrics/word_accuracy.py | 102 +++++++++++++++++++++++ tests/test_metrics/test_word_accuracy.py | 22 +++++ 3 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 mmeval/metrics/word_accuracy.py create mode 100644 tests/test_metrics/test_word_accuracy.py diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py index d63fed6e..9ae21aef 100644 --- a/mmeval/metrics/__init__.py +++ b/mmeval/metrics/__init__.py @@ -32,6 +32,7 @@ from .snr import SignalNoiseRatio from .ssim import StructuralSimilarity from .voc_map import VOCMeanAP +from .word_accuracy import WordAccuracy __all__ = [ 'Accuracy', 'MeanIoU', 'VOCMeanAP', 'OIDMeanAP', 'EndPointError', @@ -42,7 +43,8 @@ 'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP', 'SumAbsoluteDifferences', 'GradientError', 'MattingMeanSquaredError', 'ConnectivityError', 'ROUGE', 'Perplexity', 'KeypointEndPointError', - 'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator' + 'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator', + 'WordAccuracy' ] _deprecated_msg = ( diff --git a/mmeval/metrics/word_accuracy.py b/mmeval/metrics/word_accuracy.py new file mode 100644 index 00000000..6bda6e61 --- /dev/null +++ b/mmeval/metrics/word_accuracy.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from typing import Dict, List, Sequence, Tuple, Union + +from mmeval.core import BaseMetric + + +class WordAccuracy(BaseMetric): + """Calculate the word level accuracy. + + Args: + mode (str or list[str]): Options are: + - 'exact': Accuracy at word level. + - 'ignore_case': Accuracy at word level, ignoring letter + case. + - 'ignore_case_symbol': Accuracy at word level, ignoring + letter case and symbol. (Default metric for academic evaluation) + If mode is a list, then metrics in mode will be calculated + separately. Defaults to 'ignore_case_symbol'. + valid_symbol (str): Valid characters. Defaults to + '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'. + + Example: + >>> from mmeval import WordAccuracy + >>> metric = WordAccuracy() + >>> metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$']) + {'ignore_case_symbol_accuracy': 1.0} + >>> metric = WordAccuracy(mode=['exact', 'ignore_case', + >>> 'ignore_case_symbol']) + >>> metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$']) + {'accuracy': 0.333333333, + 'ignore_case_accuracy': 0.666666667, + 'ignore_case_symbol_accuracy': 1.0} + """ + + def __init__(self, + mode: Union[str, Sequence[str]] = 'ignore_case_symbol', + valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]', + **kwargs): + super().__init__(**kwargs) + self.mode = mode + self.valid_symbol = re.compile(valid_symbol) + assert isinstance(mode, (str, list)) + if isinstance(mode, str): + mode = [mode] + assert all([isinstance(item, str) for item in mode]) + assert set(mode).issubset( + {'exact', 'ignore_case', 'ignore_case_symbol'}) + self.mode = set(mode) # type: ignore + + def add(self, predictions: Sequence[str], labels: Sequence[str]) -> None: # type: ignore # yapf: disable # noqa: E501 + """Process one batch of data and predictions. + + Args: + predictions (list[str]): The prediction texts. + labels (list[str]): The ground truth texts. + """ + for pred, label in zip(predictions, labels): + num, ignore_case_num, ignore_case_symbol_num = 0, 0, 0 + if 'exact' in self.mode: + num = pred == label + if 'ignore_case' in self.mode or 'ignore_case_symbol' in self.mode: + pred_lower = pred.lower() + label_lower = label.lower() + ignore_case_num = pred_lower == label_lower + if 'ignore_case_symbol' in self.mode: + label_lower_ignore = self.valid_symbol.sub('', label_lower) + pred_lower_ignore = self.valid_symbol.sub('', pred_lower) + ignore_case_symbol_num =\ + label_lower_ignore == pred_lower_ignore + self._results.append( + (num, ignore_case_num, ignore_case_symbol_num)) + + def compute_metric(self, results: List[Tuple[int, int, int]]) -> Dict: + """Compute the metrics from processed results. + + Args: + results (list[float]): The processed results of each batch. + + Returns: + dict[str, float]: Nested dicts as results. Provided keys are: + - accuracy (float): Accuracy at word level. + - ignore_case_accuracy (float): Accuracy at word level, ignoring + letter case. + - ignore_case_symbol_accuracy (float): Accuracy at word level, + ignoring letter case and symbol. + """ + eval_res = {} + gt_word_num = max(len(results), 1.0) + exact_sum, ignore_case_sum, ignore_case_symbol_sum = 0.0, 0.0, 0.0 + for exact, ignore_case, ignore_case_symbol in results: + exact_sum += exact + ignore_case_sum += ignore_case + ignore_case_symbol_sum += ignore_case_symbol + if 'exact' in self.mode: + eval_res['accuracy'] = exact_sum / gt_word_num + if 'ignore_case' in self.mode: + eval_res['ignore_case_accuracy'] = ignore_case_sum / gt_word_num + if 'ignore_case_symbol' in self.mode: + eval_res['ignore_case_symbol_accuracy'] =\ + ignore_case_symbol_sum / gt_word_num + return eval_res diff --git a/tests/test_metrics/test_word_accuracy.py b/tests/test_metrics/test_word_accuracy.py new file mode 100644 index 00000000..4898d96a --- /dev/null +++ b/tests/test_metrics/test_word_accuracy.py @@ -0,0 +1,22 @@ +import pytest + +from mmeval import WordAccuracy + + +def test_init(): + with pytest.raises(AssertionError): + WordAccuracy(mode=1) + with pytest.raises(AssertionError): + WordAccuracy(mode=[1, 2]) + with pytest.raises(AssertionError): + WordAccuracy(mode='micro') + metric = WordAccuracy(mode=['ignore_case', 'ignore_case', 'exact']) + assert metric.mode == {'ignore_case', 'ignore_case', 'exact'} + + +def test_word_accuracy(): + metric = WordAccuracy(mode=['exact', 'ignore_case', 'ignore_case_symbol']) + res = metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$']) + assert abs(res['accuracy'] - 1. / 3) < 1e-7 + assert abs(res['ignore_case_accuracy'] - 2. / 3) < 1e-7 + assert abs(res['ignore_case_symbol_accuracy'] - 1.0) < 1e-7 From 3b568d84b76b5e81ee50be08242245d95875cf93 Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Fri, 17 Feb 2023 13:36:58 +0800 Subject: [PATCH 2/5] add api and fix comment --- docs/en/api/metrics.rst | 1 + docs/zh_cn/api/metrics.rst | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/en/api/metrics.rst b/docs/en/api/metrics.rst index b5c4f231..c0d7761f 100644 --- a/docs/en/api/metrics.rst +++ b/docs/en/api/metrics.rst @@ -53,3 +53,4 @@ Metrics KeypointEndPointError KeypointAUC KeypointNME + WordAccuracy diff --git a/docs/zh_cn/api/metrics.rst b/docs/zh_cn/api/metrics.rst index b5c4f231..c0d7761f 100644 --- a/docs/zh_cn/api/metrics.rst +++ b/docs/zh_cn/api/metrics.rst @@ -53,3 +53,4 @@ Metrics KeypointEndPointError KeypointAUC KeypointNME + WordAccuracy From 4e0dee5d7aa0f855351dc17a638d9010426d3e3e Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Mon, 27 Feb 2023 10:51:13 +0800 Subject: [PATCH 3/5] fix doc comment --- mmeval/metrics/word_accuracy.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mmeval/metrics/word_accuracy.py b/mmeval/metrics/word_accuracy.py index 6bda6e61..a446fe8c 100644 --- a/mmeval/metrics/word_accuracy.py +++ b/mmeval/metrics/word_accuracy.py @@ -6,15 +6,17 @@ class WordAccuracy(BaseMetric): - """Calculate the word level accuracy. + r"""Calculate the word level accuracy. Args: mode (str or list[str]): Options are: + - 'exact': Accuracy at word level. - 'ignore_case': Accuracy at word level, ignoring letter case. - 'ignore_case_symbol': Accuracy at word level, ignoring letter case and symbol. (Default metric for academic evaluation) + If mode is a list, then metrics in mode will be calculated separately. Defaults to 'ignore_case_symbol'. valid_symbol (str): Valid characters. Defaults to @@ -79,6 +81,7 @@ def compute_metric(self, results: List[Tuple[int, int, int]]) -> Dict: Returns: dict[str, float]: Nested dicts as results. Provided keys are: + - accuracy (float): Accuracy at word level. - ignore_case_accuracy (float): Accuracy at word level, ignoring letter case. From ef56e48e510582d01e56a60ec585c6320be0585c Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Tue, 28 Feb 2023 15:07:26 +0800 Subject: [PATCH 4/5] fix comment --- mmeval/metrics/word_accuracy.py | 13 +++++++------ tests/test_metrics/test_word_accuracy.py | 8 ++++++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/mmeval/metrics/word_accuracy.py b/mmeval/metrics/word_accuracy.py index a446fe8c..6a7d628b 100644 --- a/mmeval/metrics/word_accuracy.py +++ b/mmeval/metrics/word_accuracy.py @@ -19,8 +19,9 @@ class WordAccuracy(BaseMetric): If mode is a list, then metrics in mode will be calculated separately. Defaults to 'ignore_case_symbol'. - valid_symbol (str): Valid characters. Defaults to - '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'. + invalid_symbol (str): A regular expression to filter out invalid or + not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]' + **kwargs: Keyword parameters passed to :class:`BaseMetric`. Example: >>> from mmeval import WordAccuracy @@ -37,11 +38,11 @@ class WordAccuracy(BaseMetric): def __init__(self, mode: Union[str, Sequence[str]] = 'ignore_case_symbol', - valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]', + invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]', **kwargs): super().__init__(**kwargs) self.mode = mode - self.valid_symbol = re.compile(valid_symbol) + self.invalid_symbol = re.compile(invalid_symbol) assert isinstance(mode, (str, list)) if isinstance(mode, str): mode = [mode] @@ -66,8 +67,8 @@ def add(self, predictions: Sequence[str], labels: Sequence[str]) -> None: # typ label_lower = label.lower() ignore_case_num = pred_lower == label_lower if 'ignore_case_symbol' in self.mode: - label_lower_ignore = self.valid_symbol.sub('', label_lower) - pred_lower_ignore = self.valid_symbol.sub('', pred_lower) + label_lower_ignore = self.invalid_symbol.sub('', label_lower) + pred_lower_ignore = self.invalid_symbol.sub('', pred_lower) ignore_case_symbol_num =\ label_lower_ignore == pred_lower_ignore self._results.append( diff --git a/tests/test_metrics/test_word_accuracy.py b/tests/test_metrics/test_word_accuracy.py index 4898d96a..db9ee883 100644 --- a/tests/test_metrics/test_word_accuracy.py +++ b/tests/test_metrics/test_word_accuracy.py @@ -20,3 +20,11 @@ def test_word_accuracy(): assert abs(res['accuracy'] - 1. / 3) < 1e-7 assert abs(res['ignore_case_accuracy'] - 2. / 3) < 1e-7 assert abs(res['ignore_case_symbol_accuracy'] - 1.0) < 1e-7 + metric.reset() + for pred, label in zip(['hello', 'hello', 'hello'], + ['hello', 'HELLO', '$HELLO$']): + metric.add([pred], [label]) + res = metric.compute() + assert abs(res['accuracy'] - 1. / 3) < 1e-7 + assert abs(res['ignore_case_accuracy'] - 2. / 3) < 1e-7 + assert abs(res['ignore_case_symbol_accuracy'] - 1.0) < 1e-7 From 0c40a5c3c8f1fce88cb24b899394b108b23a9bbc Mon Sep 17 00:00:00 2001 From: liukuikun <641417025@qq.com> Date: Mon, 6 Mar 2023 09:49:09 +0800 Subject: [PATCH 5/5] fix comment --- mmeval/metrics/word_accuracy.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/mmeval/metrics/word_accuracy.py b/mmeval/metrics/word_accuracy.py index 6a7d628b..4a867872 100644 --- a/mmeval/metrics/word_accuracy.py +++ b/mmeval/metrics/word_accuracy.py @@ -23,7 +23,7 @@ class WordAccuracy(BaseMetric): not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]' **kwargs: Keyword parameters passed to :class:`BaseMetric`. - Example: + Examples: >>> from mmeval import WordAccuracy >>> metric = WordAccuracy() >>> metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$']) @@ -46,19 +46,19 @@ def __init__(self, assert isinstance(mode, (str, list)) if isinstance(mode, str): mode = [mode] - assert all([isinstance(item, str) for item in mode]) - assert set(mode).issubset( - {'exact', 'ignore_case', 'ignore_case_symbol'}) + assert all(isinstance(item, str) for item in mode) self.mode = set(mode) # type: ignore + assert set(self.mode).issubset( + {'exact', 'ignore_case', 'ignore_case_symbol'}) - def add(self, predictions: Sequence[str], labels: Sequence[str]) -> None: # type: ignore # yapf: disable # noqa: E501 + def add(self, predictions: Sequence[str], groundtruths: Sequence[str]) -> None: # type: ignore # yapf: disable # noqa: E501 """Process one batch of data and predictions. Args: predictions (list[str]): The prediction texts. - labels (list[str]): The ground truth texts. + groundtruths (list[str]): The ground truth texts. """ - for pred, label in zip(predictions, labels): + for pred, label in zip(predictions, groundtruths): num, ignore_case_num, ignore_case_symbol_num = 0, 0, 0 if 'exact' in self.mode: num = pred == label @@ -85,11 +85,11 @@ def compute_metric(self, results: List[Tuple[int, int, int]]) -> Dict: - accuracy (float): Accuracy at word level. - ignore_case_accuracy (float): Accuracy at word level, ignoring - letter case. + letter case. - ignore_case_symbol_accuracy (float): Accuracy at word level, - ignoring letter case and symbol. + ignoring letter case and symbol. """ - eval_res = {} + metric_results = {} gt_word_num = max(len(results), 1.0) exact_sum, ignore_case_sum, ignore_case_symbol_sum = 0.0, 0.0, 0.0 for exact, ignore_case, ignore_case_symbol in results: @@ -97,10 +97,11 @@ def compute_metric(self, results: List[Tuple[int, int, int]]) -> Dict: ignore_case_sum += ignore_case ignore_case_symbol_sum += ignore_case_symbol if 'exact' in self.mode: - eval_res['accuracy'] = exact_sum / gt_word_num + metric_results['accuracy'] = exact_sum / gt_word_num if 'ignore_case' in self.mode: - eval_res['ignore_case_accuracy'] = ignore_case_sum / gt_word_num + metric_results[ + 'ignore_case_accuracy'] = ignore_case_sum / gt_word_num if 'ignore_case_symbol' in self.mode: - eval_res['ignore_case_symbol_accuracy'] =\ + metric_results['ignore_case_symbol_accuracy'] =\ ignore_case_symbol_sum / gt_word_num - return eval_res + return metric_results