This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 9719b5c

HarshTrivedi authored and matt-gardner committed
Allow embedding extension to load from pre-trained embeddings file. (#2387)
* Rough attempt for Embedder/Embedding extension.
* fix some mistakes.
* add tests for token-embedder and text-field-embedder extension.
* fix vocab_namespace usage in embedding.py
* update names and change some comments.
* update embedding tests.
* fix some typos.
* add more tests.
* update some comments.
* fix minor pylint issue.
* Implement extend_vocab for TokenCharactersEncoder.
* minor simplification.
* Update help text for --extend-vocab in fine-tune command.
* Shift location of model tests appropriately.
* Allow root Embedding in model to be extendable.
* Incorporate PR comments in embedding.py.
* Fix annotations.
* Add appropriate docstrings and minor cleanup.
* Resolve pylint complaints.
* shift disable pylint protected-access to top of tests.
* Add a test to increase coverage.
* minor update in TokenEmbedder docstrings.
* Allow passing pretrained_file in embedding extension (with tests).
* Remove a blank line.
* Add a blank line before Returns block in Embedding docstring.
* Fix pylint complaints.
* Allow pretrained file to be passed in token_characters_encoder also.
* Fix pylint complaints and update some comments.
* Test to ensure trained embeddings do not get overridden.
* PR feedback: update comments, fix annotation.
1 parent 174f539 commit 9719b5c
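Taken together, the changes let an already-trained ``Embedding`` grow to cover a fine-tuning vocabulary while pulling vectors for the new tokens from a pretrained embeddings file (the same path exercised by the ``fine-tune`` command's ``--extend-vocab`` flag). A minimal usage sketch, not part of the commit itself; the tokens and the tiny embeddings file below are made up for illustration:

```python
import tempfile

from allennlp.common import Params
from allennlp.data import Vocabulary
from allennlp.modules.token_embedders import Embedding

# Vocabulary and embedder roughly as they might exist after the original training run.
vocab = Vocabulary()
vocab.add_token_to_namespace("cat")
embedder = Embedding.from_params(vocab, Params({"embedding_dim": 3,
                                                "vocab_namespace": "tokens"}))

# Fine-tuning data introduces a new token, so the vocabulary is extended.
vocab.add_token_to_namespace("dog")

# A tiny stand-in for a real pretrained embeddings file (token followed by its vector).
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as embeddings_file:
    embeddings_file.write("cat 0.1 0.2 0.3\ndog 0.4 0.5 0.6\n")

# New in this commit: the extra embedding rows are read from the file
# instead of being xavier-uniform initialized.
embedder.extend_vocab(vocab, pretrained_file=embeddings_file.name)
```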

File tree

3 files changed (+71, -9 lines):

* allennlp/modules/token_embedders/embedding.py
* allennlp/modules/token_embedders/token_characters_encoder.py
* allennlp/tests/modules/token_embedders/embedding_test.py

allennlp/modules/token_embedders/embedding.py

Lines changed: 22 additions & 6 deletions
@@ -147,10 +147,14 @@ def forward(self, inputs): # pylint: disable=arguments-differ
         return embedded
 
     @overrides
-    def extend_vocab(self, extended_vocab: Vocabulary, vocab_namespace: str = None):
+    def extend_vocab(self,  # pylint: disable=arguments-differ
+                     extended_vocab: Vocabulary,
+                     vocab_namespace: str = None,
+                     pretrained_file: str = None) -> None:
         """
         Extends the embedding matrix according to the extended vocabulary.
-        Extended weight would be initialized with xavier uniform.
+        If pretrained_file is available, it will be used for initializing the new words
+        in the extended vocabulary; otherwise they will be initialized with xavier uniform.
 
         Parameters
         ----------
@@ -162,6 +166,10 @@ def extend_vocab(self, extended_vocab: Vocabulary, vocab_namespace: str = None):
             can pass it. If not passed, it will check if vocab_namespace used at the
             time of ``Embedding`` construction is available. If so, this namespace
             will be used or else default 'tokens' namespace will be used.
+        pretrained_file : str, (optional, default=None)
+            A file containing pretrained embeddings can be specified here. It can be
+            the path to a local file or an URL of a (cached) remote file. Check format
+            details in ``from_params`` of ``Embedding`` class.
         """
         # Caveat: For allennlp v0.8.1 and below, we weren't storing vocab_namespace as an attribute,
         # knowing which is necessary at time of embedding vocab extension. So old archive models are
@@ -172,11 +180,19 @@ def extend_vocab(self, extended_vocab: Vocabulary, vocab_namespace: str = None):
             vocab_namespace = "tokens"
             logging.warning("No vocab_namespace provided to Embedder.extend_vocab. Defaulting to 'tokens'.")
 
-        extended_num_embeddings = extended_vocab.get_vocab_size(vocab_namespace)
-        extra_num_embeddings = extended_num_embeddings - self.num_embeddings
         embedding_dim = self.weight.data.shape[-1]
-        extra_weight = torch.FloatTensor(extra_num_embeddings, embedding_dim)
-        torch.nn.init.xavier_uniform_(extra_weight)
+        if not pretrained_file:
+            extended_num_embeddings = extended_vocab.get_vocab_size(vocab_namespace)
+            extra_num_embeddings = extended_num_embeddings - self.num_embeddings
+            extra_weight = torch.FloatTensor(extra_num_embeddings, embedding_dim)
+            torch.nn.init.xavier_uniform_(extra_weight)
+        else:
+            # It's easiest to just reload the embeddings for the entire vocab,
+            # then only keep the ones we need.
+            whole_weight = _read_pretrained_embeddings_file(pretrained_file, embedding_dim,
+                                                            extended_vocab, vocab_namespace)
+            extra_weight = whole_weight[self.num_embeddings:, :]
+
         extended_weight = torch.cat([self.weight.data, extra_weight], dim=0)
         self.weight = torch.nn.Parameter(extended_weight, requires_grad=self.weight.requires_grad)
 
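The heart of the change is the ``else`` branch above: rather than parsing the file only for the new tokens, it re-reads an embedding matrix for the whole extended vocabulary and keeps just the rows past the original ``num_embeddings``, so rows that may have been fine-tuned during training are never overwritten. A toy, AllenNLP-free sketch of that slice-and-concatenate step (all tensors here are random stand-ins):

```python
import torch

num_embeddings, embedding_dim = 4, 3                          # size of the embedder before extension
trained_weight = torch.randn(num_embeddings, embedding_dim)   # possibly fine-tuned rows

# Stand-in for _read_pretrained_embeddings_file(...) over the *extended* vocabulary.
whole_weight = torch.randn(6, embedding_dim)

# Keep only the rows for newly added tokens and append them to the trained matrix.
extra_weight = whole_weight[num_embeddings:, :]
extended_weight = torch.cat([trained_weight, extra_weight], dim=0)

assert extended_weight.shape == (6, embedding_dim)
assert torch.equal(extended_weight[:num_embeddings], trained_weight)  # old rows untouched
```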

allennlp/modules/token_embedders/token_characters_encoder.py

Lines changed: 13 additions & 3 deletions
@@ -37,10 +37,14 @@ def forward(self, token_characters: torch.Tensor) -> torch.Tensor: # pylint: di
         return self._dropout(self._encoder(self._embedding(token_characters), mask))
 
     @overrides
-    def extend_vocab(self, extended_vocab: Vocabulary, vocab_namespace: str = "token_characters"):
+    def extend_vocab(self,  # pylint: disable=arguments-differ
+                     extended_vocab: Vocabulary,
+                     vocab_namespace: str = "token_characters",
+                     pretrained_file: str = None) -> None:
         """
         Extends the embedding module according to the extended vocabulary.
-        Extended weight would be initialized with xavier uniform.
+        If pretrained_file is available, it will be used for initializing the new words
+        in the extended vocabulary; otherwise they will be initialized with xavier uniform.
 
         Parameters
         ----------
@@ -52,11 +56,17 @@ def extend_vocab(self, extended_vocab: Vocabulary, vocab_namespace: str = "token
             you can pass it here. If not passed, it will check if vocab_namespace used
             at the time of ``TokenCharactersEncoder`` construction is available. If so, this
             namespace will be used or else default 'token_characters' namespace will be used.
+        pretrained_file : str, (optional, default=None)
+            A file containing pretrained embeddings can be specified here. It can be
+            the path to a local file or an URL of a (cached) remote file. Check format
+            details in ``from_params`` of ``Embedding`` class.
         """
         # Caveat: For allennlp v0.8.1 and below, we weren't storing vocab_namespace as an attribute, knowing
         # which is necessary at time of token_characters_encoder vocab extension. So old archive models are
         # currently unextendable unless the user used default vocab_namespace 'token_characters' for it.
-        self._embedding._module.extend_vocab(extended_vocab, vocab_namespace)  # pylint: disable=protected-access
+        self._embedding._module.extend_vocab(extended_vocab,  # pylint: disable=protected-access
+                                             vocab_namespace=vocab_namespace,
+                                             pretrained_file=pretrained_file)
 
     # The setdefault requires a custom from_params
     @classmethod
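For the character-level embedder the new keyword is simply forwarded to the wrapped ``Embedding``, so callers use the same interface as for word embeddings. A brief sketch, assuming a ``TokenCharactersEncoder`` built with a bag-of-embeddings character encoder (the characters here are made up; without a ``pretrained_file`` the new rows fall back to xavier uniform):

```python
from allennlp.common import Params
from allennlp.data import Vocabulary
from allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder
from allennlp.modules.token_embedders import Embedding, TokenCharactersEncoder

vocab = Vocabulary()
vocab.add_token_to_namespace("a", namespace="token_characters")

char_embedding = Embedding.from_params(vocab, Params({"embedding_dim": 4,
                                                      "vocab_namespace": "token_characters"}))
encoder = TokenCharactersEncoder(char_embedding, BagOfEmbeddingsEncoder(embedding_dim=4))

# A new character shows up in the fine-tuning data.
vocab.add_token_to_namespace("b", namespace="token_characters")

# Same call shape as Embedding.extend_vocab; pretrained_file is optional.
encoder.extend_vocab(vocab, vocab_namespace="token_characters", pretrained_file=None)
```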

allennlp/tests/modules/token_embedders/embedding_test.py

Lines changed: 36 additions & 0 deletions
@@ -281,3 +281,39 @@ def test_embedding_vocab_extension_without_stored_namespace(self):
         extended_weight = embedder.weight
         assert extended_weight.shape[0] == 5
         assert torch.all(extended_weight[:4, :] == original_weight[:4, :])
+
+    def test_embedding_vocab_extension_works_with_pretrained_embedding_file(self):
+        vocab = Vocabulary()
+        vocab.add_token_to_namespace('word1')
+        vocab.add_token_to_namespace('word2')
+
+        embeddings_filename = str(self.TEST_DIR / "embeddings2.gz")
+        with gzip.open(embeddings_filename, 'wb') as embeddings_file:
+            embeddings_file.write("word3 0.5 0.3 -6.0\n".encode('utf-8'))
+            embeddings_file.write("word4 1.0 2.3 -1.0\n".encode('utf-8'))
+            embeddings_file.write("word2 0.1 0.4 -4.0\n".encode('utf-8'))
+            embeddings_file.write("word1 1.0 2.3 -1.0\n".encode('utf-8'))
+
+        embedding_params = Params({"vocab_namespace": "tokens", "embedding_dim": 3,
+                                   "pretrained_file": embeddings_filename})
+        embedder = Embedding.from_params(vocab, embedding_params)
+
+        # Change weight to simulate embedding training
+        embedder.weight.data += 1
+        assert torch.all(embedder.weight[2:, :] == torch.Tensor([[2.0, 3.3, 0.0], [1.1, 1.4, -3.0]]))
+        original_weight = embedder.weight
+
+        assert tuple(original_weight.size()) == (4, 3)  # 4 because of padding and OOV
+
+        vocab.add_token_to_namespace('word3')
+        embedder.extend_vocab(vocab, pretrained_file=embeddings_filename)  # default namespace
+        extended_weight = embedder.weight
+
+        # Make sure extension happened for extra token in extended vocab
+        assert tuple(extended_weight.size()) == (5, 3)
+
+        # Make sure extension doesn't change original trained weights.
+        assert torch.all(original_weight[:4, :] == extended_weight[:4, :])
+
+        # Make sure extended weight is taken from the embedding file.
+        assert torch.all(extended_weight[4, :] == torch.Tensor([0.5, 0.3, -6.0]))
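A small aside on the indices the assertions above rely on: AllenNLP's default (padded) namespaces reserve index 0 for the padding token and index 1 for the OOV token, which is why the original matrix has 4 rows and the newly added 'word3' lands in row 4. A quick standalone check of that layout (independent of the test, and assuming the default padding/OOV behavior):

```python
from allennlp.data import Vocabulary

vocab = Vocabulary()
for word in ('word1', 'word2', 'word3'):
    vocab.add_token_to_namespace(word)

# Indices 0 and 1 hold the padding and OOV tokens, so 'word3' ends up at index 4.
assert vocab.get_token_index('word3') == 4
assert vocab.get_vocab_size('tokens') == 5
```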
