from allennlp.modules import Seq2VecEncoder, TextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.models.model import Model
+from allennlp.nn.beam_search import BeamSearch
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics import UnigramRecall

+
@Model.register("event2mind")
class Event2Mind(Model):
    """
@@ -41,7 +43,9 @@ class Event2Mind(Model):
        The encoder of the "encoder/decoder" model.
    max_decoding_steps : int, required
        Length of decoded sequences.
-    target_names: ``List[str]``, optional
+    beam_size : int, optional (default = 10)
+        The width of the beam search.
+    target_names : ``List[str]``, optional (default = ['xintent', 'xreact', 'oreact'])
        Names of the target fields matching those in the ``Instance`` objects.
    target_namespace : str, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
@@ -51,17 +55,20 @@ class Event2Mind(Model):
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    """
-    # pylint: disable=dangerous-default-value
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 embedding_dropout: float,
                 encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
-                 target_names: List[str] = ["xintent", "xreact", "oreact"],
+                 beam_size: int = 10,
+                 target_names: List[str] = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None) -> None:
+        target_names = target_names or ["xintent", "xreact", "oreact"]
+
        super(Event2Mind, self).__init__(vocab)
+
        # Note: The original tweaks the embeddings for "personx" to be the mean
        # across the embeddings for "he", "she", "him" and "her". Similarly for
        # "personx's" and so forth. We could consider that here as well.
@@ -96,6 +103,12 @@ def __init__(self,
                self._decoder_output_dim
        )

+        self._beam_search = BeamSearch(
+                self._end_index,
+                beam_size=beam_size,
+                max_steps=max_decoding_steps
+        )
+
    def _update_recall(self,
                       all_top_k_predictions: torch.Tensor,
                       target_tokens: Dict[str, torch.LongTensor],
@@ -175,20 +188,16 @@ def forward(self,  # type: ignore

        # Perform beam search to obtain the predictions.
        if not self.training:
+            batch_size = final_encoder_output.size()[0]
            for name, state in self._states.items():
+                start_predictions = final_encoder_output.new_full(
+                        (batch_size,), fill_value=self._start_index, dtype=torch.long)
+                start_state = {"decoder_hidden": final_encoder_output}
+
                # (batch_size, 10, num_decoding_steps)
-                (all_top_k_predictions, log_probabilities) = self.beam_search(
-                        final_encoder_output=final_encoder_output,
-                        width=10,
-                        # We always use the max here instead of passing in the
-                        # length of the longest target to avoid biasing the
-                        # search. Whether this problem would manifest otherwise
-                        # would depend on the metric being used.
-                        num_decoding_steps=self._max_decoding_steps,
-                        target_embedder=state.embedder,
-                        decoder_cell=state.decoder_cell,
-                        output_projection_layer=state.output_projection_layer
-                )
+                all_top_k_predictions, log_probabilities = self._beam_search.search(
+                        start_predictions, start_state, state.take_step)
+
                if target_tokens:
                    self._update_recall(all_top_k_predictions, target_tokens[name], state.recall)
                output_dict[f"{name}_top_k_predictions"] = all_top_k_predictions
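For reference, a minimal self-contained sketch of the contract ``BeamSearch.search`` relies on here (the toy vocabulary, transition matrix, and dummy state below are hypothetical, not part of this change): the step function receives the previous predictions of shape ``(group_size,)`` and a state dict whose tensors keep ``group_size`` as their first dimension, and returns per-class log-probabilities plus the updated state; ``search`` returns predictions of shape ``(batch_size, beam_size, num_decoded_steps)`` and per-beam log-probabilities of shape ``(batch_size, beam_size)``.

import torch
from torch.nn import functional as F
from allennlp.nn.beam_search import BeamSearch

num_classes, end_index = 5, 4          # toy vocabulary; index 4 ends a sequence
beam = BeamSearch(end_index, beam_size=3, max_steps=6)
# Fixed transition log-probabilities so the step function has something to return.
transitions = F.log_softmax(torch.randn(num_classes, num_classes), dim=-1)

def step(last_predictions, state):
    # last_predictions: (group_size,); BeamSearch reorders every tensor in
    # ``state`` across the beam between calls, so nothing else is needed here.
    return transitions[last_predictions], state

start_predictions = torch.zeros(2, dtype=torch.long)   # batch_size = 2, "start" index 0
start_state = {"dummy": torch.zeros(2, 1)}
top_k_predictions, log_probabilities = beam.search(start_predictions, start_state, step)
# top_k_predictions: (2, 3, num_decoded_steps); log_probabilities: (2, 3)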
@@ -276,168 +285,6 @@ def greedy_predict(self,
        # Drop start symbol and return.
        return all_predictions[:, 1:]

-    def beam_search(self,
-                    final_encoder_output: torch.LongTensor,
-                    width: int,
-                    num_decoding_steps: int,
-                    target_embedder: Embedding,
-                    decoder_cell: GRUCell,
-                    output_projection_layer: Linear) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Uses beam search to compute the highest probability sequences for the
-        ``decoder_cell`` that fit within the given ``width``. Returns the tuple
-        consisting of the sequences themselves and their log probabilities.
-
-        Parameters
-        ----------
-        final_encoder_output : ``torch.LongTensor``, required
-            Vector produced by ``self._encoder``.
-        width : ``int``, required
-            Size of the beam.
-        num_decoding_steps : ``int``, required
-            Maximum sequence length.
-        target_embedder : ``Embedding``, required
-            Used to embed the token predicted at the previous time step.
-        decoder_cell: ``GRUCell``, required
-            The recurrent cell used at each time step.
-        output_projection_layer: ``Linear``, required
-            Linear layer mapping to the desired number of classes.
-
-        Returns
-        -------
-        predictions : ``torch.LongTensor``
-            Tensor of shape (batch_size, width, num_decoding_steps) with the predicted indices.
-        log_probabilities : ``torch.FloatTensor``
-            Tensor of shape (batch_size, width) with the log probability of the
-            corresponding prediction.
-        """
-        batch_size = final_encoder_output.size()[0]
-        # List of (batch_size, width) tensors. One for each time step. Does not
-        # include the start symbols, which are implicit.
-        predictions = []
-        # List of (batch_size, width) tensors. One for each time step. None for
-        # the first. Stores the index n for the parent prediction, i.e.
-        # predictions[t-1][i][n], that it came from.
-        backpointers = []
-
-        # Calculate the first timestep. This is done outside the main loop
-        # because we are going from a single decoder input (the output from the
-        # encoder) to the top ``width`` decoder outputs. On the other hand,
-        # within the main loop we are going from the ``width`` elements of the
-        # beam to ``width``^2 candidates from which we will select the top
-        # ``width`` elements for the next iteration.
-        start_predictions = final_encoder_output.new_full(
-                (batch_size,), fill_value=self._start_index, dtype=torch.long
-        )
-        start_decoder_input = target_embedder(start_predictions)
-        start_decoder_hidden = decoder_cell(start_decoder_input, final_encoder_output)
-        start_output_projections = output_projection_layer(start_decoder_hidden)
-        start_class_log_probabilities = F.log_softmax(start_output_projections, dim=-1)
-        start_top_log_probabilities, start_predicted_classes = start_class_log_probabilities.topk(width)
-
-        # Set starting values
-        # The log probabilities for the last time step. (batch_size, width)
-        last_log_probabilities = start_top_log_probabilities
-        # [(batch_size, width)]
-        predictions.append(start_predicted_classes)
-        # Set the same hidden state for each element in beam.
-        # (batch_size * width, _decoder_output_dim)
-        decoder_hidden = start_decoder_hidden.\
-            unsqueeze(1).expand(batch_size, width, self._decoder_output_dim).\
-            reshape(batch_size * width, self._decoder_output_dim)
-
-        # Log probability tensor that mandates that the end token is selected.
-        num_classes = self.vocab.get_vocab_size(self._target_namespace)
-        log_probs_after_end = start_class_log_probabilities.new_full(
-                (batch_size * width, num_classes),
-                float("-inf")
-        )
-        log_probs_after_end[:, self._end_index] = 0.0
-
-        for timestep in range(num_decoding_steps - 1):
-            # (batch_size * width,)
-            last_predictions = predictions[-1].reshape(batch_size * width)
-            decoder_input = target_embedder(last_predictions)
-            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
-            # (batch_size * width, num_classes)
-            output_projections = output_projection_layer(decoder_hidden)
-
-            # (batch_size * width, num_classes)
-            class_log_probabilities = F.log_softmax(output_projections, dim=-1)
-
-            # (batch_size * width, num_classes)
-            last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
-                    batch_size * width,
-                    num_classes
-            )
-            # Here we are finding any beams where we predicted the end token in
-            # the previous timestep and replacing the distribution with a
-            # one-hot distribution, forcing the beam to predict the end token
-            # this timestep as well.
-            cleaned_log_probabilities = torch.where(
-                    last_predictions_expanded == self._end_index,
-                    log_probs_after_end,
-                    class_log_probabilities
-            )
-
-            # Note: We could consider normalizing for length here, but the
-            # original implementation does not do so.
-
-            # (batch_size * width, width), (batch_size * width, width)
-            top_log_probabilities, predicted_classes = cleaned_log_probabilities.topk(width)
-            # Here we expand the last log probabilities to (batch_size * width,
-            # width) so that we can add them to the current log probs for this
-            # timestep. This lets us maintain the log probability of each
-            # element on the beam.
-            expanded_last_log_probabilities = last_log_probabilities.\
-                unsqueeze(2).\
-                expand(batch_size, width, width).\
-                reshape(batch_size * width, width)
-            summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities
-
-            reshaped_summed = summed_top_log_probabilities.reshape(batch_size, width * width)
-            reshaped_predicted_classes = predicted_classes.reshape(batch_size, width * width)
-            # Keep only the top ``width`` beam indices.
-            restricted_beam_log_probs, restricted_beam_indices = reshaped_summed.topk(width)
-            # Use the beam indices to extract the corresponding classes.
-            restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices)
-
-            last_log_probabilities = restricted_beam_log_probs
-            predictions.append(restricted_predicted_classes)
-            # The beam indices come from a width * width dimension where the
-            # indices with a common ancestor are grouped together. Hence
-            # dividing by width gives the ancestor. (Note that this is integer
-            # division as the tensor is a LongTensor.)
-            backpointer = restricted_beam_indices / width
-            backpointers.append(backpointer)
-            # For the gather below.
-            expanded_backpointer = backpointer.unsqueeze(2).expand(batch_size, width, self._decoder_output_dim)
-            # Keep only the pieces of the hidden state corresponding to the
-            # ancestors created this iteration.
-            decoder_hidden = decoder_hidden.\
-                reshape(batch_size, width, self._decoder_output_dim).\
-                gather(1, expanded_backpointer).\
-                reshape(batch_size * width, self._decoder_output_dim)
-
-        assert len(predictions) == num_decoding_steps,\
-            "len(predictions) not equal to num_decoding_steps"
-        assert len(backpointers) == num_decoding_steps - 1,\
-            "len(backpointers) not equal to num_decoding_steps"
-
-        # Reconstruct the sequences.
-        reconstructed_predictions = [predictions[num_decoding_steps - 1].unsqueeze(2)]
-        cur_backpointers = backpointers[num_decoding_steps - 2]
-        for timestep in range(num_decoding_steps - 2, 0, -1):
-            cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
-            reconstructed_predictions.append(cur_preds)
-            cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)
-        final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
-        reconstructed_predictions.append(final_preds)
-        # We don't add the start tokens here. They are implicit.
-
-        all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)
-        return (all_predictions, last_log_probabilities)
-
    @staticmethod
    def _get_loss(logits: torch.LongTensor,
                  targets: torch.LongTensor,
@@ -509,6 +356,7 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics[name] = state.recall.get_metric(reset=reset)
        return all_metrics

+
class StateDecoder:
    """
    Simple struct-like class for internal use.
@@ -526,3 +374,14 @@ def __init__(self,
        self.output_projection_layer = Linear(output_dim, num_classes)
        event2mind.add_module(f"{name}_output_project_layer", self.output_projection_layer)
        self.recall = UnigramRecall()
+
+    def take_step(self,
+                  last_predictions: torch.Tensor,
+                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        decoder_hidden = state["decoder_hidden"]
+        decoder_input = self.embedder(last_predictions)
+        decoder_hidden = self.decoder_cell(decoder_input, decoder_hidden)
+        state["decoder_hidden"] = decoder_hidden
+        output_projections = self.output_projection_layer(decoder_hidden)
+        class_log_probabilities = F.log_softmax(output_projections, dim=-1)
+        return class_log_probabilities, state
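As a usage illustration, ``take_step`` follows the same pattern as the sketch below, which wires a throwaway ``Embedding``/``GRUCell``/``Linear`` decoder into ``BeamSearch``; the sizes and indices are invented for the example and are not part of this change. Keeping ``decoder_hidden`` inside the ``state`` dict is what lets ``BeamSearch`` copy it across the beam at the start and reorder it with the backpointers after every step, bookkeeping the removed hand-rolled ``beam_search`` had to do manually.

import torch
from torch.nn import Embedding, GRUCell, Linear
from torch.nn import functional as F
from allennlp.nn.beam_search import BeamSearch

# Hypothetical sizes and special-token indices, purely for the sketch.
num_classes, embedding_dim, hidden_dim = 11, 8, 16
start_index, end_index = 1, 2
embedder = Embedding(num_classes, embedding_dim)
decoder_cell = GRUCell(embedding_dim, hidden_dim)
output_projection_layer = Linear(hidden_dim, num_classes)

def take_step(last_predictions, state):
    # Advance the decoder one step; the hidden state lives in ``state`` so
    # that BeamSearch can manage it per beam element.
    decoder_hidden = decoder_cell(embedder(last_predictions), state["decoder_hidden"])
    state["decoder_hidden"] = decoder_hidden
    return F.log_softmax(output_projection_layer(decoder_hidden), dim=-1), state

batch_size = 2
start_predictions = torch.full((batch_size,), start_index, dtype=torch.long)
start_state = {"decoder_hidden": torch.zeros(batch_size, hidden_dim)}
beam = BeamSearch(end_index, beam_size=5, max_steps=7)
top_k_predictions, log_probabilities = beam.search(start_predictions, start_state, take_step)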