Skip to content

Commit

Permalink
fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Harold-lkk committed Feb 28, 2023
1 parent c065ab7 commit a421ca9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
12 changes: 6 additions & 6 deletions mmeval/metrics/word_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class WordAccuracy(BaseMetric):
If mode is a list, then metrics in mode will be calculated
separately. Defaults to 'ignore_case_symbol'.
valid_symbol (str): Valid characters. Defaults to
'[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.
invalid_symbol (str): A regular expression to filter out invalid or
not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'
Example:
>>> from mmeval import WordAccuracy
Expand All @@ -37,11 +37,11 @@ class WordAccuracy(BaseMetric):

def __init__(self,
mode: Union[str, Sequence[str]] = 'ignore_case_symbol',
valid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]',
**kwargs):
super().__init__(**kwargs)
self.mode = mode
self.valid_symbol = re.compile(valid_symbol)
self.invalid_symbol = re.compile(invalid_symbol)
assert isinstance(mode, (str, list))
if isinstance(mode, str):
mode = [mode]
Expand All @@ -66,8 +66,8 @@ def add(self, predictions: Sequence[str], labels: Sequence[str]) -> None: # typ
label_lower = label.lower()
ignore_case_num = pred_lower == label_lower
if 'ignore_case_symbol' in self.mode:
label_lower_ignore = self.valid_symbol.sub('', label_lower)
pred_lower_ignore = self.valid_symbol.sub('', pred_lower)
label_lower_ignore = self.invalid_symbol.sub('', label_lower)
pred_lower_ignore = self.invalid_symbol.sub('', pred_lower)
ignore_case_symbol_num =\
label_lower_ignore == pred_lower_ignore
self._results.append(
Expand Down
8 changes: 8 additions & 0 deletions tests/test_metrics/test_word_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,11 @@ def test_word_accuracy():
assert abs(res['accuracy'] - 1. / 3) < 1e-7
assert abs(res['ignore_case_accuracy'] - 2. / 3) < 1e-7
assert abs(res['ignore_case_symbol_accuracy'] - 1.0) < 1e-7
metric.reset()
for pred, label in zip(['hello', 'hello', 'hello'],
['hello', 'HELLO', '$HELLO$']):
metric.add([pred], [label])
res = metric.compute()
assert abs(res['accuracy'] - 1. / 3) < 1e-7
assert abs(res['ignore_case_accuracy'] - 2. / 3) < 1e-7
assert abs(res['ignore_case_symbol_accuracy'] - 1.0) < 1e-7

0 comments on commit a421ca9

Please sign in to comment.