Skip to content

Fix PTL2.2 saving multiple *-last.ckpt checkpoints in resumed training #8480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 23, 2024
28 changes: 27 additions & 1 deletion nemo/utils/callbacks/nemo_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import pytorch_lightning
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol
from pytorch_lightning.utilities import rank_zero_info

from nemo.collections.common.callbacks import EMA
Expand Down Expand Up @@ -454,3 +454,29 @@ def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None:
# delete markers
for marker_path in existing_marker_filepaths:
os.remove(marker_path)

def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool:
Copy link
Collaborator

@athitten athitten Mar 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stevehuang52 we want to avoid overriding PTL's protected methods in NeMo as much as possible so that NeMo doesn't break when PTL changes something in the protected methods.

Maybe we can create this PR in PTL itself ? Since its not a lot of modification and just addition of couple of lines in _should_remove_checkpoint function.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@athitten thanks for the advice, I've created a PR to PTL Lightning-AI/pytorch-lightning#19613, let's see if they'll approve

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thank you very much @stevehuang52.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems they closed the PR above. Is the resolution implemented here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@titu1994 Yeah they don't want to do the fix in PTL, we will do the fix in NeMo (which is this PR).

"""Checks if the previous checkpoint should be deleted.
A checkpoint won't be deleted if any of the cases apply:
- The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new)
- The previous checkpoint is not in the current checkpoint directory and the filesystem is local
- The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local
and the resumed from checkpoint is not the last checkpoint
"""
if previous == current:
return False
if not _is_local_file_protocol(previous):
return True
previous = Path(previous).absolute()
resume_path = Path(trainer.ckpt_path).absolute() if trainer.ckpt_path is not None else None

if resume_path is not None and previous == resume_path:
if str(current).endswith("-last.ckpt") and resume_path.name.endswith("-last.ckpt"):
# delete the previous `-last.ckpt` checkpoint when current saved checkpoint is also `-last.ckpt`, if they're in the same directory
pass
else:
return False
if self.dirpath is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stevehuang52 why have we added if self.dirpath is None: instead of assert self.dirpath is not None which is the case in PTL code here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@athitten I remembered @titu1994 told me to avoid using assert in nemo, and to raise exceptions instead so that they can be caught by error handling code. However, it's up to you if you would like to keep it identical to PTL

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, assert hides actual error, we should raise it properly with useful error message (like here)

raise ValueError(f"{self.__class__}.dirpath is None.")
dirpath = Path(self.dirpath).absolute()
return dirpath in previous.parents
5 changes: 2 additions & 3 deletions tests/core/test_exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,9 +946,8 @@ def test_invalid_checkpoints_removed_from_topk(self, tmp_path):
test_trainer2.fit(model)

ckpt_filenames = {f.name for f in checkpoints_dir.rglob("*.ckpt") if f.is_file()}
# 3 top + 1 last + 1 resume ckpt since PTL >= 2.1 ensures to never delete the resume ckpt
# (https://github.com/Lightning-AI/pytorch-lightning/pull/18750)
assert len(ckpt_filenames) == 5
# 3 top + 1 last
assert len(ckpt_filenames) == 4
assert 'epoch=9-last.ckpt' in ckpt_filenames
assert 'epoch=8.ckpt' in ckpt_filenames
assert 'epoch=7.ckpt' in ckpt_filenames
Expand Down