
Commit 3fd224f (parent: a2084fd)

fix lowercase-ization in bert indexer (#2205)

* fix lowercase-ization in bert indexer
* add never_lowercase feature for [UNK], etc
* add warnings when BERT model appears incongruent with do_lowercase
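
For orientation, a minimal sketch of how the new arguments are intended to be used once this commit is applied. The import path and argument names are taken from the diff below; the model name is illustrative, and constructing the indexer this way requires the pretrained vocabulary to be obtainable by ``BertTokenizer.from_pretrained``.

    from allennlp.data.token_indexers.wordpiece_indexer import PretrainedBertIndexer

    # An -uncased BERT model expects lowercased input, so the indexer now lowercases
    # tokens itself; special tokens such as [PAD] and [CLS] are left untouched.
    indexer = PretrainedBertIndexer(
        pretrained_model="bert-base-uncased",
        do_lowercase=True,        # new: lowercase before the wordpiece lookup
        never_lowercase=None,     # new: None means "use the default special-token list"
    )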

File tree: 2 files changed (+94 −2 lines changed)

  allennlp/data/token_indexers/wordpiece_indexer.py
  allennlp/tests/data/token_indexers/bert_indexer_test.py

allennlp/data/token_indexers/wordpiece_indexer.py  (+39 −1)
@@ -15,6 +15,10 @@
 
 # TODO(joelgrus): Figure out how to generate token_type_ids out of this token indexer.
 
+# This is the default list of tokens that should not be lowercased.
+_NEVER_LOWERCASE = ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']
+
+
 class WordpieceIndexer(TokenIndexer[int]):
     """
     A token indexer that does the wordpiece-tokenization (e.g. for BERT embeddings).
@@ -39,6 +43,14 @@ class WordpieceIndexer(TokenIndexer[int]):
         maximum length for its input ids. Currently any inputs longer than this
         will be truncated. If this behavior is undesirable to you, you should
         consider filtering them out in your dataset reader.
+    do_lowercase : ``bool``, optional (default=``False``)
+        Should we lowercase the provided tokens before getting the indices?
+        You would need to do this if you are using an -uncased BERT model
+        but your DatasetReader is not lowercasing tokens (which might be the
+        case if you're also using other embeddings based on cased tokens).
+    never_lowercase: ``List[str]``, optional
+        Tokens that should never be lowercased. Default is
+        ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'].
     start_tokens : ``List[str]``, optional (default=``None``)
         These are prepended to the tokens provided to ``tokens_to_indices``.
     end_tokens : ``List[str]``, optional (default=``None``)
@@ -50,6 +62,8 @@ def __init__(self,
                  namespace: str = "wordpiece",
                  use_starting_offsets: bool = False,
                  max_pieces: int = 512,
+                 do_lowercase: bool = False,
+                 never_lowercase: List[str] = None,
                  start_tokens: List[str] = None,
                  end_tokens: List[str] = None) -> None:
         self.vocab = vocab
@@ -64,6 +78,13 @@ def __init__(self,
         self._added_to_vocabulary = False
         self.max_pieces = max_pieces
         self.use_starting_offsets = use_starting_offsets
+        self._do_lowercase = do_lowercase
+
+        if never_lowercase is None:
+            # Use the defaults
+            self._never_lowercase = set(_NEVER_LOWERCASE)
+        else:
+            self._never_lowercase = set(never_lowercase)
 
         # Convert the start_tokens and end_tokens to wordpiece_ids
         self._start_piece_ids = [vocab[wordpiece]
@@ -108,8 +129,12 @@ def tokens_to_indices(self,
         offset = len(wordpiece_ids) if self.use_starting_offsets else len(wordpiece_ids) - 1
 
         for token in tokens:
+            # Lowercase if necessary
+            text = (token.text.lower()
+                    if self._do_lowercase and token.text not in self._never_lowercase
+                    else token.text)
             token_wordpiece_ids = [self.vocab[wordpiece]
-                                   for wordpiece in self.wordpiece_tokenizer(token.text)]
+                                   for wordpiece in self.wordpiece_tokenizer(text)]
             # If we have enough room to add these ids *and also* the end_token ids.
             if len(wordpiece_ids) + len(token_wordpiece_ids) + len(self._end_piece_ids) <= self.max_pieces:
                 # For initial offsets, the current value of ``offset`` is the start of
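
The branching added above, isolated as a self-contained sketch of the rule; the helper name here is illustrative and not part of the indexer's API.

    from typing import Iterable, List, Set

    NEVER_LOWERCASE: Set[str] = {'[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'}

    def maybe_lowercase(texts: Iterable[str],
                        do_lowercase: bool,
                        never_lowercase: Set[str] = NEVER_LOWERCASE) -> List[str]:
        # Lowercase every token except the protected wordpiece special tokens.
        return [text.lower() if do_lowercase and text not in never_lowercase else text
                for text in texts]

    assert maybe_lowercase(["The", "Quick", "[PAD]"], do_lowercase=True) == ["the", "quick", "[PAD]"]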
@@ -189,6 +214,9 @@ class PretrainedBertIndexer(WordpieceIndexer):
         they will instead correspond to the first wordpiece in each word.
     do_lowercase: ``bool``, optional (default = True)
         Whether to lowercase the tokens before converting to wordpiece ids.
+    never_lowercase: ``List[str]``, optional
+        Tokens that should never be lowercased. Default is
+        ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'].
     max_pieces: int, optional (default: 512)
         The BERT embedder uses positional embeddings and so has a corresponding
         maximum length for its input ids. Currently any inputs longer than this
@@ -199,12 +227,22 @@ def __init__(self,
                  pretrained_model: str,
                  use_starting_offsets: bool = False,
                  do_lowercase: bool = True,
+                 never_lowercase: List[str] = None,
                  max_pieces: int = 512) -> None:
+        if pretrained_model.endswith("-cased") and do_lowercase:
+            logger.warning("Your BERT model appears to be cased, "
+                           "but your indexer is lowercasing tokens.")
+        elif pretrained_model.endswith("-uncased") and not do_lowercase:
+            logger.warning("Your BERT model appears to be uncased, "
+                           "but your indexer is not lowercasing tokens.")
+
         bert_tokenizer = BertTokenizer.from_pretrained(pretrained_model, do_lower_case=do_lowercase)
         super().__init__(vocab=bert_tokenizer.vocab,
                          wordpiece_tokenizer=bert_tokenizer.wordpiece_tokenizer.tokenize,
                          namespace="bert",
                          use_starting_offsets=use_starting_offsets,
                          max_pieces=max_pieces,
+                         do_lowercase=do_lowercase,
+                         never_lowercase=never_lowercase,
                          start_tokens=["[CLS]"],
                          end_tokens=["[SEP]"])
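
A hedged example of the new mismatch warning. This constructs a real tokenizer, so the pretrained vocabulary must be downloadable, and the model name is only illustrative; the check relies on the -cased/-uncased suffix convention used by the stock BERT model names.

    import logging
    logging.basicConfig(level=logging.WARNING)

    from allennlp.data.token_indexers.wordpiece_indexer import PretrainedBertIndexer

    # A cased model combined with do_lowercase=True now logs:
    # "Your BERT model appears to be cased, but your indexer is lowercasing tokens."
    indexer = PretrainedBertIndexer(pretrained_model="bert-base-cased", do_lowercase=True)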

allennlp/tests/data/token_indexers/bert_indexer_test.py  (+55 −1)
@@ -1,12 +1,13 @@
 # pylint: disable=no-self-use,invalid-name
 from allennlp.common.testing import ModelTestCase
 from allennlp.data.token_indexers.wordpiece_indexer import PretrainedBertIndexer
-from allennlp.data.tokenizers import WordTokenizer
+from allennlp.data.tokenizers import WordTokenizer, Token
 from allennlp.data.tokenizers.word_splitter import BertBasicWordSplitter
 from allennlp.data.vocabulary import Vocabulary
 
 class TestBertIndexer(ModelTestCase):
 
+
     def test_starting_ending_offsets(self):
         tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())
 
@@ -30,3 +31,56 @@ def test_starting_ending_offsets(self):
 
         assert indexed_tokens["bert"] == [16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]
         assert indexed_tokens["bert-offsets"] == [1, 2, 3, 4, 5, 6, 7, 8, 11, 12]
+
+
+    def test_do_lowercase(self):
+        # Our default tokenizer doesn't handle lowercasing.
+        tokenizer = WordTokenizer()
+
+        # Quick is UNK because of capitalization
+        # 2 1 5 6 8 9 2 15 10 11 14 1
+        sentence = "the Quick brown fox jumped over the laziest lazy elmo"
+        tokens = tokenizer.tokenize(sentence)
+
+        vocab = Vocabulary()
+        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
+        token_indexer = PretrainedBertIndexer(str(vocab_path), do_lowercase=False)
+
+        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")
+
+        # Quick should get 1 == OOV
+        assert indexed_tokens["bert"] == [16, 2, 1, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]
+
+        # Does lowercasing by default
+        token_indexer = PretrainedBertIndexer(str(vocab_path))
+        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")
+
+        # Now Quick should get indexed correctly as 3 ( == "quick")
+        assert indexed_tokens["bert"] == [16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]
+
+
+    def test_never_lowercase(self):
+        # Our default tokenizer doesn't handle lowercasing.
+        tokenizer = WordTokenizer()
+
+        # 2 15 10 11 6
+        sentence = "the laziest fox"
+
+        tokens = tokenizer.tokenize(sentence)
+        tokens.append(Token("[PAD]"))  # have to do this b/c tokenizer splits it in three
+
+        vocab = Vocabulary()
+        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
+        token_indexer = PretrainedBertIndexer(str(vocab_path), do_lowercase=True)
+
+        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")
+
+        # PAD should get recognized and not lowercased       # [PAD]
+        assert indexed_tokens["bert"] == [16, 2, 15, 10, 11, 6, 0, 17]
+
+        # Unless we manually override the never lowercases
+        token_indexer = PretrainedBertIndexer(str(vocab_path), do_lowercase=True, never_lowercase=())
+        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")
+
+        # now PAD should get lowercased and be UNK           # [UNK]
+        assert indexed_tokens["bert"] == [16, 2, 15, 10, 11, 6, 1, 17]
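
Assuming a standard development checkout, these tests should be runnable with pytest, e.g. ``pytest allennlp/tests/data/token_indexers/bert_indexer_test.py``; the small fixture vocabulary referenced via ``self.FIXTURES_ROOT / 'bert' / 'vocab.txt'`` supplies the ids asserted above.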
