This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 02e2930

OyvindTafjord authored and DeNeutoy committed
Model can store extra pretrained embeddings (#1817)
This adds a `min_pretrained_embeddings` parameter to the `"embedding"` token embedder, which keeps at least that many embeddings from the top of an embedding text file (like GloVe). This is a pragmatic way to support unseen words, say in a demo, at the cost of a larger model size (e.g., specifying 200k here increases model.tar.gz by about 75MB). It leverages the fact that, at least for GloVe files, the words are ordered by frequency, so by the time you get to, say, 200k, you're mostly missing out on rare words whose embeddings are less useful anyway. I'd use this in the QuaRel demo.
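For concreteness, here is a minimal sketch of how the parameter is exercised through `Embedding.from_params` (mirroring the new test below); the GloVe path is a placeholder, and 200k is just the example figure from above:

```python
# A minimal sketch, assuming a local GloVe text file; the path is a placeholder.
from allennlp.common import Params
from allennlp.data import Vocabulary
from allennlp.modules.token_embedders import Embedding

vocab = Vocabulary()  # normally built from your dataset
params = Params({
    'pretrained_file': '/path/to/glove.6B.100d.txt.gz',  # placeholder path
    'embedding_dim': 100,
    'min_pretrained_embeddings': 200000,  # keep the first 200k rows (~75MB larger model.tar.gz, per above)
})
embedding = Embedding.from_params(vocab, params)
# Side effect of this change: `vocab` now also contains the kept pretrained tokens,
# so common unseen words map to real vectors instead of @@UNKNOWN@@.
```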
1 parent f65ced5 commit 02e2930

File tree

2 files changed: +33 −9 lines changed


allennlp/modules/token_embedders/embedding.py

+19 −9

@@ -182,6 +182,7 @@ def from_params(cls, vocab: Vocabulary, params: Params) -> 'Embedding': # type:
         norm_type = params.pop_float('norm_type', 2.)
         scale_grad_by_freq = params.pop_bool('scale_grad_by_freq', False)
         sparse = params.pop_bool('sparse', False)
+        min_pretrained_embeddings = params.pop_int("min_pretrained_embeddings", 0)
         params.assert_empty(cls.__name__)
 
         if pretrained_file:
@@ -191,7 +192,10 @@ def from_params(cls, vocab: Vocabulary, params: Params) -> 'Embedding': # type:
             weight = _read_pretrained_embeddings_file(pretrained_file,
                                                       embedding_dim,
                                                       vocab,
-                                                      vocab_namespace)
+                                                      vocab_namespace,
+                                                      min_pretrained_embeddings)
+            if min_pretrained_embeddings > 0:
+                num_embeddings = vocab.get_vocab_size(vocab_namespace)
         else:
             weight = None
 
@@ -210,7 +214,8 @@ def from_params(cls, vocab: Vocabulary, params: Params) -> 'Embedding': # type:
 def _read_pretrained_embeddings_file(file_uri: str,
                                      embedding_dim: int,
                                      vocab: Vocabulary,
-                                     namespace: str = "tokens") -> torch.FloatTensor:
+                                     namespace: str = "tokens",
+                                     min_pretrained_embeddings: int = None) -> torch.FloatTensor:
     """
     Returns and embedding matrix for the given vocabulary using the pretrained embeddings
     contained in the given file. Embeddings for tokens not found in the pretrained embedding file
@@ -244,8 +249,9 @@ def _read_pretrained_embeddings_file(file_uri: str,
         A Vocabulary object.
     namespace : str, (optional, default=tokens)
         The namespace of the vocabulary to find pretrained embeddings for.
-    trainable : bool, (optional, default=True)
-        Whether or not the embedding parameters should be optimized.
+    min_pretrained_embeddings : int, (optional, default=None):
+        If given, will keep at least this number of embeddings from the start of the pretrained
+        embedding text file (typically the most common words)
 
     Returns
     -------
@@ -261,13 +267,14 @@ def _read_pretrained_embeddings_file(file_uri: str,
 
     return _read_embeddings_from_text_file(file_uri,
                                            embedding_dim,
-                                           vocab, namespace)
+                                           vocab, namespace, min_pretrained_embeddings)
 
 
 def _read_embeddings_from_text_file(file_uri: str,
                                     embedding_dim: int,
                                     vocab: Vocabulary,
-                                    namespace: str = "tokens") -> torch.FloatTensor:
+                                    namespace: str = "tokens",
+                                    min_pretrained_embeddings: int = 0) -> torch.FloatTensor:
     """
     Read pre-trained word vectors from an eventually compressed text file, possibly contained
     inside an archive with multiple files. The text file is assumed to be utf-8 encoded with
@@ -278,16 +285,15 @@ def _read_embeddings_from_text_file(file_uri: str,
     The remainder of the docstring is identical to ``_read_pretrained_embeddings_file``.
     """
     tokens_to_keep = set(vocab.get_index_to_token_vocabulary(namespace).values())
-    vocab_size = vocab.get_vocab_size(namespace)
     embeddings = {}
 
     # First we read the embeddings from the file, only keeping vectors for the words we need.
     logger.info("Reading pretrained embeddings from file")
 
     with EmbeddingsTextFile(file_uri) as embeddings_file:
-        for line in Tqdm.tqdm(embeddings_file):
+        for index, line in Tqdm.tqdm(enumerate(embeddings_file)):
             token = line.split(' ', 1)[0]
-            if token in tokens_to_keep:
+            if token in tokens_to_keep or index < min_pretrained_embeddings:
                 fields = line.rstrip().split(' ')
                 if len(fields) - 1 != embedding_dim:
                     # Sometimes there are funny unicode parsing problems that lead to different
@@ -303,6 +309,10 @@ def _read_embeddings_from_text_file(file_uri: str,
 
                 vector = numpy.asarray(fields[1:], dtype='float32')
                 embeddings[token] = vector
+                if token not in tokens_to_keep:
+                    vocab.add_token_to_namespace(token, namespace)
+
+    vocab_size = vocab.get_vocab_size(namespace)
 
     if not embeddings:
         raise ConfigurationError("No embeddings of correct dimension found; you probably "
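
To make the new filtering rule easier to see in isolation, here is a small standalone sketch (not the AllenNLP code itself; it omits the vocabulary update and the dimension check): a row is kept if its token is already in the vocabulary, or if it is among the first `min_pretrained_embeddings` rows of the frequency-ordered file.

```python
# Standalone illustration of the selection rule added above (names here are illustrative).
from typing import Dict, List, Set

import numpy


def select_pretrained_rows(lines: List[str],
                           tokens_to_keep: Set[str],
                           min_pretrained_embeddings: int) -> Dict[str, numpy.ndarray]:
    embeddings: Dict[str, numpy.ndarray] = {}
    for index, line in enumerate(lines):
        token = line.split(' ', 1)[0]
        # Keep vocab tokens, plus everything in the first `min_pretrained_embeddings` rows.
        if token in tokens_to_keep or index < min_pretrained_embeddings:
            fields = line.rstrip().split(' ')
            embeddings[token] = numpy.asarray(fields[1:], dtype='float32')
    return embeddings


# 'the' is kept because it is in the vocab; 'of' only because it is within the first 2 rows.
rows = ["the 0.1 0.2", "of 0.3 0.4", "his 0.5 0.6"]
kept = select_pretrained_rows(rows, tokens_to_keep={"the"}, min_pretrained_embeddings=2)
assert set(kept) == {"the", "of"}
```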

allennlp/tests/modules/token_embedders/embedding_test.py

+14

@@ -53,6 +53,20 @@ def test_forward_works_with_projection_layer(self):
         embedded = embedding_layer(input_tensor).data.numpy()
         assert embedded.shape == (1, 1, 4, 20)
 
+    def test_min_pretrained_embeddings(self):
+        vocab = Vocabulary()
+        vocab.add_token_to_namespace('the')
+        vocab.add_token_to_namespace('a')
+        params = Params({
+                'pretrained_file': str(self.FIXTURES_ROOT / 'embeddings/glove.6B.100d.sample.txt.gz'),
+                'embedding_dim': 100,
+                'min_pretrained_embeddings': 50
+        })
+        # This will now update vocab
+        _ = Embedding.from_params(vocab, params)
+        assert vocab.get_vocab_size() >= 50
+        assert vocab.get_token_index("his") > 1  # not @@UNKNOWN@@
+
     def test_embedding_layer_actually_initializes_word_vectors_correctly(self):
         vocab = Vocabulary()
         vocab.add_token_to_namespace("word")
