
Commit 088f0bb

Turn BidirectionalLM into a more-general LanguageModel class (#2264)
Fixes #2255. This PR replaces the `BidirectionalLM` class with a more-general `LanguageModel` that can be used in either the unidirectional/forward setting or the bidirectional setting. It also accordingly replaces the `BidirectionalLanguageModelTokenEmbedder` with a `LanguageModelTokenEmbedder`. It also fixes a bug in the experiment_unsampled.jsonnet config that was preventing a test from actually being unsampled.

TODO:
- [x] test the unidirectional case
- [x] properly deprecate `BidirectionalLM` and `BidirectionalLanguageModelTokenEmbedder`
- [x] check docs for accuracy
- [x] fix user-facing training configs
1 parent f76dc70 · commit 088f0bb

34 files changed: +829 / -483 lines
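
As context for the change described above, here is a minimal sketch (not part of this commit) of how the new, more general class might be constructed for the forward-only case. It assumes the constructor accepts the same keyword arguments that `BidirectionalLanguageModel` forwards to it in the diff below; the embedder and contextualizer wiring is illustrative only.

```python
import torch

from allennlp.data.vocabulary import Vocabulary
from allennlp.models.language_model import LanguageModel
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding

vocab = Vocabulary()  # in practice, built from a dataset
embedder = BasicTextFieldEmbedder(
        {"tokens": Embedding(num_embeddings=vocab.get_vocab_size("tokens"),
                             embedding_dim=16)})

# A unidirectional contextualizer for the forward-only setting.
contextualizer = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(input_size=16, hidden_size=16, batch_first=True))

# Keyword arguments mirror those forwarded by BidirectionalLanguageModel below;
# bidirectional=False is assumed to select the unidirectional/forward setting.
forward_lm = LanguageModel(vocab=vocab,
                           text_field_embedder=embedder,
                           contextualizer=contextualizer,
                           bidirectional=False)
```

The deprecated `BidirectionalLanguageModel` (see the bidirectional_lm.py diff below) becomes a thin subclass that passes `bidirectional=True` to this constructor.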

allennlp/models/__init__.py (+1 line)

@@ -27,3 +27,4 @@
 from allennlp.models.bimpm import BiMpm
 from allennlp.models.graph_parser import GraphParser
 from allennlp.models.bidirectional_lm import BidirectionalLanguageModel
+from allennlp.models.language_model import LanguageModel

allennlp/models/bidirectional_lm.py (+12, -248 lines)
@@ -1,53 +1,16 @@
-from typing import Dict, List, Tuple, Union
+from typing import Union

-import torch
-import numpy as np
-
-from allennlp.common.checks import ConfigurationError
 from allennlp.data.vocabulary import Vocabulary
+from allennlp.models.language_model import LanguageModel
 from allennlp.models.model import Model
 from allennlp.modules.text_field_embedders import TextFieldEmbedder
-from allennlp.modules.sampled_softmax_loss import SampledSoftmaxLoss
 from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
-from allennlp.nn.util import get_text_field_mask
 from allennlp.nn import InitializerApplicator


-class _SoftmaxLoss(torch.nn.Module):
-    """
-    Given some embeddings and some targets, applies a linear layer
-    to create logits over possible words and then returns the
-    negative log likelihood.
-    """
-    def __init__(self,
-                 num_words: int,
-                 embedding_dim: int) -> None:
-        super().__init__()
-
-        # TODO(joelgrus): implement tie_embeddings (maybe)
-        self.tie_embeddings = False
-
-        self.softmax_w = torch.nn.Parameter(
-                torch.randn(embedding_dim, num_words) / np.sqrt(embedding_dim)
-        )
-        self.softmax_b = torch.nn.Parameter(torch.zeros(num_words))
-
-    def forward(self, embeddings: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
-        # pylint: disable=arguments-differ
-        # embeddings is size (n, embedding_dim)
-        # targets is (batch_size, ) with the correct class id
-        # Does not do any count normalization / divide by batch size
-        probs = torch.nn.functional.log_softmax(
-                torch.matmul(embeddings, self.softmax_w) + self.softmax_b,
-                dim=-1
-        )
-
-        return torch.nn.functional.nll_loss(probs, targets.long(), reduction="sum")
-
-
 @Model.register('bidirectional-language-model')
 @Model.register('bidirectional_language_model')
-class BidirectionalLanguageModel(Model):
+class BidirectionalLanguageModel(LanguageModel):
     """
     The ``BidirectionalLanguageModel`` applies a bidirectional "contextualizing"
     ``Seq2SeqEncoder`` to uncontextualized embeddings, using a ``SoftmaxLoss``
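
The `_SoftmaxLoss` module removed in the hunk above (presumably relocated alongside the new `LanguageModel`) computes a full, non-sampled softmax loss. The sketch below, which is not part of the diff, reproduces that computation in plain PyTorch: project the embeddings to vocabulary logits, then return the summed negative log likelihood with no batch normalization.

```python
import torch

num_words, embedding_dim = 50, 8

# Parameters as in the removed module: an (embedding_dim, num_words) projection and a bias.
softmax_w = torch.randn(embedding_dim, num_words) / embedding_dim ** 0.5
softmax_b = torch.zeros(num_words)

embeddings = torch.randn(4, embedding_dim)   # (num_targets, embedding_dim)
targets = torch.tensor([3, 17, 0, 42])       # correct class ids

log_probs = torch.nn.functional.log_softmax(
        torch.matmul(embeddings, softmax_w) + softmax_b, dim=-1)
loss = torch.nn.functional.nll_loss(log_probs, targets, reduction="sum")
print(loss)  # scalar: sum of -log p(target) over the four rows
```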
@@ -90,211 +53,12 @@ def __init__(self,
                  num_samples: int = None,
                  sparse_embeddings: bool = False,
                  initializer: InitializerApplicator = None) -> None:
-        super().__init__(vocab)
-        self._text_field_embedder = text_field_embedder
-
-        if not contextualizer.is_bidirectional():
-            raise ConfigurationError("contextualizer must be bidirectional")
-
-        self._contextualizer = contextualizer
-        # The dimension for making predictions just in the forward
-        # (or backward) direction.
-        self._forward_dim = contextualizer.get_output_dim() // 2
-
-        # TODO(joelgrus): more sampled softmax configuration options, as needed.
-        if num_samples is not None:
-            self._softmax_loss = SampledSoftmaxLoss(num_words=vocab.get_vocab_size(),
-                                                    embedding_dim=self._forward_dim,
-                                                    num_samples=num_samples,
-                                                    sparse=sparse_embeddings)
-        else:
-            self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
-                                              embedding_dim=self._forward_dim)
-
-        # TODO(brendanr): Output perplexity here. e^loss
-        self.register_buffer('_last_average_loss', torch.zeros(1))
-
-        if dropout:
-            self._dropout = torch.nn.Dropout(dropout)
-        else:
-            self._dropout = lambda x: x
-
-        self._loss_scale = loss_scale
-        if initializer is not None:
-            initializer(self)
-
-    def _get_target_token_embedding(self,
-                                    token_embeddings: torch.Tensor,
-                                    mask: torch.Tensor,
-                                    direction: int) -> torch.Tensor:
-        # Need to shift the mask in the correct direction
-        zero_col = token_embeddings.new_zeros(mask.size(0), 1).byte()
-        if direction == 0:
-            # forward direction, get token to right
-            shifted_mask = torch.cat([zero_col, mask[:, 0:-1]], dim=1)
-        else:
-            shifted_mask = torch.cat([mask[:, 1:], zero_col], dim=1)
-        return token_embeddings.masked_select(shifted_mask.unsqueeze(-1)).view(-1, self._forward_dim)
-
-    def _compute_loss(self,
-                      lm_embeddings: torch.Tensor,
-                      token_embeddings: torch.Tensor,
-                      forward_targets: torch.Tensor,
-                      backward_targets: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        # lm_embeddings is shape (batch_size, timesteps, dim * 2)
-        # forward_targets, backward_targets are shape (batch_size, timesteps)
-        # masked with 0
-        forward_embeddings, backward_embeddings = lm_embeddings.chunk(2, -1)
-        losses: List[torch.Tensor] = []
-        for idx, embedding, targets in ((0, forward_embeddings, forward_targets),
-                                        (1, backward_embeddings, backward_targets)):
-            mask = targets > 0
-            # we need to subtract 1 to undo the padding id since the softmax
-            # does not include a padding dimension
-
-            # shape (batch_size * timesteps, )
-            non_masked_targets = targets.masked_select(mask) - 1
-
-            # shape (batch_size * timesteps, embedding_dim)
-            non_masked_embedding = embedding.masked_select(
-                    mask.unsqueeze(-1)
-            ).view(-1, self._forward_dim)
-            # note: need to return average loss across forward and backward
-            # directions, but total sum loss across all batches.
-            # Assuming batches include full sentences, forward and backward
-            # directions have the same number of samples, so sum up loss
-            # here then divide by 2 just below
-            if not self._softmax_loss.tie_embeddings or not self._use_character_inputs:
-                losses.append(self._softmax_loss(non_masked_embedding, non_masked_targets))
-            else:
-                # we also need the token embeddings corresponding to the
-                # the targets
-                raise NotImplementedError("This requires SampledSoftmaxLoss, which isn't implemented yet.")
-                # pylint: disable=unreachable
-                non_masked_token_embedding = self._get_target_token_embedding(token_embeddings, mask, idx)
-                losses.append(self._softmax(non_masked_embedding,
-                                            non_masked_targets,
-                                            non_masked_token_embedding))
-
-        return losses[0], losses[1]
-
-    def delete_softmax(self) -> None:
-        """
-        Remove the softmax weights. Useful for saving memory when calculating the loss
-        is not necessary, e.g. in an embedder.
-        """
-        self._softmax_loss = None
-
-    def num_layers(self) -> int:
-        """
-        Returns the depth of this LM. That is, how many layers the contextualizer has plus one for
-        the non-contextual layer.
-        """
-        if hasattr(self._contextualizer, 'num_layers'):
-            return self._contextualizer.num_layers + 1
-        else:
-            raise NotImplementedError(f"Contextualizer of type {type(self._contextualizer)} " +
-                                      "does not report how many layers it has.")
-
-    def forward(self,  # type: ignore
-                source: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]:
-        """
-        Computes the averaged forward and backward LM loss from the batch.
-
-        By convention, the input dict is required to have at least a ``"tokens"``
-        entry that's the output of a ``SingleIdTokenIndexer``, which is used
-        to compute the language model targets.
-
-        Parameters
-        ----------
-        tokens: ``torch.Tensor``, required.
-            The output of ``Batch.as_tensor_dict()`` for a batch of sentences.
-
-        Returns
-        -------
-        Dict with keys:
-
-        ``'loss'``: ``torch.Tensor``
-            averaged forward/backward negative log likelihood
-        ``'forward_loss'``: ``torch.Tensor``
-            forward direction negative log likelihood
-        ``'backward_loss'``: ``torch.Tensor``
-            backward direction negative log likelihood
-        ``'lm_embeddings'``: ``Union[torch.Tensor, List[torch.Tensor]]``
-            (batch_size, timesteps, embed_dim) tensor of top layer contextual representations or
-            list of all layers. No dropout applied.
-        ``'noncontextual_token_embeddings'``: ``torch.Tensor``
-            (batch_size, timesteps, token_embed_dim) tensor of bottom layer noncontextual
-            representations
-        ``'mask'``: ``torch.Tensor``
-            (batch_size, timesteps) mask for the embeddings
-        """
-        # pylint: disable=arguments-differ
-        mask = get_text_field_mask(source)
-
-        # shape (batch_size, timesteps, embedding_size)
-        embeddings = self._text_field_embedder(source)
-
-        # Either the top layer or all layers.
-        contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = self._contextualizer(
-                embeddings, mask
-        )
-
-        return_dict = {}
-
-        # If we have target tokens, calculate the loss.
-        token_ids = source.get("tokens")
-        if token_ids is not None:
-            assert isinstance(contextual_embeddings, torch.Tensor)
-
-            # Use token_ids to compute targets
-            forward_targets = torch.zeros_like(token_ids)
-            backward_targets = torch.zeros_like(token_ids)
-            forward_targets[:, 0:-1] = token_ids[:, 1:]
-            backward_targets[:, 1:] = token_ids[:, 0:-1]
-
-            # add dropout
-            contextual_embeddings_with_dropout = self._dropout(contextual_embeddings)
-
-            # compute softmax loss
-            forward_loss, backward_loss = self._compute_loss(contextual_embeddings_with_dropout,
-                                                             embeddings,
-                                                             forward_targets,
-                                                             backward_targets)
-
-            num_targets = torch.sum((forward_targets > 0).long())
-            if num_targets > 0:
-                average_loss = 0.5 * (forward_loss + backward_loss) / num_targets.float()
-            else:
-                average_loss = torch.tensor(0.0).to(forward_targets.device)  # pylint: disable=not-callable
-            # this is stored to compute perplexity if needed
-            self._last_average_loss[0] = average_loss.detach().item()
-
-            if num_targets > 0:
-                # loss is directly minimized
-                if self._loss_scale == 'n_samples':
-                    scale_factor = num_targets.float()
-                else:
-                    scale_factor = self._loss_scale
-
-                return_dict.update({
-                        'loss': average_loss * scale_factor,
-                        'forward_loss': forward_loss * scale_factor / num_targets.float(),
-                        'backward_loss': backward_loss * scale_factor / num_targets.float()
-                })
-            else:
-                # average_loss zero tensor, return it for all
-                return_dict.update({
-                        'loss': average_loss,
-                        'forward_loss': average_loss,
-                        'backward_loss': average_loss
-                })
-
-        return_dict.update({
-                # Note: These embeddings do not have dropout applied.
-                'lm_embeddings': contextual_embeddings,
-                'noncontextual_token_embeddings': embeddings,
-                'mask': mask
-        })
-
-        return return_dict
+        super().__init__(vocab=vocab,
+                         text_field_embedder=text_field_embedder,
+                         contextualizer=contextualizer,
+                         dropout=dropout,
+                         loss_scale=loss_scale,
+                         num_samples=num_samples,
+                         sparse_embeddings=sparse_embeddings,
+                         bidirectional=True,
+                         initializer=initializer)
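
For reference, the target construction performed by the removed `forward()` (presumably now handled inside the general `LanguageModel`) can be illustrated in isolation. This sketch is not part of the diff: forward targets are the token ids shifted one step left, backward targets one step right, with 0 marking positions that have no target.

```python
import torch

token_ids = torch.tensor([[2, 7, 5, 9, 0]])   # one padded sentence of token ids

forward_targets = torch.zeros_like(token_ids)
backward_targets = torch.zeros_like(token_ids)
forward_targets[:, 0:-1] = token_ids[:, 1:]   # predict the next token
backward_targets[:, 1:] = token_ids[:, 0:-1]  # predict the previous token

print(forward_targets)   # tensor([[7, 5, 9, 0, 0]])
print(backward_targets)  # tensor([[0, 2, 7, 5, 9]])
```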
