
Commit 4ba47a3

Persist IterableDataset epoch in workers (#6710)

* persist IterableDataset epoch in workers
* more tests
* comment
* re-share memory after pickling
* Update src/datasets/iterable_dataset.py

1 parent 100361d commit 4ba47a3

File tree

2 files changed: +49 −6

- src/datasets/iterable_dataset.py
- tests/test_iterable_dataset.py
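Context for the change: with `persistent_workers=True`, a PyTorch `DataLoader` pickles the dataset into long-lived worker processes once, so rebinding a plain int attribute like `_epoch` in the main process is never seen by those workers. Keeping the epoch in a shared-memory tensor gives the main process and the workers views of the same storage. A minimal standalone sketch of that mechanism, not part of the commit (the `_SharedEpoch` class is hypothetical; torch is assumed installed):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class _SharedEpoch(IterableDataset):
    """Hypothetical toy dataset that yields the epoch value its worker sees."""

    def __init__(self):
        # 0-dim tensor whose storage lives in shared memory
        self.epoch = torch.tensor(0).share_memory_()

    def __iter__(self):
        yield self.epoch.item()


if __name__ == "__main__":
    ds = _SharedEpoch()
    dl = DataLoader(ds, num_workers=1, persistent_workers=True)
    print(list(dl))  # [tensor([0])]
    ds.epoch += 1    # in-place write into shared storage: the live worker sees it
    print(list(dl))  # [tensor([1])]
```

The two prints differ only because the `+=` writes through the storage the persistent worker already holds; rebinding `ds.epoch` to a new object would not propagate.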

src/datasets/iterable_dataset.py

+27 −6
```diff
@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from functools import partial
 from itertools import cycle, islice
-from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
 
 import fsspec.asyn
 import numpy as np
@@ -26,6 +26,9 @@
 from .utils.sharding import _merge_gen_kwargs, _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs
 
 
+if TYPE_CHECKING:
+    import torch
+
 logger = get_logger(__name__)
 
 Key = Union[int, str]
```
```diff
@@ -1690,6 +1693,18 @@ def _maybe_add_torch_iterable_dataset_parent_class(cls):
         cls.__bases__ += (torch.utils.data.IterableDataset,)
 
 
+def _maybe_share_with_torch_persistent_workers(value: Union[int, "torch.Tensor"]) -> Union[int, "torch.Tensor"]:
+    if config.TORCH_AVAILABLE:
+        import torch
+
+        if isinstance(value, torch.Tensor):
+            return value.share_memory_()
+        else:
+            return torch.tensor(value).share_memory_()
+    else:
+        return value
+
+
 class IterableDataset(DatasetInfoMixin):
     """A Dataset backed by an iterable."""
```
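A quick illustrative check of the helper's behavior, assuming torch is available (this snippet is not part of the commit):

```python
import torch
from datasets.iterable_dataset import _maybe_share_with_torch_persistent_workers

# an int is wrapped in a 0-dim tensor placed in shared memory
epoch = _maybe_share_with_torch_persistent_workers(0)
assert isinstance(epoch, torch.Tensor) and epoch.is_shared()

# an existing tensor is moved to shared memory in place and returned as-is,
# since share_memory_() returns self
assert _maybe_share_with_torch_persistent_workers(epoch) is epoch
```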

```diff
@@ -1722,8 +1737,8 @@ def __init__(
         self._formatting = formatting
         self._shuffling = shuffling
         self._distributed = distributed
-        self._epoch = 0
         self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {}
+        self._epoch: Union[int, "torch.Tensor"] = _maybe_share_with_torch_persistent_workers(0)
         self._starting_state_dict: Optional[dict] = None
         self._prepared_ex_iterable = self._prepare_ex_iterable_for_iteration()
         self._state_dict = self._prepared_ex_iterable._init_state_dict()
```
```diff
@@ -1841,18 +1856,24 @@ def __getstate__(self):
 
     def __setstate__(self, d):
         self.__dict__ = d
+        # Re-add torch shared memory, since shared memory is not always kept when pickling
+        self._epoch = _maybe_share_with_torch_persistent_workers(self._epoch)
         # Re-add torch iterable dataset as a parent class, since dynamically added parent classes are not kept when pickling
         _maybe_add_torch_iterable_dataset_parent_class(self.__class__)
 
     def _head(self, n=5):
         return _examples_to_batch(list(self.take(n)))
 
+    @property
+    def epoch(self) -> int:
+        return int(self._epoch)
+
     def _effective_generator(self):
-        if self._shuffling and self._epoch == 0:
+        if self._shuffling and self.epoch == 0:
             return self._shuffling.generator
         elif self._shuffling:
-            # Create effective seed using self._epoch (we subtract in order to avoid overflow in long_scalars)
-            effective_seed = deepcopy(self._shuffling.generator).integers(0, 1 << 63) - self._epoch
+            # Create effective seed using self.epoch (we subtract in order to avoid overflow in long_scalars)
+            effective_seed = deepcopy(self._shuffling.generator).integers(0, 1 << 63) - self.epoch
             effective_seed = (1 << 63) + effective_seed if effective_seed < 0 else effective_seed
             return np.random.default_rng(effective_seed)
         else:
```
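The re-sharing in `__setstate__` is needed because a plain `pickle` round trip serializes tensors by value, so the unpickled copy typically no longer lives in shared memory. A sketch, assuming torch is installed:

```python
import pickle

import torch

t = torch.tensor(0).share_memory_()
assert t.is_shared()

t2 = pickle.loads(pickle.dumps(t))
# plain pickle copies by value; the round-tripped tensor is typically not shared
print(t2.is_shared())  # False on a standard setup
t2.share_memory_()     # what __setstate__ restores via the helper
assert t2.is_shared()
```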
```diff
@@ -2465,7 +2486,7 @@ def shuffle(
         )
 
     def set_epoch(self, epoch: int):
-        self._epoch = epoch
+        self._epoch += epoch - self._epoch  # update torch value in shared memory in-place
 
     def skip(self, n: int) -> "IterableDataset":
         """
```

tests/test_iterable_dataset.py

+22 −0
```diff
@@ -1641,6 +1641,28 @@ def test_iterable_dataset_is_torch_iterable_dataset(dataset: IterableDataset):
     assert len(out) == DEFAULT_N_EXAMPLES
 
 
+@require_torch
+def test_iterable_dataset_persists_epoch_in_torch_workers(dataset: IterableDataset):
+    from torch.utils.data import DataLoader
+
+    dataset = dataset.shuffle(seed=42)
+    dataloader = DataLoader(dataset, num_workers=1, persistent_workers=True)
+    epoch0 = list(dataloader)
+    assert list(dataloader) == epoch0
+    dataset.set_epoch(1)
+    assert list(dataloader) != epoch0
+
+    # Make sure pickle works even with torch objects in shared memory
+    dataset_copy: IterableDataset = pickle.loads(pickle.dumps(dataset))
+    dataloader = DataLoader(dataset_copy, num_workers=1, persistent_workers=True)
+    epoch1 = list(dataloader)
+    assert list(dataloader) == epoch1
+    dataset.set_epoch(2)  # this should not affect the copy
+    assert list(dataloader) == epoch1
+    dataset_copy.set_epoch(2)
+    assert list(dataloader) != epoch1
+
+
 @pytest.mark.parametrize("n", [0, 2, int(1e10)])
 def test_iterable_dataset_skip(dataset: IterableDataset, n):
     skip_dataset = dataset.skip(n)
```
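The new test mirrors the intended training-loop usage: call `set_epoch` between epochs so a shuffled streaming dataset reshuffles even while `DataLoader` workers stay alive. A hedged usage sketch ("some/dataset" and the epoch count are placeholders):

```python
from torch.utils.data import DataLoader

from datasets import load_dataset

ds = load_dataset("some/dataset", split="train", streaming=True).shuffle(seed=42)
dl = DataLoader(ds, num_workers=4, persistent_workers=True)

for epoch in range(3):
    ds.set_epoch(epoch)  # workers see the new epoch via the shared tensor
    for batch in dl:
        ...  # training step
```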
