-from typing import Dict, List, Tuple, Union
+from typing import Union
 
-import torch
-import numpy as np
-
-from allennlp.common.checks import ConfigurationError
 from allennlp.data.vocabulary import Vocabulary
+from allennlp.models.language_model import LanguageModel
 from allennlp.models.model import Model
 from allennlp.modules.text_field_embedders import TextFieldEmbedder
-from allennlp.modules.sampled_softmax_loss import SampledSoftmaxLoss
 from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
-from allennlp.nn.util import get_text_field_mask
 from allennlp.nn import InitializerApplicator
 
 
-class _SoftmaxLoss(torch.nn.Module):
-    """
-    Given some embeddings and some targets, applies a linear layer
-    to create logits over possible words and then returns the
-    negative log likelihood.
-    """
-    def __init__(self,
-                 num_words: int,
-                 embedding_dim: int) -> None:
-        super().__init__()
-
-        # TODO(joelgrus): implement tie_embeddings (maybe)
-        self.tie_embeddings = False
-
-        self.softmax_w = torch.nn.Parameter(
-                torch.randn(embedding_dim, num_words) / np.sqrt(embedding_dim)
-        )
-        self.softmax_b = torch.nn.Parameter(torch.zeros(num_words))
-
-    def forward(self, embeddings: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
-        # pylint: disable=arguments-differ
-        # embeddings is size (n, embedding_dim)
-        # targets is (batch_size, ) with the correct class id
-        # Does not do any count normalization / divide by batch size
-        probs = torch.nn.functional.log_softmax(
-                torch.matmul(embeddings, self.softmax_w) + self.softmax_b,
-                dim=-1
-        )
-
-        return torch.nn.functional.nll_loss(probs, targets.long(), reduction="sum")
-
-
 @Model.register('bidirectional-language-model')
 @Model.register('bidirectional_language_model')
-class BidirectionalLanguageModel(Model):
+class BidirectionalLanguageModel(LanguageModel):
     """
     The ``BidirectionalLanguageModel`` applies a bidirectional "contextualizing"
     ``Seq2SeqEncoder`` to uncontextualized embeddings, using a ``SoftmaxLoss``
@@ -90,211 +53,12 @@ def __init__(self,
                  num_samples: int = None,
                  sparse_embeddings: bool = False,
                  initializer: InitializerApplicator = None) -> None:
-        super().__init__(vocab)
-        self._text_field_embedder = text_field_embedder
-
-        if not contextualizer.is_bidirectional():
-            raise ConfigurationError("contextualizer must be bidirectional")
-
-        self._contextualizer = contextualizer
-        # The dimension for making predictions just in the forward
-        # (or backward) direction.
-        self._forward_dim = contextualizer.get_output_dim() // 2
-
-        # TODO(joelgrus): more sampled softmax configuration options, as needed.
-        if num_samples is not None:
-            self._softmax_loss = SampledSoftmaxLoss(num_words=vocab.get_vocab_size(),
-                                                    embedding_dim=self._forward_dim,
-                                                    num_samples=num_samples,
-                                                    sparse=sparse_embeddings)
-        else:
-            self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
-                                              embedding_dim=self._forward_dim)
-
-        # TODO(brendanr): Output perplexity here. e^loss
-        self.register_buffer('_last_average_loss', torch.zeros(1))
-
-        if dropout:
-            self._dropout = torch.nn.Dropout(dropout)
-        else:
-            self._dropout = lambda x: x
-
-        self._loss_scale = loss_scale
-        if initializer is not None:
-            initializer(self)
-
-    def _get_target_token_embedding(self,
-                                    token_embeddings: torch.Tensor,
-                                    mask: torch.Tensor,
-                                    direction: int) -> torch.Tensor:
-        # Need to shift the mask in the correct direction
-        zero_col = token_embeddings.new_zeros(mask.size(0), 1).byte()
-        if direction == 0:
-            # forward direction, get token to right
-            shifted_mask = torch.cat([zero_col, mask[:, 0:-1]], dim=1)
-        else:
-            shifted_mask = torch.cat([mask[:, 1:], zero_col], dim=1)
-        return token_embeddings.masked_select(shifted_mask.unsqueeze(-1)).view(-1, self._forward_dim)
-
-    def _compute_loss(self,
-                      lm_embeddings: torch.Tensor,
-                      token_embeddings: torch.Tensor,
-                      forward_targets: torch.Tensor,
-                      backward_targets: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        # lm_embeddings is shape (batch_size, timesteps, dim * 2)
-        # forward_targets, backward_targets are shape (batch_size, timesteps)
-        # masked with 0
-        forward_embeddings, backward_embeddings = lm_embeddings.chunk(2, -1)
-        losses: List[torch.Tensor] = []
-        for idx, embedding, targets in ((0, forward_embeddings, forward_targets),
-                                        (1, backward_embeddings, backward_targets)):
-            mask = targets > 0
-            # we need to subtract 1 to undo the padding id since the softmax
-            # does not include a padding dimension
-
-            # shape (batch_size * timesteps, )
-            non_masked_targets = targets.masked_select(mask) - 1
-
-            # shape (batch_size * timesteps, embedding_dim)
-            non_masked_embedding = embedding.masked_select(
-                    mask.unsqueeze(-1)
-            ).view(-1, self._forward_dim)
-            # note: need to return average loss across forward and backward
-            # directions, but total sum loss across all batches.
-            # Assuming batches include full sentences, forward and backward
-            # directions have the same number of samples, so sum up loss
-            # here then divide by 2 just below
-            if not self._softmax_loss.tie_embeddings or not self._use_character_inputs:
-                losses.append(self._softmax_loss(non_masked_embedding, non_masked_targets))
-            else:
-                # we also need the token embeddings corresponding to the
-                # the targets
-                raise NotImplementedError("This requires SampledSoftmaxLoss, which isn't implemented yet.")
-                # pylint: disable=unreachable
-                non_masked_token_embedding = self._get_target_token_embedding(token_embeddings, mask, idx)
-                losses.append(self._softmax(non_masked_embedding,
-                                            non_masked_targets,
-                                            non_masked_token_embedding))
-
-        return losses[0], losses[1]
-
-    def delete_softmax(self) -> None:
-        """
-        Remove the softmax weights. Useful for saving memory when calculating the loss
-        is not necessary, e.g. in an embedder.
-        """
-        self._softmax_loss = None
-
-    def num_layers(self) -> int:
-        """
-        Returns the depth of this LM. That is, how many layers the contextualizer has plus one for
-        the non-contextual layer.
-        """
-        if hasattr(self._contextualizer, 'num_layers'):
-            return self._contextualizer.num_layers + 1
-        else:
-            raise NotImplementedError(f"Contextualizer of type {type(self._contextualizer)} " +
-                                      "does not report how many layers it has.")
-
-    def forward(self,  # type: ignore
-                source: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]:
-        """
-        Computes the averaged forward and backward LM loss from the batch.
-
-        By convention, the input dict is required to have at least a ``"tokens"``
-        entry that's the output of a ``SingleIdTokenIndexer``, which is used
-        to compute the language model targets.
-
-        Parameters
-        ----------
-        tokens: ``torch.Tensor``, required.
-            The output of ``Batch.as_tensor_dict()`` for a batch of sentences.
-
-        Returns
-        -------
-        Dict with keys:
-
-        ``'loss'``: ``torch.Tensor``
-            averaged forward/backward negative log likelihood
-        ``'forward_loss'``: ``torch.Tensor``
-            forward direction negative log likelihood
-        ``'backward_loss'``: ``torch.Tensor``
-            backward direction negative log likelihood
-        ``'lm_embeddings'``: ``Union[torch.Tensor, List[torch.Tensor]]``
-            (batch_size, timesteps, embed_dim) tensor of top layer contextual representations or
-            list of all layers. No dropout applied.
-        ``'noncontextual_token_embeddings'``: ``torch.Tensor``
-            (batch_size, timesteps, token_embed_dim) tensor of bottom layer noncontextual
-            representations
-        ``'mask'``: ``torch.Tensor``
-            (batch_size, timesteps) mask for the embeddings
-        """
-        # pylint: disable=arguments-differ
-        mask = get_text_field_mask(source)
-
-        # shape (batch_size, timesteps, embedding_size)
-        embeddings = self._text_field_embedder(source)
-
-        # Either the top layer or all layers.
-        contextual_embeddings: Union[torch.Tensor, List[torch.Tensor]] = self._contextualizer(
-                embeddings, mask
-        )
-
-        return_dict = {}
-
-        # If we have target tokens, calculate the loss.
-        token_ids = source.get("tokens")
-        if token_ids is not None:
-            assert isinstance(contextual_embeddings, torch.Tensor)
-
-            # Use token_ids to compute targets
-            forward_targets = torch.zeros_like(token_ids)
-            backward_targets = torch.zeros_like(token_ids)
-            forward_targets[:, 0:-1] = token_ids[:, 1:]
-            backward_targets[:, 1:] = token_ids[:, 0:-1]
-
-            # add dropout
-            contextual_embeddings_with_dropout = self._dropout(contextual_embeddings)
-
-            # compute softmax loss
-            forward_loss, backward_loss = self._compute_loss(contextual_embeddings_with_dropout,
-                                                             embeddings,
-                                                             forward_targets,
-                                                             backward_targets)
-
-            num_targets = torch.sum((forward_targets > 0).long())
-            if num_targets > 0:
-                average_loss = 0.5 * (forward_loss + backward_loss) / num_targets.float()
-            else:
-                average_loss = torch.tensor(0.0).to(forward_targets.device)  # pylint: disable=not-callable
-            # this is stored to compute perplexity if needed
-            self._last_average_loss[0] = average_loss.detach().item()
-
-            if num_targets > 0:
-                # loss is directly minimized
-                if self._loss_scale == 'n_samples':
-                    scale_factor = num_targets.float()
-                else:
-                    scale_factor = self._loss_scale
-
-                return_dict.update({
-                        'loss': average_loss * scale_factor,
-                        'forward_loss': forward_loss * scale_factor / num_targets.float(),
-                        'backward_loss': backward_loss * scale_factor / num_targets.float()
-                })
-            else:
-                # average_loss zero tensor, return it for all
-                return_dict.update({
-                        'loss': average_loss,
-                        'forward_loss': average_loss,
-                        'backward_loss': average_loss
-                })
-
-        return_dict.update({
-                # Note: These embeddings do not have dropout applied.
-                'lm_embeddings': contextual_embeddings,
-                'noncontextual_token_embeddings': embeddings,
-                'mask': mask
-        })
-
-        return return_dict
+        super().__init__(vocab=vocab,
+                         text_field_embedder=text_field_embedder,
+                         contextualizer=contextualizer,
+                         dropout=dropout,
+                         loss_scale=loss_scale,
+                         num_samples=num_samples,
+                         sparse_embeddings=sparse_embeddings,
+                         bidirectional=True,
+                         initializer=initializer)
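
For reference, here is a minimal usage sketch of the refactored class: after this change the subclass only fixes bidirectional=True and forwards its remaining constructor arguments to LanguageModel. The embedder and contextualizer built below (the "tokens" namespace, embedding and LSTM sizes) and the assumption that the class still lives in allennlp.models.bidirectional_lm are illustrative choices for the example, not part of this commit.

# Illustrative sketch only; component choices are assumptions, not from this diff.
import torch

from allennlp.data.vocabulary import Vocabulary
from allennlp.models.bidirectional_lm import BidirectionalLanguageModel
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding

vocab = Vocabulary()  # assume this has been populated from a dataset elsewhere

# Non-contextual token embeddings for the "tokens" namespace.
embedder = BasicTextFieldEmbedder(
        {"tokens": Embedding(num_embeddings=vocab.get_vocab_size(), embedding_dim=16)})

# The contextualizer must be bidirectional; the parent LanguageModel enforces this
# because the subclass passes bidirectional=True.
contextualizer = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(input_size=16, hidden_size=8, batch_first=True, bidirectional=True))

model = BidirectionalLanguageModel(vocab=vocab,
                                   text_field_embedder=embedder,
                                   contextualizer=contextualizer)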