diff --git a/mmeval/metrics/word_accuracy.py b/mmeval/metrics/word_accuracy.py index a446fe8c..a51f00aa 100644 --- a/mmeval/metrics/word_accuracy.py +++ b/mmeval/metrics/word_accuracy.py @@ -19,8 +19,8 @@ class WordAccuracy(BaseMetric): If mode is a list, then metrics in mode will be calculated separately. Defaults to 'ignore_case_symbol'. - valid_symbol (str): Valid characters. Defaults to - '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'. + invalid_symbol (str): A regular expression to filter out invalid or + not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]' Example: >>> from mmeval import WordAccuracy @@ -37,11 +37,11 @@ class WordAccuracy(BaseMetric): def __init__(self, mode: Union[str, Sequence[str]] = 'ignore_case_symbol', - valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]', + invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]', **kwargs): super().__init__(**kwargs) self.mode = mode - self.valid_symbol = re.compile(valid_symbol) + self.invalid_symbol = re.compile(invalid_symbol) assert isinstance(mode, (str, list)) if isinstance(mode, str): mode = [mode] @@ -66,8 +66,8 @@ def add(self, predictions: Sequence[str], labels: Sequence[str]) -> None: # typ label_lower = label.lower() ignore_case_num = pred_lower == label_lower if 'ignore_case_symbol' in self.mode: - label_lower_ignore = self.valid_symbol.sub('', label_lower) - pred_lower_ignore = self.valid_symbol.sub('', pred_lower) + label_lower_ignore = self.invalid_symbol.sub('', label_lower) + pred_lower_ignore = self.invalid_symbol.sub('', pred_lower) ignore_case_symbol_num =\ label_lower_ignore == pred_lower_ignore self._results.append( diff --git a/tests/test_metrics/test_word_accuracy.py b/tests/test_metrics/test_word_accuracy.py index 4898d96a..db9ee883 100644 --- a/tests/test_metrics/test_word_accuracy.py +++ b/tests/test_metrics/test_word_accuracy.py @@ -20,3 +20,11 @@ def test_word_accuracy(): assert abs(res['accuracy'] - 1. / 3) < 1e-7 assert abs(res['ignore_case_accuracy'] - 2. / 3) < 1e-7 assert abs(res['ignore_case_symbol_accuracy'] - 1.0) < 1e-7 + metric.reset() + for pred, label in zip(['hello', 'hello', 'hello'], + ['hello', 'HELLO', '$HELLO$']): + metric.add([pred], [label]) + res = metric.compute() + assert abs(res['accuracy'] - 1. / 3) < 1e-7 + assert abs(res['ignore_case_accuracy'] - 2. / 3) < 1e-7 + assert abs(res['ignore_case_symbol_accuracy'] - 1.0) < 1e-7