Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 5d22ce6

Browse files
committed
Merge remote-tracking branch 'origin/master' into vision
2 parents 602399c + d99f7f8 commit 5d22ce6

File tree

5 files changed

+42
-11
lines changed

5 files changed

+42
-11
lines changed

CHANGELOG.md

+8-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
3838

3939
## Unreleased (1.x branch)
4040

41+
### Fixed
42+
43+
- `GumbelSampler` now sorts the beams by their true log prob.
44+
45+
4146
## [v1.2.1](https://github.com/allenai/allennlp/releases/tag/v1.2.1) - 2020-11-10
4247

4348
### Added
@@ -48,7 +53,7 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
4853
- Added more documentation about plugins.
4954
- Added sampler class and parameter in beam search for non-deterministic search, with several
5055
implementations, including `MultinomialSampler`, `TopKSampler`, `TopPSampler`, and
51-
`GumbelMaxSampler`. Utilizing `GumbelMaxSampler` will give [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).
56+
`GumbelSampler`. Utilizing `GumbelSampler` will give [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).
5257

5358
### Changed
5459

@@ -67,6 +72,8 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
6772
- Fixed typo with registered name of ROUGE metric. Previously was `rogue`, fixed to `rouge`.
6873
- Fixed default masks that were erroneously created on the CPU even when a GPU is available.
6974
- Fixed pretrained embeddings for transformers that don't use end tokens.
75+
- Fixed the transformer tokenizer cache when the tokenizers are initialized with custom kwargs.
76+
7077

7178
## [v1.2.0](https://github.com/allenai/allennlp/releases/tag/v1.2.0) - 2020-10-29
7279

allennlp/common/cached_transformers.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,13 @@ def strip_prefix(s):
9494
return transformer
9595

9696

97-
_tokenizer_cache: Dict[Tuple[str, frozenset], transformers.PreTrainedTokenizer] = {}
97+
_tokenizer_cache: Dict[Tuple[str, str], transformers.PreTrainedTokenizer] = {}
9898

9999

100100
def get_tokenizer(model_name: str, **kwargs) -> transformers.PreTrainedTokenizer:
101-
cache_key = (model_name, frozenset(kwargs.items()))
101+
from allennlp.common.util import hash_object
102+
103+
cache_key = (model_name, hash_object(kwargs))
102104

103105
global _tokenizer_cache
104106
tokenizer = _tokenizer_cache.get(cache_key, None)

allennlp/common/util.py

+12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""
22
Various utilities that don't fit anywhere else.
33
"""
4+
import hashlib
5+
import io
6+
import pickle
47
from datetime import timedelta
58
import importlib
69
import json
@@ -679,3 +682,12 @@ def cycle_iterator_function(iterator_function: Callable[[], Iterable[T]]) -> Ite
679682
yield next(iterator)
680683
except StopIteration:
681684
iterator = iter(iterator_function())
685+
686+
687+
def hash_object(o: Any) -> str:
    """
    Return a deterministic hex-digest hash of an arbitrary (picklable) Python object.

    The object is serialized with `pickle` and hashed with BLAKE2b. Note that
    `hashlib.blake2b()` uses its default 64-byte digest size, so the returned
    hex digest is 128 characters long — not 32 as the original docstring claimed.

    # Parameters

    o : `Any`
        Any picklable object (e.g. the tokenizer kwargs used as a cache key
        in `cached_transformers.get_tokenizer`).

    # Returns

    `str`
        The 128-character BLAKE2b hex digest of the pickled object.
    """
    m = hashlib.blake2b()
    # Serialize into an in-memory buffer, then feed the raw bytes to the hash.
    # `getbuffer()` gives a zero-copy memoryview over the BytesIO contents.
    with io.BytesIO() as buffer:
        pickle.dump(o, buffer)
        m.update(buffer.getbuffer())
    return m.hexdigest()

allennlp/nn/beam_search.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -385,12 +385,18 @@ def sample_beams(
385385
# shape (both): (batch_size, beam_size)
386386
G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1)
387387

388-
# shape: (batch_size * beam_size,)
389-
G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)
390-
391388
# shape: (batch_size, beam_size)
392389
selected_log_probs = log_probs.gather(1, selected_indices)
393390

391+
# Now sort the selected beams by their true log prob.
392+
# shape (all): (batch_size, beam_size)
393+
selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True)
394+
selected_indices = selected_indices.gather(1, sort_indices)
395+
G_phi_S_new = G_phi_S_new.gather(1, sort_indices)
396+
397+
# shape: (batch_size * beam_size,)
398+
G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)
399+
394400
# shape: (batch_size * beam_size,)
395401
phi_S = selected_log_probs.reshape(batch_size * beam_size)
396402

tests/nn/beam_search_test.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -434,10 +434,14 @@ def test_gumbel_sampler(self):
434434
num_classes = len(log_probabilities[0])
435435
sampler_state = sampler.init_state(log_probabilities, batch_size=2, num_classes=num_classes)
436436

437-
probabilities, classes, state = sampler.sample_beams(log_probabilities, 3, sampler_state)
437+
log_probs, indices, state = sampler.sample_beams(log_probabilities, 3, sampler_state)
438438

439-
assert probabilities.size() == classes.size()
440-
assert classes.size() == (2, 3)
439+
assert log_probs.size() == indices.size()
440+
assert indices.size() == (2, 3)
441+
442+
# Make sure the probabilities are sorted.
443+
_, sorted_indices = log_probs.sort(dim=-1, descending=True)
444+
assert (sorted_indices == torch.arange(3).unsqueeze(0)).all()
441445

442-
assert all([x >= 0 and x < 4 for x in classes[0]])
443-
assert all([x > 1 and x <= 5 for x in classes[1]])
446+
assert all([x >= 0 and x < 4 for x in indices[0]])
447+
assert all([x > 1 and x <= 5 for x in indices[1]])

0 commit comments

Comments
 (0)