Commit 438f29f

Relax restrictions on wrapping a custom batch sampler in predict (#19678)

1 parent 94167d6 · commit 438f29f

File tree: 3 files changed (+20 −4 lines)

src/lightning/pytorch/CHANGELOG.md (+1 −1)

```diff
@@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The `prepare_data()` hook in `LightningModule` and `LightningDataModule` is now subject to a barrier without timeout to avoid long-running tasks to be interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
 
--
+- Relaxed the requirement for custom batch samplers to expose `drop_last` for prediction ([#19678](https://github.com/Lightning-AI/pytorch-lightning/pull/19678))
 
 -
 
```
src/lightning/pytorch/utilities/data.py (+9)

```diff
@@ -28,6 +28,7 @@
     has_iterable_dataset,
     sized_len,
 )
+from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
 from lightning.pytorch.trainer.states import RunningStage
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -301,6 +302,14 @@ def _dataloader_init_kwargs_resolve_sampler(
                     " or set `Trainer(use_distributed_sampler=False)`. If you choose the latter, you will be"
                     " responsible for handling the distributed sampling within your batch sampler."
                 ) from ex
+        elif is_predicting:
+            rank_zero_warn(
+                f"You are using a custom batch sampler `{batch_sampler_cls.__qualname__}` for prediction."
+                " Lightning would normally set `drop_last=False` to ensure all samples are returned, but for"
+                " custom samplers it can't guarantee this. Make sure your sampler is configured correctly to return"
+                " all indices.",
+                category=PossibleUserWarning,
+            )
         else:
             # The sampler is not a PyTorch `BatchSampler`, we don't know how to inject a custom sampler or
             # how to adjust the `drop_last` value
```
tests/tests_pytorch/utilities/test_data.py (+10 −3)

```diff
@@ -5,6 +5,7 @@
 import pytest
 import torch
 from lightning.fabric.utilities.data import _replace_dunder_methods
+from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import RandomDataset, RandomIterableDataset
 from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
@@ -230,7 +231,8 @@ def __len__(self) -> int:
     assert batch_sampler.drop_last == (not predicting)
 
 
-def test_custom_batch_sampler():
+@pytest.mark.parametrize("predicting", [True, False])
+def test_custom_batch_sampler(predicting):
     """Test that a custom (non-PyTorch) batch sampler requires the user to set `use_distributed_sampler=False`."""
 
     class CustomBatchSampler:  # not inheriting from `BatchSampler`
@@ -240,8 +242,13 @@ def __iter__(self):
 
     batch_sampler = CustomBatchSampler()
     dataloader = DataLoader(range(100), batch_sampler=batch_sampler)
-    with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
-        _ = _update_dataloader(dataloader, sampler=Mock())
+
+    if predicting:
+        with pytest.warns(PossibleUserWarning, match=r"Make sure your sampler is configured correctly to return all"):
+            _ = _update_dataloader(dataloader, sampler=Mock(), mode=RunningStage.PREDICTING)
+    else:
+        with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
+            _ = _update_dataloader(dataloader, sampler=Mock(), mode=None)
 
 
 def test_custom_batch_sampler_no_drop_last():
```
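The warning added in this commit asks users to make sure a custom batch sampler returns all indices during prediction, since Lightning can no longer force `drop_last=False` on it. A hypothetical sampler with that property (the class name and batching scheme are my own, not from this repository) might look like this:

```python
from torch.utils.data import DataLoader


class AllSamplesBatchSampler:  # hypothetical; not a `torch.utils.data.BatchSampler`
    """Yields every dataset index, emitting a final short batch instead of dropping it."""

    def __init__(self, dataset_size: int, batch_size: int) -> None:
        self.dataset_size = dataset_size
        self.batch_size = batch_size

    def __iter__(self):
        for start in range(0, self.dataset_size, self.batch_size):
            # the last batch may be shorter than `batch_size`; nothing is dropped
            yield list(range(start, min(start + self.batch_size, self.dataset_size)))

    def __len__(self) -> int:
        # number of batches, rounding up
        return -(-self.dataset_size // self.batch_size)


loader = DataLoader(range(10), batch_sampler=AllSamplesBatchSampler(10, 4))
assert sum(len(batch) for batch in loader) == 10  # all 10 samples are returned
```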
