@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from functools import partial
 from itertools import cycle, islice
-from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union

 import fsspec.asyn
 import numpy as np
@@ -26,6 +26,9 @@
 from .utils.sharding import _merge_gen_kwargs, _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs


+if TYPE_CHECKING:
+    import torch
+
 logger = get_logger(__name__)

 Key = Union[int, str]
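The `TYPE_CHECKING` guard keeps `torch` optional at runtime: the import only runs under static type checkers, and the string annotation `"torch.Tensor"` is never evaluated by the interpreter. A minimal standalone sketch of the pattern (the `double` function is illustrative, not part of this change):

```python
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:  # evaluated by type checkers only, never at runtime
    import torch


def double(value: Union[int, "torch.Tensor"]) -> Union[int, "torch.Tensor"]:
    # The annotation stays a string, so torch is only required if a tensor is passed
    return value * 2
```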
@@ -1690,6 +1693,18 @@ def _maybe_add_torch_iterable_dataset_parent_class(cls):
             cls.__bases__ += (torch.utils.data.IterableDataset,)


+def _maybe_share_with_torch_persistent_workers(value: Union[int, "torch.Tensor"]) -> Union[int, "torch.Tensor"]:
+    if config.TORCH_AVAILABLE:
+        import torch
+
+        if isinstance(value, torch.Tensor):
+            return value.share_memory_()
+        else:
+            return torch.tensor(value).share_memory_()
+    else:
+        return value
+
+
 class IterableDataset(DatasetInfoMixin):
     """A Dataset backed by an iterable."""

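`share_memory_()` moves a tensor's storage into shared memory, so an in-place update made by the main process is visible to worker processes that already hold a handle to the same storage; this is what lets persistent DataLoader workers observe later `set_epoch()` calls. A minimal standalone sketch of that behavior (the `reader`, `updated`, and `out` names are illustrative, not part of this PR):

```python
import torch
import torch.multiprocessing as mp


def reader(epoch, updated, out):
    updated.wait()       # block until the parent has written the new value
    out.put(int(epoch))  # reads the parent's update through the shared storage


if __name__ == "__main__":
    epoch = torch.tensor(0).share_memory_()  # storage now lives in shared memory
    updated, out = mp.Event(), mp.Queue()
    p = mp.Process(target=reader, args=(epoch, updated, out))
    p.start()
    epoch += 3 - int(epoch)  # in-place update, same trick as set_epoch() below
    updated.set()
    assert out.get() == 3    # the child sees 3, not the 0 it was started with
    p.join()
```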
@@ -1722,8 +1737,8 @@ def __init__(
         self._formatting = formatting
         self._shuffling = shuffling
         self._distributed = distributed
-        self._epoch = 0
         self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {}
+        self._epoch: Union[int, "torch.Tensor"] = _maybe_share_with_torch_persistent_workers(0)
         self._starting_state_dict: Optional[dict] = None
         self._prepared_ex_iterable = self._prepare_ex_iterable_for_iteration()
         self._state_dict = self._prepared_ex_iterable._init_state_dict()
@@ -1841,18 +1856,24 @@ def __getstate__(self):

     def __setstate__(self, d):
         self.__dict__ = d
+        # Re-add torch shared memory, since shared memory is not always kept when pickling
+        self._epoch = _maybe_share_with_torch_persistent_workers(self._epoch)
         # Re-add torch iterable dataset as a parent class, since dynamically added parent classes are not kept when pickling
         _maybe_add_torch_iterable_dataset_parent_class(self.__class__)

     def _head(self, n=5):
         return _examples_to_batch(list(self.take(n)))

+    @property
+    def epoch(self) -> int:
+        return int(self._epoch)
+
     def _effective_generator(self):
-        if self._shuffling and self._epoch == 0:
+        if self._shuffling and self.epoch == 0:
             return self._shuffling.generator
         elif self._shuffling:
-            # Create effective seed using self._epoch (we subtract in order to avoir overflow in long_scalars)
-            effective_seed = deepcopy(self._shuffling.generator).integers(0, 1 << 63) - self._epoch
+            # Create effective seed using self.epoch (we subtract in order to avoid overflow in long_scalars)
+            effective_seed = deepcopy(self._shuffling.generator).integers(0, 1 << 63) - self.epoch
             effective_seed = (1 << 63) + effective_seed if effective_seed < 0 else effective_seed
             return np.random.default_rng(effective_seed)
         else:
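Each epoch deterministically perturbs the base shuffle seed, so every epoch gets a different but reproducible shuffle order, while epoch 0 reuses the original generator unchanged. A standalone sketch of the seed derivation above (with an explicit `int()` conversion, an assumption added here to keep the demo clear of NumPy integer-overflow rules):

```python
from copy import deepcopy

import numpy as np

base = np.random.default_rng(42)  # the generator fixed by shuffle(seed=42)
for epoch in range(3):
    # The base draw is identical every epoch; subtracting the epoch yields a distinct seed
    seed = int(deepcopy(base).integers(0, 1 << 63)) - epoch
    seed = (1 << 63) + seed if seed < 0 else seed  # wrap negatives back into range
    print(epoch, np.random.default_rng(seed).permutation(5))
```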
@@ -2465,7 +2486,7 @@ def shuffle(
         )

     def set_epoch(self, epoch: int):
-        self._epoch = epoch
+        self._epoch += epoch - self._epoch  # update torch value in shared memory in-place

     def skip(self, n: int) -> "IterableDataset":
         """