This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 9267ce7

pvcastro and dirkgr authored
Resize transformers word embeddings layer for additional_special_tokens (#4946)
* Adding a mechanism to resize the word embeddings layer of transformers models in case additional special tokens are provided in tokenizer_kwargs.
* Updating changelog.
* Reformatting test file with black.
* Fixing failed test for transformer models that don't implement get_input_embeddings().
* Adding a message to warn the user that the transformer model is unable to resize its embeddings layer when additional tokens are provided.
* Reformatting with black.

Co-authored-by: Dirk Groeneveld <[email protected]>
1 parent 52c23dd commit 9267ce7
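
A minimal usage sketch of the new behavior (assuming an AllenNLP build that includes this commit; the token string "<EXTRA_TOKEN>" is purely illustrative):

    from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

    # Passing additional_special_tokens through tokenizer_kwargs now also triggers
    # a resize of the transformer's word embedding matrix, so the new token ids
    # map to real embedding rows.
    embedder = PretrainedTransformerEmbedder(
        "bert-base-cased",
        tokenizer_kwargs={"additional_special_tokens": ["<EXTRA_TOKEN>"]},
    )

    # bert-base-cased ships with 28996 word pieces; one extra special token -> 28997.
    print(embedder.transformer_model.get_input_embeddings().num_embeddings)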

4 files changed: +33 -1 lines changed


CHANGELOG.md (+1)

@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 - Added `tokenizer_kwargs` and `transformer_kwargs` arguments to `PretrainedTransformerBackbone`
+- Resize transformers word embeddings layer for `additional_special_tokens`
 
 ### Changed
 
allennlp/modules/token_embedders/pretrained_transformer_embedder.py (+16)

@@ -1,3 +1,4 @@
+import logging
 import math
 from typing import Optional, Tuple, Dict, Any
 
@@ -12,6 +13,8 @@
 from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
 from allennlp.nn.util import batched_index_select
 
+logger = logging.getLogger(__name__)
+
 
 @TokenEmbedder.register("pretrained_transformer")
 class PretrainedTransformerEmbedder(TokenEmbedder):
@@ -101,6 +104,19 @@ def __init__(
             model_name,
             tokenizer_kwargs=tokenizer_kwargs,
         )
+
+        try:
+            if self.transformer_model.get_input_embeddings().num_embeddings != len(
+                tokenizer.tokenizer
+            ):
+                self.transformer_model.resize_token_embeddings(len(tokenizer.tokenizer))
+        except NotImplementedError:
+            # Can't resize for transformers models that don't implement base_model.get_input_embeddings()
+            logger.warning(
+                "Could not resize the token embedding matrix of the transformer model. "
+                "This model does not support resizing."
+            )
+
         self._num_added_start_tokens = len(tokenizer.single_sequence_start_tokens)
         self._num_added_end_tokens = len(tokenizer.single_sequence_end_tokens)
         self._num_added_tokens = self._num_added_start_tokens + self._num_added_end_tokens
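
For reference, the guard added above follows the usual Hugging Face transformers pattern for keeping a model's embedding matrix in sync with a tokenizer that has gained extra tokens. A standalone sketch of that pattern (model name and token string are illustrative, not taken from this commit):

    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    model = AutoModel.from_pretrained("bert-base-cased")

    # Register the extra special tokens with the tokenizer, then grow the
    # embedding matrix so every token id the tokenizer can emit has a row.
    tokenizer.add_special_tokens({"additional_special_tokens": ["<NEW_TOKEN>"]})
    if model.get_input_embeddings().num_embeddings != len(tokenizer):
        model.resize_token_embeddings(len(tokenizer))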

allennlp/training/metric_tracker.py (+1 -1)

@@ -128,6 +128,6 @@ def combined_score(self, metrics: Dict[str, float]) -> float:
             )
         except KeyError as e:
             raise ConfigurationError(
-                f"You configured the trainer to use the {e.args[0]}"
+                f"You configured the trainer to use the {e.args[0]} "
                 "metric for early stopping, but the model did not produce that metric."
             )

tests/modules/token_embedders/pretrained_transformer_embedder_test.py (+15)

@@ -315,3 +315,18 @@ def test_encoder_decoder_model(self):
         token_ids = torch.LongTensor([[1, 2, 3], [2, 3, 4]])
         mask = torch.ones_like(token_ids).bool()
         token_embedder(token_ids, mask)
+
+    def test_embeddings_resize(self):
+        regular_token_embedder = PretrainedTransformerEmbedder("bert-base-cased")
+        assert (
+            regular_token_embedder.transformer_model.embeddings.word_embeddings.num_embeddings
+            == 28996
+        )
+        tokenizer_kwargs = {"additional_special_tokens": ["<NEW_TOKEN>"]}
+        enhanced_token_embedder = PretrainedTransformerEmbedder(
+            "bert-base-cased", tokenizer_kwargs=tokenizer_kwargs
+        )
+        assert (
+            enhanced_token_embedder.transformer_model.embeddings.word_embeddings.num_embeddings
+            == 28997
+        )

0 commit comments
