Skip to content

Commit 445fb5c

Browse files
Fix PTL2.2 saving multiple *-last.ckpt checkpoints in resumed training (NVIDIA#8480)
* Fix PTL2.2 saving multiple `*-last.ckpt` checkpoints when resuming from previous run Signed-off-by: He Huang (Steve) <[email protected]> * Fix missing import Signed-off-by: He Huang (Steve) <[email protected]> * fix broken test Signed-off-by: stevehuang52 <[email protected]> --------- Signed-off-by: He Huang (Steve) <[email protected]> Signed-off-by: stevehuang52 <[email protected]> Co-authored-by: Abhishree Thittenamane <[email protected]>
1 parent d0115ab commit 445fb5c

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

nemo/utils/callbacks/nemo_model_checkpoint.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import pytorch_lightning
2323
import torch
24-
from pytorch_lightning.callbacks import ModelCheckpoint
24+
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol
2525
from pytorch_lightning.utilities import rank_zero_info
2626

2727
from nemo.collections.common.callbacks import EMA
@@ -454,3 +454,29 @@ def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None:
454454
# delete markers
455455
for marker_path in existing_marker_filepaths:
456456
os.remove(marker_path)
457+
458+
def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool:
459+
"""Checks if the previous checkpoint should be deleted.
460+
A checkpoint won't be deleted if any of the cases apply:
461+
- The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new)
462+
- The previous checkpoint is not in the current checkpoint directory and the filesystem is local
463+
- The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local
464+
and the resumed from checkpoint is not the last checkpoint
465+
"""
466+
if previous == current:
467+
return False
468+
if not _is_local_file_protocol(previous):
469+
return True
470+
previous = Path(previous).absolute()
471+
resume_path = Path(trainer.ckpt_path).absolute() if trainer.ckpt_path is not None else None
472+
473+
if resume_path is not None and previous == resume_path:
474+
if str(current).endswith("-last.ckpt") and resume_path.name.endswith("-last.ckpt"):
475+
# delete the previous `-last.ckpt` checkpoint when current saved checkpoint is also `-last.ckpt`, if they're in the same directory
476+
pass
477+
else:
478+
return False
479+
if self.dirpath is None:
480+
raise ValueError(f"{self.__class__}.dirpath is None.")
481+
dirpath = Path(self.dirpath).absolute()
482+
return dirpath in previous.parents

tests/core/test_exp_manager.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -946,9 +946,8 @@ def test_invalid_checkpoints_removed_from_topk(self, tmp_path):
946946
test_trainer2.fit(model)
947947

948948
ckpt_filenames = {f.name for f in checkpoints_dir.rglob("*.ckpt") if f.is_file()}
949-
# 3 top + 1 last + 1 resume ckpt since PTL >= 2.1 ensures to never delete the resume ckpt
950-
# (https://github.com/Lightning-AI/pytorch-lightning/pull/18750)
951-
assert len(ckpt_filenames) == 5
949+
# 3 top + 1 last
950+
assert len(ckpt_filenames) == 4
952951
assert 'epoch=9-last.ckpt' in ckpt_filenames
953952
assert 'epoch=8.ckpt' in ckpt_filenames
954953
assert 'epoch=7.ckpt' in ckpt_filenames

0 commit comments

Comments
 (0)