
Commit 6ec74aa

Added a TokenEmbedder for use with pytorch-transformers
1 parent: 0e872a0

File tree

3 files changed (+34, −0 lines)


allennlp/modules/token_embedders/__init__.py (+1)
@@ -16,3 +16,4 @@
         LanguageModelTokenEmbedder
 from allennlp.modules.token_embedders.bag_of_word_counts_token_embedder import BagOfWordCountsTokenEmbedder
 from allennlp.modules.token_embedders.pass_through_token_embedder import PassThroughTokenEmbedder
+from allennlp.modules.token_embedders.pretrained_transformer_embedder import PretrainedTransformerEmbedder
allennlp/modules/token_embedders/pretrained_transformer_embedder.py (new file, +17)
@@ -0,0 +1,17 @@
+from pytorch_transformers.modeling_auto import AutoModel
+import torch
+
+from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
+
+
+@TokenEmbedder.register("pretrained_transformer")
+class PretrainedTransformerEmbedder(TokenEmbedder):
+    """
+    Uses a pretrained model from ``pytorch-transformers`` as a ``TokenEmbedder``.
+    """
+    def __init__(self, model_name: str) -> None:
+        super().__init__()
+        self.transformer_model = AutoModel.from_pretrained(model_name)
+
+    def forward(self, token_ids: torch.LongTensor) -> torch.Tensor:
+        return self.transformer_model(token_ids)[0]
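Because the class registers itself under the name ``pretrained_transformer``, it can also be built through AllenNLP's registry from configuration. A minimal usage sketch (the model name is illustrative, and it assumes pytorch-transformers can fetch the pretrained weights):

import torch
from allennlp.common import Params
from allennlp.modules.token_embedders import TokenEmbedder

# "type" selects the registered subclass; "model_name" is handed to
# AutoModel.from_pretrained underneath.
embedder = TokenEmbedder.from_params(Params({
    "type": "pretrained_transformer",
    "model_name": "bert-base-uncased",
}))
token_ids = torch.randint(0, 100, (1, 4))  # (batch_size, sequence_length)
embeddings = embedder(token_ids)           # (1, 4, hidden_size)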
Test for PretrainedTransformerEmbedder (new file, +16)
@@ -0,0 +1,16 @@
+# pylint: disable=no-self-use,invalid-name
+import torch
+
+from allennlp.common import Params
+from allennlp.modules.token_embedders import PretrainedTransformerEmbedder
+from allennlp.common.testing import AllenNlpTestCase
+
+class TestPretrainedTransformerEmbedder(AllenNlpTestCase):
+    def test_forward_runs_when_initialized_from_params(self):
+        # This code just passes things off to pytorch-transformers, so we only have a very simple
+        # test.
+        params = Params({'model_name': 'bert-base-uncased'})
+        embedder = PretrainedTransformerEmbedder.from_params(params)
+        tensor = torch.randint(0, 100, (1, 4))
+        output = embedder(tensor)
+        assert tuple(output.size()) == (1, 4, 768)
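The expected trailing dimension of 768 is the hidden size of BERT-base models, and the embedder's output is simply the first element of the tuple that pytorch-transformers returns. A sketch of the equivalent direct call, for comparison (same illustrative model name as the test):

import torch
from pytorch_transformers.modeling_auto import AutoModel

model = AutoModel.from_pretrained('bert-base-uncased')
token_ids = torch.randint(0, 100, (1, 4))
# The model returns a tuple; element 0 holds the per-token hidden states.
last_hidden_states = model(token_ids)[0]
assert tuple(last_hidden_states.size()) == (1, 4, 768)  # 768 = BERT-base hidden size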
