Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Make BeamSearch Registrable #5231

Merged
merged 5 commits into from
Jun 1, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0).
- Trainer callbacks can now store and restore state in case a training run gets interrupted.
- VilBERT backbone now rolls and unrolls extra dimensions to handle input with > 3 dimensions.
- Register `BeamSearch` with name `"beam_search"` and make this the `default_implementation`.

### Added

Expand Down
9 changes: 7 additions & 2 deletions allennlp/nn/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from overrides import overrides
import torch

from allennlp.common import FromParams, Registrable
from allennlp.common import Registrable
from allennlp.common.checks import ConfigurationError
from allennlp.nn.util import min_value_of_dtype

Expand Down Expand Up @@ -683,7 +683,7 @@ def _update_state(
return state


class BeamSearch(FromParams):
class BeamSearch(Registrable):
"""
Implements the beam search algorithm for decoding the most likely sequences.

Expand Down Expand Up @@ -731,6 +731,8 @@ class BeamSearch(FromParams):
provided, no constraints will be enforced.
"""

default_implementation = "beam_search"

def __init__(
self,
end_index: int,
Expand Down Expand Up @@ -1180,3 +1182,6 @@ def _update_state(self, state: StateType, backpointer: torch.Tensor):
.gather(1, expanded_backpointer)
.reshape(batch_size * self.beam_size, *last_dims)
)


BeamSearch.register("beam_search")(BeamSearch)