diff --git a/mmeval/metrics/word_accuracy.py b/mmeval/metrics/word_accuracy.py index 6a7d628b..4a867872 100644 --- a/mmeval/metrics/word_accuracy.py +++ b/mmeval/metrics/word_accuracy.py @@ -23,7 +23,7 @@ class WordAccuracy(BaseMetric): not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]' **kwargs: Keyword parameters passed to :class:`BaseMetric`. - Example: + Examples: >>> from mmeval import WordAccuracy >>> metric = WordAccuracy() >>> metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$']) @@ -46,19 +46,19 @@ def __init__(self, assert isinstance(mode, (str, list)) if isinstance(mode, str): mode = [mode] - assert all([isinstance(item, str) for item in mode]) - assert set(mode).issubset( - {'exact', 'ignore_case', 'ignore_case_symbol'}) + assert all(isinstance(item, str) for item in mode) self.mode = set(mode) # type: ignore + assert set(self.mode).issubset( + {'exact', 'ignore_case', 'ignore_case_symbol'}) - def add(self, predictions: Sequence[str], labels: Sequence[str]) -> None: # type: ignore # yapf: disable # noqa: E501 + def add(self, predictions: Sequence[str], groundtruths: Sequence[str]) -> None: # type: ignore # yapf: disable # noqa: E501 """Process one batch of data and predictions. Args: predictions (list[str]): The prediction texts. - labels (list[str]): The ground truth texts. + groundtruths (list[str]): The ground truth texts. """ - for pred, label in zip(predictions, labels): + for pred, label in zip(predictions, groundtruths): num, ignore_case_num, ignore_case_symbol_num = 0, 0, 0 if 'exact' in self.mode: num = pred == label @@ -85,11 +85,11 @@ def compute_metric(self, results: List[Tuple[int, int, int]]) -> Dict: - accuracy (float): Accuracy at word level. - ignore_case_accuracy (float): Accuracy at word level, ignoring - letter case. + letter case. - ignore_case_symbol_accuracy (float): Accuracy at word level, - ignoring letter case and symbol. + ignoring letter case and symbol. """ - eval_res = {} + metric_results = {} gt_word_num = max(len(results), 1.0) exact_sum, ignore_case_sum, ignore_case_symbol_sum = 0.0, 0.0, 0.0 for exact, ignore_case, ignore_case_symbol in results: @@ -97,10 +97,11 @@ def compute_metric(self, results: List[Tuple[int, int, int]]) -> Dict: ignore_case_sum += ignore_case ignore_case_symbol_sum += ignore_case_symbol if 'exact' in self.mode: - eval_res['accuracy'] = exact_sum / gt_word_num + metric_results['accuracy'] = exact_sum / gt_word_num if 'ignore_case' in self.mode: - eval_res['ignore_case_accuracy'] = ignore_case_sum / gt_word_num + metric_results[ + 'ignore_case_accuracy'] = ignore_case_sum / gt_word_num if 'ignore_case_symbol' in self.mode: - eval_res['ignore_case_symbol_accuracy'] =\ + metric_results['ignore_case_symbol_accuracy'] =\ ignore_case_symbol_sum / gt_word_num - return eval_res + return metric_results