Skip to content

Commit 14f6679

Browse files
authored
add support to F1 metric values (#439)
1 parent 32e6f9e commit 14f6679

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

src/pytorch_ie/metrics/f1.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def _compute(self) -> Dict[str, Dict[str, float]]:
141141
p = tp / (tp + fp)
142142
r = tp / (tp + fn)
143143
f1 = 2 * p * r / (p + r)
144-
res[label] = {"f1": f1, "p": p, "r": r}
144+
res[label] = {"f1": f1, "p": p, "r": r, "s": tp + fn}
145145
if self.per_label and label in self.labels:
146146
res["MACRO"]["f1"] += f1 / len(self.labels)
147147
res["MACRO"]["p"] += p / len(self.labels)

tests/metrics/test_f1.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_f1(documents):
5252
metric(documents)
5353
# tp, fp, fn for micro
5454
assert dict(metric.counts) == {"MICRO": (3, 2, 0)}
55-
assert metric.compute() == {"MICRO": {"f1": 0.7499999999999999, "p": 0.6, "r": 1.0}}
55+
assert metric.compute() == {"MICRO": {"f1": 0.7499999999999999, "p": 0.6, "r": 1.0, "s": 3}}
5656

5757

5858
def test_f1_per_label(documents):
@@ -67,10 +67,10 @@ def test_f1_per_label(documents):
6767
}
6868
assert metric.compute() == {
6969
"MACRO": {"f1": 0.5555555555555556, "p": 0.5, "r": 0.6666666666666666},
70-
"MICRO": {"f1": 0.7499999999999999, "p": 0.6, "r": 1.0},
71-
"cat": {"f1": 0.0, "p": 0.0, "r": 0.0},
72-
"company": {"f1": 0.6666666666666666, "p": 0.5, "r": 1.0},
73-
"animal": {"f1": 1.0, "p": 1.0, "r": 1.0},
70+
"MICRO": {"f1": 0.7499999999999999, "p": 0.6, "r": 1.0, "s": 3},
71+
"animal": {"f1": 1.0, "p": 1.0, "r": 1.0, "s": 2},
72+
"cat": {"f1": 0.0, "p": 0.0, "r": 0.0, "s": 0},
73+
"company": {"f1": 0.6666666666666666, "p": 0.5, "r": 1.0, "s": 1},
7474
}
7575

7676

@@ -86,10 +86,10 @@ def test_f1_per_label_inferred(documents):
8686
}
8787
assert metric.compute() == {
8888
"MACRO": {"f1": 0.5555555555555556, "p": 0.5, "r": 0.6666666666666666},
89-
"MICRO": {"f1": 0.7499999999999999, "p": 0.6, "r": 1.0},
90-
"animal": {"f1": 1.0, "p": 1.0, "r": 1.0},
91-
"cat": {"f1": 0.0, "p": 0.0, "r": 0.0},
92-
"company": {"f1": 0.6666666666666666, "p": 0.5, "r": 1.0},
89+
"MICRO": {"f1": 0.7499999999999999, "p": 0.6, "r": 1.0, "s": 3},
90+
"animal": {"f1": 1.0, "p": 1.0, "r": 1.0, "s": 2},
91+
"cat": {"f1": 0.0, "p": 0.0, "r": 0.0, "s": 0},
92+
"company": {"f1": 0.6666666666666666, "p": 0.5, "r": 1.0, "s": 1},
9393
}
9494

9595

0 commit comments

Comments
 (0)