This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 6ea273e

mikerossgithub authored and joelgrus committed
Allow checkpointer to be initialized from params (#2491)
* Allow passing checkpointer as argument to trainer. Also adds some checkpointer unit tests:
  * Added unit test for registration
  * Added unit test comments
  * Added checkpointer unit test for configuration error
* Added checkpointer unit test comments
1 parent b0ea7ab commit 6ea273e
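
This commit adds a second way to control checkpointing: instead of tuning num_serialized_models_to_keep and keep_serialized_model_every_num_seconds on the Trainer, you can now construct a Checkpointer yourself and pass it in. A minimal sketch of that usage (model, optimizer, iterator, and train_data are placeholders, not part of this commit):

    from allennlp.training.checkpointer import Checkpointer
    from allennlp.training.trainer import Trainer

    # Keep only the five most recent checkpoints, plus a permanently saved
    # checkpoint roughly every hour.
    checkpointer = Checkpointer(serialization_dir="/tmp/serialization",
                                num_serialized_models_to_keep=5,
                                keep_serialized_model_every_num_seconds=3600)

    trainer = Trainer(model, optimizer, iterator, train_data,
                      serialization_dir="/tmp/serialization",
                      # New argument; do not also pass num_serialized_models_to_keep
                      # or keep_serialized_model_every_num_seconds, or the Trainer
                      # raises a ConfigurationError.
                      checkpointer=checkpointer)

Passing a checkpointer together with either of the old keyword arguments is rejected, as the new unit tests below verify.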

File tree

3 files changed: +154 -9 lines changed
New file: checkpointer tests

+116-0
@@ -0,0 +1,116 @@
# pylint: disable=invalid-name
import os
import re
import time

from allennlp.common.testing import AllenNlpTestCase
from allennlp.training.checkpointer import Checkpointer
from allennlp.common.params import Params
from allennlp.training.trainer import Trainer
from allennlp.common.checks import ConfigurationError


class TestCheckpointer(AllenNlpTestCase):
    def retrieve_and_delete_saved(self):
        """
        Helper function for the tests below. Finds the weight and training state files in
        self.TEST_DIR, parses their names for the epochs that were saved, deletes them,
        and returns the saved epochs as two lists of integers.
        """
        serialization_files = os.listdir(self.TEST_DIR)
        model_checkpoints = [x for x in serialization_files if "model_state_epoch" in x]
        found_model_epochs = [int(re.search(r"model_state_epoch_([0-9\.\-]+)\.th", x).group(1))
                              for x in model_checkpoints]
        for f in model_checkpoints:
            os.remove(os.path.join(self.TEST_DIR, f))
        training_checkpoints = [x for x in serialization_files if "training_state_epoch" in x]
        found_training_epochs = [int(re.search(r"training_state_epoch_([0-9\.\-]+)\.th", x).group(1))
                                 for x in training_checkpoints]
        for f in training_checkpoints:
            os.remove(os.path.join(self.TEST_DIR, f))
        return sorted(found_model_epochs), sorted(found_training_epochs)

    def test_default(self):
        """
        Tests that the default behavior keeps just the last 20 checkpoints.
        """
        default_num_to_keep = 20
        num_epochs = 30
        target = list(range(num_epochs - default_num_to_keep, num_epochs))

        checkpointer = Checkpointer(serialization_dir=self.TEST_DIR)

        for e in range(num_epochs):
            checkpointer.save_checkpoint(epoch=e,
                                         model_state={"epoch": e},
                                         training_states={"epoch": e},
                                         is_best_so_far=False)
        models, training = self.retrieve_and_delete_saved()
        assert models == training == target

    def test_with_time(self):
        """
        Tests that keep_serialized_model_every_num_seconds parameter causes a checkpoint to be saved
        after enough time has elapsed between epochs.
        """
        num_to_keep = 10
        num_epochs = 30
        target = list(range(num_epochs - num_to_keep, num_epochs))
        pauses = [5, 18, 26]
        target = sorted(set(target + pauses))
        checkpointer = Checkpointer(serialization_dir=self.TEST_DIR,
                                    num_serialized_models_to_keep=num_to_keep,
                                    keep_serialized_model_every_num_seconds=1)
        for e in range(num_epochs):
            if e in pauses:
                time.sleep(2)
            checkpointer.save_checkpoint(epoch=e,
                                         model_state={"epoch": e},
                                         training_states={"epoch": e},
                                         is_best_so_far=False)
        models, training = self.retrieve_and_delete_saved()
        assert models == training == target

    def test_configuration_error_when_passed_as_conflicting_argument_to_trainer(self):
        """
        Users should initialize Trainer either with an instance of Checkpointer or by specifying
        parameter values for num_serialized_models_to_keep and keep_serialized_model_every_num_seconds.
        Check that Trainer raises a ConfigurationError if both methods are used at the same time.
        """
        with self.assertRaises(ConfigurationError):
            Trainer(None, None, None, None,
                    num_serialized_models_to_keep=30,
                    keep_serialized_model_every_num_seconds=None,
                    checkpointer=Checkpointer(serialization_dir=self.TEST_DIR,
                                              num_serialized_models_to_keep=40,
                                              keep_serialized_model_every_num_seconds=2))
        with self.assertRaises(ConfigurationError):
            Trainer(None, None, None, None,
                    num_serialized_models_to_keep=20,
                    keep_serialized_model_every_num_seconds=2,
                    checkpointer=Checkpointer(serialization_dir=self.TEST_DIR,
                                              num_serialized_models_to_keep=40,
                                              keep_serialized_model_every_num_seconds=2))
        try:
            Trainer(None, None, None, None,
                    checkpointer=Checkpointer(serialization_dir=self.TEST_DIR,
                                              num_serialized_models_to_keep=40,
                                              keep_serialized_model_every_num_seconds=2))
        except ConfigurationError:
            self.fail("Configuration Error raised for passed checkpointer")

    def test_registered_subclass(self):
        """
        Tests that registering Checkpointer subclasses works correctly.
        """
        @Checkpointer.register("checkpointer_subclass")
        class CheckpointerSubclass(Checkpointer):
            def __init__(self, x: int, y: int) -> None:
                super().__init__()
                self.x = x
                self.y = y

        sub_inst = Checkpointer.from_params(Params({"type": "checkpointer_subclass", "x": 1, "y": 3}))
        assert sub_inst.__class__ == CheckpointerSubclass
        assert sub_inst.x == 1 and sub_inst.y == 3

allennlp/training/checkpointer.py

+2-1
@@ -8,11 +8,12 @@
 
 import torch
 
+from allennlp.common.registrable import Registrable
 from allennlp.nn import util as nn_util
 
 logger = logging.getLogger(__name__)
 
-class Checkpointer:
+class Checkpointer(Registrable):
     """
     This class implements the functionality for checkpointing your model and trainer state
     during training. It is agnostic as to what those states look like (they are typed as
allennlp/training/trainer.py

+36-8
@@ -50,6 +50,7 @@ def __init__(self,
                  serialization_dir: Optional[str] = None,
                  num_serialized_models_to_keep: int = 20,
                  keep_serialized_model_every_num_seconds: int = None,
+                 checkpointer: Checkpointer = None,
                  model_save_interval: float = None,
                  cuda_device: Union[int, List] = -1,
                  grad_norm: Optional[float] = None,
@@ -115,6 +116,11 @@ def __init__(self,
             To do so, specify keep_serialized_model_every_num_seconds as the number of seconds
             between permanently saved checkpoints. Note that this option is only used if
             num_serialized_models_to_keep is not None, otherwise all checkpoints are kept.
+        checkpointer : ``Checkpointer``, optional (default=None)
+            An instance of class Checkpointer to use instead of the default. If a checkpointer is specified,
+            the arguments num_serialized_models_to_keep and keep_serialized_model_every_num_seconds should
+            not be specified. The caller is responsible for initializing the checkpointer so that it is
+            consistent with serialization_dir.
         model_save_interval : ``float``, optional (default=None)
             If provided, then serialize models every ``model_save_interval``
             seconds within single epochs. In all cases, models are also saved
@@ -196,9 +202,19 @@ def __init__(self,
 
         self._num_epochs = num_epochs
 
-        self._checkpointer = Checkpointer(serialization_dir,
-                                          keep_serialized_model_every_num_seconds,
-                                          num_serialized_models_to_keep)
+        if checkpointer is not None:
+            # We can't easily check if these parameters were passed in, so check against their default values.
+            # We don't check against serialization_dir since it is also used by the parent class.
+            if num_serialized_models_to_keep != 20 or \
+                    keep_serialized_model_every_num_seconds is not None:
+                raise ConfigurationError(
+                        "When passing a custom Checkpointer, you may not also pass in separate checkpointer "
+                        "args 'num_serialized_models_to_keep' or 'keep_serialized_model_every_num_seconds'.")
+            self._checkpointer = checkpointer
+        else:
+            self._checkpointer = Checkpointer(serialization_dir,
+                                              keep_serialized_model_every_num_seconds,
+                                              num_serialized_models_to_keep)
 
         self._model_save_interval = model_save_interval
 
@@ -683,9 +699,22 @@ def from_params(cls,  # type: ignore
         else:
             momentum_scheduler = None
 
-        num_serialized_models_to_keep = params.pop_int("num_serialized_models_to_keep", 20)
-        keep_serialized_model_every_num_seconds = params.pop_int(
-                "keep_serialized_model_every_num_seconds", None)
+        if 'checkpointer' in params:
+            if 'keep_serialized_model_every_num_seconds' in params or \
+                    'num_serialized_models_to_keep' in params:
+                raise ConfigurationError(
+                        "Checkpointer may be initialized either from the 'checkpointer' key or from the "
+                        "keys 'num_serialized_models_to_keep' and 'keep_serialized_model_every_num_seconds'"
+                        " but the passed config uses both methods.")
+            checkpointer = Checkpointer.from_params(params.pop("checkpointer"))
+        else:
+            num_serialized_models_to_keep = params.pop_int("num_serialized_models_to_keep", 20)
+            keep_serialized_model_every_num_seconds = params.pop_int(
+                    "keep_serialized_model_every_num_seconds", None)
+            checkpointer = Checkpointer(
+                    serialization_dir=serialization_dir,
+                    num_serialized_models_to_keep=num_serialized_models_to_keep,
+                    keep_serialized_model_every_num_seconds=keep_serialized_model_every_num_seconds)
         model_save_interval = params.pop_float("model_save_interval", None)
         summary_interval = params.pop_int("summary_interval", 100)
         histogram_interval = params.pop_int("histogram_interval", None)
@@ -707,8 +736,7 @@ def from_params(cls,  # type: ignore
                    grad_clipping=grad_clipping,
                    learning_rate_scheduler=lr_scheduler,
                    momentum_scheduler=momentum_scheduler,
-                   num_serialized_models_to_keep=num_serialized_models_to_keep,
-                   keep_serialized_model_every_num_seconds=keep_serialized_model_every_num_seconds,
+                   checkpointer=checkpointer,
                    model_save_interval=model_save_interval,
                    summary_interval=summary_interval,
                    histogram_interval=histogram_interval,
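
On the configuration side, the new branch in Trainer.from_params reads an optional "checkpointer" key and builds it with Checkpointer.from_params, which dispatches on the "type" field to any registered subclass. A hedged equivalent of that call, reusing the hypothetical "my_checkpointer" registration sketched above:

    from allennlp.common.params import Params
    from allennlp.training.checkpointer import Checkpointer

    # This mirrors what Trainer.from_params now does when the trainer config
    # contains a "checkpointer" key. Supplying that key alongside
    # "num_serialized_models_to_keep" or "keep_serialized_model_every_num_seconds"
    # raises a ConfigurationError instead.
    checkpointer = Checkpointer.from_params(
            Params({"type": "my_checkpointer", "num_serialized_models_to_keep": 5}))
    assert isinstance(checkpointer, MyCheckpointer)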
