# <https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/bleu.py>`_.
import numpy as np
from collections import Counter
- from typing import List, Optional, Sequence, Tuple
+ from typing import Callable, List, Optional, Sequence, Tuple, Union

from mmeval import BaseMetric
-
-
- def get_n_gram(token: Sequence[str], n_gram: int) -> Counter:
-     """A function get n_gram of sentences.
-
-     Args:
-         token (Sequence[str]): A series of tokens about sentences.
-         n_gram (int): The maximum number of words contained in a phrase
-             when calculating word fragments. Defaults to 4.
-
-     Returns:
-         Counter: The n_gram contained in sentences with Counter format.
-     """
-     counter: Counter = Counter()
-     for i in range(1, n_gram + 1):
-         for j in range(len(token) - i + 1):
-             key = tuple(token[j:(i + j)])
-             counter[key] += 1
-     return counter
-
-
- def tokenizer_fn(sentence: str) -> List[str]:
-     """This function is used to segment a sentence.
-
-     Args:
-         sentence (str): A sentence.
-
-     Returns:
-         List[str]: A list of tokens after word segmentation.
-     """
-     return sentence.split()
+ from mmeval.metrics.utils import get_n_gram, get_tokenizer, infer_language
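For reference, the helper removed here (and now shared via `mmeval.metrics.utils`) counts every contiguous 1- to n-gram of a token list as a tuple key in a `Counter`. A minimal sketch of that behaviour, mirroring the removed implementation above (the function name here is illustrative):

```python
from collections import Counter
from typing import Sequence


def count_ngrams(token: Sequence[str], n_gram: int) -> Counter:
    # Mirrors the removed helper: every contiguous slice of length
    # 1..n_gram becomes a tuple key whose count is incremented.
    counter: Counter = Counter()
    for i in range(1, n_gram + 1):
        for j in range(len(token) - i + 1):
            counter[tuple(token[j:(i + j)])] += 1
    return counter


counts = count_ngrams('the cat sat on the mat'.split(), n_gram=2)
assert counts[('the',)] == 2        # unigram 'the' appears twice
assert counts[('the', 'cat')] == 1  # the bigram appears once
```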


def _get_brevity_penalty(pred_len: np.array,
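Only the signature of `_get_brevity_penalty` is visible in this hunk. For context, a sketch of the standard BLEU brevity penalty (Papineni et al., 2002) that such a helper conventionally implements; this is the textbook definition, not necessarily the exact body of the function in this file:

```python
import numpy as np


def brevity_penalty(pred_len: float, references_len: float) -> float:
    # No penalty when the candidate is longer than the reference;
    # otherwise an exponential penalty for overly short candidates.
    if pred_len > references_len:
        return 1.0
    return float(np.exp(1.0 - references_len / pred_len))
```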
@@ -67,9 +37,12 @@ class BLEU(BaseMetric):
        n_gram (int): The maximum number of words contained in a phrase
            when calculating word fragments. Defaults to 4.
        smooth (bool): Whether or not to apply smoothing. Defaults to False.
-       ngram_weights(Sequence[float], optional): Weights used
+       ngram_weights (Sequence[float], optional): Weights used
            for unigrams, bigrams, etc. to calculate BLEU score.
            If not provided, uniform weights are used. Defaults to None.
+       tokenizer_fn (Union[Callable, str, None]): A user's own tokenizer function.
+           Defaults to None.
+           New in version 0.3.0.
        **kwargs: Keyword parameters passed to :class:`BaseMetric`.

    Examples:
@@ -93,6 +66,7 @@ def __init__(self,
                 n_gram: int = 4,
                 smooth: bool = False,
                 ngram_weights: Optional[Sequence[float]] = None,
+                tokenizer_fn: Union[Callable, str, None] = None,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.n_gram = n_gram
@@ -105,21 +79,35 @@ def __init__(self,
            ngram_weights = [1.0 / n_gram] * n_gram
        self.ngram_weights = ngram_weights
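For readers unfamiliar with where these weights go: standard BLEU combines the per-order modified precisions as a weighted geometric mean, scaled by the brevity penalty, so the uniform default above weights all n-gram orders equally. A sketch under that standard definition (names here are illustrative, not from this file):

```python
import numpy as np


def combine_precisions(precisions, weights, brevity_penalty):
    # BLEU = BP * exp(sum_n w_n * log p_n); with uniform weights this
    # is the geometric mean of p_1..p_N times the brevity penalty.
    log_p = np.log(np.asarray(precisions, dtype=float))
    return brevity_penalty * float(np.exp(np.sum(np.asarray(weights) * log_p)))
```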

+        # Select tokenizer according to the entered value.
+        self.tokenizer_fn = None
+        if callable(tokenizer_fn):
+            self.tokenizer_fn = tokenizer_fn
+        elif isinstance(tokenizer_fn, str):
+            self.tokenizer_fn = get_tokenizer(tokenizer_fn)
+            if self.tokenizer_fn is None:
+                raise ValueError('Right now, `tokenizer_fn` only supports '
+                                 "pre-defined 'en' or 'cn'.")
+        else:
+            assert tokenizer_fn is None, \
+                f'`tokenizer_fn` supports Callable, str or None, but not `{type(tokenizer_fn)}`'  # noqa: E501
+
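The branch above accepts three forms for `tokenizer_fn`. A usage sketch of the three accepted values, assuming `BLEU` stays importable from `mmeval` as before:

```python
from mmeval import BLEU

bleu_inferred = BLEU()                      # None: language inferred from the first prediction in `add`
bleu_english = BLEU(tokenizer_fn='en')      # pre-defined English tokenizer
bleu_chinese = BLEU(tokenizer_fn='cn')      # pre-defined Chinese tokenizer
bleu_custom = BLEU(tokenizer_fn=str.split)  # any callable mapping str -> List[str]
```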
    def add(self, predictions: Sequence[str], references: Sequence[Sequence[str]]) -> None:  # type: ignore # yapf: disable # noqa: E501
        """Add the intermediate results to ``self._results``.

        Args:
-            predictions (Sequence[str]): An iterable of machine
-                translated corpus.
-            references (Sequence[Sequence[str]]): An iterable of
-                iterables of reference corpus.
+            predictions (Sequence[str]): An iterable of predicted sentences.
+            references (Sequence[Sequence[str]]): An iterable of
+                referenced sentences.
        """
-
+        if self.tokenizer_fn is None:
+            language = infer_language(predictions[0])
+            self.tokenizer_fn = get_tokenizer(language)
        references_token: Sequence[Sequence[Sequence[str]]] = [
-            [tokenizer_fn(line) for line in r] for r in references
+            [self.tokenizer_fn(line) for line in r] for r in references
        ]
        predictions_token: Sequence[Sequence[str]] = [
-            tokenizer_fn(line) for line in predictions
+            self.tokenizer_fn(line) for line in predictions
        ]
        for prediction, references in zip(predictions_token, references_token):
            pred_len = len(prediction)
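Putting the change together, an end-to-end sketch, assuming the usual mmeval convention that a metric instance is callable and returns a result dict:

```python
from mmeval import BLEU

predictions = ['the cat is on the mat']
references = [['a cat is on the mat', 'the cat is on the mat']]

bleu = BLEU(n_gram=4, tokenizer_fn='en')
result = bleu(predictions, references)  # e.g. {'bleu': ...}
print(result)
```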