Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 720d306

Browse files
epwalsh and brendan-ai2
authored and committed
Handle edge cases in beam search (#2557)
Addresses the edge case brought up in #2486 (fixes #2486) as well as another. This actually turned out to be a little more nuanced... I first thought the issue brought up was caused by `start_predictions` being the `end_index`, but it actually occurs in general with a beam size of 1 when the first predictions that the step function produces are the `end_index`, regardless of what `start_predictions` are, i.e. at this line: `start_class_log_probabilities, state = step(start_predictions, start_state)` The other edge case is similar, and occurs when the beam size is larger than the number of valid (non-zero probability) transitions that the step function produces. For example, this could happen in a semantic parsing task where a masked log softmax is used to create predicted log probs for valid next actions. Though this doesn't cause the beam search to crash per se, I thought it would still be good to warn the user in these cases since some of the predicted sequences may be improbable.
1 parent 79936e5 commit 720d306

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

allennlp/nn/beam_search.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Callable, Tuple, Dict
2+
import warnings
23

34
import torch
45

@@ -48,6 +49,16 @@ def search(self,
4849
Given a starting state and a step function, apply beam search to find the
4950
most likely target sequences.
5051
52+
Notes
53+
-----
54+
If your step function returns ``-inf`` for some log probabilities
55+
(like if you're using a masked log-softmax) then some of the "best"
56+
sequences returned may also have ``-inf`` log probability. Specifically
57+
this happens when the beam size is larger than the number of actions
58+
with finite log probability (non-zero probability) returned by the step function.
59+
Therefore if you're using a mask you may want to check the results from ``search``
60+
and potentially discard sequences with non-finite log probability.
61+
5162
Parameters
5263
----------
5364
start_predictions : ``torch.Tensor``
@@ -110,6 +121,11 @@ def search(self,
110121
# shape: (batch_size, beam_size), (batch_size, beam_size)
111122
start_top_log_probabilities, start_predicted_classes = \
112123
start_class_log_probabilities.topk(self.beam_size)
124+
if self.beam_size == 1 and (start_predicted_classes == self._end_index).all():
125+
warnings.warn("Empty sequences predicted. You may want to increase the beam size or ensure "
126+
"your step function is working properly.",
127+
RuntimeWarning)
128+
return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities
113129

114130
# The log probabilities for the last time step.
115131
# shape: (batch_size, beam_size)
@@ -166,9 +182,9 @@ def search(self,
166182
class_log_probabilities
167183
)
168184

185+
# shape (both): (batch_size * beam_size, per_node_beam_size)
169186
top_log_probabilities, predicted_classes = \
170187
cleaned_log_probabilities.topk(self.per_node_beam_size)
171-
# shape (both): (batch_size * beam_size, per_node_beam_size)
172188

173189
# Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size)
174190
# so that we can add them to the current log probs for this timestep.
@@ -227,6 +243,12 @@ def search(self,
227243
gather(1, expanded_backpointer).\
228244
reshape(batch_size * self.beam_size, *last_dims)
229245

246+
if not torch.isfinite(last_log_probabilities).all():
247+
warnings.warn("Infinite log probabilities encountered. Some final sequences may not make sense. "
248+
"This can happen when the beam size is larger than the number of valid (non-zero "
249+
"probability) transitions that the step function produces.",
250+
RuntimeWarning)
251+
230252
# Reconstruct the sequences.
231253
# shape: [(batch_size, beam_size, 1)]
232254
reconstructed_predictions = [predictions[-1].unsqueeze(2)]

allennlp/tests/nn/beam_search_test.py

+21
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,24 @@ def test_catch_bad_config(self):
167167
beam_search = BeamSearch(self.end_index, beam_size=20)
168168
with pytest.raises(ConfigurationError):
169169
self._check_results(beam_search=beam_search)
170+
171+
def test_warn_for_bad_log_probs(self):
172+
# The only valid next step from the initial predictions is the end index.
173+
# But with a beam size of 3, the call to `topk` to find the 3 most likely
174+
next beams will result in 2 new beams that are invalid, in that they have probability of 0.
175+
# The beam search should warn us of this.
176+
initial_predictions = torch.LongTensor([self.end_index-1, self.end_index-1])
177+
with pytest.warns(RuntimeWarning, match="Infinite log probabilities"):
178+
self.beam_search.search(initial_predictions, {}, take_step)
179+
180+
def test_empty_sequences(self):
181+
initial_predictions = torch.LongTensor([self.end_index-1, self.end_index-1])
182+
beam_search = BeamSearch(self.end_index, beam_size=1)
183+
with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
184+
predictions, log_probs = beam_search.search(initial_predictions, {}, take_step)
185+
# predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
186+
assert list(predictions.size()) == [2, 1, 1]
187+
# log probs should have shape `(batch_size, beam_size)`.
188+
assert list(log_probs.size()) == [2, 1]
189+
assert (predictions == self.end_index).all()
190+
assert (log_probs == 0).all()

0 commit comments

Comments
 (0)