
Commit 711afaa

Authored by David Wadden and matt-gardner
Fix division by zero when there are zero-length spans in MismatchedEmbedder. (#4615)
* Implement MattG's fix for NaN gradients in MismatchedEmbedder. Fix `clamp_min` on embeddings.
* Fix NaN gradients caused by weird tokens in MismatchedEmbedder. Fixed division by zero error when there are zero-length spans in the input to a mismatched embedder.
* Add changelog message.
* Re-run `black` to get code formatting right.
* Combine fixed sections after merging with master.

Co-authored-by: Matt Gardner <[email protected]>
1 parent be97943 commit 711afaa
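A minimal sketch of the failure mode described above, using plain PyTorch tensors rather than the embedder itself (the names and shapes below are illustrative): the mismatched embedder averages each original token's wordpiece vectors by dividing a per-span sum by the span length. A token that maps to zero wordpieces has length 0, and although the bad forward values can be overwritten with zeros afterwards, the backward pass still divides by that zero length and produces NaN gradients.

```python
import torch

# Hypothetical minimal reproduction (not the AllenNLP forward() itself).
# Two "original tokens"; the second maps to zero wordpieces, so its span length is 0.
span_sums = torch.randn(2, 3, requires_grad=True)   # summed wordpiece embeddings per token
span_lens = torch.tensor([[2.0], [0.0]])             # number of wordpieces per token

orig = span_sums / span_lens                          # divides by zero for the second token
orig = orig.masked_fill(span_lens == 0, 0.0)          # overwriting the values repairs the forward pass

orig.sum().backward()
print(span_sums.grad)                                 # backward still divides by zero -> NaN gradients
```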

3 files changed (+39 -2 lines)


CHANGELOG.md

+4 -1

@@ -11,12 +11,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Fixed handling of some edge cases when constructing classes with `FromParams` where the class
   accepts `**kwargs`.
+- Fixed division by zero error when there are zero-length spans in the input to a
+  `PretrainedTransformerMismatchedIndexer`.

 ### Added

 - `Predictor.capture_model_internals()` now accepts a regex specifying
   which modules to capture

+
 ## [v1.1.0rc4](https://github.com/allenai/allennlp/releases/tag/v1.1.0rc4) - 2020-08-20

 ### Added

@@ -63,7 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Added the option to specify `requires_grad: false` within an optimizer's parameter groups.
 - Added the `file-friendly-logging` flag back to the `train` command. Also added this flag to the `predict`, `evaluate`, and `find-learning-rate` commands.
-- Added an `EpochCallback` to track current epoch as a model class member.
+- Added an `EpochCallback` to track current epoch as a model class member.
 - Added the option to enable or disable gradient checkpointing for transformer token embedders via boolean parameter `gradient_checkpointing`.

 ### Removed

allennlp/modules/token_embedders/pretrained_transformer_mismatched_embedder.py

+1 -1

@@ -105,7 +105,7 @@ def forward(
         span_embeddings_sum = span_embeddings.sum(2)
         span_embeddings_len = span_mask.sum(2)
         # Shape: (batch_size, num_orig_tokens, embedding_size)
-        orig_embeddings = span_embeddings_sum / span_embeddings_len
+        orig_embeddings = span_embeddings_sum / torch.clamp_min(span_embeddings_len, 1)

         # All the places where the span length is zero, write in zeros.
         orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0
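Why the one-line change is enough, sketched with the same toy tensors as above (not the embedder's actual `forward`): clamping the denominator to at least 1 leaves the averages of non-empty spans unchanged, while zero-length spans now divide by 1 in both the forward and backward passes, so the follow-up zero-masking keeps the values correct and the gradients finite.

```python
import torch

# Illustrative check of the patched arithmetic (toy tensors, not the real embedder).
span_sums = torch.randn(2, 3, requires_grad=True)
span_lens = torch.tensor([[2.0], [0.0]])

orig = span_sums / torch.clamp_min(span_lens, 1)      # zero lengths become 1, so no zero division
orig = orig.masked_fill(span_lens == 0, 0.0)          # zero-length spans still get zero vectors

orig.sum().backward()
assert not torch.isnan(span_sums.grad).any()          # gradients are finite everywhere
```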

tests/modules/token_embedders/pretrained_transformer_mismatched_embedder_test.py

+34

@@ -8,6 +8,7 @@
 from allennlp.data.instance import Instance
 from allennlp.data.token_indexers import PretrainedTransformerMismatchedIndexer
 from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
+from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder
 from allennlp.common.testing import AllenNlpTestCase


@@ -143,3 +144,36 @@ def test_token_without_wordpieces(self):
         assert not torch.isnan(bert_vectors).any()
         assert all(bert_vectors[0, 1] == 0)
         assert all(bert_vectors[1, 1] == 0)
+
+    def test_exotic_tokens_no_nan_grads(self):
+        token_indexer = PretrainedTransformerMismatchedIndexer("bert-base-uncased")
+
+        sentence1 = ["A", "", "AllenNLP", "sentence", "."]
+        sentence2 = ["A", "\uf732\uf730\uf730\uf733", "AllenNLP", "sentence", "."]
+
+        tokens1 = [Token(word) for word in sentence1]
+        tokens2 = [Token(word) for word in sentence2]
+        vocab = Vocabulary()
+
+        token_embedder = BasicTextFieldEmbedder(
+            {"bert": PretrainedTransformerMismatchedEmbedder("bert-base-uncased")}
+        )
+
+        instance1 = Instance({"tokens": TextField(tokens1, {"bert": token_indexer})})
+        instance2 = Instance({"tokens": TextField(tokens2, {"bert": token_indexer})})
+
+        batch = Batch([instance1, instance2])
+        batch.index_instances(vocab)
+
+        padding_lengths = batch.get_padding_lengths()
+        tensor_dict = batch.as_tensor_dict(padding_lengths)
+        tokens = tensor_dict["tokens"]
+
+        bert_vectors = token_embedder(tokens)
+        test_loss = bert_vectors.mean()
+
+        test_loss.backward()
+
+        for name, param in token_embedder.named_parameters():
+            grad = param.grad
+            assert (grad is None) or (not torch.any(torch.isnan(grad)).item())
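One note on the test's final assertion: parameters that never contribute to the loss (for example, an unused pooling head) are left with `grad is None` after `backward()`, which is why the check accepts `None` rather than demanding a finite gradient for every parameter. A tiny, self-contained sketch of the same pattern (hypothetical modules, not the AllenNLP test):

```python
import torch

used = torch.nn.Linear(4, 4)
unused = torch.nn.Linear(4, 4)               # never touched by the loss, so its grads stay None

loss = used(torch.randn(2, 4)).mean()
loss.backward()

for module in (used, unused):
    for name, param in module.named_parameters():
        grad = param.grad
        assert (grad is None) or (not torch.any(torch.isnan(grad)).item())
```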
