This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 8c06c4b

Adding a LanguageModelHead abstraction (#3200)
* Modules and docs
* Added tests
* Docstrings
* pylint
* moved linear layer to tests
* add todos about caching
* fix import...
* doc
1 parent 370d512 commit 8c06c4b

File tree

11 files changed: +193 -0 lines changed


allennlp/modules/__init__.py

+1
@@ -24,3 +24,4 @@
 from allennlp.modules.input_variational_dropout import InputVariationalDropout
 from allennlp.modules.bimpm_matching import BiMpmMatching
 from allennlp.modules.residual_with_layer_dropout import ResidualWithLayerDropout
+from allennlp.modules.language_model_heads import LanguageModelHead
allennlp/modules/language_model_heads/__init__.py

@@ -0,0 +1,3 @@
from allennlp.modules.language_model_heads.language_model_head import LanguageModelHead
from allennlp.modules.language_model_heads.bert import BertLanguageModelHead
from allennlp.modules.language_model_heads.gpt2 import Gpt2LanguageModelHead
allennlp/modules/language_model_heads/bert.py

@@ -0,0 +1,37 @@
from overrides import overrides
from pytorch_transformers import BertConfig, BertForMaskedLM
import torch

from allennlp.modules.language_model_heads.language_model_head import LanguageModelHead


@LanguageModelHead.register('bert')
class BertLanguageModelHead(LanguageModelHead):
    """
    Loads just the LM head from ``pytorch_transformers.BertForMaskedLM``.  It was easiest to load
    the entire model before only pulling out the head, so this is a bit slower than it could be,
    but for practical use in a model, the few seconds of extra loading time is probably not a big
    deal.
    """
    def __init__(self, model_name: str) -> None:
        super().__init__()
        config = BertConfig.from_pretrained(model_name)
        self.input_dim = config.hidden_size
        self.output_dim = config.vocab_size
        # TODO(mattg): It's possible that we could use some kind of cache like we have in
        # allennlp.modules.token_embedders.bert_token_embedder.PretrainedBertModel.  That way, we
        # would only load the BERT weights once.  Though, it's not clear how to do that here, as we
        # need to load `BertForMaskedLM`, not just `BertModel`...
        bert_model = BertForMaskedLM.from_pretrained(model_name)
        self.bert_lm_head = bert_model.cls  # pylint: disable=no-member

    @overrides
    def get_input_dim(self) -> int:
        return self.input_dim

    @overrides
    def get_output_dim(self) -> int:
        return self.output_dim

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.bert_lm_head(hidden_states)
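The TODO in the constructor above refers to the cache in ``allennlp.modules.token_embedders.bert_token_embedder.PretrainedBertModel``. As a rough illustration of that idea (not part of this commit; the class name ``PretrainedBertForMaskedLM`` and its ``load`` method are hypothetical), such a cache could look like this, and the same pattern would apply to the GPT-2 head below:

from typing import Dict

from pytorch_transformers import BertForMaskedLM


class PretrainedBertForMaskedLM:
    """
    Hypothetical cache in the spirit of PretrainedBertModel: load each
    BertForMaskedLM once per model name and reuse it afterwards.
    """
    _cache: Dict[str, BertForMaskedLM] = {}

    @classmethod
    def load(cls, model_name: str, cache_model: bool = True) -> BertForMaskedLM:
        # Return the already-loaded model if we have one for this name.
        if model_name in cls._cache:
            return cls._cache[model_name]
        model = BertForMaskedLM.from_pretrained(model_name)
        if cache_model:
            cls._cache[model_name] = model
        return model

With something like this in place, the ``__init__`` above could call ``PretrainedBertForMaskedLM.load(model_name)`` instead of ``BertForMaskedLM.from_pretrained(model_name)``, so several modules sharing the same ``model_name`` would only load the weights once.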
allennlp/modules/language_model_heads/gpt2.py

@@ -0,0 +1,37 @@
from overrides import overrides
from pytorch_transformers import GPT2Config, GPT2LMHeadModel
import torch

from allennlp.modules.language_model_heads.language_model_head import LanguageModelHead


@LanguageModelHead.register('gpt2')
class Gpt2LanguageModelHead(LanguageModelHead):
    """
    Loads just the LM head from ``pytorch_transformers.GPT2LMHeadModel``.  It was easiest to load
    the entire model before only pulling out the head, so this is a bit slower than it could be,
    but for practical use in a model, the few seconds of extra loading time is probably not a big
    deal.
    """
    def __init__(self, model_name: str) -> None:
        super().__init__()
        config = GPT2Config.from_pretrained(model_name)
        self.input_dim = config.hidden_size
        self.output_dim = config.vocab_size
        # TODO(mattg): It's possible that we could use some kind of cache like we have in
        # allennlp.modules.token_embedders.bert_token_embedder.PretrainedBertModel.  That way, we
        # would only load the GPT2 weights once.  Though, it's not clear how to do that here, as we
        # need to load `GPT2LMHeadModel`, not just `GPT2Model`...
        gpt2_model = GPT2LMHeadModel.from_pretrained(model_name)
        self.gpt2_lm_head = gpt2_model.lm_head  # pylint: disable=no-member

    @overrides
    def get_input_dim(self) -> int:
        return self.input_dim

    @overrides
    def get_output_dim(self) -> int:
        return self.output_dim

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.gpt2_lm_head(hidden_states)
allennlp/modules/language_model_heads/language_model_head.py

@@ -0,0 +1,19 @@
import torch

from allennlp.common import Registrable


class LanguageModelHead(torch.nn.Module, Registrable):
    """
    A ``LanguageModelHead`` encapsulates a function that goes from some hidden state to logits over
    a vocabulary.
    """
    def get_input_dim(self) -> int:
        raise NotImplementedError

    def get_output_dim(self) -> int:
        raise NotImplementedError

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:  # type: ignore
        # pylint: disable=arguments-differ
        raise NotImplementedError
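For orientation, the contract here is simply hidden states in, vocabulary logits out. A minimal sketch of the intended call pattern (illustrative only, not code from this commit; it reuses the ``from_params`` construction shown in the tests below):

import torch

from allennlp.common import Params
from allennlp.modules.language_model_heads import LanguageModelHead

# Construct any registered head from configuration ('bert', 'gpt2', ...).
head = LanguageModelHead.from_params(Params({"type": "bert", "model_name": "bert-base-uncased"}))

# Pretend these are contextual hidden states from an encoder:
# shape (batch_size, sequence_length, input_dim).
hidden_states = torch.rand(2, 7, head.get_input_dim())

# The head maps hidden states to logits over the vocabulary:
# shape (batch_size, sequence_length, output_dim).
logits = head(hidden_states)
assert logits.size(-1) == head.get_output_dim()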

allennlp/tests/modules/language_model_heads/__init__.py

Whitespace-only changes.
@@ -0,0 +1,19 @@
# pylint: disable=invalid-name,no-self-use,protected-access
import torch

from allennlp.common import Params
from allennlp.common.testing.test_case import AllenNlpTestCase
from allennlp.modules.language_model_heads import LanguageModelHead, BertLanguageModelHead


class TestBertLanguageModelHead(AllenNlpTestCase):
    def test_can_init_and_run(self):
        # The LM head code reads a module from somewhere else; we're basically just testing here
        # that we can initialize the expected model `from_params`.
        head = LanguageModelHead.from_params(Params({"type": "bert", "model_name": "bert-base-uncased"}))
        assert isinstance(head, BertLanguageModelHead)
        assert head.get_input_dim() == 768
        assert head.get_output_dim() == 30522
        tensor = torch.rand(1, 768)
        logits = head(tensor)
        assert tuple(logits.size()) == (1, 30522)
@@ -0,0 +1,19 @@
# pylint: disable=invalid-name,no-self-use,protected-access
import torch

from allennlp.common import Params
from allennlp.common.testing.test_case import AllenNlpTestCase
from allennlp.modules.language_model_heads import LanguageModelHead, Gpt2LanguageModelHead


class TestGpt2LanguageModelHead(AllenNlpTestCase):
    def test_can_init_and_run(self):
        # The LM head code reads a module from somewhere else; we're basically just testing here
        # that we can initialize the expected model `from_params`.
        head = LanguageModelHead.from_params(Params({"type": "gpt2", "model_name": "gpt2"}))
        assert isinstance(head, Gpt2LanguageModelHead)
        assert head.get_input_dim() == 768
        assert head.get_output_dim() == 50257
        tensor = torch.rand(1, 768)
        logits = head(tensor)
        assert tuple(logits.size()) == (1, 50257)
@@ -0,0 +1,35 @@
from overrides import overrides
import torch

from allennlp.data import Vocabulary
from allennlp.modules.language_model_heads.language_model_head import LanguageModelHead


@LanguageModelHead.register('linear')
class LinearLanguageModelHead(LanguageModelHead):
    """
    Uses ``torch.nn.Linear`` as a language model head.  Does nothing else fancy.  This was intended
    largely for testing code with small models and simple components.  It's likely that you would
    want something nicer for actually training a language model, such as tying weights with an
    input embedding, or an adaptive softmax, or something.  But, if you find this class useful for
    something you're doing and want it moved into the repo, open an issue on github.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 input_dim: int,
                 vocab_namespace: str) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = vocab.get_vocab_size(vocab_namespace)
        self.linear = torch.nn.Linear(self.input_dim, self.output_dim)

    @overrides
    def get_input_dim(self) -> int:
        return self.input_dim

    @overrides
    def get_output_dim(self) -> int:
        return self.output_dim

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.linear(hidden_states)
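This linear head lives with the tests rather than the main library (the commit message notes the linear layer was moved to tests), and it is easy to use directly. A small hedged sketch follows; the vocabulary contents and the ``tokens`` namespace are made up for illustration, and the import of ``LinearLanguageModelHead`` is left implicit because this diff does not show the file's final path:

import torch

from allennlp.data import Vocabulary

# Build a tiny vocabulary; the words and the 'tokens' namespace are illustrative.
vocab = Vocabulary()
for token in ["the", "quick", "brown", "fox"]:
    vocab.add_token_to_namespace(token, namespace="tokens")

# LinearLanguageModelHead is the class defined in the file above.
head = LinearLanguageModelHead(vocab=vocab, input_dim=10, vocab_namespace="tokens")

hidden_states = torch.rand(3, 10)
logits = head(hidden_states)
# The output size is the namespace size, which includes any padding/OOV
# entries the Vocabulary adds on top of the four tokens above.
assert tuple(logits.size()) == (3, vocab.get_vocab_size("tokens"))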
doc/api/allennlp.modules.language_model_heads.rst

@@ -0,0 +1,22 @@
allennlp.modules.language_model_heads
=====================================

.. automodule:: allennlp.modules.language_model_heads
   :members:
   :undoc-members:
   :show-inheritance:

.. automodule:: allennlp.modules.language_model_heads.language_model_head
   :members:
   :undoc-members:
   :show-inheritance:

.. automodule:: allennlp.modules.language_model_heads.bert
   :members:
   :undoc-members:
   :show-inheritance:

.. automodule:: allennlp.modules.language_model_heads.gpt2
   :members:
   :undoc-members:
   :show-inheritance:

doc/api/allennlp.modules.rst

+1
@@ -14,6 +14,7 @@ allennlp.modules
    allennlp.modules.lstm_cell_with_projection
    allennlp.modules.elmo
    allennlp.modules.elmo_lstm
+   allennlp.modules.language_model_heads
    allennlp.modules.conditional_random_field
    allennlp.modules.feedforward
    allennlp.modules.highway
