|
21 | 21 |
|
22 | 22 | import pytorch_lightning
|
23 | 23 | import torch
|
24 |
| -from pytorch_lightning.callbacks import ModelCheckpoint |
| 24 | +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol |
25 | 25 | from pytorch_lightning.utilities import rank_zero_info
|
26 | 26 |
|
27 | 27 | from nemo.collections.common.callbacks import EMA
|
@@ -454,3 +454,29 @@ def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None:
|
454 | 454 | # delete markers
|
455 | 455 | for marker_path in existing_marker_filepaths:
|
456 | 456 | os.remove(marker_path)
|
| 457 | + |
| 458 | + def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool: |
| 459 | + """Checks if the previous checkpoint should be deleted. |
| 460 | + A checkpoint won't be deleted if any of the cases apply: |
| 461 | + - The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new) |
| 462 | + - The previous checkpoint is not in the current checkpoint directory and the filesystem is local |
| 463 | + - The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local |
| 464 | + and the resumed from checkpoint is not the last checkpoint |
| 465 | + """ |
| 466 | + if previous == current: |
| 467 | + return False |
| 468 | + if not _is_local_file_protocol(previous): |
| 469 | + return True |
| 470 | + previous = Path(previous).absolute() |
| 471 | + resume_path = Path(trainer.ckpt_path).absolute() if trainer.ckpt_path is not None else None |
| 472 | + |
| 473 | + if resume_path is not None and previous == resume_path: |
| 474 | + if str(current).endswith("-last.ckpt") and resume_path.name.endswith("-last.ckpt"): |
| 475 | + # delete the previous `-last.ckpt` checkpoint when current saved checkpoint is also `-last.ckpt`, if they're in the same directory |
| 476 | + pass |
| 477 | + else: |
| 478 | + return False |
| 479 | + if self.dirpath is None: |
| 480 | + raise ValueError(f"{self.__class__}.dirpath is None.") |
| 481 | + dirpath = Path(self.dirpath).absolute() |
| 482 | + return dirpath in previous.parents |
0 commit comments