@@ -29,6 +29,8 @@
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning.callbacks import LearningRateMonitor
 
+from .callbacks import CheckpointEveryNSteps
+
 PYTORCH_IMPORT_ERROR = """
 Openspeech requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
 installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
@@ -228,7 +230,10 @@ def get_pl_trainer(
                              logger=logger,
                              auto_scale_batch_size=configs.trainer.auto_scale_batch_size,
                              max_epochs=configs.trainer.max_epochs,
-                             callbacks=[LearningRateMonitor(logging_interval='step')])
+                             callbacks=[
+                                 LearningRateMonitor(logging_interval='step'),
+                                 CheckpointEveryNSteps(configs.save_checkpoint_n_steps)
+                             ])
     elif configs.trainer.name == "gpu":
         trainer = pl.Trainer(accelerator=configs.trainer.accelerator,
                              gpus=num_devices,
@@ -239,7 +244,10 @@ def get_pl_trainer(
                              logger=logger,
                              auto_scale_batch_size=configs.trainer.auto_scale_batch_size,
                              max_epochs=configs.trainer.max_epochs,
-                             callbacks=[LearningRateMonitor(logging_interval='step')])
+                             callbacks=[
+                                 LearningRateMonitor(logging_interval='step'),
+                                 CheckpointEveryNSteps(configs.save_checkpoint_n_steps)
+                             ])
     elif configs.trainer.name == "tpu":
         trainer = pl.Trainer(accelerator=configs.trainer.accelerator,
                              tpu_cores=configs.trainer.tpu_cores,
@@ -250,7 +258,10 @@ def get_pl_trainer(
                              logger=logger,
                              auto_scale_batch_size=configs.trainer.auto_scale_batch_size,
                              max_epochs=configs.trainer.max_epochs,
-                             callbacks=[LearningRateMonitor(logging_interval='step')])
+                             callbacks=[
+                                 LearningRateMonitor(logging_interval='step'),
+                                 CheckpointEveryNSteps(configs.save_checkpoint_n_steps)
+                             ])
     elif configs.trainer.name == "gpu-fp16":
         trainer = pl.Trainer(precision=configs.trainer.precision,
                              accelerator=configs.trainer.accelerator,
@@ -275,7 +286,10 @@ def get_pl_trainer(
                              logger=logger,
                              auto_scale_batch_size=configs.trainer.auto_scale_batch_size,
                              max_epochs=configs.trainer.max_epochs,
-                             callbacks=[LearningRateMonitor(logging_interval='step')])
+                             callbacks=[
+                                 LearningRateMonitor(logging_interval='step'),
+                                 CheckpointEveryNSteps(configs.save_checkpoint_n_steps)
+                             ])
     elif configs.trainer.name == "cpu-fp64":
         trainer = pl.Trainer(precision=configs.trainer.precision,
                              accelerator=configs.trainer.accelerator,
@@ -286,7 +300,10 @@ def get_pl_trainer(
                              logger=logger,
                              auto_scale_batch_size=configs.trainer.auto_scale_batch_size,
                              max_epochs=configs.trainer.max_epochs,
-                             callbacks=[LearningRateMonitor(logging_interval='step')])
+                             callbacks=[
+                                 LearningRateMonitor(logging_interval='step'),
+                                 CheckpointEveryNSteps(configs.save_checkpoint_n_steps)
+                             ])
     else:
         raise ValueError(f"Unsupported trainer: {configs.trainer.name}")
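
The diff imports CheckpointEveryNSteps from the package-local .callbacks module but does not show its implementation. For context, a minimal sketch of what such a step-interval checkpoint callback could look like follows; the hook used, the filename pattern, and the save directory are assumptions rather than the PR's actual code. The only detail implied by the call sites above is that the constructor takes the step interval as its first argument.

# Sketch only: a Lightning Callback that saves a checkpoint every N training steps,
# assuming the constructor signature implied by CheckpointEveryNSteps(configs.save_checkpoint_n_steps).
import os

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class CheckpointEveryNSteps(Callback):
    """Save a checkpoint every `save_step_frequency` training steps,
    independently of any epoch-end checkpointing."""

    def __init__(self, save_step_frequency: int, prefix: str = "step_checkpoint"):
        self.save_step_frequency = save_step_frequency
        self.prefix = prefix

    def on_train_batch_end(self, trainer: pl.Trainer, pl_module, *args, **kwargs):
        # trainer.global_step counts optimizer steps across the whole run.
        step = trainer.global_step
        if step > 0 and step % self.save_step_frequency == 0:
            # Hypothetical filename pattern; the real callback may name files differently.
            ckpt_path = os.path.join(trainer.default_root_dir, f"{self.prefix}_{step}.ckpt")
            trainer.save_checkpoint(ckpt_path)

Note that the call sites read configs.save_checkpoint_n_steps from the top level of the config rather than from configs.trainer, so the OmegaConf/Hydra configuration would need a matching save_checkpoint_n_steps entry for the updated trainers to build.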