Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Make BeamSearch Registrable #5231

Merged
merged 5 commits into from
Jun 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0).
- Trainer callbacks can now store and restore state in case a training run gets interrupted.
- VilBERT backbone now rolls and unrolls extra dimensions to handle input with > 3 dimensions.
- `BeamSearch` is now a `Registrable` class.

### Added

Expand Down
9 changes: 7 additions & 2 deletions allennlp/nn/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from overrides import overrides
import torch

from allennlp.common import FromParams, Registrable
from allennlp.common import Registrable
from allennlp.common.checks import ConfigurationError
from allennlp.nn.util import min_value_of_dtype

Expand Down Expand Up @@ -683,7 +683,7 @@ def _update_state(
return state


class BeamSearch(FromParams):
class BeamSearch(Registrable):
"""
Implements the beam search algorithm for decoding the most likely sequences.

Expand Down Expand Up @@ -731,6 +731,8 @@ class BeamSearch(FromParams):
provided, no constraints will be enforced.
"""

default_implementation = "beam_search"

def __init__(
self,
end_index: int,
Expand Down Expand Up @@ -1180,3 +1182,6 @@ def _update_state(self, state: StateType, backpointer: torch.Tensor):
.gather(1, expanded_backpointer)
.reshape(batch_size * self.beam_size, *last_dims)
)


BeamSearch.register("beam_search")(BeamSearch)