from overrides import overrides
from pytorch_transformers import GPT2Config, GPT2LMHeadModel
import torch

from allennlp.modules.language_model_heads.language_model_head import LanguageModelHead


@LanguageModelHead.register('gpt2')
class Gpt2LanguageModelHead(LanguageModelHead):
| 10 | + """ |
| 11 | + Loads just the LM head from ``pytorch_transformers.GPT2LMHeadModel``. It was easiest to load |
| 12 | + the entire model before only pulling out the head, so this is a bit slower than it could be, |
| 13 | + but for practical use in a model, the few seconds of extra loading time is probably not a big |
| 14 | + deal. |
| 15 | + """ |
    def __init__(self, model_name: str) -> None:
        super().__init__()
        config = GPT2Config.from_pretrained(model_name)
        self.input_dim = config.hidden_size
        self.output_dim = config.vocab_size
        # TODO(mattg): It's possible that we could use some kind of cache like we have in
        # allennlp.modules.token_embedders.bert_token_embedder.PretrainedBertModel. That way, we
        # would only load the GPT2 weights once. Though, it's not clear how to do that here, as we
        # need to load `GPT2LMHeadModel`, not just `GPT2Model`...
        gpt2_model = GPT2LMHeadModel.from_pretrained(model_name)
        self.gpt2_lm_head = gpt2_model.lm_head  # pylint: disable=no-member

    @overrides
    def get_input_dim(self) -> int:
        return self.input_dim

    @overrides
    def get_output_dim(self) -> int:
        return self.output_dim

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The GPT-2 LM head is a linear projection from hidden_size to vocab_size, so this maps
        # hidden states of shape (..., hidden_size) to vocabulary logits of shape (..., vocab_size).
        return self.gpt2_lm_head(hidden_states)
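
A minimal usage sketch, not part of the patch: it assumes the standard 'gpt2' pretrained weights can be downloaded by pytorch_transformers and that the class is re-exported from the allennlp.modules.language_model_heads package (adjust the import to wherever this file actually lives).

import torch
from allennlp.modules.language_model_heads import Gpt2LanguageModelHead  # assumed re-export

# Loads the full GPT2LMHeadModel once, then keeps only its output projection.
head = Gpt2LanguageModelHead('gpt2')
hidden_states = torch.randn(2, 7, head.get_input_dim())  # (batch, sequence, hidden_size)
logits = head(hidden_states)                             # (batch, sequence, vocab_size)
assert logits.shape == (2, 7, head.get_output_dim())

Because of the ``@LanguageModelHead.register('gpt2')`` decorator, the same head can also be constructed from configuration with an entry like ``{"type": "gpt2", "model_name": "gpt2"}``.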