
Commit 7e08298

maksymbevza authored and joelgrus committed

Fix wordpiece indexer truncation (#2931)

* Fix wordpiece indexer
* Add comments for test and count pieces accumulated

1 parent 03aa838

2 files changed: +87 -3 lines changed

allennlp/data/token_indexers/wordpiece_indexer.py (+7 -3)

@@ -181,10 +181,12 @@ def tokens_to_indices(self,
         # offset is the last wordpiece of "tokens[-1]".
         offset = len(self._start_piece_ids) if self.use_starting_offsets else len(self._start_piece_ids) - 1
 
+        # Count amount of wordpieces accumulated
+        pieces_accumulated = 0
         for token in token_wordpiece_ids:
             # Truncate the sequence if specified, which depends on where the offsets are
             next_offset = 1 if self.use_starting_offsets else 0
-            if self._truncate_long_sequences and offset >= window_length + next_offset:
+            if self._truncate_long_sequences and offset + len(token) - 1 >= window_length + next_offset:
                 break
 
             # For initial offsets, the current value of ``offset`` is the start of
@@ -198,15 +200,17 @@ def tokens_to_indices(self,
             offset += len(token)
             offsets.append(offset)
 
+            pieces_accumulated += len(token)
+
         if len(flat_wordpiece_ids) <= window_length:
             # If all the wordpieces fit, then we don't need to do anything special
             wordpiece_windows = [self._add_start_and_end(flat_wordpiece_ids)]
             token_type_ids = self._extend(flat_token_type_ids)
         elif self._truncate_long_sequences:
             logger.warning("Too many wordpieces, truncating sequence. If you would like a sliding window, set"
                            "`truncate_long_sequences` to False %s", str([token.text for token in tokens]))
-            wordpiece_windows = [self._add_start_and_end(flat_wordpiece_ids[:window_length])]
-            token_type_ids = self._extend(flat_token_type_ids[:window_length])
+            wordpiece_windows = [self._add_start_and_end(flat_wordpiece_ids[:pieces_accumulated])]
+            token_type_ids = self._extend(flat_token_type_ids[:pieces_accumulated])
         else:
             # Create a sliding window of wordpieces of length `max_pieces` that advances by `stride` steps and
             # add start/end wordpieces to each window
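
A minimal, self-contained sketch of the rule this hunk implements, for readers outside of AllenNLP: walk the per-word lists of wordpiece ids and stop before the first word whose pieces would cross the window boundary, so every word is either kept whole or dropped entirely. The truncate_wordpieces helper and the piece ids below are illustrative only, not part of the AllenNLP API.

from typing import List

def truncate_wordpieces(word_pieces: List[List[int]], window_length: int) -> List[int]:
    """Flatten per-word wordpiece ids, stopping before any word whose
    pieces would cross the window boundary (no word is ever split)."""
    flat: List[int] = []
    for pieces in word_pieces:
        if len(flat) + len(pieces) > window_length:
            break  # drop the whole word instead of splitting it
        flat.extend(pieces)
    return flat

# Hypothetical piece ids: the middle word is split into three wordpieces.
words = [[7], [12, 44, 51], [9]]
print(truncate_wordpieces(words, window_length=3))  # [7] -- the 3-piece word is dropped whole
print(truncate_wordpieces(words, window_length=5))  # [7, 12, 44, 51, 9] -- everything fits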

allennlp/tests/data/token_indexers/bert_indexer_test.py (+80)

@@ -203,3 +203,83 @@ def test_truncate_window(self):
         # 1 full window + 1 half window with start/end tokens
         assert indexed_tokens["bert"] == [16, 2, 3, 4, 3, 5, 6, 8, 9, 17]
         assert indexed_tokens["bert-offsets"] == [1, 3, 4, 5, 6, 7, 8]
+
+    def test_truncate_window_dont_split_wordpieces(self):
+        """
+        Tests that the sentence is not truncated inside a word that has two or
+        more wordpieces.
+        """
+
+        tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())
+
+        sentence = "the quickest quick brown fox jumped over the quickest dog"
+        tokens = tokenizer.tokenize(sentence)
+
+        vocab = Vocabulary()
+        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
+        token_indexer = PretrainedBertIndexer(str(vocab_path),
+                                              truncate_long_sequences=True,
+                                              use_starting_offsets=True,
+                                              max_pieces=12)
+
+        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")
+
+        # 16 = [CLS], 17 = [SEP]
+        # 1 full window + 1 half window with start/end tokens
+        assert indexed_tokens["bert"] == [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 17]
+        # We could fit one more piece here, but we don't, so that we never cut
+        # in the middle of a word
+        assert indexed_tokens["bert-offsets"] == [1, 2, 4, 5, 6, 7, 8, 9]
+        assert indexed_tokens["bert-type-ids"] == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+
+        token_indexer = PretrainedBertIndexer(str(vocab_path),
+                                              truncate_long_sequences=True,
+                                              use_starting_offsets=False,
+                                              max_pieces=12)
+
+        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")
+
+        # 16 = [CLS], 17 = [SEP]
+        # 1 full window + 1 half window with start/end tokens
+        assert indexed_tokens["bert"] == [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 17]
+        # We could fit one more piece here, but we don't, so that we never cut
+        # in the middle of a word
+        assert indexed_tokens["bert-offsets"] == [1, 3, 4, 5, 6, 7, 8, 9]
+
+    def test_truncate_window_fit_two_wordpieces(self):
+        """
+        Tests that both `use_starting_offsets` options work properly when the last
+        word in the truncated sentence consists of two wordpieces.
+        """
+
+        tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())
+
+        sentence = "the quickest quick brown fox jumped over the quickest dog"
+        tokens = tokenizer.tokenize(sentence)
+
+        vocab = Vocabulary()
+        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
+        token_indexer = PretrainedBertIndexer(str(vocab_path),
+                                              truncate_long_sequences=True,
+                                              use_starting_offsets=True,
+                                              max_pieces=13)
+
+        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")
+
+        # 16 = [CLS], 17 = [SEP]
+        # 1 full window + 1 half window with start/end tokens
+        assert indexed_tokens["bert"] == [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 3, 4, 17]
+        assert indexed_tokens["bert-offsets"] == [1, 2, 4, 5, 6, 7, 8, 9, 10]
+        assert indexed_tokens["bert-type-ids"] == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+
+        token_indexer = PretrainedBertIndexer(str(vocab_path),
+                                              truncate_long_sequences=True,
+                                              use_starting_offsets=False,
+                                              max_pieces=13)
+
+        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")
+
+        # 16 = [CLS], 17 = [SEP]
+        # 1 full window + 1 half window with start/end tokens
+        assert indexed_tokens["bert"] == [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 3, 4, 17]
+        assert indexed_tokens["bert-offsets"] == [1, 3, 4, 5, 6, 7, 8, 9, 11]
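
The expected values in these tests follow from the same word-boundary rule. Judging from the expected output, the fixture vocabulary evidently splits "quickest" into two wordpieces (ids 3 and 4) while the other kept words map to single pieces, and 16/17 are [CLS]/[SEP] as the comments state. A rough check of the arithmetic, using hypothetical per-word piece ids inferred from those expectations (the id for "dog" is a guess; it is truncated away in both cases):

# Sentence: "the quickest quick brown fox jumped over the quickest dog"
# Hypothetical per-word wordpiece ids inferred from the expected output above;
# every word is one piece except "quickest", which is two pieces (3, 4).
pieces_per_word = [[2], [3, 4], [3], [5], [6], [8], [9], [2], [3, 4], [10]]  # [10] for "dog" is a guess

for max_pieces in (12, 13):
    window = max_pieces - 2          # two slots are reserved for [CLS] and [SEP]
    kept = []
    for pieces in pieces_per_word:
        if len(kept) + len(pieces) > window:
            break                    # drop the whole word rather than splitting it
        kept.extend(pieces)
    print(max_pieces, [16] + kept + [17])
# 12 -> [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 17]        (one slot left unused, matching the first test)
# 13 -> [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 3, 4, 17]  (both pieces of the second "quickest" fit)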
