
Commit a9cf49b

[Feature] Add CharRecallPrecision for OCR Task (#96)
* [Feature] Adapt MMEval for CharRecallPrecision, WordAccuracy, OneMinusNDE
* fix comment
* add api doc
* fix doc comment
* fix comment
* fix comment
1 parent 2cae388 commit a9cf49b

File tree: 5 files changed, +138 -1 lines changed

docs/en/api/metrics.rst

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@ Metrics
    ROUGE
    NaturalImageQualityEvaluator
    Perplexity
+   CharRecallPrecision
    KeypointEndPointError
    KeypointAUC
    KeypointNME

docs/zh_cn/api/metrics.rst

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@ Metrics
    ROUGE
    NaturalImageQualityEvaluator
    Perplexity
+   CharRecallPrecision
    KeypointEndPointError
    KeypointAUC
    KeypointNME

mmeval/metrics/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -6,6 +6,7 @@
 from .ava_map import AVAMeanAP
 from .average_precision import AveragePrecision
 from .bleu import BLEU
+from .char_recall_precision import CharRecallPrecision
 from .coco_detection import COCODetection
 from .connectivity_error import ConnectivityError
 from .dota_map import DOTAMeanAP
@@ -46,7 +47,8 @@
     'ConnectivityError', 'ROUGE', 'Perplexity', 'KeypointEndPointError',
     'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator',
     'WordAccuracy', 'PrecisionRecallF1score',
-    'SingleLabelPrecisionRecallF1score', 'MultiLabelPrecisionRecallF1score'
+    'SingleLabelPrecisionRecallF1score', 'MultiLabelPrecisionRecallF1score',
+    'CharRecallPrecision'
 ]

 _deprecated_msg = (
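
With this export in place, the metric is importable from the package root, exactly as the docstring and tests below do. A quick sanity check (a minimal sketch, not part of the commit; the expected output follows from the implementation below):

from mmeval import CharRecallPrecision

# letter_case='lower' folds both prediction and label to lowercase
# before matching, so 'Hel' matches the first three chars of 'hello'.
metric = CharRecallPrecision(letter_case='lower')
print(metric(['Hel'], ['hello']))  # {'recall': 0.6, 'precision': 1.0}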
mmeval/metrics/char_recall_precision.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@

# Copyright (c) OpenMMLab. All rights reserved.
import re
from difflib import SequenceMatcher
from typing import Dict, Sequence, Tuple

from mmeval.core import BaseMetric


class CharRecallPrecision(BaseMetric):
    r"""Calculate the char level recall & precision.

    Args:
        letter_case (str): There are three options to alter the letter case:

            - unchanged: Do not change prediction texts and labels.
            - upper: Convert prediction texts and labels into uppercase
              characters.
            - lower: Convert prediction texts and labels into lowercase
              characters.

            Usually, it only works for English characters. Defaults to
            'unchanged'.
        invalid_symbol (str): A regular expression matching invalid or
            ignored characters, which are stripped before counting.
            Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'.
        **kwargs: Keyword parameters passed to :class:`BaseMetric`.

    Examples:
        >>> from mmeval import CharRecallPrecision
        >>> metric = CharRecallPrecision()
        >>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
        {'recall': 0.6, 'precision': 0.8571428571428571}
        >>> metric = CharRecallPrecision(letter_case='upper')
        >>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
        {'recall': 0.7, 'precision': 1.0}
    """

    def __init__(self,
                 letter_case: str = 'unchanged',
                 invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]',
                 **kwargs):
        super().__init__(**kwargs)
        assert letter_case in ['unchanged', 'upper', 'lower']
        self.letter_case = letter_case
        self.invalid_symbol = re.compile(invalid_symbol)

    def add(self, predictions: Sequence[str], groundtruths: Sequence[str]) -> None:  # type: ignore # yapf: disable # noqa: E501
        """Process one batch of data and predictions.

        Args:
            predictions (Sequence[str]): The prediction texts.
            groundtruths (Sequence[str]): The ground truth texts.
        """
        for pred, label in zip(predictions, groundtruths):
            if self.letter_case in ['upper', 'lower']:
                pred = getattr(pred, self.letter_case)()
                label = getattr(label, self.letter_case)()
            valid_label = self.invalid_symbol.sub('', label)
            valid_pred = self.invalid_symbol.sub('', pred)
            # Counts needed for char-level recall & precision.
            true_positive_char_num = self._cal_true_positive_char(
                valid_pred, valid_label)
            self._results.append(
                (len(valid_label), len(valid_pred), true_positive_char_num))

    def compute_metric(self, results: Sequence[Tuple[int, int, int]]) -> Dict:
        """Compute the metrics from processed results.

        Args:
            results (list[tuple]): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        gt_sum, pred_sum, true_positive_sum = 0.0, 0.0, 0.0
        for gt, pred, true_positive in results:
            gt_sum += gt
            pred_sum += pred
            true_positive_sum += true_positive
        # Guard against division by zero for empty inputs.
        char_recall = true_positive_sum / max(gt_sum, 1.0)
        char_precision = true_positive_sum / max(pred_sum, 1.0)
        return {'recall': char_recall, 'precision': char_precision}

    def _cal_true_positive_char(self, pred: str, gt: str) -> int:
        """Calculate the number of correctly predicted characters.

        Args:
            pred (str): Prediction text.
            gt (str): Ground truth text.

        Returns:
            int: The true positive character count.
        """
        all_opt = SequenceMatcher(None, pred, gt)
        true_positive_char_num = 0
        for opt, _, _, s2, e2 in all_opt.get_opcodes():
            if opt == 'equal':
                true_positive_char_num += (e2 - s2)
        return true_positive_char_num
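
To see where the docstring numbers come from: the true-positive count is simply the total length of the spans SequenceMatcher labels 'equal'. A standalone sketch (not part of the commit) reproducing the 'unchanged' case by hand:

from difflib import SequenceMatcher

def true_positives(pred: str, gt: str) -> int:
    # Sum the lengths of the blocks SequenceMatcher marks as 'equal',
    # mirroring _cal_true_positive_char above.
    matcher = SequenceMatcher(None, pred, gt)
    return sum(e2 - s2 for opt, _, _, s2, e2 in matcher.get_opcodes()
               if opt == 'equal')

preds, gts = ['helL', 'HEL'], ['hello', 'HELLO']
tp = sum(true_positives(p, g) for p, g in zip(preds, gts))  # 3 + 3 = 6
print(tp / sum(map(len, gts)), tp / sum(map(len, preds)))
# 0.6 0.8571428571428571 -> recall = 6/10, precision = 6/7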
tests/test_metrics/test_char_recall_precision.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@

import pytest

from mmeval import CharRecallPrecision


def test_init():
    with pytest.raises(AssertionError):
        CharRecallPrecision(letter_case='fake')


@pytest.mark.parametrize(
    argnames=['letter_case', 'recall', 'precision'],
    argvalues=[
        ('lower', 0.7, 1),
        ('upper', 0.7, 1),
        ('unchanged', 0.6, 6.0 / 7),
    ])
def test_char_recall_precision_metric(letter_case, recall, precision):
    metric = CharRecallPrecision(letter_case=letter_case)
    res = metric(['helL', 'HEL'], ['hello', 'HELLO'])
    assert abs(res['recall'] - recall) < 1e-7
    assert abs(res['precision'] - precision) < 1e-7
    metric.reset()
    for pred, label in zip(['helL', 'HEL'], ['hello', 'HELLO']):
        metric.add([pred], [label])
    res = metric.compute()
    assert abs(res['recall'] - recall) < 1e-7
    assert abs(res['precision'] - precision) < 1e-7
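
The tests exercise letter_case but not invalid_symbol. For completeness, a hypothetical example (not in the commit) of how the default pattern strips punctuation and whitespace before counting:

from mmeval import CharRecallPrecision

# The default invalid_symbol '[^A-Za-z0-9\u4e00-\u9fa5]' removes anything
# outside ASCII alphanumerics and the CJK unified ideograph range.
metric = CharRecallPrecision()
print(metric(['hello!'], ['hello world']))
# pred -> 'hello' (5 chars), gt -> 'helloworld' (10 chars), 5 matches:
# {'recall': 0.5, 'precision': 1.0}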
