Skip to content

Relax restrictions on wrapping a custom batch sampler in predict #19678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- The `prepare_data()` hook in `LightningModule` and `LightningDataModule` is now subject to a barrier without timeout to avoid long-running tasks from being interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))

-
- Relaxed the requirement for custom batch samplers to expose `drop_last` for prediction ([#19678](https://github.com/Lightning-AI/pytorch-lightning/pull/19678))

-

Expand Down
9 changes: 9 additions & 0 deletions src/lightning/pytorch/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
has_iterable_dataset,
sized_len,
)
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -301,6 +302,14 @@ def _dataloader_init_kwargs_resolve_sampler(
" or set `Trainer(use_distributed_sampler=False)`. If you choose the latter, you will be"
" responsible for handling the distributed sampling within your batch sampler."
) from ex
elif is_predicting:
rank_zero_warn(
f"You are using a custom batch sampler `{batch_sampler_cls.__qualname__}` for prediction."
" Lightning would normally set `drop_last=False` to ensure all samples are returned, but for"
" custom samplers it can't guarantee this. Make sure your sampler is configured correctly to return"
" all indices.",
category=PossibleUserWarning,
)
else:
# The sampler is not a PyTorch `BatchSampler`, we don't know how to inject a custom sampler or
# how to adjust the `drop_last` value
Expand Down
13 changes: 10 additions & 3 deletions tests/tests_pytorch/utilities/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
import torch
from lightning.fabric.utilities.data import _replace_dunder_methods
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import RandomDataset, RandomIterableDataset
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
Expand Down Expand Up @@ -230,7 +231,8 @@ def __len__(self) -> int:
assert batch_sampler.drop_last == (not predicting)


def test_custom_batch_sampler():
@pytest.mark.parametrize("predicting", [True, False])
def test_custom_batch_sampler(predicting):
"""Test that a custom (non-PyTorch) batch sampler requires the user to set `use_distributed_sampler=False`."""

class CustomBatchSampler: # not inheriting from `BatchSampler`
Expand All @@ -240,8 +242,13 @@ def __iter__(self):

batch_sampler = CustomBatchSampler()
dataloader = DataLoader(range(100), batch_sampler=batch_sampler)
with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
_ = _update_dataloader(dataloader, sampler=Mock())

if predicting:
with pytest.warns(PossibleUserWarning, match=r"Make sure your sampler is configured correctly to return all"):
_ = _update_dataloader(dataloader, sampler=Mock(), mode=RunningStage.PREDICTING)
else:
with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
_ = _update_dataloader(dataloader, sampler=Mock(), mode=None)


def test_custom_batch_sampler_no_drop_last():
Expand Down