diff --git a/CHANGELOG.md b/CHANGELOG.md
index a316a6ce2c4..0f8a9b13503 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   See [PR #5172](https://github.com/allenai/allennlp/pull/5172) for more details.
 - Added `SpanExtractorWithSpanWidthEmbedding`, putting specific span embedding computations into the `_embed_spans` method and leaving the common code in `SpanExtractorWithSpanWidthEmbedding` to unify the arguments, and modified `BidirectionalEndpointSpanExtractor`, `EndpointSpanExtractor` and `SelfAttentiveSpanExtractor` accordingly. Now, `SelfAttentiveSpanExtractor` can also embed span widths.
 - Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences.
+- Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`.
 
 ### Fixed
 
diff --git a/allennlp/nn/beam_search.py b/allennlp/nn/beam_search.py
index f1d43226a83..4337e3efc4a 100644
--- a/allennlp/nn/beam_search.py
+++ b/allennlp/nn/beam_search.py
@@ -431,6 +431,99 @@ def gumbel_with_max(self, phi, T) -> torch.Tensor:
         return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))
 
 
+class FinalSequenceScorer(Registrable):
+    """
+    An abstract class that can be used to score the final generated sequences found
+    by beam search. Given the predicted sequences and the corresponding log probabilities of
+    those sequences, the class calculates and returns the final score of the sequences.
+
+    The default implementation scores the sequences using the sum of the log probabilities of
+    each sequence, which is passed in as input.
+    """
+
+    default_implementation = "sequence-log-prob"
+
+    def score(
+        self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int
+    ) -> torch.Tensor:
+        """
+        Score the final predictions found by beam search.
+
+        # Parameters
+
+        predictions : `torch.Tensor`
+            A tensor containing the final predictions with shape `(batch_size, beam_size, max_steps)`.
+
+        log_probabilities : `torch.Tensor`
+            A tensor containing the log probabilities of the sequence, defined as the sum
+            of the log probabilities per token, with shape `(batch_size, beam_size)`.
+
+        end_index : `int`
+            The index of the end symbol.
+
+        # Returns
+
+        `torch.Tensor`
+            A tensor of the final sequence scores of shape `(batch_size, beam_size)`.
+        """
+        raise NotImplementedError
+
+
+@FinalSequenceScorer.register("sequence-log-prob")
+class SequenceLogProbabilityScorer(FinalSequenceScorer):
+    """
+    A `FinalSequenceScorer` which scores the sequences by the sum of the log probabilities
+    across the sequence's tokens.
+    """
+
+    @overrides
+    def score(
+        self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int
+    ) -> torch.Tensor:
+        # The sum of the sequence log probabilities is the input parameter, so just
+        # return it.
+        return log_probabilities
+
+
+@FinalSequenceScorer.register("length-normalized-sequence-log-prob")
+class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
+    """
+    A `FinalSequenceScorer` which scores the sequences by the average log probability of the
+    tokens in the sequence. It optionally includes a length penalty which promotes
+    or demotes sequences based on their lengths. The final score for a sequence will
+    be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length
+    here includes the end token.
+
+    # Parameters
+
+    length_penalty : `float`, optional (default = `1.0`)
+        The length penalty to use. A value of 1.0 means no length penalty is used.
+        A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
+    """
+
+    def __init__(self, length_penalty: float = 1.0):
+        super().__init__()
+        self.length_penalty = length_penalty
+
+    @overrides
+    def score(
+        self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int
+    ) -> torch.Tensor:
+        # shape: (batch_size, beam_size)
+        lengths = (predictions != end_index).long().sum(dim=2)
+
+        # If the sequence ended during beam search, the `log_probabilities` will include
+        # the transition to the end token. Therefore, in such situations, `lengths` is
+        # actually off by 1. This corrects for that.
+        # shape: (batch_size, beam_size)
+        is_end_token = predictions[:, :, -1] == end_index
+        lengths += is_end_token.long()
+
+        # shape: (batch_size, beam_size)
+        average_log_probs = log_probabilities / (lengths ** self.length_penalty)
+        return average_log_probs
+
+
 class BeamSearch(FromParams):
     """
     Implements the beam search algorithm for decoding the most likely sequences.
@@ -467,6 +560,12 @@ class BeamSearch(FromParams):
         The minimum number of decoding steps to take, i.e. the minimum length of
         the predicted sequences. This does not include the start or end tokens. If `None`,
         no minimum is enforced.
+
+    final_sequence_scorer : `FinalSequenceScorer`, optional (default = `None`)
+        An optional `FinalSequenceScorer` which is used to score the final generated sequences.
+        The output from this module is what is returned by the `search` method. If not
+        specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences
+        by the sum of the token log probabilities.
     """
 
     def __init__(
@@ -477,6 +576,7 @@ def __init__(
         per_node_beam_size: int = None,
         sampler: Sampler = None,
         min_steps: Optional[int] = None,
+        final_sequence_scorer: FinalSequenceScorer = None,
     ) -> None:
         if not max_steps > 0:
             raise ValueError("max_steps must be positive")
@@ -496,6 +596,7 @@ def __init__(
         self.per_node_beam_size = per_node_beam_size or beam_size
         self.sampler = sampler or DeterministicSampler()
         self.min_steps = min_steps or 0
+        self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
 
     @staticmethod
     def _reconstruct_sequences(predictions, backpointers):
@@ -580,8 +681,8 @@ def search(
         # Returns
 
         `Tuple[torch.Tensor, torch.Tensor]`
-            Tuple of `(predictions, log_probabilities)`, where `predictions`
-            has shape `(batch_size, beam_size, max_steps)` and `log_probabilities`
+            Tuple of `(predictions, final_scores)`, where `predictions`
+            has shape `(batch_size, beam_size, max_steps)` and `final_scores`
             has shape `(batch_size, beam_size)`.
         """
         step_signature = signature(step)
@@ -786,7 +887,20 @@ def _search(
         # shape: (batch_size, beam_size, max_steps)
         all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)
 
-        return all_predictions, last_log_probabilities
+        # Calculate the final sequence scores
+        # shape: (batch_size, beam_size)
+        final_scores = self.final_sequence_scorer.score(
+            all_predictions, last_log_probabilities, self._end_index
+        )
+
+        # Sort the sequences based on the final scores so the best scoring
+        # sequence is at index 0
+        sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True)
+        sorted_all_predictions = torch.gather(
+            all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions)
+        )
+
+        return sorted_all_predictions, sorted_final_scores
 
     @staticmethod
     def _is_multilayer_rnn_decoder(key: str, state_tensor: torch.Tensor) -> bool:
diff --git a/tests/nn/beam_search_test.py b/tests/nn/beam_search_test.py
index 4fcd892ab91..275390cc135 100644
--- a/tests/nn/beam_search_test.py
+++ b/tests/nn/beam_search_test.py
@@ -12,6 +12,8 @@
     TopKSampler,
     TopPSampler,
     GumbelSampler,
+    SequenceLogProbabilityScorer,
+    LengthNormalizedSequenceLogProbabilityScorer,
 )
 
 from allennlp.common.params import Params
@@ -538,3 +540,59 @@ def test_gumbel_sampler(self):
 
         assert all([x >= 0 and x < 4 for x in indices[0]])
         assert all([x > 1 and x <= 5 for x in indices[1]])
+
+    def test_sequence_log_prob_scorer(self):
+        # SequenceLogProbabilityScorer is the default, so manually setting the
+        # sequence scorer shouldn't actually change anything
+        self.beam_search.final_sequence_scorer = SequenceLogProbabilityScorer()
+        self._check_results()
+
+    def test_length_normalized_sequence_log_prob_scorer(self):
+        """
+        Tests to ensure the sequences are normalized by the correct values. The end token is
+        included in the length. The start token is not.
+        """
+        self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer()
+        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
+        length_normalization = np.array([5, 4, 3])
+        expected_scores = expected_log_probs / length_normalization
+        self._check_results(expected_log_probs=expected_scores)
+
+        # Introduce a length penalty
+        length_penalty = 2.0
+        self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
+            length_penalty=length_penalty
+        )
+        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
+        length_normalization = np.array(
+            [5 ** length_penalty, 4 ** length_penalty, 3 ** length_penalty]
+        )
+        expected_scores = expected_log_probs / length_normalization
+        self._check_results(expected_log_probs=expected_scores)
+
+        # Pick a length penalty so extreme that the order of the sequences is reversed
+        length_penalty = -2.0
+        self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
+            length_penalty=length_penalty
+        )
+        expected_top_k = np.array([[3, 4, 5, 5, 5], [2, 3, 4, 5, 5], [1, 2, 3, 4, 5]])
+        expected_log_probs = np.log(np.array([0.2, 0.3, 0.4]))
+        length_normalization = np.array(
+            [3 ** length_penalty, 4 ** length_penalty, 5 ** length_penalty]
+        )
+        expected_scores = expected_log_probs / length_normalization
+        self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores)
+
+        # Here, we set max_steps = 4. This prevents the first sequence from finishing,
+        # so its length does not include the end token, whereas the other sequences do.
+        length_penalty = 2.0
+        self.beam_search.max_steps = 4
+        self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
+            length_penalty=length_penalty
+        )
+        expected_top_k = np.array([[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 5]])
+        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
+        length_normalization = np.array(
+            [4 ** length_penalty, 4 ** length_penalty, 3 ** length_penalty]
+        )
+        expected_scores = expected_log_probs / length_normalization
+        self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores)
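
A minimal usage sketch of the new `final_sequence_scorer` argument (illustrative only, not part of the diff). The toy `step` function, `start_predictions`, and `start_state` below are hypothetical stand-ins for a real model's decoding inputs, and the sketch assumes the three-argument step signature (`last_predictions, state, timestep`) that the `step_signature = signature(step)` logic above supports:

import torch

from allennlp.nn.beam_search import BeamSearch, LengthNormalizedSequenceLogProbabilityScorer

num_classes = 6
end_index = num_classes - 1  # assume the last vocabulary index is the end symbol


def step(last_predictions, state, timestep):
    # Toy transition function: a uniform distribution over the vocabulary
    # at every step, with no recurrent state.
    batch_size = last_predictions.size(0)
    log_probs = torch.full((batch_size, num_classes), 1.0 / num_classes).log()
    return log_probs, state


beam_search = BeamSearch(
    end_index=end_index,
    max_steps=10,
    beam_size=3,
    # Final score = sequence_log_prob / (sequence_length ** length_penalty),
    # where the length includes the end token but not the start token.
    final_sequence_scorer=LengthNormalizedSequenceLogProbabilityScorer(length_penalty=2.0),
)

start_predictions = torch.zeros(2, dtype=torch.long)  # batch_size = 2
start_state = {}  # the toy step function carries no state

# `search` now returns the predictions already sorted by the final scores,
# so the best-scoring sequence for each batch element is at beam index 0.
top_k_predictions, final_scores = beam_search.search(start_predictions, start_state, step)

Because `FinalSequenceScorer` is `Registrable`, the same choice can also be made from configuration via the registered names `"sequence-log-prob"` (the default) and `"length-normalized-sequence-log-prob"`.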
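
And a quick sanity check of the normalization arithmetic exercised by the tests above, assuming nothing beyond the Python standard library. For the beam `[2, 3, 4, 5, 5]` with total probability 0.3 and `length_penalty = 2.0`, the normalized length is the three non-end tokens plus one for the end token (index 5):

import math

sequence_length = 3 + 1            # [2, 3, 4] plus the end token
length_penalty = 2.0
sequence_log_prob = math.log(0.3)  # the sequence's summed token log probability
score = sequence_log_prob / (sequence_length ** length_penalty)
print(f"{score:.4f}")  # -0.0752, i.e. log(0.3) / 4 ** 2.0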