@@ -14,14 +14,14 @@ class CharRecallPrecision(BaseMetric):
14
14
15
15
- unchanged: Do not change prediction texts and labels.
16
16
- upper: Convert prediction texts and labels into uppercase
17
- characters.
17
+ characters.
18
18
- lower: Convert prediction texts and labels into lowercase
19
- characters.
19
+ characters.
20
20
21
21
Usually, it only works for English characters. Defaults to
22
22
'unchanged'.
23
23
invalid_symbol (str): A regular expression to filter out invalid or
24
- not cared characters. Defaults to '[^A-Z^a-z^0-9^ \u4e00-\u9fa5]'.
24
+ not cared characters. Defaults to '[^A-Za-z0-9 \u4e00-\u9fa5]'.
25
25
**kwargs: Keyword parameters passed to :class:`BaseMetric`.
26
26
27
27
Examples:
@@ -36,21 +36,21 @@ class CharRecallPrecision(BaseMetric):
36
36
37
37
def __init__ (self ,
38
38
letter_case : str = 'unchanged' ,
39
- invalid_symbol : str = '[^A-Z^a-z^0-9^ \u4e00 -\u9fa5 ]' ,
39
+ invalid_symbol : str = '[^A-Za-z0-9 \u4e00 -\u9fa5 ]' ,
40
40
** kwargs ):
41
41
super ().__init__ (** kwargs )
42
42
assert letter_case in ['unchanged' , 'upper' , 'lower' ]
43
43
self .letter_case = letter_case
44
44
self .invalid_symbol = re .compile (invalid_symbol )
45
45
46
- def add (self , predictions : Sequence [str ], labels : Sequence [str ]) -> None : # type: ignore # yapf: disable # noqa: E501
46
+ def add (self , predictions : Sequence [str ], groundtruths : Sequence [str ]) -> None : # type: ignore # yapf: disable # noqa: E501
47
47
"""Process one batch of data and predictions.
48
48
49
49
Args:
50
50
predictions (list[str]): The prediction texts.
51
- labels (list[str]): The ground truth texts.
51
+ groundtruths (list[str]): The ground truth texts.
52
52
"""
53
- for pred , label in zip (predictions , labels ):
53
+ for pred , label in zip (predictions , groundtruths ):
54
54
if self .letter_case in ['upper' , 'lower' ]:
55
55
pred = getattr (pred , self .letter_case )()
56
56
label = getattr (label , self .letter_case )()
@@ -79,10 +79,10 @@ def compute_metric(self, results: Sequence[Tuple[int, int, int]]) -> Dict:
79
79
true_positive_sum += true_positive
80
80
char_recall = true_positive_sum / max (gt_sum , 1.0 )
81
81
char_precision = true_positive_sum / max (pred_sum , 1.0 )
82
- eval_res = {}
83
- eval_res ['recall' ] = char_recall
84
- eval_res ['precision' ] = char_precision
85
- return eval_res
82
+ metric_results = {}
83
+ metric_results ['recall' ] = char_recall
84
+ metric_results ['precision' ] = char_precision
85
+ return metric_results
86
86
87
87
def _cal_true_positive_char (self , pred : str , gt : str ) -> int :
88
88
"""Calculate correct character number in prediction.
0 commit comments