
Implementing abstraction to score final sequences in BeamSearch #5208

Merged
merged 7 commits on May 18, 2021
Changes from 4 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
You can do this by setting the parameter `load_weights` to `False`.
See [PR #5172](https://github.com/allenai/allennlp/pull/5172) for more details.
- Added `SpanExtractorWithSpanWidthEmbedding`, which moves extractor-specific span embedding computations into the `_embed_spans` method and keeps the common argument handling in `SpanExtractorWithSpanWidthEmbedding` itself. Modified `BidirectionalEndpointSpanExtractor`, `EndpointSpanExtractor`, and `SelfAttentiveSpanExtractor` accordingly; `SelfAttentiveSpanExtractor` can now also embed span widths.
- Added the `FinalSequenceScorer` abstraction to calculate the final scores of the generated sequences in `BeamSearch`.

### Fixed

113 changes: 110 additions & 3 deletions allennlp/nn/beam_search.py
@@ -431,6 +431,100 @@ def gumbel_with_max(self, phi, T) -> torch.Tensor:
return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))


class FinalSequenceScorer(Registrable):
"""
An abstract class that can be used to score the final generated sequences found
by beam search. Given the predicted sequences and the corresponding log probabilities of
those sequences, the class calculates and returns the final score of the sequences.

The default implementation scores each sequence with the sum of its token log
probabilities, which is exactly the value passed in as input.
"""

default_implementation = "sequence-log-prob"

def score(
self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int
) -> torch.Tensor:
"""
Score the final predictions found by beam search.

# Parameters

predictions : `torch.Tensor`
A tensor containing the predicted sequences, with shape `(batch_size, beam_size, max_steps)`.

log_probabilities : `torch.Tensor`
A tensor containing the log probabilities of the sequence, defined as the sum
of the log probabilities per token, with shape `(batch_size, beam_size)`.

end_index : `int`
The index of the end symbol.

# Returns

`torch.Tensor`
A tensor of the final sequence scores of shape `(batch_size, beam_size)`.
"""
raise NotImplementedError
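
To make the abstraction concrete, here is a minimal sketch of a user-defined scorer. The class name `ShortestFirstScorer`, the registered name `"shortest-first"`, and the per-token penalty are invented for illustration; only the `FinalSequenceScorer` base class and the `score` signature come from this diff.

import torch

from allennlp.nn.beam_search import FinalSequenceScorer


@FinalSequenceScorer.register("shortest-first")  # hypothetical registered name
class ShortestFirstScorer(FinalSequenceScorer):
    """Illustrative scorer that subtracts a fixed penalty per generated token."""

    def __init__(self, penalty_per_token: float = 0.1):
        super().__init__()
        self.penalty_per_token = penalty_per_token

    def score(
        self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int
    ) -> torch.Tensor:
        # Count the non-end tokens in each sequence; shape: (batch_size, beam_size).
        lengths = (predictions != end_index).float().sum(dim=2)
        # Penalize longer sequences by subtracting a fixed amount per token.
        return log_probabilities - self.penalty_per_token * lengths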


@FinalSequenceScorer.register("sequence-log-prob")
class SequenceLogProbabilityScorer(FinalSequenceScorer):
"""
A `FinalSequenceScorer` which scores the sequences by the sum of the log probabilities
across the sequence's tokens.
"""

@overrides
def score(
self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int
) -> torch.Tensor:
# The sum of the sequence log probabilities is the input parameter, so just
# return it. The tensor is cloned so that it does not share storage with the
# input tensor, matching `LengthNormalizedSequenceLogProbabilityScorer`,
# which also returns a fresh tensor.
return log_probabilities.clone()


@FinalSequenceScorer.register("length-normalized-sequence-log-prob")
class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
"""
A `FinalSequenceScorer` which scores the sequences by the average log probability of the
tokens in the sequence. It optionally includes a length penalty which promotes
or demotes sequences based on their lengths. The final score for a sequence will
be (sequence_log_probability) / (sequence_length ** length_penalty). The sequence length
here includes the end token.

# Parameters

length_penalty : `float`, optional (default = `1.0`)
    The length penalty to use. A value of `1.0` divides by the raw sequence length
    (plain averaging, no extra penalty). A value > `1.0` favors longer sequences,
    and a value < `1.0` favors shorter sequences.
"""

def __init__(self, length_penalty: float = 1.0):
super().__init__()
self.length_penalty = length_penalty

@overrides
def score(
self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int
) -> torch.Tensor:
# shape: (batch_size, beam_size)
lengths = (predictions != end_index).long().sum(dim=2)

# If the sequence ended during beam search, the `log_probabilities` will include
# the transition to the end token. Therefore, in such situations, `lengths` is
# actually off by 1. This corrects for that.
# shape: (batch_size, beam_size)
is_end_token = predictions[:, :, -1] == end_index
lengths += is_end_token.long()

# shape: (batch_size, beam_size)
average_log_probs = log_probabilities / (lengths ** self.length_penalty)
return average_log_probs
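
As a quick numeric check of the normalization above, here is a standalone sketch (assuming the module from this diff is importable) that scores one batch with two beams, one of which emitted the end token early:

import torch

from allennlp.nn.beam_search import LengthNormalizedSequenceLogProbabilityScorer

end_index = 0
# One batch with two beams: the first sequence emits the end token after two
# steps and is padded with the end index; the second never emits it.
predictions = torch.tensor([[[4, 5, 0, 0, 0], [4, 5, 6, 7, 8]]])
log_probabilities = torch.tensor([[-2.0, -5.0]])

scorer = LengthNormalizedSequenceLogProbabilityScorer(length_penalty=1.0)
scores = scorer.score(predictions, log_probabilities, end_index)
# Beam 1: 2 non-end tokens + 1 for the end token = length 3.
# Beam 2: 5 non-end tokens, no correction (its last token is not the end index).
# Expected: [-2.0 / 3, -5.0 / 5] = [-0.6667, -1.0]
print(scores)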


class BeamSearch(FromParams):
"""
Implements the beam search algorithm for decoding the most likely sequences.
@@ -462,6 +556,12 @@ class BeamSearch(FromParams):

Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you
[Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).

final_sequence_scorer : `FinalSequenceScorer`, optional (default = `None`)
An optional `FinalSequenceScorer` used to score the final generated sequences.
Its output is what the `search` method returns as the sequence scores. If not
specified, `SequenceLogProbabilityScorer` is used, which scores each sequence
by the sum of its token log probabilities.
"""

def __init__(
@@ -471,6 +571,7 @@ def __init__(
beam_size: int = 10,
per_node_beam_size: int = None,
sampler: Sampler = None,
final_sequence_scorer: FinalSequenceScorer = None,
) -> None:
if not max_steps > 0:
raise ValueError("max_steps must be positive")
@@ -484,6 +585,7 @@
self.beam_size = beam_size
self.per_node_beam_size = per_node_beam_size or beam_size
self.sampler = sampler or DeterministicSampler()
self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()

@staticmethod
def _reconstruct_sequences(predictions, backpointers):
@@ -568,8 +670,8 @@ def search(
# Returns

`Tuple[torch.Tensor, torch.Tensor]`
Tuple of `(predictions, final_scores)`, where `predictions`
has shape `(batch_size, beam_size, max_steps)` and `final_scores`
has shape `(batch_size, beam_size)`.
"""
step_signature = signature(step)
@@ -763,7 +865,12 @@ def _search(
# shape: (batch_size, beam_size, max_steps)
all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)

# Calculate the final sequence scores
final_scores = self.final_sequence_scorer.score(
all_predictions, last_log_probabilities, self._end_index
)

return all_predictions, final_scores

@staticmethod
def _is_multilayer_rnn_decoder(key: str, state_tensor: torch.Tensor) -> bool:
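The following hedged, self-contained sketch shows how the new scorer plugs into `search`. The toy `step` function and its fixed next-token distribution are invented for illustration; `BeamSearch`, its `final_sequence_scorer` argument, and `LengthNormalizedSequenceLogProbabilityScorer` come from this diff.

from typing import Dict, Tuple

import torch

from allennlp.nn.beam_search import BeamSearch, LengthNormalizedSequenceLogProbabilityScorer

num_classes = 6
end_index = num_classes - 1
# A fixed next-token distribution, independent of the decoder state (toy model).
transition_log_probs = torch.log_softmax(torch.randn(num_classes), dim=-1)


def step(
    last_predictions: torch.Tensor,
    state: Dict[str, torch.Tensor],
    timestep: int,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    # Return the same log-probability row for every sequence in the batch.
    return transition_log_probs.expand(last_predictions.size(0), num_classes), state


beam_search = BeamSearch(
    end_index,
    max_steps=10,
    beam_size=3,
    final_sequence_scorer=LengthNormalizedSequenceLogProbabilityScorer(length_penalty=1.0),
)
start_predictions = torch.zeros(2, dtype=torch.long)
predictions, final_scores = beam_search.search(start_predictions, {}, step)
# predictions: (2, 3, <= 10); final_scores: (2, 3), now length-normalized.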
45 changes: 45 additions & 0 deletions tests/nn/beam_search_test.py
@@ -12,6 +12,8 @@
TopKSampler,
TopPSampler,
GumbelSampler,
SequenceLogProbabilityScorer,
LengthNormalizedSequenceLogProbabilityScorer,
)
from allennlp.common.params import Params

@@ -445,3 +447,46 @@ def test_gumbel_sampler(self):

assert all([x >= 0 and x < 4 for x in indices[0]])
assert all([x > 1 and x <= 5 for x in indices[1]])

def test_sequence_log_prob_scorer(self):
# SequenceLogProbabilityScorer is the default, so manually setting the
# sequence scorer shouldn't actually change anything
self.beam_search.final_sequence_scorer = SequenceLogProbabilityScorer()
self._check_results()

def test_length_normalized_sequence_log_prob_scorer(self):
"""
Tests to ensure the sequences are normalized by the correct values. The end token is
included in the length. The start token is not.
"""
self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer()
expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
length_normalization = np.array([5, 4, 3])
expected_scores = expected_log_probs / length_normalization
self._check_results(expected_log_probs=expected_scores)

# Introduce a length penalty
length_penalty = 2.0
self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
length_penalty=length_penalty
)
expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
length_normalization = np.array(
[5 ** length_penalty, 4 ** length_penalty, 3 ** length_penalty]
)
expected_scores = expected_log_probs / length_normalization
self._check_results(expected_log_probs=expected_scores)

# Here, we set the max_steps = 4. This prevents the first sequence from finishing,
# so its length does not include the end token, whereas the other sequences do.
length_penalty = 2.0
self.beam_search.max_steps = 4
self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
length_penalty=length_penalty
)
expected_top_k = np.array([[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 5]])
expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
length_normalization = np.array(
[4 ** length_penalty, 4 ** length_penalty, 3 ** length_penalty]
)
expected_scores = expected_log_probs / length_normalization
self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores)
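
Because `FinalSequenceScorer` is `Registrable` and `BeamSearch` is `FromParams`, the scorer should also be selectable from configuration. A minimal sketch, assuming the registered name `"length-normalized-sequence-log-prob"` from this diff; the `end_index` and `beam_size` values are arbitrary:

from allennlp.common.params import Params
from allennlp.nn.beam_search import (
    BeamSearch,
    LengthNormalizedSequenceLogProbabilityScorer,
)

beam_search = BeamSearch.from_params(
    Params(
        {
            "end_index": 5,  # arbitrary end-symbol index for illustration
            "beam_size": 3,
            "final_sequence_scorer": {
                "type": "length-normalized-sequence-log-prob",
                "length_penalty": 2.0,
            },
        }
    )
)
assert isinstance(
    beam_search.final_sequence_scorer, LengthNormalizedSequenceLogProbabilityScorer
)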