We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6f437b1 commit 73db2b4Copy full SHA for 73db2b4
src/pytorch_ie/metrics/confusion_matrix.py
@@ -68,7 +68,7 @@ def calculate_counts(
68
base2pred[base_ann].append(ann)
69
70
# (gold_label, pred_label) -> count
71
- counts = defaultdict(int)
+ counts: Dict[Tuple[str, str], int] = defaultdict(int)
72
for base_ann in set(base2gold) | set(base2pred):
73
gold_labels = [getattr(ann, self.label_field) for ann in base2gold[base_ann]]
74
pred_labels = [getattr(ann, self.label_field) for ann in base2pred[base_ann]]
0 commit comments