Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit e6ad6e9

Browse files
authored
Evaluate on a token-weighted basis. (#2183)
- Allow models to return a `"batch_weight"` key that is used to weight each batch's loss during evaluation.
- This follows @matt-peters' suggestion.
- The same weighting is performed in Calypso: https://github.com/allenai/calypso/blob/master/calypso/train.py#L699
- Remove the unused `loss_scale` option. It was never set in the training config, so removing it should be fairly safe.
1 parent 71ebcd8 commit e6ad6e9

File tree

9 files changed

+122
-45
lines changed

9 files changed

+122
-45
lines changed

allennlp/commands/evaluate.py

+26-6
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar
8181
default="",
8282
help='a JSON structure used to override the experiment configuration')
8383

84+
subparser.add_argument('--batch-weight-key',
85+
type=str,
86+
default="",
87+
help='If non-empty, name of metric used to weight the loss on a per-batch basis.')
88+
8489
subparser.set_defaults(func=evaluate_from_args)
8590

8691
return subparser
@@ -89,7 +94,8 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar
8994
def evaluate(model: Model,
9095
instances: Iterable[Instance],
9196
data_iterator: DataIterator,
92-
cuda_device: int) -> Dict[str, Any]:
97+
cuda_device: int,
98+
batch_weight_key: str) -> Dict[str, Any]:
9399
_warned_tqdm_ignores_underscores = False
94100
check_for_gpu(cuda_device)
95101
with torch.no_grad():
@@ -101,21 +107,34 @@ def evaluate(model: Model,
101107
logger.info("Iterating over dataset")
102108
generator_tqdm = Tqdm.tqdm(iterator, total=data_iterator.get_num_batches(instances))
103109

110+
# Number of batches in instances.
104111
batch_count = 0
112+
# Number of batches where the model produces a loss.
105113
loss_count = 0
114+
# Cumulative weighted loss
106115
total_loss = 0.0
116+
# Cumulative weight across all batches.
117+
total_weight = 0.0
107118

108119
for batch in generator_tqdm:
109120
batch_count += 1
110121
batch = util.move_to_device(batch, cuda_device)
111-
loss = model(**batch).get("loss")
122+
output_dict = model(**batch)
123+
loss = output_dict.get("loss")
112124

113125
metrics = model.get_metrics()
114126

115127
if loss is not None:
116128
loss_count += 1
117-
metrics["loss"] = loss.item()
118-
total_loss += loss.item()
129+
if batch_weight_key:
130+
weight = output_dict[batch_weight_key].item()
131+
else:
132+
weight = 1.0
133+
134+
total_weight += weight
135+
total_loss += loss.item() * weight
136+
# Report the average loss so far.
137+
metrics["loss"] = total_loss / total_weight
119138

120139
if (not _warned_tqdm_ignores_underscores and
121140
any(metric_name.startswith("_") for metric_name in metrics)):
@@ -128,10 +147,11 @@ def evaluate(model: Model,
128147

129148
final_metrics = model.get_metrics(reset=True)
130149
if loss_count > 0:
150+
# Sanity check
131151
if loss_count != batch_count:
132152
raise RuntimeError("The model you are trying to evaluate only sometimes " +
133153
"produced a loss!")
134-
final_metrics["loss"] = total_loss/batch_count
154+
final_metrics["loss"] = total_loss / total_weight
135155

136156
return final_metrics
137157

@@ -168,7 +188,7 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
168188
iterator = DataIterator.from_params(iterator_params)
169189
iterator.index_with(model.vocab)
170190

171-
metrics = evaluate(model, instances, iterator, args.cuda_device)
191+
metrics = evaluate(model, instances, iterator, args.cuda_device, args.batch_weight_key)
172192

173193
logger.info("Finished evaluating.")
174194
logger.info("Metrics:")

allennlp/commands/fine_tune.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar
7070
default=False,
7171
help='outputs tqdm status on separate lines and slows tqdm refresh rate')
7272

73+
subparser.add_argument('--batch-weight-key',
74+
type=str,
75+
default="",
76+
help='If non-empty, name of metric used to weight the loss on a per-batch basis.')
77+
7378
subparser.set_defaults(func=fine_tune_model_from_args)
7479

7580
return subparser
@@ -84,15 +89,17 @@ def fine_tune_model_from_args(args: argparse.Namespace):
8489
serialization_dir=args.serialization_dir,
8590
overrides=args.overrides,
8691
extend_vocab=args.extend_vocab,
87-
file_friendly_logging=args.file_friendly_logging)
92+
file_friendly_logging=args.file_friendly_logging,
93+
batch_weight_key=args.batch_weight_key)
8894

