@@ -182,6 +182,7 @@ def from_params(cls, vocab: Vocabulary, params: Params) -> 'Embedding': # type:
         norm_type = params.pop_float('norm_type', 2.)
         scale_grad_by_freq = params.pop_bool('scale_grad_by_freq', False)
         sparse = params.pop_bool('sparse', False)
+        min_pretrained_embeddings = params.pop_int("min_pretrained_embeddings", 0)
         params.assert_empty(cls.__name__)
 
         if pretrained_file:
@@ -191,7 +192,10 @@ def from_params(cls, vocab: Vocabulary, params: Params) -> 'Embedding': # type:
             weight = _read_pretrained_embeddings_file(pretrained_file,
                                                       embedding_dim,
                                                       vocab,
-                                                      vocab_namespace)
+                                                      vocab_namespace,
+                                                      min_pretrained_embeddings)
+            if min_pretrained_embeddings > 0:
+                num_embeddings = vocab.get_vocab_size(vocab_namespace)
         else:
             weight = None
 
@@ -210,7 +214,8 @@ def from_params(cls, vocab: Vocabulary, params: Params) -> 'Embedding': # type:
 def _read_pretrained_embeddings_file(file_uri: str,
                                      embedding_dim: int,
                                      vocab: Vocabulary,
-                                     namespace: str = "tokens") -> torch.FloatTensor:
+                                     namespace: str = "tokens",
+                                     min_pretrained_embeddings: int = None) -> torch.FloatTensor:
     """
     Returns and embedding matrix for the given vocabulary using the pretrained embeddings
     contained in the given file. Embeddings for tokens not found in the pretrained embedding file
@@ -244,8 +249,9 @@ def _read_pretrained_embeddings_file(file_uri: str,
         A Vocabulary object.
     namespace : str, (optional, default=tokens)
         The namespace of the vocabulary to find pretrained embeddings for.
-    trainable : bool, (optional, default=True)
-        Whether or not the embedding parameters should be optimized.
+    min_pretrained_embeddings : int, (optional, default=None):
+        If given, will keep at least this number of embeddings from the start of the pretrained
+        embedding text file (typically the most common words)
 
     Returns
     -------
@@ -261,13 +267,14 @@ def _read_pretrained_embeddings_file(file_uri: str,
 
     return _read_embeddings_from_text_file(file_uri,
                                            embedding_dim,
-                                           vocab, namespace)
+                                           vocab, namespace, min_pretrained_embeddings)
 
 
 def _read_embeddings_from_text_file(file_uri: str,
                                     embedding_dim: int,
                                     vocab: Vocabulary,
-                                    namespace: str = "tokens") -> torch.FloatTensor:
+                                    namespace: str = "tokens",
+                                    min_pretrained_embeddings: int = 0) -> torch.FloatTensor:
     """
     Read pre-trained word vectors from an eventually compressed text file, possibly contained
     inside an archive with multiple files. The text file is assumed to be utf-8 encoded with
@@ -278,16 +285,15 @@ def _read_embeddings_from_text_file(file_uri: str,
     The remainder of the docstring is identical to ``_read_pretrained_embeddings_file``.
     """
     tokens_to_keep = set(vocab.get_index_to_token_vocabulary(namespace).values())
-    vocab_size = vocab.get_vocab_size(namespace)
     embeddings = {}
 
     # First we read the embeddings from the file, only keeping vectors for the words we need.
     logger.info("Reading pretrained embeddings from file")
 
     with EmbeddingsTextFile(file_uri) as embeddings_file:
-        for line in Tqdm.tqdm(embeddings_file):
+        for index, line in Tqdm.tqdm(enumerate(embeddings_file)):
             token = line.split(' ', 1)[0]
-            if token in tokens_to_keep:
+            if token in tokens_to_keep or index < min_pretrained_embeddings:
                 fields = line.rstrip().split(' ')
                 if len(fields) - 1 != embedding_dim:
                     # Sometimes there are funny unicode parsing problems that lead to different
@@ -303,6 +309,10 @@ def _read_embeddings_from_text_file(file_uri: str,
 
                 vector = numpy.asarray(fields[1:], dtype='float32')
                 embeddings[token] = vector
+                if token not in tokens_to_keep:
+                    vocab.add_token_to_namespace(token, namespace)
+
+    vocab_size = vocab.get_vocab_size(namespace)
 
     if not embeddings:
         raise ConfigurationError("No embeddings of correct dimension found; you probably "
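
For orientation, here is a rough usage sketch (not part of the diff): it exercises the new "min_pretrained_embeddings" key through Embedding.from_params. The imports follow the AllenNLP package layout this file lives in, while the embedding file path and the concrete parameter values are hypothetical; only the "min_pretrained_embeddings" key itself is taken from the change above.

# Usage sketch only; assumes AllenNLP's Params/Vocabulary/Embedding and a made-up
# embedding file path. Only "min_pretrained_embeddings" comes from the diff above.
from allennlp.common import Params
from allennlp.data import Vocabulary
from allennlp.modules.token_embedders import Embedding

vocab = Vocabulary()  # in practice, built from your training data
params = Params({
        "embedding_dim": 300,
        "pretrained_file": "/path/to/glove.840B.300d.txt.gz",  # hypothetical path
        "min_pretrained_embeddings": 250000,  # keep at least the first 250k rows of the file
})
token_embedder = Embedding.from_params(vocab, params)

# Tokens from those first rows that were missing from the "tokens" namespace have been
# added to the vocabulary, and num_embeddings reflects the enlarged vocabulary size.
print(vocab.get_vocab_size("tokens"), token_embedder.get_output_dim())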