Fix progress bar display to correctly handle iterable dataset and max_steps during training #20869


Open · wants to merge 4 commits into master

Conversation


@bandpooja bandpooja commented Jun 1, 2025

What does this PR do?

This PR fixes the progress bar display in PyTorch Lightning to correctly handle the case where max_steps is set and max_epochs is -1 (infinite-epochs mode), including iterable datasets whose dataloaders have no length. Previously, the progress bar did not reflect the actual number of batches that would be processed when training was capped by max_steps, so it showed a confusing or incomplete total.

Fixes #20862 and #20124

Does this PR introduce any breaking changes?

No breaking changes introduced. This is a UI/progress bar improvement only.
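
For reviewers skimming the diff, the core idea reduces to: use the step budget as the bar total whenever the dataloader length is unknown, and take the smaller of the two limits otherwise. Below is a minimal sketch of that logic, not the actual patch; num_training_batches, global_step, and max_steps mirror existing Trainer attributes, while the helper itself is hypothetical:

import math

def progress_bar_epoch_total(
    num_training_batches: float,  # batches per epoch; float("inf") if the dataloader has no __len__
    max_steps: int,               # Trainer(max_steps=...); -1 means no step limit
    global_step: int,             # optimizer steps completed so far
) -> float:
    """Total the progress bar should display for the current epoch (sketch)."""
    remaining = math.inf if max_steps == -1 else max_steps - global_step
    if math.isinf(num_training_batches):
        # IterableDataset without a length: the step budget is the only known bound.
        return remaining
    # Finite dataloader: the epoch ends at whichever limit is hit first.
    return min(num_training_batches, remaining)

For the three local tests below, this would yield first-epoch totals of 313, 100, and 500 respectively.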


Additional notes

LOCAL TESTS

  1. max_steps > total training batches (max_steps=500 vs. ceil(10000 / 32) = 313 batches per epoch)
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pytorch_lightning as pl
from torch import nn
from pytorch_lightning import Trainer

# Dummy Dataset
class DummyDataset(Dataset):
    def __len__(self):
        return 10000  # Large enough to allow many steps

    def __getitem__(self, idx):
        x = torch.randn(10)
        y = torch.randint(0, 2, (1,))
        return x, y[0]

# Simple Model
class DummyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

dataset = DummyDataset()
loader = DataLoader(dataset, batch_size=32)

model = DummyModel()

trainer = Trainer(
    max_steps=500,              # 👈 This tests the changes!
    accelerator="cpu",
    log_every_n_steps=1,
    enable_model_summary=False,
)

trainer.fit(model, train_dataloaders=loader)
(Screenshots of the resulting progress bar)
  2. max_steps < total training batches (max_steps=100 vs. 313 batches per epoch)
trainer = Trainer(
    max_steps=100,              # 👈 This tests the changes!
    accelerator="cpu",
    log_every_n_steps=1,
    enable_model_summary=False,
)

trainer.fit(model, train_dataloaders=loader)
(Screenshot of the resulting progress bar)
  3. Training with an iterable dataset, as in #20124 (Why does the progress bar not show the total steps when using iterable dataset?)
from torch.utils.data import IterableDataset

class InfiniteIterableDataset(IterableDataset):
    def __iter__(self):
        while True:
            x = torch.randn(10)
            y = torch.randint(0, 2, (1,))
            yield x, y[0] # infinite stream

dataset = InfiniteIterableDataset()
loader = DataLoader(dataset, batch_size=32)

model = DummyModel()

trainer = Trainer(
    max_steps=500,              # 👈 This tests the changes!
    accelerator="cpu",
    log_every_n_steps=1,
    enable_model_summary=False,
)

trainer.fit(model, train_dataloaders=loader)
(Screenshot of the resulting progress bar)
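
For context on why this case needs special handling: a DataLoader wrapped around an IterableDataset that defines no __len__ has no length itself, so the number of batches per epoch cannot be computed and Lightning treats the epoch length as infinite. A quick check, reusing the loader defined above:

# len() on a DataLoader over a length-less IterableDataset raises TypeError:
try:
    len(loader)
except TypeError:
    print("loader has no length; only max_steps can bound the progress bar total")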

@github-actions bot added the pl label (Generic label for PyTorch Lightning package) on Jun 1, 2025
@Borda (Member) left a comment:


could we have a test that this returns the correct number, pls?

@Borda added the waiting on author label (Waiting on user action, correction, or update) on Jun 2, 2025
@bandpooja (Author) commented Jun 3, 2025

@Borda, I'm still getting up to speed with writing effective tests — happy to hear any feedback!

(Screenshot of the passing test run)

Comment on lines +238 to +239
# tqdm total steps should equal max_steps for iterator with no length
assert trainer.estimated_stepping_batches == max_steps
Member:

let's have an assert on the progress_bar property total_train_batches

@Borda (Member) commented Jun 3, 2025

> I'm still getting up to speed with writing effective tests — happy to hear any feedback!

overall the test looks good, let's just have a direct assert on the property
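
A sketch of what that direct assertion could look like, reusing Trainer, DataLoader, DummyModel, and InfiniteIterableDataset from the examples above (an illustration of the reviewer's suggestion rather than the test that was committed; estimated_stepping_batches, progress_bar_callback, and total_train_batches are existing Trainer / progress bar attributes):

def test_progress_bar_total_with_iterable_dataset_and_max_steps(tmp_path):
    max_steps = 500
    loader = DataLoader(InfiniteIterableDataset(), batch_size=32)
    trainer = Trainer(
        default_root_dir=tmp_path,
        max_steps=max_steps,
        accelerator="cpu",
        enable_model_summary=False,
        enable_checkpointing=False,
        logger=False,
    )
    trainer.fit(DummyModel(), train_dataloaders=loader)

    # With no dataloader length available, both the estimated step count and
    # the total shown by the progress bar should equal the step budget.
    assert trainer.estimated_stepping_batches == max_steps
    assert trainer.progress_bar_callback.total_train_batches == max_steps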

Labels: pl (Generic label for PyTorch Lightning package), waiting on author (Waiting on user action, correction, or update)

Projects: None yet

Development: successfully merging this pull request may close these issues:

Show actual steps instead of the length of dataset on the progress bar

2 participants