8995

9096
def fine_tune_model_from_file_paths(model_archive_path: str,
9197
config_file: str,
9298
serialization_dir: str,
9399
overrides: str = "",
94100
extend_vocab: bool = False,
95-
file_friendly_logging: bool = False) -> Model:
101+
file_friendly_logging: bool = False,
102+
batch_weight_key: str = "") -> Model:
96103
"""
97104
A wrapper around :func:`fine_tune_model` which loads the model archive from a file.
98105
@@ -121,14 +128,16 @@ def fine_tune_model_from_file_paths(model_archive_path: str,
121128
params=params,
122129
serialization_dir=serialization_dir,
123130
extend_vocab=extend_vocab,
124-
file_friendly_logging=file_friendly_logging)
131+
file_friendly_logging=file_friendly_logging,
132+
batch_weight_key=batch_weight_key)
125133

126134

127135
def fine_tune_model(model: Model,
128136
params: Params,
129137
serialization_dir: str,
130138
extend_vocab: bool = False,
131-
file_friendly_logging: bool = False) -> Model:
139+
file_friendly_logging: bool = False,
140+
batch_weight_key: str = "") -> Model:
132141
"""
133142
Fine tunes the given model, using a set of parameters that is largely identical to those used
134143
for :func:`~allennlp.commands.train.train_model`, except that the ``model`` section is ignored,
@@ -248,7 +257,13 @@ def fine_tune_model(model: Model,
248257
archive_model(serialization_dir, files_to_archive=params.files_to_archive)
249258

250259
if test_data and evaluate_on_test:
251-
test_metrics = evaluate(model, test_data, iterator, cuda_device=trainer._cuda_devices[0]) # pylint: disable=protected-access
260+
test_metrics = evaluate(
261+
model,
262+
test_data,
263+
iterator,
264+
cuda_device=trainer._cuda_devices[0], # pylint: disable=protected-access
265+
batch_weight_key=batch_weight_key
266+
)
252267
for key, value in test_metrics.items():
253268
metrics["test_" + key] = value
254269

allennlp/commands/train.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,9 @@ def train_model(params: Params,
385385
logger.info("The model will be evaluated using the best epoch weights.")
386386
test_metrics = evaluate(
387387
best_model, test_data, validation_iterator or iterator,
388-
cuda_device=trainer._cuda_devices[0] # pylint: disable=protected-access
388+
cuda_device=trainer._cuda_devices[0], # pylint: disable=protected-access,
389+
# TODO(brendanr): Pass in an arg following Joel's trainer refactor.
390+
batch_weight_key=""
389391
)
390392
for key, value in test_metrics.items():
391393
metrics["test_" + key] = value

allennlp/common/testing/model_test_case.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import copy
2-
from typing import Any, Dict, Set, Union
2+
from typing import Any, Dict, Set, Union, Iterable
33

44
from numpy.testing import assert_allclose
55
import torch
@@ -197,7 +197,18 @@ def check_model_computes_gradients_correctly(model: Model,
197197
print(f"Parameter: {name} had incorrect gradient: {grad}")
198198
raise Exception("Incorrect gradients found. See stdout for more info.")
199199

200-
def ensure_batch_predictions_are_consistent(self):
200+
def ensure_batch_predictions_are_consistent(
201+
self,
202+
keys_to_ignore: Iterable[str] = ()):
203+
"""
204+
Ensures that the model performs the same on a batch of instances as on individual instances.
205+
Ignores metrics matching the regexp .*loss.* and those specified explicitly.
206+
207+
Parameters
208+
----------
209+
keys_to_ignore : ``Iterable[str]``, optional (default=())
210+
Names of metrics that should not be taken into account, e.g. "batch_weight".
211+
"""
201212
self.model.eval()
202213
single_predictions = []
203214
for i, instance in enumerate(self.instances):
@@ -215,6 +226,8 @@ def ensure_batch_predictions_are_consistent(self):
215226
# Loss is particularly unstable; we'll just be satisfied if everything else is
216227
# close.
217228
continue
229+
if key in keys_to_ignore:
230+
continue
218231
single_predicted = single_predicted[0]
219232
batch_predicted = batch_predictions[key][i]
220233
if isinstance(single_predicted, torch.Tensor):

allennlp/models/bidirectional_lm.py

-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Union
2-
31
from allennlp.data.vocabulary import Vocabulary
42
from allennlp.models.language_model import LanguageModel
53
from allennlp.models.model import Model
@@ -33,10 +31,6 @@ class BidirectionalLanguageModel(LanguageModel):
3331
dropout: ``float``, optional (default: None)
3432
If specified, dropout is applied to the contextualized embeddings before computation of
3533
the softmax. The contextualized embeddings themselves are returned without dropout.
36-
loss_scale: ``Union[float, str]``, optional (default: 1.0)
37-
This scaling factor is applied to the average language model loss.
38-
You can also specify ``"n_samples"`` in which case we compute total
39-
loss across all predictions.
4034
num_samples: ``int``, optional (default: None)
4135
If provided, the model will use ``SampledSoftmaxLoss``
4236
with the specified number of samples. Otherwise, it will use
@@ -49,15 +43,13 @@ def __init__(self,
4943
text_field_embedder: TextFieldEmbedder,
5044
contextualizer: Seq2SeqEncoder,
5145
dropout: float = None,
52-
loss_scale: Union[float, str] = 1.0,
5346
num_samples: int = None,
5447
sparse_embeddings: bool = False,
5548
initializer: InitializerApplicator = None) -> None:
5649
super().__init__(vocab=vocab,
5750
text_field_embedder=text_field_embedder,
5851
contextualizer=contextualizer,
5952
dropout=dropout,
60-
loss_scale=loss_scale,
6153
num_samples=num_samples,
6254
sparse_embeddings=sparse_embeddings,
6355
bidirectional=True,

allennlp/models/language_model.py

+5-16
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,6 @@ class LanguageModel(Model):
7777
dropout: ``float``, optional (default: None)
7878
If specified, dropout is applied to the contextualized embeddings before computation of
7979
the softmax. The contextualized embeddings themselves are returned without dropout.
80-
loss_scale: ``Union[float, str]``, optional (default: 1.0)
81-
This scaling factor is applied to the average language model loss.
82-
You can also specify ``"n_samples"`` in which case we compute total
83-
loss across all predictions.
8480
num_samples: ``int``, optional (default: None)
8581
If provided, the model will use ``SampledSoftmaxLoss``
8682
with the specified number of samples. Otherwise, it will use
@@ -97,7 +93,6 @@ def __init__(self,
9793
text_field_embedder: TextFieldEmbedder,
9894
contextualizer: Seq2SeqEncoder,
9995
dropout: float = None,
100-
loss_scale: Union[float, str] = 1.0,
10196
num_samples: int = None,
10297
sparse_embeddings: bool = False,
10398
bidirectional: bool = False,
@@ -140,7 +135,6 @@ def __init__(self,
140135
else:
141136
self._dropout = lambda x: x
142137

143-
self._loss_scale = loss_scale
144138
if initializer is not None:
145139
initializer(self)
146140

@@ -312,17 +306,12 @@ def forward(self, # type: ignore
312306
self._last_average_loss[0] = average_loss.detach().item()
313307

314308
if num_targets > 0:
315-
# loss is directly minimized
316-
if self._loss_scale == 'n_samples':
317-
scale_factor = num_targets.float()
318-
else:
319-
scale_factor = self._loss_scale
320-
321309
return_dict.update({
322-
'loss': average_loss * scale_factor,
323-
'forward_loss': forward_loss * scale_factor / num_targets.float(),
324-
'backward_loss': (backward_loss * scale_factor / num_targets.float()
325-
if backward_loss is not None else None)
310+
'loss': average_loss,
311+
'forward_loss': forward_loss / num_targets.float(),
312+
'backward_loss': (backward_loss / num_targets.float()
313+
if backward_loss is not None else None),
314+
'batch_weight': num_targets.float()
326315
})
327316
else:
328317
# average_loss zero tensor, return it for all

allennlp/tests/commands/evaluate_test.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,40 @@
11
# pylint: disable=invalid-name,no-self-use
22
import argparse
33
import json
4+
from typing import Iterator, List, Dict, Iterable
45

6+
import torch
57
from flaky import flaky
68

7-
from allennlp.commands.evaluate import evaluate_from_args, Evaluate
9+
from allennlp.commands.evaluate import evaluate_from_args, Evaluate, evaluate
810
from allennlp.common.testing import AllenNlpTestCase
11+
from allennlp.data import DataIterator, Instance
12+
from allennlp.data.dataset import Batch
13+
from allennlp.data.iterators.data_iterator import TensorDict
14+
from allennlp.models import Model
15+
16+
17+
class DummyIterator(DataIterator):
18+
def __init__(self, outputs: List[TensorDict]) -> None:
19+
super().__init__()
20+
self._outputs = outputs
21+
22+
def __call__(self,
23+
instances: Iterable[Instance],
24+
num_epochs: int = None,
25+
shuffle: bool = True) -> Iterator[TensorDict]:
26+
yield from self._outputs
27+
28+
def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]:
29+
raise NotImplementedError
30+
31+
32+
class DummyModel(Model):
33+
def __init__(self) -> None:
34+
super().__init__(None) # type: ignore
35+
36+
def forward(self, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore # pylint: disable=arguments-differ
37+
return kwargs
938

1039

1140
class TestEvaluate(AllenNlpTestCase):
@@ -16,6 +45,23 @@ def setUp(self):
1645
subparsers = self.parser.add_subparsers(title='Commands', metavar='')
1746
Evaluate().add_subparser('evaluate', subparsers)
1847

48+
def test_evaluate_calculates_average_loss(self):
49+
losses = [7.0, 9.0, 8.0]
50+
outputs = [{"loss": torch.Tensor([loss])} for loss in losses]
51+
iterator = DummyIterator(outputs)
52+
metrics = evaluate(DummyModel(), None, iterator, -1, "")
53+
self.assertAlmostEqual(metrics["loss"], 8.0)
54+
55+
def test_evaluate_calculates_average_loss_with_weights(self):
56+
losses = [7.0, 9.0, 8.0]
57+
weights = [10, 2, 1.5]
58+
inputs = zip(losses, weights)
59+
outputs = [{"loss": torch.Tensor([loss]), "batch_weight": torch.Tensor([weight])}
60+
for loss, weight in inputs]
61+
iterator = DummyIterator(outputs)
62+
metrics = evaluate(DummyModel(), None, iterator, -1, "batch_weight")
63+
self.assertAlmostEqual(metrics["loss"], (70 + 18 + 12)/13.5)
64+
1965
@flaky
2066
def test_evaluate_from_args(self):
2167
kebab_args = ["evaluate", str(self.FIXTURES_ROOT / "bidaf" / "serialization" / "model.tar.gz"),

allennlp/tests/models/bidirectional_lm_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ def test_bidirectional_lm_can_train_save_load(self):
1616
self.ensure_model_can_train_save_and_load(self.param_file)
1717

1818
def test_batch_predictions_are_consistent(self):
19-
self.ensure_batch_predictions_are_consistent()
19+
self.ensure_batch_predictions_are_consistent(keys_to_ignore=["batch_weight"])
2020

2121
def test_forward_pass_runs_correctly(self):
2222
training_tensors = self.dataset.as_tensor_dict()
2323
result = self.model(**training_tensors)
2424

25-
assert set(result) == {"loss", "forward_loss", "backward_loss",
26-
"lm_embeddings", "noncontextual_token_embeddings", "mask"}
25+
assert set(result) == {"loss", "forward_loss", "backward_loss", "lm_embeddings",
26+
"noncontextual_token_embeddings", "mask", "batch_weight"}
2727

2828
# The model should preserve the BOS / EOS tokens.
2929
embeddings = result["lm_embeddings"]

allennlp/tests/models/language_model_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ def test_unidirectional_language_model_can_train_save_and_load(self):
2323
self.ensure_model_can_train_save_and_load(self.param_file)
2424

2525
def test_batch_predictions_are_consistent(self):
26-
self.ensure_batch_predictions_are_consistent()
26+
self.ensure_batch_predictions_are_consistent(keys_to_ignore=["batch_weight"])
2727

2828
def test_forward_pass_runs_correctly(self):
2929
training_tensors = self.dataset.as_tensor_dict()
3030
result = self.model(**training_tensors)
3131

32-
assert set(result) == {"loss", "forward_loss", "backward_loss",
33-
"lm_embeddings", "noncontextual_token_embeddings", "mask"}
32+
assert set(result) == {"loss", "forward_loss", "backward_loss", "lm_embeddings",
33+
"noncontextual_token_embeddings", "mask", "batch_weight"}
3434

3535
# The model should preserve the BOS / EOS tokens.
3636
embeddings = result["lm_embeddings"]

0 commit comments

Comments
 (0)