# TODO(joelgrus): Figure out how to generate token_type_ids out of this token indexer.

+# This is the default list of tokens that should not be lowercased.
+_NEVER_LOWERCASE = ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']
+
+
class WordpieceIndexer(TokenIndexer[int]):
    """
    A token indexer that does the wordpiece-tokenization (e.g. for BERT embeddings).
@@ -39,6 +43,14 @@ class WordpieceIndexer(TokenIndexer[int]):
        maximum length for its input ids. Currently any inputs longer than this
        will be truncated. If this behavior is undesirable to you, you should
        consider filtering them out in your dataset reader.
+    do_lowercase : ``bool``, optional (default=``False``)
+        Should we lowercase the provided tokens before getting the indices?
+        You would need to do this if you are using an -uncased BERT model
+        but your DatasetReader is not lowercasing tokens (which might be the
+        case if you're also using other embeddings based on cased tokens).
+    never_lowercase: ``List[str]``, optional
+        Tokens that should never be lowercased. Default is
+        ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'].
    start_tokens : ``List[str]``, optional (default=``None``)
        These are prepended to the tokens provided to ``tokens_to_indices``.
    end_tokens : ``List[str]``, optional (default=``None``)
@@ -50,6 +62,8 @@ def __init__(self,
                 namespace: str = "wordpiece",
                 use_starting_offsets: bool = False,
                 max_pieces: int = 512,
+                 do_lowercase: bool = False,
+                 never_lowercase: List[str] = None,
                 start_tokens: List[str] = None,
                 end_tokens: List[str] = None) -> None:
        self.vocab = vocab
@@ -64,6 +78,13 @@ def __init__(self,
        self._added_to_vocabulary = False
        self.max_pieces = max_pieces
        self.use_starting_offsets = use_starting_offsets
+        self._do_lowercase = do_lowercase
+
+        if never_lowercase is None:
+            # Use the defaults
+            self._never_lowercase = set(_NEVER_LOWERCASE)
+        else:
+            self._never_lowercase = set(never_lowercase)

        # Convert the start_tokens and end_tokens to wordpiece_ids
        self._start_piece_ids = [vocab[wordpiece]
@@ -108,8 +129,12 @@ def tokens_to_indices(self,
        offset = len(wordpiece_ids) if self.use_starting_offsets else len(wordpiece_ids) - 1

        for token in tokens:
+            # Lowercase if necessary
+            text = (token.text.lower()
+                    if self._do_lowercase and token.text not in self._never_lowercase
+                    else token.text)
            token_wordpiece_ids = [self.vocab[wordpiece]
-                                   for wordpiece in self.wordpiece_tokenizer(token.text)]
+                                   for wordpiece in self.wordpiece_tokenizer(text)]
            # If we have enough room to add these ids *and also* the end_token ids.
            if len(wordpiece_ids) + len(token_wordpiece_ids) + len(self._end_piece_ids) <= self.max_pieces:
                # For initial offsets, the current value of ``offset`` is the start of
@@ -189,6 +214,9 @@ class PretrainedBertIndexer(WordpieceIndexer):
        they will instead correspond to the first wordpiece in each word.
    do_lowercase: ``bool``, optional (default = True)
        Whether to lowercase the tokens before converting to wordpiece ids.
+    never_lowercase: ``List[str]``, optional
+        Tokens that should never be lowercased. Default is
+        ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'].
    max_pieces: int, optional (default: 512)
        The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Currently any inputs longer than this
@@ -199,12 +227,22 @@ def __init__(self,
                 pretrained_model: str,
                 use_starting_offsets: bool = False,
                 do_lowercase: bool = True,
+                 never_lowercase: List[str] = None,
                 max_pieces: int = 512) -> None:
+        if pretrained_model.endswith("-cased") and do_lowercase:
+            logger.warning("Your BERT model appears to be cased, "
+                           "but your indexer is lowercasing tokens.")
+        elif pretrained_model.endswith("-uncased") and not do_lowercase:
+            logger.warning("Your BERT model appears to be uncased, "
+                           "but your indexer is not lowercasing tokens.")
+
        bert_tokenizer = BertTokenizer.from_pretrained(pretrained_model, do_lower_case=do_lowercase)
        super().__init__(vocab=bert_tokenizer.vocab,
                         wordpiece_tokenizer=bert_tokenizer.wordpiece_tokenizer.tokenize,
                         namespace="bert",
                         use_starting_offsets=use_starting_offsets,
                         max_pieces=max_pieces,
+                         do_lowercase=do_lowercase,
+                         never_lowercase=never_lowercase,
                         start_tokens=["[CLS]"],
                         end_tokens=["[SEP]"])
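
For reference, here is a minimal sketch of how the new arguments might be used when constructing the indexer directly. It assumes PretrainedBertIndexer is importable from allennlp.data.token_indexers and uses "bert-base-uncased" purely as an example model name; it is not part of this change.

    # Sketch: an uncased BERT model, so the indexer lowercases token text,
    # but the BERT special tokens are passed through unchanged.
    from allennlp.data.token_indexers import PretrainedBertIndexer

    indexer = PretrainedBertIndexer(
            pretrained_model="bert-base-uncased",
            do_lowercase=True,
            never_lowercase=['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'])

With these settings, an input token such as "Movie" is lowercased before wordpiece tokenization, while anything in never_lowercase is looked up verbatim; passing a model name ending in "-cased" together with do_lowercase=True would instead trigger the new warning added in this commit.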