Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 3170f6f

Browse files
authored
Make FBetaMeasure work with batch size of 1. (#2777)
A stray squeeze caused batches of size one to become zero-dimensional, so indexing into them was invalid and crashed. See #2275 (comment).
1 parent b97bca0 commit 3170f6f

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

allennlp/tests/training/metrics/fbeta_measure_test.py

+14
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,17 @@ def test_fbeta_multiclass_with_explicit_labels(self):
161161
numpy.testing.assert_almost_equal(precisions, desired_precisions, decimal=2)
162162
numpy.testing.assert_almost_equal(recalls, desired_recalls, decimal=2)
163163
numpy.testing.assert_almost_equal(fscores, desired_fscores, decimal=2)
164+
165+
def test_fbeta_handles_batch_size_of_one(self):
    """Regression test: FBetaMeasure must not crash when the batch has a single row.

    A stray ``squeeze`` previously collapsed size-one batches to zero
    dimensions, making indexing into them invalid (see issue #2275).
    """
    # One example, four classes; class 1 has the highest score.
    predictions = torch.Tensor([[0.2862, 0.3479, 0.1627, 0.2033]])
    targets = torch.Tensor([1])
    mask = torch.Tensor([1])

    fbeta = FBetaMeasure()
    fbeta(predictions, targets, mask)
    metric = fbeta.get_metric()

    precisions = metric['precision']
    recalls = metric['recall']

    # The single example is predicted correctly as class 1, so that class
    # gets perfect precision/recall and every other class gets zero.
    numpy.testing.assert_almost_equal(precisions, [0.0, 1.0, 0.0, 0.0])
    numpy.testing.assert_almost_equal(recalls, [0.0, 1.0, 0.0, 0.0])

allennlp/training/metrics/fbeta_measure.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def __call__(self,
122122
mask = mask.to(torch.uint8)
123123
gold_labels = gold_labels.float()
124124

125-
argmax_predictions = predictions.max(dim=-1)[1].float().squeeze(dim=-1)
125+
argmax_predictions = predictions.max(dim=-1)[1].float()
126126
true_positives = (gold_labels == argmax_predictions) * mask
127127
true_positives_bins = gold_labels[true_positives]
128128

0 commit comments

Comments (0)