Skip to content

Commit

Permalink
fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Harold-lkk committed Mar 6, 2023
1 parent ef56e48 commit 0c40a5c
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions mmeval/metrics/word_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class WordAccuracy(BaseMetric):
not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'
**kwargs: Keyword parameters passed to :class:`BaseMetric`.
Example:
Examples:
>>> from mmeval import WordAccuracy
>>> metric = WordAccuracy()
>>> metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$'])
Expand All @@ -46,19 +46,19 @@ def __init__(self,
assert isinstance(mode, (str, list))
if isinstance(mode, str):
mode = [mode]
assert all([isinstance(item, str) for item in mode])
assert set(mode).issubset(
{'exact', 'ignore_case', 'ignore_case_symbol'})
assert all(isinstance(item, str) for item in mode)
self.mode = set(mode) # type: ignore
assert set(self.mode).issubset(
{'exact', 'ignore_case', 'ignore_case_symbol'})

def add(self, predictions: Sequence[str], labels: Sequence[str]) -> None: # type: ignore # yapf: disable # noqa: E501
def add(self, predictions: Sequence[str], groundtruths: Sequence[str]) -> None: # type: ignore # yapf: disable # noqa: E501
"""Process one batch of data and predictions.
Args:
predictions (list[str]): The prediction texts.
labels (list[str]): The ground truth texts.
groundtruths (list[str]): The ground truth texts.
"""
for pred, label in zip(predictions, labels):
for pred, label in zip(predictions, groundtruths):
num, ignore_case_num, ignore_case_symbol_num = 0, 0, 0
if 'exact' in self.mode:
num = pred == label
Expand All @@ -85,22 +85,23 @@ def compute_metric(self, results: List[Tuple[int, int, int]]) -> Dict:
- accuracy (float): Accuracy at word level.
- ignore_case_accuracy (float): Accuracy at word level, ignoring
letter case.
letter case.
- ignore_case_symbol_accuracy (float): Accuracy at word level,
ignoring letter case and symbol.
ignoring letter case and symbol.
"""
eval_res = {}
metric_results = {}
gt_word_num = max(len(results), 1.0)
exact_sum, ignore_case_sum, ignore_case_symbol_sum = 0.0, 0.0, 0.0
for exact, ignore_case, ignore_case_symbol in results:
exact_sum += exact
ignore_case_sum += ignore_case
ignore_case_symbol_sum += ignore_case_symbol
if 'exact' in self.mode:
eval_res['accuracy'] = exact_sum / gt_word_num
metric_results['accuracy'] = exact_sum / gt_word_num
if 'ignore_case' in self.mode:
eval_res['ignore_case_accuracy'] = ignore_case_sum / gt_word_num
metric_results[
'ignore_case_accuracy'] = ignore_case_sum / gt_word_num
if 'ignore_case_symbol' in self.mode:
eval_res['ignore_case_symbol_accuracy'] =\
metric_results['ignore_case_symbol_accuracy'] =\
ignore_case_symbol_sum / gt_word_num
return eval_res
return metric_results

0 comments on commit 0c40a5c

Please sign in to comment.