This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Fixes token type ids for folded sequences #5149

Merged 6 commits on Apr 26, 2021
Changes from 3 commits
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## Unreleased

### Fixed

- When `PretrainedTransformerIndexer` folds long sequences, it no longer discards the token type ids.


## [v2.4.0](https://github.com/allenai/allennlp/releases/tag/v2.4.0) - 2021-04-22

### Added
29 changes: 28 additions & 1 deletion allennlp/data/token_indexers/pretrained_transformer_indexer.py
@@ -154,25 +154,52 @@ def _postprocess_output(self, output: IndexedTokenList) -> IndexedTokenList:
# TODO(zhaofengw): we aren't respecting word boundaries when segmenting wordpieces.

indices = output["token_ids"]
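# Models without token type embeddings may not emit type ids; default to all zeros.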
type_ids = output.get("type_ids", [0] * len(indices))

# Strips original special tokens
indices = indices[
self._num_added_start_tokens : len(indices) - self._num_added_end_tokens
]
type_ids = type_ids[
self._num_added_start_tokens : len(type_ids) - self._num_added_end_tokens
]

# Folds indices
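# (and type ids, using the same boundaries so the two stay aligned)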
folded_indices = [
indices[i : i + self._effective_max_length]
for i in range(0, len(indices), self._effective_max_length)
]
folded_type_ids = [
type_ids[i : i + self._effective_max_length]
for i in range(0, len(type_ids), self._effective_max_length)
]

# Adds special tokens to each segment
folded_indices = [
self._tokenizer.build_inputs_with_special_tokens(segment)
for segment in folded_indices
]
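# Mirrors build_inputs_with_special_tokens for the type ids: each segment gets
# the type ids of the tokenizer's single-sequence start and end special tokens.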
single_sequence_start_type_ids = [
t.type_id for t in self._allennlp_tokenizer.single_sequence_start_tokens
]
single_sequence_end_type_ids = [
t.type_id for t in self._allennlp_tokenizer.single_sequence_end_tokens
]
folded_type_ids = [
single_sequence_start_type_ids + segment + single_sequence_end_type_ids
for segment in folded_type_ids
]
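# Each segment's indices and type ids must have ended up the same length.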
assert all(
len(segment_indices) == len(segment_type_ids)
for segment_indices, segment_type_ids in zip(folded_indices, folded_type_ids)
)

# Flattens
indices = [i for segment in folded_indices for i in segment]
type_ids = [i for segment in folded_type_ids for i in segment]

output["token_ids"] = indices
output["type_ids"] = [0] * len(indices)
output["type_ids"] = type_ids
output["segment_concat_mask"] = [True] * len(indices)

return output
16 changes: 16 additions & 0 deletions tests/data/token_indexers/pretrained_transformer_indexer_test.py
@@ -163,6 +163,22 @@ def test_long_sequence_splitting(self):
assert indexed["segment_concat_mask"] == [True] * len(expected_ids)
assert indexed["mask"] == [True] * 7 # original length

def test_type_ids_when_folding(self):
allennlp_tokenizer = PretrainedTransformerTokenizer(
"bert-base-uncased", add_special_tokens=False
)
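# A deliberately small max_length so the indexer has to fold the sequence pair.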
indexer = PretrainedTransformerIndexer(model_name="bert-base-uncased", max_length=6)
first_string = "How do trees get online?"
second_string = "They log in!"

tokens = allennlp_tokenizer.add_special_tokens(
allennlp_tokenizer.tokenize(first_string), allennlp_tokenizer.tokenize(second_string)
)
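# BERT combines a pair as [CLS] first [SEP] second [SEP]; pieces of the second
# sequence (and the final [SEP]) carry type id 1.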
vocab = Vocabulary()
indexed = indexer.tokens_to_indices(tokens, vocab)
assert min(indexed["type_ids"]) == 0
assert max(indexed["type_ids"]) == 1

@staticmethod
def _assert_tokens_equal(expected_tokens, actual_tokens):
for expected, actual in zip(expected_tokens, actual_tokens):