This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit a6a600d

Fixes token type ids for folded sequences (#5149)

* Fixes token type ids for folded sequences
* Changelog
* Save memory on the GitHub test runners
* Tensors have to be on the same device

1 parent: 402bc78

File tree: 4 files changed, +66 -10 lines

* CHANGELOG.md (+7)
* allennlp/data/token_indexers/pretrained_transformer_indexer.py (+28, -1)
* tests/data/token_indexers/pretrained_transformer_indexer_test.py (+16)
* tests/modules/token_embedders/pretrained_transformer_embedder_test.py (+15, -9)

CHANGELOG.md (+7)

@@ -12,6 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.sanity_checks.task_checklists` module.
 
 
+## Unreleased
+
+### Fixed
+
+- When `PretrainedTransformerIndexer` folds long sequences, it no longer loses the information from token type ids.
+
+
 ## [v2.4.0](https://github.com/allenai/allennlp/releases/tag/v2.4.0) - 2021-04-22
 
 ### Added
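For context, "token type ids" are the per-wordpiece segment markers that models such as BERT use to tell the two halves of a sentence pair apart. A quick background sketch, assuming the Hugging Face `transformers` package is installed and `bert-base-uncased` can be downloaded (not part of this commit, just illustration):

```python
# Background sketch: token type ids mark which of two paired sequences
# each wordpiece belongs to (0 for the first sentence, 1 for the second).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer("How do trees get online?", "They log in!")
print(encoded["token_type_ids"])  # a run of 0s followed by a run of 1s
```

The bug fixed here is that these markers were thrown away whenever the indexer had to fold a sequence longer than `max_length`.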

allennlp/data/token_indexers/pretrained_transformer_indexer.py (+28, -1)

@@ -154,25 +154,52 @@ def _postprocess_output(self, output: IndexedTokenList) -> IndexedTokenList:
             # TODO(zhaofengw): we aren't respecting word boundaries when segmenting wordpieces.
 
             indices = output["token_ids"]
+            type_ids = output.get("type_ids", [0] * len(indices))
+
             # Strips original special tokens
             indices = indices[
                 self._num_added_start_tokens : len(indices) - self._num_added_end_tokens
             ]
+            type_ids = type_ids[
+                self._num_added_start_tokens : len(type_ids) - self._num_added_end_tokens
+            ]
+
             # Folds indices
             folded_indices = [
                 indices[i : i + self._effective_max_length]
                 for i in range(0, len(indices), self._effective_max_length)
             ]
+            folded_type_ids = [
+                type_ids[i : i + self._effective_max_length]
+                for i in range(0, len(type_ids), self._effective_max_length)
+            ]
+
             # Adds special tokens to each segment
             folded_indices = [
                 self._tokenizer.build_inputs_with_special_tokens(segment)
                 for segment in folded_indices
             ]
+            single_sequence_start_type_ids = [
+                t.type_id for t in self._allennlp_tokenizer.single_sequence_start_tokens
+            ]
+            single_sequence_end_type_ids = [
+                t.type_id for t in self._allennlp_tokenizer.single_sequence_end_tokens
+            ]
+            folded_type_ids = [
+                single_sequence_start_type_ids + segment + single_sequence_end_type_ids
+                for segment in folded_type_ids
+            ]
+            assert all(
+                len(segment_indices) == len(segment_type_ids)
+                for segment_indices, segment_type_ids in zip(folded_indices, folded_type_ids)
+            )
+
             # Flattens
             indices = [i for segment in folded_indices for i in segment]
+            type_ids = [i for segment in folded_type_ids for i in segment]
 
             output["token_ids"] = indices
-            output["type_ids"] = [0] * len(indices)
+            output["type_ids"] = type_ids
             output["segment_concat_mask"] = [True] * len(indices)
 
         return output
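The diff above folds `type_ids` in lockstep with `token_ids` instead of overwriting them with zeros. A minimal standalone sketch of that fold-and-reattach pattern, using plain Python lists; the segment length, the `CLS_ID`/`SEP_ID` values, and their type ids are illustrative assumptions, not the indexer's real internals:

```python
# Illustrative sketch of the folding bookkeeping shown in the diff above.
# Only the pattern (slicing token ids and type ids together) mirrors the real code.

CLS_ID, SEP_ID = 101, 102      # hypothetical special-token ids
CLS_TYPE, SEP_TYPE = 0, 0      # type ids attached to those special tokens
EFFECTIVE_MAX_LENGTH = 4       # room left per segment after special tokens


def fold(indices, type_ids):
    # Fold both lists into segments of the same size so they stay aligned.
    folded_indices = [
        indices[i : i + EFFECTIVE_MAX_LENGTH]
        for i in range(0, len(indices), EFFECTIVE_MAX_LENGTH)
    ]
    folded_type_ids = [
        type_ids[i : i + EFFECTIVE_MAX_LENGTH]
        for i in range(0, len(type_ids), EFFECTIVE_MAX_LENGTH)
    ]
    # Re-add special tokens to every segment, with matching type ids alongside.
    folded_indices = [[CLS_ID] + seg + [SEP_ID] for seg in folded_indices]
    folded_type_ids = [[CLS_TYPE] + seg + [SEP_TYPE] for seg in folded_type_ids]
    assert all(len(a) == len(b) for a, b in zip(folded_indices, folded_type_ids))
    # Flatten back to 1-d lists; the embedder later reshapes them per segment.
    flat_indices = [i for seg in folded_indices for i in seg]
    flat_type_ids = [t for seg in folded_type_ids for t in seg]
    return flat_indices, flat_type_ids


# Two "sentences": type id 0 for the first, 1 for the second.
indices = [11, 12, 13, 14, 15, 21, 22, 23]
type_ids = [0, 0, 0, 0, 0, 1, 1, 1]
folded, folded_types = fold(indices, type_ids)
print(folded)        # [101, 11, 12, 13, 14, 102, 101, 15, 21, 22, 23, 102]
print(folded_types)  # [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0]
```

Before this commit, the indexer discarded the original type ids at this point and returned `[0] * len(indices)`, so the two-segment structure of a sentence pair was lost whenever folding kicked in.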

tests/data/token_indexers/pretrained_transformer_indexer_test.py (+16)

@@ -163,6 +163,22 @@ def test_long_sequence_splitting(self):
         assert indexed["segment_concat_mask"] == [True] * len(expected_ids)
         assert indexed["mask"] == [True] * 7  # original length
 
+    def test_type_ids_when_folding(self):
+        allennlp_tokenizer = PretrainedTransformerTokenizer(
+            "bert-base-uncased", add_special_tokens=False
+        )
+        indexer = PretrainedTransformerIndexer(model_name="bert-base-uncased", max_length=6)
+        first_string = "How do trees get online?"
+        second_string = "They log in!"
+
+        tokens = allennlp_tokenizer.add_special_tokens(
+            allennlp_tokenizer.tokenize(first_string), allennlp_tokenizer.tokenize(second_string)
+        )
+        vocab = Vocabulary()
+        indexed = indexer.tokens_to_indices(tokens, vocab)
+        assert min(indexed["type_ids"]) == 0
+        assert max(indexed["type_ids"]) == 1
+
     @staticmethod
     def _assert_tokens_equal(expected_tokens, actual_tokens):
         for expected, actual in zip(expected_tokens, actual_tokens):
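Roughly the same check can be run outside pytest. The sketch below follows the new test above; it assumes a working AllenNLP 2.x install, and the first run downloads `bert-base-uncased` from Hugging Face:

```python
# A minimal sketch mirroring the new test above (assumes AllenNLP 2.x is installed).
from allennlp.data import Vocabulary
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.tokenizers import PretrainedTransformerTokenizer

tokenizer = PretrainedTransformerTokenizer("bert-base-uncased", add_special_tokens=False)
# max_length=6 forces the indexer to fold the sentence pair into several segments.
indexer = PretrainedTransformerIndexer(model_name="bert-base-uncased", max_length=6)

tokens = tokenizer.add_special_tokens(
    tokenizer.tokenize("How do trees get online?"),
    tokenizer.tokenize("They log in!"),
)
indexed = indexer.tokens_to_indices(tokens, Vocabulary())

# With the fix, the second sentence keeps its type id of 1 even after folding;
# before the fix, indexed["type_ids"] came back as all zeros.
print(set(indexed["type_ids"]))  # expected: {0, 1}
```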

tests/modules/token_embedders/pretrained_transformer_embedder_test.py (+15, -9)

@@ -3,7 +3,7 @@
 import torch
 
 from allennlp.common import Params
-from allennlp.common.testing import AllenNlpTestCase
+from allennlp.common.testing import AllenNlpTestCase, requires_gpu
 from allennlp.data import Vocabulary
 from allennlp.data.batch import Batch
 from allennlp.data.fields import TextField
@@ -15,14 +15,15 @@
 
 
 class TestPretrainedTransformerEmbedder(AllenNlpTestCase):
+    @requires_gpu
     def test_forward_runs_when_initialized_from_params(self):
         # This code just passes things off to `transformers`, so we only have a very simple
         # test.
         params = Params({"model_name": "bert-base-uncased"})
-        embedder = PretrainedTransformerEmbedder.from_params(params)
+        embedder = PretrainedTransformerEmbedder.from_params(params).cuda()
         token_ids = torch.randint(0, 100, (1, 4))
         mask = torch.randint(0, 2, (1, 4)).bool()
-        output = embedder(token_ids=token_ids, mask=mask)
+        output = embedder(token_ids=token_ids.cuda(), mask=mask.cuda())
         assert tuple(output.size()) == (1, 4, 768)
 
     @pytest.mark.parametrize(
@@ -169,22 +170,24 @@ def test_end_to_end_t5(
         assert bert_vectors.size() == (2, 8, 64)
         assert bert_vectors.requires_grad == (train_parameters or not last_layer_only)
 
+    @requires_gpu
     def test_big_token_type_ids(self):
-        token_embedder = PretrainedTransformerEmbedder("roberta-base")
+        token_embedder = PretrainedTransformerEmbedder("roberta-base").cuda()
         token_ids = torch.LongTensor([[1, 2, 3], [2, 3, 4]])
         mask = torch.ones_like(token_ids).bool()
         type_ids = torch.zeros_like(token_ids)
         type_ids[1, 1] = 1
         with pytest.raises(ValueError):
-            token_embedder(token_ids, mask, type_ids)
+            token_embedder(token_ids.cuda(), mask.cuda(), type_ids.cuda())
 
+    @requires_gpu
     def test_xlnet_token_type_ids(self):
-        token_embedder = PretrainedTransformerEmbedder("xlnet-base-cased")
+        token_embedder = PretrainedTransformerEmbedder("xlnet-base-cased").cuda()
         token_ids = torch.LongTensor([[1, 2, 3], [2, 3, 4]])
         mask = torch.ones_like(token_ids).bool()
         type_ids = torch.zeros_like(token_ids)
         type_ids[1, 1] = 1
-        token_embedder(token_ids, mask, type_ids)
+        token_embedder(token_ids.cuda(), mask.cuda(), type_ids.cuda())
 
     def test_long_sequence_splitting_end_to_end(self):
         # Mostly the same as the end_to_end test (except for adding max_length=4),
@@ -310,11 +313,14 @@ def test_unfold_long_sequences(self):
         )
         assert (unfolded_embeddings_out == unfolded_embeddings).all()
 
+    @requires_gpu
     def test_encoder_decoder_model(self):
-        token_embedder = PretrainedTransformerEmbedder("facebook/bart-large", sub_module="encoder")
+        token_embedder = PretrainedTransformerEmbedder(
+            "facebook/bart-large", sub_module="encoder"
+        ).cuda()
         token_ids = torch.LongTensor([[1, 2, 3], [2, 3, 4]])
         mask = torch.ones_like(token_ids).bool()
-        token_embedder(token_ids, mask)
+        token_embedder(token_ids.cuda(), mask.cuda())
 
     def test_embeddings_resize(self):
         regular_token_embedder = PretrainedTransformerEmbedder("bert-base-cased")
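The `@requires_gpu` marker and the added `.cuda()` calls move these heavyweight models onto the GPU runner (and skip the tests entirely on CPU-only machines), which is how the commit "saves memory on the GitHub test runners"; moving both model and inputs is also why "tensors have to be on the same device". A marker like this is typically just a `pytest.mark.skipif`; the sketch below is an assumption about its shape, not the actual `allennlp.common.testing` source:

```python
# Hypothetical sketch of a requires_gpu-style marker (not the actual
# allennlp.common.testing implementation): skip the test unless CUDA is available.
import pytest
import torch

requires_gpu = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="test requires a CUDA-capable GPU"
)


@requires_gpu
def test_runs_on_gpu_only():
    # The test's tensors and the model must live on the same device,
    # hence the .cuda() calls added throughout the diff above.
    x = torch.ones(2, 2).cuda()
    assert x.device.type == "cuda"
```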
