
Commit b59dfd5

Merge pull request #57 from openspeech-team/checkpoint-n-step
Add CheckpointEveryNSteps class to save a checkpoint every N steps.
2 parents 5d24f71 + 38e9b06 commit b59dfd5

3 files changed (+78, -5 lines)


openspeech/callbacks.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
# MIT License
#
# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import os
import pytorch_lightning as pl


class CheckpointEveryNSteps(pl.Callback):
    """
    Save a checkpoint every N steps, instead of Lightning's default that checkpoints
    based on validation loss.

    Args:
        save_step_frequency: how often to save in steps
        use_modelcheckpoint_filename: just use the ModelCheckpoint callback's default filename, don't use ours.
    """

    def __init__(
            self,
            save_step_frequency,
            use_modelcheckpoint_filename=False,
    ) -> None:
        self.save_step_frequency = save_step_frequency
        self.use_modelcheckpoint_filename = use_modelcheckpoint_filename

    def on_batch_end(self, trainer: pl.Trainer, _):
        """ Check if we should save a checkpoint after every train batch """
        epoch = trainer.current_epoch
        global_step = trainer.global_step
        if global_step % self.save_step_frequency == 0:
            if self.use_modelcheckpoint_filename:
                filename = trainer.checkpoint_callback.filename
            else:
                filename = f"{epoch=}_{global_step=}.ckpt"
            ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, filename)
            trainer.save_checkpoint(ckpt_path)
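For context, a minimal usage sketch (not part of the commit; the directory, frequency, and the commented-out model are illustrative assumptions) showing how this callback can be passed to a Trainer alongside a ModelCheckpoint, whose dirpath it reuses:

# Illustrative sketch only: attaching CheckpointEveryNSteps to a Trainer.
# The dirpath and save_step_frequency values below are made up for the example.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

from openspeech.callbacks import CheckpointEveryNSteps

checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/")  # supplies the dirpath the callback reuses
trainer = pl.Trainer(
    max_epochs=20,
    callbacks=[
        checkpoint_callback,
        CheckpointEveryNSteps(save_step_frequency=10000),
    ],
)
# trainer.fit(model) would then also write files such as
# "epoch=0_global_step=10000.ckpt" into checkpoints/ every 10000 training steps.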

openspeech/dataclass/configurations.py

Lines changed: 3 additions & 0 deletions
@@ -202,6 +202,9 @@ class BaseTrainerConfigs(OpenspeechDataclass):
     max_epochs: int = field(
         default=20, metadata={"help": "Stop training once this number of epochs is reached."}
     )
+    save_checkpoint_n_steps: int = field(
+        default=10000, metadata={"help": "Save a checkpoint every N steps."}
+    )
     auto_scale_batch_size: str = field(
         default="binsearch", metadata={"help": "If set to True, will initially run a batch size finder trying to find "
                                                "the largest batch size that fits into memory."}

openspeech/utils.py

Lines changed: 22 additions & 5 deletions
@@ -29,6 +29,8 @@
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning.callbacks import LearningRateMonitor
 
+from .callbacks import CheckpointEveryNSteps
+
 PYTORCH_IMPORT_ERROR = """
 Openspeech requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
 installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
@@ -228,7 +230,10 @@ def get_pl_trainer(
                              logger=logger,
                              auto_scale_batch_size=configs.trainer.auto_scale_batch_size,
                              max_epochs=configs.trainer.max_epochs,
-                             callbacks=[LearningRateMonitor(logging_interval='step')])
+                             callbacks=[
+                                 LearningRateMonitor(logging_interval='step'),
+                                 CheckpointEveryNSteps(configs.save_checkpoint_n_steps)
+                             ])
     elif configs.trainer.name == "gpu":
         trainer = pl.Trainer(accelerator=configs.trainer.accelerator,
                              gpus=num_devices,
@@ -239,7 +244,10 @@ def get_pl_trainer(
                              logger=logger,
                              auto_scale_batch_size=configs.trainer.auto_scale_batch_size,
                              max_epochs=configs.trainer.max_epochs,
-                             callbacks=[LearningRateMonitor(logging_interval='step')])
+                             callbacks=[
+                                 LearningRateMonitor(logging_interval='step'),
+                                 CheckpointEveryNSteps(configs.save_checkpoint_n_steps)
+                             ])
     elif configs.trainer.name == "tpu":
         trainer = pl.Trainer(accelerator=configs.trainer.accelerator,
                              tpu_cores=configs.trainer.tpu_cores,
@@ -250,7 +258,10 @@ def get_pl_trainer(
                              logger=logger,
                              auto_scale_batch_size=configs.trainer.auto_scale_batch_size,
                              max_epochs=configs.trainer.max_epochs,
-                             callbacks=[LearningRateMonitor(logging_interval='step')])
+                             callbacks=[
+                                 LearningRateMonitor(logging_interval='step'),
+                                 CheckpointEveryNSteps(configs.save_checkpoint_n_steps)
+                             ])
     elif configs.trainer.name == "gpu-fp16":
         trainer = pl.Trainer(precision=configs.trainer.precision,
                              accelerator=configs.trainer.accelerator,
@@ -275,7 +286,10 @@ def get_pl_trainer(
                              logger=logger,
                              auto_scale_batch_size=configs.trainer.auto_scale_batch_size,
                              max_epochs=configs.trainer.max_epochs,
-                             callbacks=[LearningRateMonitor(logging_interval='step')])
+                             callbacks=[
+                                 LearningRateMonitor(logging_interval='step'),
+                                 CheckpointEveryNSteps(configs.save_checkpoint_n_steps)
+                             ])
     elif configs.trainer.name == "cpu-fp64":
         trainer = pl.Trainer(precision=configs.trainer.precision,
                              accelerator=configs.trainer.accelerator,
@@ -286,7 +300,10 @@ def get_pl_trainer(
                              logger=logger,
                              auto_scale_batch_size=configs.trainer.auto_scale_batch_size,
                              max_epochs=configs.trainer.max_epochs,
-                             callbacks=[LearningRateMonitor(logging_interval='step')])
+                             callbacks=[
+                                 LearningRateMonitor(logging_interval='step'),
+                                 CheckpointEveryNSteps(configs.save_checkpoint_n_steps)
+                             ])
     else:
         raise ValueError(f"Unsupported trainer: {configs.trainer.name}")
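The change repeats the same callbacks list for each trainer variant; below is a condensed sketch of that pattern (not the repository's get_pl_trainer — the config values are made up, and the configs.trainer.save_checkpoint_n_steps path is an assumption based on where the field was added in BaseTrainerConfigs above):

# Illustrative sketch only: building the callback list from a config object.
import pytorch_lightning as pl
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import LearningRateMonitor

from openspeech.callbacks import CheckpointEveryNSteps

configs = OmegaConf.create({"trainer": {"max_epochs": 20, "save_checkpoint_n_steps": 10000}})

trainer = pl.Trainer(
    max_epochs=configs.trainer.max_epochs,
    callbacks=[
        LearningRateMonitor(logging_interval="step"),
        CheckpointEveryNSteps(configs.trainer.save_checkpoint_n_steps),
    ],
)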
