
Commit 8d144eb

HarshTrivedi authored and DeNeutoy committed
Add option to have tie breaking in Categorical Accuracy (allenai#1485)
* Add tie-breaking support in Categorical Accuracy.
* Add tests for tie-breaking in categorical accuracy.
* Remove redundant expand_as and add some comments for complex indexing.
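For context, a minimal usage sketch of the new option (not part of the commit; the import path is inferred from the file changed below). With a tied row, the default argmax path gives all-or-nothing credit, while tie_break=True splits the credit across the tied classes:

    import torch
    from allennlp.training.metrics import CategoricalAccuracy

    predictions = torch.Tensor([[0.2, 0.5, 0.5]])  # classes 1 and 2 tie for the max
    targets = torch.Tensor([2])

    plain = CategoricalAccuracy()                  # plain argmax path
    tied = CategoricalAccuracy(tie_break=True)
    plain(predictions, targets)
    tied(predictions, targets)
    print(plain.get_metric())  # 0.0 or 1.0, depending on which tied index .max() returns
    print(tied.get_metric())   # 0.5: credit is split evenly between the two tied classes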
1 parent 89b4751 commit 8d144eb

File tree: 2 files changed (+69, −10 lines)


allennlp/tests/training/metrics/categorical_accuracy_test.py

Lines changed: 37 additions & 0 deletions
@@ -78,3 +78,40 @@ def test_top_k_categorical_accuracy_catches_exceptions(self):
         out_of_range_labels = torch.Tensor([10, 3, 4, 0, 1])
         with pytest.raises(ConfigurationError):
             accuracy(predictions, out_of_range_labels)
+
+    def test_tie_break_categorical_accuracy(self):
+        accuracy = CategoricalAccuracy(tie_break=True)
+        predictions = torch.Tensor([[0.35, 0.25, 0.35, 0.35, 0.35],
+                                    [0.1, 0.6, 0.1, 0.2, 0.2],
+                                    [0.1, 0.0, 0.1, 0.2, 0.2]])
+        # Test without mask:
+        targets = torch.Tensor([2, 1, 4])
+        accuracy(predictions, targets)
+        assert accuracy.get_metric(reset=True) == (0.25 + 1 + 0.5)/3.0
+
+        # Test with mask:
+        mask = torch.Tensor([1, 0, 1])
+        targets = torch.Tensor([2, 1, 4])
+        accuracy(predictions, targets, mask)
+        assert accuracy.get_metric(reset=True) == (0.25 + 0.5)/2.0
+
+        # Test tie-break with sequence:
+        predictions = torch.Tensor([[[0.35, 0.25, 0.35, 0.35, 0.35],
+                                     [0.1, 0.6, 0.1, 0.2, 0.2],
+                                     [0.1, 0.0, 0.1, 0.2, 0.2]],
+                                    [[0.35, 0.25, 0.35, 0.35, 0.35],
+                                     [0.1, 0.6, 0.1, 0.2, 0.2],
+                                     [0.1, 0.0, 0.1, 0.2, 0.2]]])
+        targets = torch.Tensor([[0, 1, 3],   # 0.25 + 1 + 0.5
+                                [0, 3, 4]])  # 0.25 + 0 + 0.5 = 2.5
+        accuracy(predictions, targets)
+        actual_accuracy = accuracy.get_metric(reset=True)
+        numpy.testing.assert_almost_equal(actual_accuracy, 2.5/6.0)
+
+    def test_top_k_and_tie_break_together_catches_exceptions(self):
+        with pytest.raises(ConfigurationError):
+            CategoricalAccuracy(top_k=2, tie_break=True)
+
+    def test_incorrect_top_k_catches_exceptions(self):
+        with pytest.raises(ConfigurationError):
+            CategoricalAccuracy(top_k=0)
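The sequence case divides by 6 because the updated metric (see the second file below) flattens the batch and sequence dimensions into rows before scoring. A hedged sketch of that reshaping, with shapes chosen to match the test:

    import torch

    batch_size, seq_len, num_classes = 2, 3, 5
    predictions = torch.randn(batch_size, seq_len, num_classes)
    gold_labels = torch.randint(0, num_classes, (batch_size, seq_len))

    flat_predictions = predictions.view(-1, num_classes)  # shape (6, 5): one row per time step
    flat_gold = gold_labels.view(-1).long()                # shape (6,)
    assert flat_predictions.shape == (6, 5) and flat_gold.shape == (6,)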

allennlp/training/metrics/categorical_accuracy.py

Lines changed: 32 additions & 10 deletions
@@ -12,9 +12,17 @@ class CategoricalAccuracy(Metric):
     """
     Categorical Top-K accuracy. Assumes integer labels, with
     each item to be classified having a single correct class.
+    Tie break enables equal distribution of scores among the
+    classes with the same maximum predicted score.
     """
-    def __init__(self, top_k: int = 1) -> None:
+    def __init__(self, top_k: int = 1, tie_break: bool = False) -> None:
+        if top_k > 1 and tie_break:
+            raise ConfigurationError("Tie break in Categorical Accuracy "
+                                     "can be done only for maximum (top_k = 1)")
+        if top_k <= 0:
+            raise ConfigurationError("top_k passed to Categorical Accuracy must be > 0")
         self._top_k = top_k
+        self._tie_break = tie_break
         self.correct_count = 0.
         self.total_count = 0.

@@ -44,18 +52,32 @@ def __call__(self,
             raise ConfigurationError("A gold label passed to Categorical Accuracy contains an id >= {}, "
                                      "the number of classes.".format(num_classes))

-        # Top K indexes of the predictions (or fewer, if there aren't K of them).
-        # Special case topk == 1, because it's common and .max() is much faster than .topk().
-        if self._top_k == 1:
-            top_k = predictions.max(-1)[1].unsqueeze(-1)
-        else:
-            top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1]
+        predictions = predictions.view((-1, num_classes))
+        gold_labels = gold_labels.view(-1).long()
+        if not self._tie_break:
+            # Top K indexes of the predictions (or fewer, if there aren't K of them).
+            # Special case topk == 1, because it's common and .max() is much faster than .topk().
+            if self._top_k == 1:
+                top_k = predictions.max(-1)[1].unsqueeze(-1)
+            else:
+                top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1]

-        # This is of shape (batch_size, ..., top_k).
-        correct = top_k.eq(gold_labels.long().unsqueeze(-1)).float()
+            # This is of shape (batch_size, ..., top_k).
+            correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
+        else:
+            # Prediction is correct if the gold label falls on any of the max scores; distribute credit by tie_counts.
+            max_predictions = predictions.max(-1)[0]
+            max_predictions_mask = predictions.eq(max_predictions.unsqueeze(-1))
+            # max_predictions_mask is (rows x num_classes) and gold_labels is (rows,).
+            # The ith entry of gold_labels indexes (0 to num_classes - 1) into the ith row of max_predictions_mask.
+            # For each row, check whether the entry the gold label points to is 1 (i.e. among the max-scored classes).
+            correct = max_predictions_mask[torch.arange(gold_labels.numel()).long(), gold_labels].float()
+            tie_counts = max_predictions_mask.sum(-1)
+            correct /= tie_counts.float()
+            correct.unsqueeze_(-1)

         if mask is not None:
-            correct *= mask.float().unsqueeze(-1)
+            correct *= mask.view(-1, 1).float()
             self.total_count += mask.sum()
         else:
             self.total_count += gold_labels.numel()
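To make the advanced-indexing step in the else branch concrete, here is a self-contained sketch (not part of the commit) that replays it on the tensors from the first new test; the per-row credits come out to 0.25, 1.0, and 0.5:

    import torch

    predictions = torch.Tensor([[0.35, 0.25, 0.35, 0.35, 0.35],
                                [0.1, 0.6, 0.1, 0.2, 0.2],
                                [0.1, 0.0, 0.1, 0.2, 0.2]])
    gold_labels = torch.Tensor([2, 1, 4]).long()

    max_predictions = predictions.max(-1)[0]                  # per-row max score
    max_predictions_mask = predictions.eq(max_predictions.unsqueeze(-1))
    # Pair row i with column gold_labels[i]: was the gold class among that row's max scorers?
    correct = max_predictions_mask[torch.arange(gold_labels.numel()), gold_labels].float()
    tie_counts = max_predictions_mask.sum(-1)                 # number of tied classes per row
    correct /= tie_counts.float()
    print(correct)                                     # tensor([0.2500, 1.0000, 0.5000])
    print(correct.sum().item() / gold_labels.numel())  # 0.5833... == (0.25 + 1 + 0.5) / 3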
