This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit b529f6d

epwalsh authored and brendan-ai2 committed
Generalize beam search, improve simple_seq2seq (#1841)
Updates to `simple_seq2seq`:
- supports attention module
- utilizes beam search for decoding

~~@brendan-ai2 I didn't want to end up re-implementing beam search again, so I used the existing system by defining a simple state class and then just implementing the `take_step` method within `SimpleSeq2Seq`. Let me know what you think!~~

~~@DeNeutoy the test `allennlp/tests/predictors/simple_seq2seq_test.py` is failing because it depends on an old archived model fixture at `allennlp/tests/fixtures/encoder_decoder/simple_seq2seq/serialization`. I'm not sure if you needed that exact model for something else, or if I could just replace it with another trained model.~~

**TODO:**
- [x] pull out and generalize beam search from `event2mind`
- [x] replace `state_machine.beam_search` in `simple_seq2seq` with a more efficient version
- [x] update the model that the predictor test depends on
- [x] add unit tests for the new beam search
1 parent b0ade1b commit b529f6d

20 files changed (+1291 −406 lines)
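
The gist of the change, per the commit description above: beam search now lives in a standalone `allennlp.nn.beam_search.BeamSearch` class, and a model only supplies a step function that advances its decoder by one token. Below is a minimal sketch of that contract, assembled from the usage visible in the `event2mind` diff further down; the toy vocabulary and hidden sizes are invented for illustration, and the only `BeamSearch` arguments assumed are the ones that appear in this diff.

```python
from typing import Dict, Tuple

import torch
import torch.nn.functional as F
from torch.nn import Embedding, GRUCell, Linear

from allennlp.nn.beam_search import BeamSearch

# Illustrative sizes only; not taken from the repository.
VOCAB_SIZE, HIDDEN_DIM, BATCH_SIZE = 20, 16, 2
START_INDEX, END_INDEX = 1, 2

embedder = Embedding(VOCAB_SIZE, HIDDEN_DIM)
decoder_cell = GRUCell(HIDDEN_DIM, HIDDEN_DIM)
output_projection_layer = Linear(HIDDEN_DIM, VOCAB_SIZE)


def take_step(last_predictions: torch.Tensor,
              state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    # One decoder step: embed the last predicted tokens, advance the GRU cell,
    # and return log-probabilities over the target vocabulary plus the updated state.
    decoder_hidden = decoder_cell(embedder(last_predictions), state["decoder_hidden"])
    state["decoder_hidden"] = decoder_hidden
    return F.log_softmax(output_projection_layer(decoder_hidden), dim=-1), state


beam_search = BeamSearch(END_INDEX, beam_size=10, max_steps=20)

# The search is seeded with the start symbol and whatever state the encoder produced.
final_encoder_output = torch.randn(BATCH_SIZE, HIDDEN_DIM)
start_predictions = final_encoder_output.new_full(
        (BATCH_SIZE,), fill_value=START_INDEX, dtype=torch.long)
start_state = {"decoder_hidden": final_encoder_output}

# all_top_k_predictions: (batch_size, beam_size, num_decoding_steps)
# log_probabilities:     (batch_size, beam_size)
all_top_k_predictions, log_probabilities = beam_search.search(
        start_predictions, start_state, take_step)
```

The step function receives the tokens predicted at the previous timestep (flattened across the beam) plus a dict of state tensors, and returns log-probabilities over the target vocabulary along with the updated state. The responsibilities the hand-rolled `Event2Mind.beam_search` method used to carry, expanding the state across the beam, forcing the end token once a beam has finished, and reconstructing the top-k sequences from backpointers, move into `BeamSearch`.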

allennlp/models/encoder_decoders/simple_seq2seq.py (+354 −157)

Large diffs are not rendered by default.

allennlp/models/event2mind.py (+36 −177)
```diff
@@ -14,9 +14,11 @@
 from allennlp.modules import Seq2VecEncoder, TextFieldEmbedder
 from allennlp.modules.token_embedders import Embedding
 from allennlp.models.model import Model
+from allennlp.nn.beam_search import BeamSearch
 from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
 from allennlp.training.metrics import UnigramRecall
 
+
 @Model.register("event2mind")
 class Event2Mind(Model):
     """
```

```diff
@@ -41,7 +43,9 @@ class Event2Mind(Model):
         The encoder of the "encoder/decoder" model.
     max_decoding_steps : int, required
         Length of decoded sequences.
-    target_names: ``List[str]``, optional
+    beam_size : int, optional (default = 10)
+        The width of the beam search.
+    target_names: ``List[str]``, optional, (default = ['xintent', 'xreact', 'oreact'])
         Names of the target fields matching those in the ``Instance`` objects.
     target_namespace : str, optional (default = 'tokens')
         If the target side vocabulary is different from the source side's, you need to specify the
```

```diff
@@ -51,17 +55,20 @@ class Event2Mind(Model):
         You can specify an embedding dimensionality for the target side. If not, we'll use the same
         value as the source embedder's.
     """
-    # pylint: disable=dangerous-default-value
     def __init__(self,
                  vocab: Vocabulary,
                  source_embedder: TextFieldEmbedder,
                  embedding_dropout: float,
                  encoder: Seq2VecEncoder,
                  max_decoding_steps: int,
-                 target_names: List[str] = ["xintent", "xreact", "oreact"],
+                 beam_size: int = 10,
+                 target_names: List[str] = None,
                  target_namespace: str = "tokens",
                  target_embedding_dim: int = None) -> None:
+        target_names = target_names or ["xintent", "xreact", "oreact"]
+
         super(Event2Mind, self).__init__(vocab)
+
         # Note: The original tweaks the embeddings for "personx" to be the mean
         # across the embeddings for "he", "she", "him" and "her". Similarly for
         # "personx's" and so forth. We could consider that here as a well.
```

```diff
@@ -96,6 +103,12 @@ def __init__(self,
                     self._decoder_output_dim
             )
 
+        self._beam_search = BeamSearch(
+                self._end_index,
+                beam_size=beam_size,
+                max_steps=max_decoding_steps
+        )
+
     def _update_recall(self,
                        all_top_k_predictions: torch.Tensor,
                        target_tokens: Dict[str, torch.LongTensor],
```

```diff
@@ -175,20 +188,16 @@ def forward(self, # type: ignore
 
         # Perform beam search to obtain the predictions.
         if not self.training:
+            batch_size = final_encoder_output.size()[0]
             for name, state in self._states.items():
+                start_predictions = final_encoder_output.new_full(
+                        (batch_size,), fill_value=self._start_index, dtype=torch.long)
+                start_state = {"decoder_hidden": final_encoder_output}
+
                 # (batch_size, 10, num_decoding_steps)
-                (all_top_k_predictions, log_probabilities) = self.beam_search(
-                        final_encoder_output=final_encoder_output,
-                        width=10,
-                        # We always use the max here instead of passing in the
-                        # length of the longest target to avoid biasing the
-                        # search. Whether this problem would manifest otherwise
-                        # would depend on the metric being used.
-                        num_decoding_steps=self._max_decoding_steps,
-                        target_embedder=state.embedder,
-                        decoder_cell=state.decoder_cell,
-                        output_projection_layer=state.output_projection_layer
-                )
+                all_top_k_predictions, log_probabilities = self._beam_search.search(
+                        start_predictions, start_state, state.take_step)
+
                 if target_tokens:
                     self._update_recall(all_top_k_predictions, target_tokens[name], state.recall)
                 output_dict[f"{name}_top_k_predictions"] = all_top_k_predictions
```

```diff
@@ -276,168 +285,6 @@ def greedy_predict(self,
         # Drop start symbol and return.
         return all_predictions[:, 1:]
 
-    def beam_search(self,
-                    final_encoder_output: torch.LongTensor,
-                    width: int,
-                    num_decoding_steps: int,
-                    target_embedder: Embedding,
-                    decoder_cell: GRUCell,
-                    output_projection_layer: Linear) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Uses beam search to compute the highest probability sequences for the
-        ``decoder_cell`` that fit within the given ``width``. Returns the tuple
-        consisting of the sequences themselves and their log probabilities.
-
-        Parameters
-        ----------
-        final_encoder_output : ``torch.LongTensor``, required
-            Vector produced by ``self._encoder``.
-        width : ``int``, required
-            Size of the beam.
-        num_decoding_steps : ``int``, required
-            Maximum sequence length.
-        target_embedder : ``Embedding``, required
-            Used to embed the token predicted at the previous time step.
-        decoder_cell: ``GRUCell``, required
-            The recurrent cell used at each time step.
-        output_projection_layer: ``Linear``, required
-            Linear layer mapping to the desired number of classes.
-
-        Returns
-        -------
-        predictions : ``torch.LongTensor``
-            Tensor of shape (batch_size, width, num_decoding_steps) with the predicted indices.
-        log_probabilities : ``torch.FloatTensor``
-            Tensor of shape (batch_size, width) with the log probability of the
-            corresponding prediction.
-        """
-        batch_size = final_encoder_output.size()[0]
-        # List of (batch_size, width) tensors. One for each time step. Does not
-        # include the start symbols, which are implicit.
-        predictions = []
-        # List of (batch_size, width) tensors. One for each time step. None for
-        # the first. Stores the index n for the parent prediction, i.e.
-        # predictions[t-1][i][n], that it came from.
-        backpointers = []
-
-        # Calculate the first timestep. This is done outside the main loop
-        # because we are going from a single decoder input (the output from the
-        # encoder) to the top ``width`` decoder outputs. On the other hand,
-        # within the main loop we are going from the ``width`` elements of the
-        # beam to ``width``^2 candidates from which we will select the top
-        # ``width`` elements for the next iteration.
-        start_predictions = final_encoder_output.new_full(
-                (batch_size,), fill_value=self._start_index, dtype=torch.long
-        )
-        start_decoder_input = target_embedder(start_predictions)
-        start_decoder_hidden = decoder_cell(start_decoder_input, final_encoder_output)
-        start_output_projections = output_projection_layer(start_decoder_hidden)
-        start_class_log_probabilities = F.log_softmax(start_output_projections, dim=-1)
-        start_top_log_probabilities, start_predicted_classes = start_class_log_probabilities.topk(width)
-
-        # Set starting values
-        # The log probabilities for the last time step. (batch_size, width)
-        last_log_probabilities = start_top_log_probabilities
-        # [(batch_size, width)]
-        predictions.append(start_predicted_classes)
-        # Set the same hidden state for each element in beam.
-        # (batch_size * width, _decoder_output_dim)
-        decoder_hidden = start_decoder_hidden.\
-            unsqueeze(1).expand(batch_size, width, self._decoder_output_dim).\
-            reshape(batch_size * width, self._decoder_output_dim)
-
-        # Log probability tensor that mandates that the end token is selected.
-        num_classes = self.vocab.get_vocab_size(self._target_namespace)
-        log_probs_after_end = start_class_log_probabilities.new_full(
-                (batch_size * width, num_classes),
-                float("-inf")
-        )
-        log_probs_after_end[:, self._end_index] = 0.0
-
-        for timestep in range(num_decoding_steps - 1):
-            # (batch_size * width,)
-            last_predictions = predictions[-1].reshape(batch_size * width)
-            decoder_input = target_embedder(last_predictions)
-            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
-            # (batch_size * width, num_classes)
-            output_projections = output_projection_layer(decoder_hidden)
-
-            # (batch_size * width, num_classes)
-            class_log_probabilities = F.log_softmax(output_projections, dim=-1)
-
-            # (batch_size * width, num_classes)
-            last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
-                    batch_size * width,
-                    num_classes
-            )
-            # Here we are finding any beams where we predicted the end token in
-            # the previous timestep and replacing the distribution with a
-            # one-hot distribution, forcing the beam to predict the end token
-            # this timestep as well.
-            cleaned_log_probabilities = torch.where(
-                    last_predictions_expanded == self._end_index,
-                    log_probs_after_end,
-                    class_log_probabilities
-            )
-
-            # Note: We could consider normalizing for length here, but the
-            # original implementation does not do so.
-
-            # (batch_size * width, width), (batch_size * width, width)
-            top_log_probabilities, predicted_classes = cleaned_log_probabilities.topk(width)
-            # Here we expand the last log probabilities to (batch_size * width,
-            # width) so that we can add them to the current log probs for this
-            # timestep. This lets us maintain the log probability of each
-            # element on the beam.
-            expanded_last_log_probabilities = last_log_probabilities.\
-                unsqueeze(2).\
-                expand(batch_size, width, width).\
-                reshape(batch_size * width, width)
-            summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities
-
-            reshaped_summed = summed_top_log_probabilities.reshape(batch_size, width * width)
-            reshaped_predicted_classes = predicted_classes.reshape(batch_size, width * width)
-            # Keep only the top ``width`` beam indices.
-            restricted_beam_log_probs, restricted_beam_indices = reshaped_summed.topk(width)
-            # Use the beam indices to extract the corresponding classes.
-            restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices)
-
-            last_log_probabilities = restricted_beam_log_probs
-            predictions.append(restricted_predicted_classes)
-            # The beam indices come from a width * width dimension where the
-            # indices with a common ancestor are grouped together. Hence
-            # dividing by width gives the ancestor. (Note that this is integer
-            # division as the tensor is a LongTensor.)
-            backpointer = restricted_beam_indices / width
-            backpointers.append(backpointer)
-            # For the gather below.
-            expanded_backpointer = backpointer.unsqueeze(2).expand(batch_size, width, self._decoder_output_dim)
-            # Keep only the pieces of the hidden state corresponding to the
-            # ancestors created this iteration.
-            decoder_hidden = decoder_hidden.\
-                reshape(batch_size, width, self._decoder_output_dim).\
-                gather(1, expanded_backpointer).\
-                reshape(batch_size * width, self._decoder_output_dim)
-
-        assert len(predictions) == num_decoding_steps,\
-            "len(predictions) not equal to num_decoding_steps"
-        assert len(backpointers) == num_decoding_steps - 1,\
-            "len(backpointers) not equal to num_decoding_steps"
-
-        # Reconstruct the sequences.
-        reconstructed_predictions = [predictions[num_decoding_steps - 1].unsqueeze(2)]
-        cur_backpointers = backpointers[num_decoding_steps - 2]
-        for timestep in range(num_decoding_steps - 2, 0, -1):
-            cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
-            reconstructed_predictions.append(cur_preds)
-            cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)
-        final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
-        reconstructed_predictions.append(final_preds)
-        # We don't add the start tokens here. They are implicit.
-
-        all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)
-        return (all_predictions, last_log_probabilities)
-
     @staticmethod
     def _get_loss(logits: torch.LongTensor,
                   targets: torch.LongTensor,
```

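Aside: the comment in the removed method about backpointers ("dividing by width gives the ancestor") is easiest to see with a tiny worked example. This is purely illustrative and not code from the repository; the removed code writes the division as `restricted_beam_indices / width`, which on the LongTensors of the PyTorch releases of the time behaved as integer division, as its own comment notes.

```python
import torch

# With width = 3, the width * width candidates for one batch element are laid out
# so that candidates 0-2 descend from beam 0, 3-5 from beam 1, and 6-8 from beam 2.
width = 3
restricted_beam_indices = torch.tensor([[8, 1, 5]])  # indices kept by topk over width * width
backpointer = restricted_beam_indices // width       # parent beam of each surviving candidate
print(backpointer)                                   # tensor([[2, 0, 1]])
```
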
```diff
@@ -509,6 +356,7 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
                 all_metrics[name] = state.recall.get_metric(reset=reset)
         return all_metrics
 
+
 class StateDecoder:
     """
     Simple struct-like class for internal use.
```

```diff
@@ -526,3 +374,14 @@ def __init__(self,
         self.output_projection_layer = Linear(output_dim, num_classes)
         event2mind.add_module(f"{name}_output_project_layer", self.output_projection_layer)
         self.recall = UnigramRecall()
+
+    def take_step(self,
+                  last_predictions: torch.Tensor,
+                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        decoder_hidden = state["decoder_hidden"]
+        decoder_input = self.embedder(last_predictions)
+        decoder_hidden = self.decoder_cell(decoder_input, decoder_hidden)
+        state["decoder_hidden"] = decoder_hidden
+        output_projections = self.output_projection_layer(decoder_hidden)
+        class_log_probabilities = F.log_softmax(output_projections, dim=-1)
+        return class_log_probabilities, state
```

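
The TODO above also mentions unit tests for the new beam search. As a rough sketch (not the test actually added in this commit), the generalized interface can be exercised with a deterministic step function whose best sequence is known in advance; everything below other than the `BeamSearch` constructor and `search()` call shown in the diff is invented for illustration.

```python
from typing import Dict, Tuple

import torch

from allennlp.nn.beam_search import BeamSearch

# A deterministic toy "language model": from token i the overwhelmingly likely next
# token is i + 1, so the best sequence from the start token 0 is 1, 2, 3, 4, 5.
NUM_CLASSES, START_INDEX, END_INDEX = 6, 0, 5
transitions = torch.full((NUM_CLASSES, NUM_CLASSES), 1e-4)
for i in range(NUM_CLASSES - 1):
    transitions[i, i + 1] = 1.0
log_transitions = (transitions / transitions.sum(dim=-1, keepdim=True)).log()


def take_step(last_predictions: torch.Tensor,
              state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    # Stateless step: the next-token distribution depends only on the last prediction,
    # so we just index into the (NUM_CLASSES, NUM_CLASSES) transition table.
    return log_transitions[last_predictions], state


beam_search = BeamSearch(END_INDEX, beam_size=3, max_steps=10)
start_predictions = torch.tensor([START_INDEX])
all_top_k_predictions, log_probabilities = beam_search.search(
        start_predictions, {}, take_step)

# The best beam should walk straight up the chain until it reaches the end token.
assert all_top_k_predictions[0, 0, :5].tolist() == [1, 2, 3, 4, 5]
```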
