Fix PTL2.2 saving multiple *-last.ckpt checkpoints in resumed training #19613
Conversation
Fix PTL 2.2 saving multiple *-last.ckpt checkpoints in resumed training. In the current PTL 2.2 model_checkpoint.py, the previous *-last.ckpt checkpoint that training was resumed from is not deleted when a new *-last.ckpt checkpoint is saved, which leaves multiple *-last.ckpt checkpoints behind. This causes a "multiple *-last.ckpt checkpoints" error when trying to resume from a previous job. The fix is to check whether the checkpoint being saved is itself a *-last.ckpt checkpoint; if it is, delete the previous *-last.ckpt when both are under the same directory.
Codecov Report
Additional details and impacted files

@@            Coverage Diff             @@
##           master   #19613      +/-   ##
==========================================
- Coverage      84%      53%     -31%
==========================================
  Files         424      416       -8
  Lines       34903    34752     -151
==========================================
- Hits        29347    18434   -10913
- Misses       5556    16318   +10762
@stevehuang52 Thanks for sending a PR. I tried to show the problem with an example based on your description, but failed. Here is what I tried:

import os
import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset
import shutil


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    if os.path.exists("./checkpoints"):
        shutil.rmtree("./checkpoints")
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    callback = ModelCheckpoint(dirpath="./checkpoints", monitor="step", mode="max", every_n_train_steps=2, save_last=True)
    trainer_kwargs = {
        "limit_train_batches": 10,
        "limit_val_batches": 0,
        "enable_progress_bar": False,
        "enable_model_summary": False,
        "logger": False,
    }
    model = BoringModel()
    trainer = Trainer(callbacks=callback, max_steps=5, **trainer_kwargs)
    trainer.fit(model, train_data)

    # Resume the first time
    trainer = Trainer(callbacks=callback, max_steps=10, **trainer_kwargs)
    trainer.fit(model, train_data, ckpt_path="last")

    # Resume one more time
    trainer = Trainer(callbacks=callback, max_steps=20, **trainer_kwargs)
    trainer.fit(model, train_data, ckpt_path="last")

    print(os.listdir("./checkpoints"))
    # ['last.ckpt', 'epoch=1-step=20.ckpt']


if __name__ == "__main__":
    run()

Could you please modify it to your use case?
@awaelchli Thanks for your reply. I was using the NeMo toolkit when I ran into this problem. I think they set the last-checkpoint format here, which makes the last checkpoint look like *-last.ckpt rather than the plain last.ckpt.
Here is the modified example with your custom last-checkpoint format:

import os
import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset
import shutil


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    if os.path.exists("./checkpoints"):
        shutil.rmtree("./checkpoints")
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    callback = ModelCheckpoint(
        dirpath="./checkpoints", monitor="step", mode="max", every_n_train_steps=2, save_last=True
    )
    callback.CHECKPOINT_NAME_LAST = "{step}-last"
    trainer_kwargs = {
        "limit_train_batches": 10,
        "limit_val_batches": 0,
        "enable_progress_bar": False,
        "enable_model_summary": False,
        "logger": False,
    }
    model = BoringModel()
    trainer = Trainer(callbacks=callback, max_steps=5, **trainer_kwargs)
    trainer.fit(model, train_data)

    # Resume the first time
    trainer = Trainer(callbacks=callback, max_steps=10, **trainer_kwargs)
    trainer.fit(model, train_data, ckpt_path="last")

    # Resume one more time
    trainer = Trainer(callbacks=callback, max_steps=20, **trainer_kwargs)
    trainer.fit(model, train_data, ckpt_path="last")

    print(os.listdir("./checkpoints"))
    # ['step=20-last.ckpt', 'epoch=1-step=20.ckpt']


if __name__ == "__main__":
    run()

With Lightning 2.2, two checkpoints remain: step=20-last.ckpt and epoch=1-step=20.ckpt.

If there is an issue in Lightning, I'd like to reproduce it first before we add any fixes, because we also need to add a test case. We can't add fixes in Lightning that only apply to NeMo, because NeMo has a custom model checkpoint callback. If your issue only occurs with NeMo, it might be an indication that the logic in NeMoModelCheckpoint should be changed. If you can show that the issue is with Lightning, I'm happy to investigate more.
@awaelchli Thanks for the update. Please take a look at my output, and you'll notice there is a line warning about multiple *-last.ckpt checkpoints.
@stevehuang52 Yes, that's right. I missed that warning, sorry. But if the last-checkpoint name is set to a pattern like "{step}-last" (or any other pattern that includes metrics) and during training it saved several such checkpoints with different metric values, then which one is the last? It's not possible to tell. That's the reason why the "finding the last checkpoint" logic doesn't handle this, and you see the warning.

Note that the "last.ckpt" feature was never intended to work that way. It was always meant to be a deterministic name. Making the name include metrics is not meant to be supported, in my interpretation. If possible, I would suggest that NeMo move away from using this pattern.
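For illustration, here is a minimal sketch of that ambiguity, not Lightning's actual resolution code; the helper name find_last_candidates and the example file names are made up, and "./checkpoints" is taken from the repro script above:

from pathlib import Path

def find_last_candidates(dirpath):
    # Hypothetical helper: with the default name there is exactly one candidate,
    # "last.ckpt"; with a pattern such as "{step}-last" several files can match.
    d = Path(dirpath)
    if not d.is_dir():
        return []
    return sorted(d.glob("*last.ckpt"))

candidates = find_last_candidates("./checkpoints")
if len(candidates) > 1:
    # e.g. [step=10-last.ckpt, step=20-last.ckpt]: nothing in the names says
    # deterministically which one was written last, so any tie-break (modification
    # time, parsing the metric out of the name) is guesswork.
    print(f"Ambiguous 'last' checkpoint: {candidates}")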
Thanks for the clarification, we'll make the changes in NeMo instead.
What does this PR do?
Fix PTL 2.2 saving multiple *-last.ckpt checkpoints in resumed training.

In the current PTL 2.2 model_checkpoint.py, the previous *-last.ckpt checkpoint that the model was resumed from is not deleted when a new *-last.ckpt checkpoint is saved, which results in multiple *-last.ckpt checkpoints. This causes a "multiple *-last.ckpt checkpoints" error when trying to resume from a previous job (the first resume is fine since there is only one *-last.ckpt; the second resume crashes because there are multiple).

The fix is to check whether the checkpoint to be saved is a *-last.ckpt checkpoint and the model was resumed from a *-last.ckpt checkpoint. If so, delete the previous *-last.ckpt when both are under the same directory.
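A minimal sketch of that idea, not the actual code in model_checkpoint.py; the helper and argument names here are made up for illustration:

import os

def remove_previous_last_ckpt(new_last_path, resumed_last_path):
    # Hypothetical helper: both arguments are paths to *-last.ckpt files
    # (the newly saved one and the one training was resumed from).
    # Delete the old one only if it is a different file in the same directory
    # as the checkpoint that was just saved, so a single *-last.ckpt remains.
    if (
        resumed_last_path
        and resumed_last_path != new_last_path
        and os.path.dirname(resumed_last_path) == os.path.dirname(new_last_path)
        and os.path.isfile(resumed_last_path)
    ):
        os.remove(resumed_last_path)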