Skip to content

Commit 5bd89db

Browse files
[Feature] Add ROUGE metric (#72)
* Add lowercase super parameter * Add lowercase super parameter * Add lowercase super parameter * Add lowercase super parameter
1 parent 7a12a70 commit 5bd89db

File tree

11 files changed

+552
-89
lines changed

11 files changed

+552
-89
lines changed

docs/en/api/metrics.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,4 @@ Metrics
4747
MattingMSE
4848
ConnectivityError
4949
DOTAMeanAP
50+
ROUGE

docs/zh_cn/api/metrics.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,4 @@ Metrics
4747
MattingMSE
4848
ConnectivityError
4949
DOTAMeanAP
50+
ROUGE

mmeval/metrics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .pck_accuracy import JhmdbPCKAccuracy, MpiiPCKAccuracy, PCKAccuracy
2020
from .proposal_recall import ProposalRecall
2121
from .psnr import PSNR
22+
from .rouge import ROUGE
2223
from .sad import SAD
2324
from .single_label import SingleLabelMetric
2425
from .snr import SNR
@@ -31,5 +32,5 @@
3132
'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'ProposalRecall',
3233
'PSNR', 'MAE', 'MSE', 'SSIM', 'SNR', 'MultiLabelMetric',
3334
'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP', 'SAD',
34-
'GradientError', 'MattingMSE', 'ConnectivityError'
35+
'GradientError', 'MattingMSE', 'ConnectivityError', 'ROUGE'
3536
]

mmeval/metrics/bleu.py

Lines changed: 28 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,10 @@
33
# <https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/bleu.py>`_.
44
import numpy as np
55
from collections import Counter
6-
from typing import List, Optional, Sequence, Tuple
6+
from typing import Callable, List, Optional, Sequence, Tuple, Union
77

88
from mmeval import BaseMetric
9-
10-
11-
def get_n_gram(token: Sequence[str], n_gram: int) -> Counter:
12-
"""A function get n_gram of sentences.
13-
14-
Args:
15-
token (Sequence[str]): A series of tokens about sentences.
16-
n_gram (int): The maximum number of words contained in a phrase
17-
when calculating word fragments. Defaults to 4.
18-
19-
Returns:
20-
Counter: The n_gram contained in sentences with Counter format.
21-
"""
22-
counter: Counter = Counter()
23-
for i in range(1, n_gram + 1):
24-
for j in range(len(token) - i + 1):
25-
key = tuple(token[j:(i + j)])
26-
counter[key] += 1
27-
return counter
28-
29-
30-
def tokenizer_fn(sentence: str) -> List[str]:
31-
"""This function is used to segment a sentence.
32-
33-
Args:
34-
sentence (str): A sentence.
35-
36-
Returns:
37-
List[str]: A list of tokens after word segmentation.
38-
"""
39-
return sentence.split()
9+
from mmeval.metrics.utils import get_n_gram, get_tokenizer, infer_language
4010

4111

4212
def _get_brevity_penalty(pred_len: np.array,
@@ -67,9 +37,12 @@ class BLEU(BaseMetric):
6737
n_gram (int): The maximum number of words contained in a phrase
6838
when calculating word fragments. Defaults to 4.
6939
smooth (bool): Whether or not to apply to smooth. Defaults to False.
70-
ngram_weights(Sequence[float], optional): Weights used
40+
ngram_weights (Sequence[float], optional): Weights used
7141
for unigrams, bigrams, etc. to calculate BLEU score.
7242
If not provided, uniform weights are used. Defaults to None.
43+
tokenizer_fn (Union[Callable, str, None]): A user's own tokenizer function.
44+
Defaults to None.
45+
New in version 0.3.0.
7346
**kwargs: Keyword parameters passed to :class:`BaseMetric`.
7447
7548
Examples:
@@ -93,6 +66,7 @@ def __init__(self,
9366
n_gram: int = 4,
9467
smooth: bool = False,
9568
ngram_weights: Optional[Sequence[float]] = None,
69+
tokenizer_fn: Union[Callable, str, None] = None,
9670
**kwargs) -> None:
9771
super().__init__(**kwargs)
9872
self.n_gram = n_gram
@@ -105,21 +79,35 @@ def __init__(self,
10579
ngram_weights = [1.0 / n_gram] * n_gram
10680
self.ngram_weights = ngram_weights
10781

82+
# Select tokenizer according to the entered value.
83+
self.tokenizer_fn = None
84+
if callable(tokenizer_fn):
85+
self.tokenizer_fn = tokenizer_fn
86+
elif isinstance(tokenizer_fn, str):
87+
self.tokenizer_fn = get_tokenizer(tokenizer_fn)
88+
if self.tokenizer_fn is None:
89+
raise ValueError('Right now, `tokenizer_fn` only supports '
90+
"pre-defined 'en' or 'cn'.")
91+
else:
92+
assert tokenizer_fn is None, \
93+
f'`tokenizer_fn` supports Callable, str or None, but not `{type(tokenizer_fn)}`' # noqa: E501
94+
10895
def add(self, predictions: Sequence[str], references: Sequence[Sequence[str]]) -> None: # type: ignore # yapf: disable # noqa: E501
10996
"""Add the intermediate results to ``self._results``.
11097
11198
Args:
112-
predictions (Sequence[str]): An iterable of machine
113-
translated corpus.
114-
references (Sequence[Sequence[str]]): An iterable of
115-
iterables of reference corpus.
99+
predictions (Sequence[str]): An iterable of predicted sentences.
100+
references (Sequence[Sequence[str]]): An iterable of
101+
referenced sentences.
116102
"""
117-
103+
if self.tokenizer_fn is None:
104+
language = infer_language(predictions[0])
105+
self.tokenizer_fn = get_tokenizer(language)
118106
references_token: Sequence[Sequence[Sequence[str]]] = [
119-
[tokenizer_fn(line) for line in r] for r in references
107+
[self.tokenizer_fn(line) for line in r] for r in references
120108
]
121109
predictions_token: Sequence[Sequence[str]] = [
122-
tokenizer_fn(line) for line in predictions
110+
self.tokenizer_fn(line) for line in predictions
123111
]
124112
for prediction, references in zip(predictions_token, references_token):
125113
pred_len = len(prediction)

0 commit comments

Comments
 (0)