
Fix PTL2.2 saving multiple *-last.ckpt checkpoints in resumed training #19613


Closed
wants to merge 4 commits

Conversation


@stevehuang52 stevehuang52 commented Mar 11, 2024

What does this PR do?

Fix PTL2.2 saving multiple *-last.ckpt checkpoints in resumed training.

In the current PTL 2.2 model_checkpoint.py, the previous *-last.ckpt checkpoint that training was resumed from is not deleted when a new *-last.ckpt checkpoint is saved, which leaves multiple *-last.ckpt checkpoints behind. This causes a "multiple *-last.ckpt checkpoints" error when trying to resume from a previous job (the first resume is fine since there is only one *-last.ckpt; the second resume crashes because multiple *-last.ckpt files now exist).

The fix is to check whether the checkpoint being saved is a *-last.ckpt checkpoint and whether the model was resumed from a *-last.ckpt checkpoint. If both are true and the two checkpoints are in the same directory, the previous *-last.ckpt is deleted.
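
Below is a minimal sketch of the intended logic (the helper name and arguments are hypothetical and only illustrate the description above, not the actual diff):

import os

def remove_previous_last_ckpt(save_dir, new_last_path, resumed_from):
    # Only act when the run was resumed from a *-last.ckpt checkpoint.
    if not resumed_from or not resumed_from.endswith("-last.ckpt"):
        return
    resumed_from = os.path.abspath(resumed_from)
    # Delete the stale checkpoint only if it sits in the same directory
    # and is not the file about to be written.
    if (
        os.path.dirname(resumed_from) == os.path.abspath(save_dir)
        and resumed_from != os.path.abspath(new_last_path)
        and os.path.isfile(resumed_from)
    ):
        os.remove(resumed_from)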

@github-actions github-actions bot added the pl (Generic label for PyTorch Lightning package) label Mar 11, 2024

codecov bot commented Mar 11, 2024

Codecov Report

Merging #19613 (b46dbcc) into master (096b063) will decrease coverage by 31%.
The diff coverage is 67%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19613      +/-   ##
==========================================
- Coverage      84%      53%     -31%     
==========================================
  Files         424      416       -8     
  Lines       34903    34752     -151     
==========================================
- Hits        29347    18434   -10913     
- Misses       5556    16318   +10762     

@awaelchli
Contributor

@stevehuang52 Thanks for sending a PR. I tried to reproduce the problem with an example based on your description, but couldn't. Here is what I tried:

import os

import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset
import shutil


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    if os.path.exists("./checkpoints"):
        shutil.rmtree("./checkpoints")

    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    callback = ModelCheckpoint(dirpath="./checkpoints", monitor="step", mode="max", every_n_train_steps=2, save_last=True)
    trainer_kwargs = {
        "limit_train_batches": 10,
        "limit_val_batches": 0,
        "enable_progress_bar": False,
        "enable_model_summary": False,
        "logger": False,
    }

    model = BoringModel()
    trainer = Trainer(callbacks=callback, max_steps=5, **trainer_kwargs)
    trainer.fit(model, train_data)

    # Resume the first time
    trainer = Trainer(callbacks=callback, max_steps=10, **trainer_kwargs)
    trainer.fit(model, train_data, ckpt_path="last")

    # Resume one more time
    trainer = Trainer(callbacks=callback, max_steps=20, **trainer_kwargs)
    trainer.fit(model, train_data, ckpt_path="last")

    print(os.listdir("./checkpoints"))
    # ['last.ckpt', 'epoch=1-step=20.ckpt']


if __name__ == "__main__":
    run()

Could you please modify it to match your use case?
I suspect that perhaps you are not using the Trainer's resume functionality via .fit(..., ckpt_path=...)?

@awaelchli awaelchli self-assigned this Mar 11, 2024
@awaelchli awaelchli added the callback: model checkpoint and community (This PR is from the community) labels Mar 11, 2024
@stevehuang52
Author

stevehuang52 commented Mar 12, 2024

@awaelchli Thanks for your reply. I was using the NeMo toolkit when I hit this problem. I think they set the last-checkpoint format here, which makes the last checkpoint look like xxxxx-step=x-epoch=xx-last.ckpt instead of the default last.ckpt as in your example. I tried manually setting callback.CHECKPOINT_NAME_LAST="{step}-last.ckpt" in your script, but it couldn't find the stored last checkpoint when trying to resume training. Could you please take a look? Thanks again~

@awaelchli
Contributor

awaelchli commented Mar 13, 2024

Here is the modified example with your custom last-checkpoint format:

import os

import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset
import shutil


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    if os.path.exists("./checkpoints"):
        shutil.rmtree("./checkpoints")

    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    callback = ModelCheckpoint(
        dirpath="./checkpoints", monitor="step", mode="max", every_n_train_steps=2, save_last=True
    )
    callback.CHECKPOINT_NAME_LAST = "{step}-last"

    trainer_kwargs = {
        "limit_train_batches": 10,
        "limit_val_batches": 0,
        "enable_progress_bar": False,
        "enable_model_summary": False,
        "logger": False,
    }

    model = BoringModel()
    trainer = Trainer(callbacks=callback, max_steps=5, **trainer_kwargs)
    trainer.fit(model, train_data)

    # Resume the first time
    trainer = Trainer(callbacks=callback, max_steps=10, **trainer_kwargs)
    trainer.fit(model, train_data, ckpt_path="last")

    # Resume one more time
    trainer = Trainer(callbacks=callback, max_steps=20, **trainer_kwargs)
    trainer.fit(model, train_data, ckpt_path="last")

    print(os.listdir("./checkpoints"))
    # ['step=20-last.ckpt', 'epoch=1-step=20.ckpt']


if __name__ == "__main__":
    run()

With Lightning 2.2 we get two checkpoints that remain: ['step=20-last.ckpt', 'epoch=1-step=20.ckpt'], which looks OK to me.

If there is an issue in Lightning, I'd like to reproduce it first before we add any fixes, because we also need to add a test case. We can't add fixes in Lightning that only apply to NeMo, because NeMo has a custom model checkpoint callback. If your issue only occurs with NeMo, it might be an indication that the logic in NeMoModelCheckpoint should be changed. If you can show that the issue is with Lightning, I'm happy to investigate more.

@stevehuang52
Author

@awaelchli Thanks for the update. Please take a look at my output below, and you'll notice there's a line that says `.fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set ModelCheckpoint(..., save_last=True).`, which means the previous last checkpoint was not properly loaded. However, this message wasn't there when using the default last.ckpt. I think this might be related to how PTL finds the last checkpoint via last_pattern = rf"^{self.CHECKPOINT_NAME_LAST}(-(\d+))?" here (see the sketch after the log below). What do you think?

(nemo2) (base) ➜  nemo_experiments git:(main) python debug_ptl.py
/home/heh/anaconda3/envs/nemo2/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python debug_ptl.py ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/heh/anaconda3/envs/nemo2/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_steps=5` reached.
> /home/heh/github/NeMo-main/examples/asr/nemo_experiments/debug_ptl.py(64)run()
     63     import ipdb; ipdb.set_trace()
---> 64     trainer = Trainer(callbacks=callback, max_steps=10, **trainer_kwargs)
     65     trainer.fit(model, train_data, ckpt_path="last")

ipdb> n
/home/heh/anaconda3/envs/nemo2/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python debug_ptl.py ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
> /home/heh/github/NeMo-main/examples/asr/nemo_experiments/debug_ptl.py(65)run()
     64     trainer = Trainer(callbacks=callback, max_steps=10, **trainer_kwargs)
---> 65     trainer.fit(model, train_data, ckpt_path="last")
     66 

ipdb> n
/home/heh/anaconda3/envs/nemo2/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
/home/heh/anaconda3/envs/nemo2/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:652: Checkpoint directory /home/heh/github/NeMo-main/examples/asr/nemo_experiments/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/heh/anaconda3/envs/nemo2/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_steps=10` reached.
> /home/heh/github/NeMo-main/examples/asr/nemo_experiments/debug_ptl.py(68)run()
     67     # Resume one more time
---> 68     import ipdb; ipdb.set_trace()
     69     trainer = Trainer(callbacks=callback, max_steps=20, **trainer_kwargs)

ipdb> import pytorch_lightning as pl
ipdb> pl.__version__
'2.2.0.post0'
ipdb> 
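
For reference, here is a small standalone illustration of why that pattern can't match a templated last-checkpoint name (the helper below is hypothetical and only mirrors the quoted pattern, not the actual Lightning code path):

import re

def last_pattern(name_last):
    # Mirrors the shape of the pattern quoted above: the template string is
    # inserted into the regex verbatim.
    return rf"^{name_last}(-(\d+))?"

# Default name: the stored file stem "last" matches.
print(re.match(last_pattern("last"), "last"))                 # <re.Match object; span=(0, 4), match='last'>

# Templated name: "{step}" stays a literal in the regex, so the formatted
# file stem "step=4-last" never matches and no last checkpoint is found.
print(re.match(last_pattern("{step}-last"), "step=4-last"))   # None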

@awaelchli
Contributor

awaelchli commented Mar 16, 2024

@stevehuang52 Yes, that's right. Sorry, I missed that warning. But if self.CHECKPOINT_NAME_LAST is a pattern that includes metrics, then finding the last checkpoint is ill-defined and can't be done. For example, if we set

model_checkpoint.CHECKPOINT_NAME_LAST = "{val_loss:.2f}-last"

(or any other pattern that includes metrics)

and during our training it saved these checkpoints:

"val_loss=1.24-last.ckpt"
"val_loss=0.92-last.ckpt"
"val_loss=2.22-last.ckpt"

then which one is the last? It's not possible to tell. That's why the "find the last checkpoint" logic doesn't handle this case, and you see the warning. Note that the last.ckpt feature was never intended to work that way; it was always meant to be a deterministic name. In my interpretation, making the name include metrics is not meant to be supported.

If possible, I would suggest that NeMo move away from using this pattern.
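
For illustration, one way to keep metrics out of the last-checkpoint name while still embedding them in the regular checkpoint filenames (a sketch under that assumption, not a prescription for NeMo):

from lightning.pytorch.callbacks import ModelCheckpoint

callback = ModelCheckpoint(
    dirpath="./checkpoints",
    filename="{epoch}-{step}",   # templates only in the regular checkpoint names
    monitor="step",
    mode="max",
    every_n_train_steps=2,
    save_last=True,              # last checkpoint keeps the default, deterministic "last.ckpt" name
)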

@stevehuang52
Author

Thanks for the clarification, we'll make the changes in NeMo instead.
