
Commit 15e288f

EpochCallback for tracking epoch (#4540)
* EpochCallback for tracking epoch in the model
* minor lint
* updated CHANGELOG
* added unit test for track epoch callback
* Update allennlp/training/trainer.py

Co-authored-by: Matt Gardner <[email protected]>
1 parent 9209bc9 commit 15e288f

File tree: 4 files changed, +45 −1 lines changed

CHANGELOG.md (+1)

@@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added the option to specify `requires_grad: false` within an optimizer's parameter groups.
 - Added the `file-friendly-logging` flag back to the `train` command. Also added this flag to the `predict`, `evaluate`, and `find-learning-rate` commands.
+- Added an `EpochCallback` to track current epoch as a model class member.
 
 ### Removed

allennlp/training/__init__.py (+7 −1)

@@ -1,4 +1,10 @@
 from allennlp.training.checkpointer import Checkpointer
 from allennlp.training.tensorboard_writer import TensorboardWriter
 from allennlp.training.no_op_trainer import NoOpTrainer
-from allennlp.training.trainer import Trainer, GradientDescentTrainer, BatchCallback, EpochCallback
+from allennlp.training.trainer import (
+    Trainer,
+    GradientDescentTrainer,
+    BatchCallback,
+    EpochCallback,
+    TrackEpochCallback,
+)

allennlp/training/trainer.py (+23)

@@ -176,6 +176,29 @@ def __call__(
 EpochCallback.register("null")(EpochCallback)
 
 
+@EpochCallback.register("track_epoch_callback")
+class TrackEpochCallback:
+    """
+    A callback that you can pass to the `GradientDescentTrainer` to access the current epoch number
+    in your model during training. This callback sets `model.epoch`, which can be read inside of
+    `model.forward()`. Since the `EpochCallback` passes `epoch=-1`
+    at the start of the training, we set `model.epoch = epoch + 1`, which now denotes the number of
+    completed epochs at a given training state.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def __call__(
+        self,
+        trainer: "GradientDescentTrainer",
+        metrics: Dict[str, Any],
+        epoch: int,
+        is_master: bool,
+    ) -> None:
+        trainer.model.epoch = epoch + 1
+
+
 @Trainer.register("gradient_descent", constructor="from_partial_objects")
 class GradientDescentTrainer(Trainer):
     """

tests/training/trainer_test.py (+14)

@@ -28,6 +28,7 @@
     TensorboardWriter,
     BatchCallback,
     EpochCallback,
+    TrackEpochCallback,
 )
 from allennlp.training.learning_rate_schedulers import CosineWithRestarts
 from allennlp.training.learning_rate_schedulers import ExponentialLearningRateScheduler
@@ -986,6 +987,19 @@ def __call__(
         expected_calls = [epoch for epoch in range(-1, 4)]
         assert trainer.epoch_callback_calls == expected_calls
 
+    def test_track_epoch_callback(self):
+        num_epochs = 4
+        trainer = GradientDescentTrainer(
+            self.model,
+            self.optimizer,
+            self.data_loader,
+            num_epochs=num_epochs,
+            validation_data_loader=self.validation_data_loader,
+            epoch_callbacks=[TrackEpochCallback()],
+        )
+        trainer.train()
+        assert trainer.model.epoch == num_epochs
+
     def test_total_loss_is_average_of_batch_loss(self):
 
         batches_per_epoch = 3
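Because the diff registers the callback under the name `track_epoch_callback`, config-driven setups can also resolve it by name; a minimal sketch, assuming only the standard `Registrable.by_name` lookup that AllenNLP's config loading uses:

from allennlp.training import EpochCallback, TrackEpochCallback

# `by_name` resolves the string used in config files to the registered class.
callback_cls = EpochCallback.by_name("track_epoch_callback")
assert callback_cls is TrackEpochCallback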
