|
1 | 1 | from typing import List, Callable, Tuple, Dict
|
| 2 | +import warnings |
2 | 3 |
|
3 | 4 | import torch
|
4 | 5 |
|
@@ -48,6 +49,16 @@ def search(self,
|
48 | 49 | Given a starting state and a step function, apply beam search to find the
|
49 | 50 | most likely target sequences.
|
50 | 51 |
|
| 52 | + Notes |
| 53 | + ----- |
| 54 | + If your step function returns ``-inf`` for some log probabilities |
| 55 | + (like if you're using a masked log-softmax) then some of the "best" |
| 56 | + sequences returned may also have ``-inf`` log probability. Specifically |
| 57 | + this happens when the beam size is larger than the number of actions |
| 58 | + with finite log probability (non-zero probability) returned by the step function. |
| 59 | + Therefore if you're using a mask you may want to check the results from ``search`` |
| 60 | + and potentially discard sequences with non-finite log probability. |
| 61 | +
|
51 | 62 | Parameters
|
52 | 63 | ----------
|
53 | 64 | start_predictions : ``torch.Tensor``
|
@@ -110,6 +121,11 @@ def search(self,
|
110 | 121 | # shape: (batch_size, beam_size), (batch_size, beam_size)
|
111 | 122 | start_top_log_probabilities, start_predicted_classes = \
|
112 | 123 | start_class_log_probabilities.topk(self.beam_size)
|
| 124 | + if self.beam_size == 1 and (start_predicted_classes == self._end_index).all(): |
| 125 | + warnings.warn("Empty sequences predicted. You may want to increase the beam size or ensure " |
| 126 | + "your step function is working properly.", |
| 127 | + RuntimeWarning) |
| 128 | + return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities |
113 | 129 |
|
114 | 130 | # The log probabilities for the last time step.
|
115 | 131 | # shape: (batch_size, beam_size)
|
@@ -166,9 +182,9 @@ def search(self,
|
166 | 182 | class_log_probabilities
|
167 | 183 | )
|
168 | 184 |
|
| 185 | + # shape (both): (batch_size * beam_size, per_node_beam_size) |
169 | 186 | top_log_probabilities, predicted_classes = \
|
170 | 187 | cleaned_log_probabilities.topk(self.per_node_beam_size)
|
171 |
| - # shape (both): (batch_size * beam_size, per_node_beam_size) |
172 | 188 |
|
173 | 189 | # Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size)
|
174 | 190 | # so that we can add them to the current log probs for this timestep.
|
@@ -227,6 +243,12 @@ def search(self,
|
227 | 243 | gather(1, expanded_backpointer).\
|
228 | 244 | reshape(batch_size * self.beam_size, *last_dims)
|
229 | 245 |
|
| 246 | + if not torch.isfinite(last_log_probabilities).all(): |
| 247 | + warnings.warn("Infinite log probabilities encountered. Some final sequences may not make sense. " |
| 248 | + "This can happen when the beam size is larger than the number of valid (non-zero " |
| 249 | + "probability) transitions that the step function produces.", |
| 250 | + RuntimeWarning) |
| 251 | + |
230 | 252 | # Reconstruct the sequences.
|
231 | 253 | # shape: [(batch_size, beam_size, 1)]
|
232 | 254 | reconstructed_predictions = [predictions[-1].unsqueeze(2)]
|
|
0 commit comments