Skip to content

Commit 390af24

Browse files
committed
Revert "Fix PTL2.2 saving multiple *-last.ckpt checkpoints in resumed training (#8480)"
This reverts commit 11b7a73.
1 parent 479f5a8 commit 390af24

File tree

2 files changed

+4
-29
lines changed

2 files changed

+4
-29
lines changed

nemo/utils/callbacks/nemo_model_checkpoint.py

+1 -27
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121

2222
import pytorch_lightning
2323
import torch
24-
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol
24+
from pytorch_lightning.callbacks import ModelCheckpoint
2525
from pytorch_lightning.utilities import rank_zero_info
2626

2727
from nemo.collections.common.callbacks import EMA
@@ -454,29 +454,3 @@ def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None:
454454
# delete markers
455455
for marker_path in existing_marker_filepaths:
456456
os.remove(marker_path)
457-
458-
def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool:
459-
"""Checks if the previous checkpoint should be deleted.
460-
A checkpoint won't be deleted if any of the cases apply:
461-
- The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new)
462-
- The previous checkpoint is not in the current checkpoint directory and the filesystem is local
463-
- The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local
464-
and the resumed from checkpoint is not the last checkpoint
465-
"""
466-
if previous == current:
467-
return False
468-
if not _is_local_file_protocol(previous):
469-
return True
470-
previous = Path(previous).absolute()
471-
resume_path = Path(trainer.ckpt_path).absolute() if trainer.ckpt_path is not None else None
472-
473-
if resume_path is not None and previous == resume_path:
474-
if str(current).endswith("-last.ckpt") and resume_path.name.endswith("-last.ckpt"):
475-
# delete the previous `-last.ckpt` checkpoint when current saved checkpoint is also `-last.ckpt`, if they're in the same directory
476-
pass
477-
else:
478-
return False
479-
if self.dirpath is None:
480-
raise ValueError(f"{self.__class__}.dirpath is None.")
481-
dirpath = Path(self.dirpath).absolute()
482-
return dirpath in previous.parents

tests/core/test_exp_manager.py

+3 -2
Original file line number | Diff line number | Diff line change
@@ -946,8 +946,9 @@ def test_invalid_checkpoints_removed_from_topk(self, tmp_path):
946946
test_trainer2.fit(model)
947947

948948
ckpt_filenames = {f.name for f in checkpoints_dir.rglob("*.ckpt") if f.is_file()}
949-
# 3 top + 1 last
950-
assert len(ckpt_filenames) == 4
949+
# 3 top + 1 last + 1 resume ckpt since PTL >= 2.1 ensures to never delete the resume ckpt
950+
# (https://github.com/Lightning-AI/pytorch-lightning/pull/18750)
951+
assert len(ckpt_filenames) == 5
951952
assert 'epoch=9-last.ckpt' in ckpt_filenames
952953
assert 'epoch=8.ckpt' in ckpt_filenames
953954
assert 'epoch=7.ckpt' in ckpt_filenames

0 commit comments

Comments
 (0)