This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 86da880

add sampled softmax loss (#2042)
* first stab at contextual encoder wrappers
* contextual encoders
* remove sru encoder
* pr comments
* replace _ElmoCharacterEncoder with CharacterEncoder
* docs
* sphinx stuff
* address pr comments
* address more PR comments
* make sphinx happy
* iterate
* make parameters required
* this is still wip
* wip
* bidirectional-lm proof of concept
* progress
* revert elmo
* revert elmo test
* revert elmo token embedder
* cnn_highway_encoder -> seq2vec
* remove contextual encoders
* fix docs
* remove print
* address more feedback
* replace none with identity function
* fix docs + checks
* fix tests
* add comments
* add top level imports
* fix imports
* unused import
* progress
* use brendan's dataset reader
* checkpoint
* fun
* remove tie_embeddings code paths + fast_sampler
* fix tests
* revert setup.py
* address PR feedback
* address pr feedback
1 parent 07b5749 commit 86da880

8 files changed: +427 −96 lines changed

allennlp/models/bidirectional_lm.py (+28 −30)
@@ -8,6 +8,7 @@
 from allennlp.models.model import Model
 from allennlp.modules.masked_layer_norm import MaskedLayerNorm
 from allennlp.modules.text_field_embedders import TextFieldEmbedder
+from allennlp.modules.sampled_softmax_loss import SampledSoftmaxLoss
 from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
 from allennlp.nn.util import get_text_field_mask, remove_sentence_boundaries

@@ -20,45 +21,28 @@ class _SoftmaxLoss(torch.nn.Module):
     """
     def __init__(self,
                  num_words: int,
-                 embedding_dim: int,
-                 token_encoder: torch.nn.Parameter = None) -> None:
+                 embedding_dim: int) -> None:
         super().__init__()
 
-        self.tie_embeddings = token_encoder is not None
+        # TODO(joelgrus): implement tie_embeddings (maybe)
+        self.tie_embeddings = False
 
-        # Glorit init (std=(1.0 / sqrt(fan_in))
-        if self.tie_embeddings:
-            self.softmax_w = token_encoder
-            # +1 for shape to include padding dimension
-            self.softmax_b = torch.nn.Parameter(torch.zeros(num_words + 1))
-        else:
-            self.softmax_w = torch.nn.Parameter(
-                torch.randn(embedding_dim, num_words) / np.sqrt(embedding_dim)
-            )
-            self.softmax_b = torch.nn.Parameter(torch.zeros(num_words))
+        self.softmax_w = torch.nn.Parameter(
+                torch.randn(embedding_dim, num_words) / np.sqrt(embedding_dim)
+        )
+        self.softmax_b = torch.nn.Parameter(torch.zeros(num_words))
 
     def forward(self, embeddings: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
         # pylint: disable=arguments-differ
         # embeddings is size (n, embedding_dim)
         # targets is (batch_size, ) with the correct class id
         # Does not do any count normalization / divide by batch size
-        if self.tie_embeddings:
-            softmax_w = self.softmax_w.weight.t()
-        else:
-            softmax_w = self.softmax_w
-
         probs = torch.nn.functional.log_softmax(
-            torch.matmul(embeddings, softmax_w) + self.softmax_b,
+            torch.matmul(embeddings, self.softmax_w) + self.softmax_b,
             dim=-1
         )
 
-        if self.tie_embeddings:
-            # need to add back in padding dim!
-            targets_ = targets + 1
-        else:
-            targets_ = targets
-
-        return torch.nn.functional.nll_loss(probs, targets_.long(), reduction="sum")
+        return torch.nn.functional.nll_loss(probs, targets.long(), reduction="sum")
 
 
 @Model.register('bidirectional-language-model')
@@ -95,6 +79,12 @@ class BidirectionalLanguageModel(Model):
         Typically the provided token indexes will be augmented with
         begin-sentence and end-sentence tokens. If this flag is True
         the corresponding embeddings will be removed from the return values.
+    num_samples: ``int``, optional (default: None)
+        If provided, the model will use ``SampledSoftmaxLoss``
+        with the specified number of samples. Otherwise, it will use
+        the full ``_SoftmaxLoss`` defined above.
+    sparse_embeddings: ``bool``, optional (default: False)
+        Passed on to ``SampledSoftmaxLoss`` if True.
     """
     def __init__(self,
                  vocab: Vocabulary,
@@ -103,7 +93,9 @@ def __init__(self,
                  layer_norm: Optional[MaskedLayerNorm] = None,
                  dropout: float = None,
                  loss_scale: Union[float, str] = 1.0,
-                 remove_bos_eos: bool = True) -> None:
+                 remove_bos_eos: bool = True,
+                 num_samples: int = None,
+                 sparse_embeddings: bool = False) -> None:
         super().__init__(vocab)
         self._text_field_embedder = text_field_embedder
         self._layer_norm = layer_norm or (lambda x: x)
@@ -116,9 +108,15 @@ def __init__(self,
         # (or backward) direction.
         self._forward_dim = contextualizer.get_output_dim() // 2
 
-        # TODO(joelgrus): Allow SampledSoftmaxLoss here by configuration
-        self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
-                                          embedding_dim=self._forward_dim)
+        # TODO(joelgrus): more sampled softmax configuration options, as needed.
+        if num_samples is not None:
+            self._softmax_loss = SampledSoftmaxLoss(num_words=vocab.get_vocab_size(),
+                                                    embedding_dim=self._forward_dim,
+                                                    num_samples=num_samples,
+                                                    sparse=sparse_embeddings)
+        else:
+            self._softmax_loss = _SoftmaxLoss(num_words=vocab.get_vocab_size(),
+                                              embedding_dim=self._forward_dim)
 
         self.register_buffer('_last_average_loss', torch.zeros(1))
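
For orientation, here is a minimal sketch (not part of the commit) of the new code path: it constructs SampledSoftmaxLoss with the same arguments the model now passes when ``num_samples`` is set, and runs it on dummy contextualizer outputs. The import path follows the import added above; the sizes, tensors, and the 128-sample setting are made-up example values.

import torch
from allennlp.modules.sampled_softmax_loss import SampledSoftmaxLoss

# Hypothetical sizes standing in for vocab.get_vocab_size() and
# contextualizer.get_output_dim() // 2.
vocab_size, forward_dim, batch = 50000, 256, 32

loss = SampledSoftmaxLoss(num_words=vocab_size,
                          embedding_dim=forward_dim,
                          num_samples=128,
                          sparse=False)

embeddings = torch.randn(batch, forward_dim)       # (n, embedding_dim)
targets = torch.randint(0, vocab_size, (batch,))   # target ids, padding already removed

loss.train()
sampled_nll = loss(embeddings, targets)   # sampled softmax with 128 negatives

loss.eval()
full_nll = loss(embeddings, targets)      # evaluation falls back to the full softmax

print(sampled_nll.item(), full_nll.item())

During training only 128 of the 50000 output weights are touched per batch, which is the point of the change; at evaluation time the loss is the exact full-vocabulary NLL.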

allennlp/modules/sampled_softmax_loss.py (new file, +270)
@@ -0,0 +1,270 @@
+# https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/nn_impl.py#L885
+from typing import Set, Tuple
+
+import numpy as np
+
+import torch
+
+from allennlp.common.checks import ConfigurationError
+
+
+def _choice(num_words: int, num_samples: int) -> Tuple[np.ndarray, int]:
+    """
+    Chooses ``num_samples`` samples without replacement from [0, ..., num_words).
+    Returns a tuple (samples, num_tries).
+    """
+    num_tries = 0
+    num_chosen = 0
+
+    def get_buffer() -> np.ndarray:
+        log_samples = np.random.rand(num_samples) * np.log(num_words + 1)
+        samples = np.exp(log_samples).astype('int64') - 1
+        return np.clip(samples, a_min=0, a_max=num_words - 1)
+
+    sample_buffer = get_buffer()
+    buffer_index = 0
+    samples: Set[int] = set()
+
+    while num_chosen < num_samples:
+        num_tries += 1
+        # choose sample
+        sample_id = sample_buffer[buffer_index]
+        if sample_id not in samples:
+            samples.add(sample_id)
+            num_chosen += 1
+
+        buffer_index += 1
+        if buffer_index == num_samples:
+            # Reset the buffer
+            sample_buffer = get_buffer()
+            buffer_index = 0
+
+    return np.array(list(samples)), num_tries
+
+
+class SampledSoftmaxLoss(torch.nn.Module):
+    """
+    Based on the default log_uniform_candidate_sampler in TensorFlow.
+
+    NOTE: num_words DOES NOT include padding id.
+
+    NOTE: In all cases except (tie_embeddings=True and use_character_inputs=False)
+    the weights are dimensioned as num_words and do not include an entry for the padding (0) id.
+    For the (tie_embeddings=True and use_character_inputs=False) case,
+    the embeddings DO include the extra 0 padding, to be consistent with the word embedding layer.
+
+    Parameters
+    ----------
+    num_words : ``int``
+        The number of words in the vocabulary
+    embedding_dim : ``int``
+        The dimension to softmax over
+    num_samples : ``int``
+        During training take this many samples. Must be less than num_words.
+    sparse : ``bool``, optional (default = False)
+        If this is true, we use a sparse embedding matrix.
+    unk_id : ``int``, optional (default = None)
+        If provided, the id that represents unknown characters.
+    use_character_inputs : ``bool``, optional (default = True)
+        Whether to use character inputs
+    use_fast_sampler : ``bool``, optional (default = False)
+        Whether to use the fast Cython sampler.
+    """
+    def __init__(self,
+                 num_words: int,
+                 embedding_dim: int,
+                 num_samples: int,
+                 sparse: bool = False,
+                 unk_id: int = None,
+                 use_character_inputs: bool = True,
+                 use_fast_sampler: bool = False) -> None:
+        super().__init__()
+
+        # TODO(joelgrus): implement tie_embeddings (maybe)
+        self.tie_embeddings = False
+
+        assert num_samples < num_words
+
+        if use_fast_sampler:
+            raise ConfigurationError("fast sampler is not implemented")
+        else:
+            self.choice_func = _choice
+
+        # Glorot init (std = 1.0 / sqrt(fan_in))
+        if sparse:
+            # create our own sparse embedding
+            self.softmax_w = torch.nn.Embedding(num_words, embedding_dim, sparse=True)
+            self.softmax_w.weight.data.normal_(mean=0.0, std=1.0 / np.sqrt(embedding_dim))
+            self.softmax_b = torch.nn.Embedding(num_words, 1, sparse=True)
+            self.softmax_b.weight.data.fill_(0.0)
+        else:
+            # just create tensors to use as the embeddings
+            # Glorot init (std = 1.0 / sqrt(fan_in))
+            self.softmax_w = torch.nn.Parameter(torch.randn(num_words, embedding_dim) / np.sqrt(embedding_dim))
+            self.softmax_b = torch.nn.Parameter(torch.zeros(num_words))
+
+        self.sparse = sparse
+        self.use_character_inputs = use_character_inputs
+
+        if use_character_inputs:
+            self._unk_id = unk_id
+
+        self._num_samples = num_samples
+        self._embedding_dim = embedding_dim
+        self._num_words = num_words
+        self.initialize_num_words()
+
+    def initialize_num_words(self):
+        if self.sparse:
+            num_words = self.softmax_w.weight.size(0)
+        else:
+            num_words = self.softmax_w.size(0)
+
+        self._num_words = num_words
+        self._log_num_words_p1 = np.log(num_words + 1)
+
+        # compute the probability of each sampled id
+        self._probs = (np.log(np.arange(num_words) + 2) -
+                       np.log(np.arange(num_words) + 1)) / self._log_num_words_p1
+
+
+    def forward(self,
+                embeddings: torch.Tensor,
+                targets: torch.Tensor,
+                target_token_embedding: torch.Tensor = None) -> torch.Tensor:
+        # pylint: disable=arguments-differ
+
+        # embeddings is size (n, embedding_dim)
+        # targets is (n_words, ) with the index of the actual target
+        # when tying weights, target_token_embedding is required.
+        # it is size (n_words, embedding_dim)
+        # returns log likelihood loss (summed over the batch)
+        # Does not do any count normalization / divide by batch size
+
+        if embeddings.shape[0] == 0:
+            # empty batch
+            return torch.tensor(0.0).to(embeddings.device)  # pylint: disable=not-callable
+
+        if not self.training:
+            return self._forward_eval(embeddings, targets)
+        else:
+            return self._forward_train(embeddings, targets, target_token_embedding)
+
+    def _forward_train(self,
+                       embeddings: torch.Tensor,
+                       targets: torch.Tensor,
+                       target_token_embedding: torch.Tensor) -> torch.Tensor:
+        # pylint: disable=unused-argument
+        # (target_token_embedding is only used in the tie_embeddings case,
+        #  which is not implemented)
+
+        # want to compute (n, n_samples + 1) array with the log
+        # probabilities where the first index is the true target
+        # and the remaining ones are the negative samples.
+        # then we can just select the first column
+
+        # NOTE: targets input has padding removed (so 0 == the first id, NOT the padding id)
+
+        sampled_ids, target_expected_count, sampled_expected_count = \
+                self.log_uniform_candidate_sampler(targets, choice_func=self.choice_func)
+
+        long_targets = targets.long()
+        long_targets.requires_grad_(False)
+
+        # Get the softmax weights (so we can compute logits)
+        all_ids = torch.cat([long_targets, sampled_ids], dim=0)
+
+        if self.sparse:
+            all_ids_1 = all_ids.unsqueeze(1)
+            all_w = self.softmax_w(all_ids_1).squeeze(1)
+            all_b = self.softmax_b(all_ids_1).squeeze(2).squeeze(1)
+        else:
+            all_w = torch.nn.functional.embedding(all_ids, self.softmax_w)
+            # the unsqueeze / squeeze works around an issue with 1 dim
+            # embeddings
+            all_b = torch.nn.functional.embedding(all_ids, self.softmax_b.unsqueeze(1)).squeeze(1)
+
+        batch_size = long_targets.size(0)
+        true_w = all_w[:batch_size, :]
+        sampled_w = all_w[batch_size:, :]
+        true_b = all_b[:batch_size]
+        sampled_b = all_b[batch_size:]
+
+        # compute the logits and remove log expected counts
+        # [batch_size, ]
+        true_logits = (true_w * embeddings).sum(dim=1) + true_b - torch.log(target_expected_count + 1e-7)
+        # [batch_size, n_samples]
+        sampled_logits = (torch.matmul(embeddings, sampled_w.t()) +
+                          sampled_b - torch.log(sampled_expected_count + 1e-7))
+
+        # remove true labels -- we will take
+        # softmax, so set the sampled logits of true values to a large
+        # negative number
+        # [batch_size, n_samples]
+        true_in_sample_mask = sampled_ids == long_targets.unsqueeze(1)
+        masked_sampled_logits = sampled_logits.masked_fill(true_in_sample_mask, -10000.0)
+        # now concat the true logits as index 0
+        # [batch_size, n_samples + 1]
+        logits = torch.cat([true_logits.unsqueeze(1), masked_sampled_logits], dim=1)
+
+        # finally take log_softmax
+        log_softmax = torch.nn.functional.log_softmax(logits, dim=1)
+        # true log likelihood is index 0, loss = -1.0 * sum over batch
+        # the likelihood loss can become very large if the corresponding
+        # true logit is very small; a per-target cap (disabled below) would
+        # keep a single logit for a very rare word from dominating the batch.
+        # nll_loss = -1.0 * torch.clamp(log_softmax[:, 0], -1000, 1e6).sum()
+        nll_loss = -1.0 * log_softmax[:, 0].sum()
+        return nll_loss
+
+    def _forward_eval(self, embeddings: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        # pylint: disable=invalid-name
+        # evaluation mode, use full softmax
+        if self.sparse:
+            w = self.softmax_w.weight
+            b = self.softmax_b.weight.squeeze(1)
+        else:
+            w = self.softmax_w
+            b = self.softmax_b
+
+        log_softmax = torch.nn.functional.log_softmax(torch.matmul(embeddings, w.t()) + b, dim=-1)
+        if self.tie_embeddings and not self.use_character_inputs:
+            targets_ = targets + 1
+        else:
+            targets_ = targets
+        return torch.nn.functional.nll_loss(log_softmax, targets_.long(),
+                                            reduction="sum")
+
+    def log_uniform_candidate_sampler(self, targets, choice_func=_choice):
+        # returns sampled, true_expected_count, sampled_expected_count
+        # targets = (batch_size, )
+        #
+        # samples = (n_samples, )
+        # true_expected_count = (batch_size, )
+        # sampled_expected_count = (n_samples, )
+
+        # see: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/range_sampler.h
+        # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/range_sampler.cc
+
+        # algorithm: keep track of number of tries when doing sampling,
+        # then expected count is
+        # -expm1(num_tries * log1p(-p))
+        # = (1 - (1-p)^num_tries) where p is self._probs[id]
+
+        np_sampled_ids, num_tries = choice_func(self._num_words, self._num_samples)
+
+        sampled_ids = torch.from_numpy(np_sampled_ids).to(targets.device)
+
+        # Compute expected count = (1 - (1-p)^num_tries) = -expm1(num_tries * log1p(-p))
+        # P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)
+        target_probs = torch.log((targets.float() + 2.0) / (targets.float() + 1.0)) / self._log_num_words_p1
+        target_expected_count = -1.0 * (torch.exp(num_tries * torch.log1p(-target_probs)) - 1.0)
+        sampled_probs = torch.log((sampled_ids.float() + 2.0) /
+                                  (sampled_ids.float() + 1.0)) / self._log_num_words_p1
+        sampled_expected_count = -1.0 * (torch.exp(num_tries * torch.log1p(-sampled_probs)) - 1.0)
+
+        sampled_ids.requires_grad_(False)
+        target_expected_count.requires_grad_(False)
+        sampled_expected_count.requires_grad_(False)
+
+        return sampled_ids, target_expected_count, sampled_expected_count
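
As a sanity check on the sampler math above (again, not part of the commit), the sketch below draws repeatedly with ``_choice`` and compares empirical sample rates against the log-uniform probabilities P(class) = (log(class + 2) - log(class + 1)) / log(num_words + 1) that the expected-count correction assumes. It imports from ``allennlp.modules.sampled_softmax_loss``, matching the import added in ``bidirectional_lm.py``; the values of ``num_words``, ``num_samples``, and ``num_trials`` are arbitrary, and aggregating with the average number of tries is a rough approximation.

import numpy as np
from allennlp.modules.sampled_softmax_loss import _choice

num_words, num_samples, num_trials = 1000, 50, 2000   # arbitrary example sizes

# Log-uniform probability the sampler assigns to each id:
# P(class) = (log(class + 2) - log(class + 1)) / log(num_words + 1)
probs = (np.log(np.arange(num_words) + 2) -
         np.log(np.arange(num_words) + 1)) / np.log(num_words + 1)
assert abs(probs.sum() - 1.0) < 1e-6   # the telescoping sum makes this a distribution

counts = np.zeros(num_words)
total_tries = 0
for _ in range(num_trials):
    sampled_ids, num_tries = _choice(num_words, num_samples)
    counts[sampled_ids] += 1            # ids are unique within a single draw
    total_tries += num_tries

# Per-draw expected count, 1 - (1 - p)^num_tries = -expm1(num_tries * log1p(-p)),
# evaluated with the average number of tries across draws.
expected = -np.expm1((total_tries / num_trials) * np.log1p(-probs))

print("empirical rate for ids 0-4:", counts[:5] / num_trials)
print("expected  rate for ids 0-4:", expected[:5])

Frequent (low) ids should come out close to the predicted rates, which is what makes subtracting the log expected counts from the logits in _forward_train a sensible correction.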

0 commit comments
