Skip to content

Commit 5e72fb1

Browse files
authored
Fix webdataset pickling (#6972)
* fix webdataset pickling
* more general fix
1 parent ef2fb35 commit 5e72fb1

File tree

3 files changed

+25
-26
lines changed

3 files changed

+25
-26
lines changed

src/datasets/exceptions.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from . import config
88
from .table import CastError
99
from .utils.deprecation_utils import deprecated
10-
from .utils.track import TrackedIterable, tracked_list, tracked_str
10+
from .utils.track import TrackedIterableFromGenerator, tracked_list, tracked_str
1111

1212

1313
class DatasetsError(Exception):
@@ -65,9 +65,11 @@ def from_cast_error(
6565
)
6666
formatted_tracked_gen_kwargs: List[str] = []
6767
for gen_kwarg in gen_kwargs.values():
68-
if not isinstance(gen_kwarg, (tracked_str, tracked_list, TrackedIterable)):
68+
if not isinstance(gen_kwarg, (tracked_str, tracked_list, TrackedIterableFromGenerator)):
6969
continue
70-
while isinstance(gen_kwarg, (tracked_list, TrackedIterable)) and gen_kwarg.last_item is not None:
70+
while (
71+
isinstance(gen_kwarg, (tracked_list, TrackedIterableFromGenerator)) and gen_kwarg.last_item is not None
72+
):
7173
gen_kwarg = gen_kwarg.last_item
7274
if isinstance(gen_kwarg, tracked_str):
7375
gen_kwarg = gen_kwarg.get_origin()

src/datasets/utils/file_utils.py

+4-20
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from io import BytesIO
2727
from itertools import chain
2828
from pathlib import Path, PurePosixPath
29-
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, TypeVar, Union
29+
from typing import Any, Dict, Generator, List, Optional, Tuple, TypeVar, Union
3030
from unittest.mock import patch
3131
from urllib.parse import urljoin, urlparse
3232
from xml.etree import ElementTree as ET
@@ -47,7 +47,7 @@
4747
from . import tqdm as hf_tqdm
4848
from ._filelock import FileLock
4949
from .extract import ExtractManager
50-
from .track import TrackedIterable
50+
from .track import TrackedIterableFromGenerator
5151

5252

5353
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -1564,23 +1564,7 @@ def xxml_dom_minidom_parse(filename_or_file, download_config: Optional[DownloadC
15641564
return xml.dom.minidom.parse(f, **kwargs)
15651565

15661566

1567-
class _IterableFromGenerator(TrackedIterable):
1568-
"""Utility class to create an iterable from a generator function, in order to reset the generator when needed."""
1569-
1570-
def __init__(self, generator: Callable, *args, **kwargs):
1571-
super().__init__()
1572-
self.generator = generator
1573-
self.args = args
1574-
self.kwargs = kwargs
1575-
1576-
def __iter__(self):
1577-
for x in self.generator(*self.args, **self.kwargs):
1578-
self.last_item = x
1579-
yield x
1580-
self.last_item = None
1581-
1582-
1583-
class ArchiveIterable(_IterableFromGenerator):
1567+
class ArchiveIterable(TrackedIterableFromGenerator):
15841568
"""An iterable of (path, fileobj) from a TAR archive, used by `iter_archive`"""
15851569

15861570
@staticmethod
@@ -1645,7 +1629,7 @@ def from_urlpath(cls, urlpath_or_buf, download_config: Optional[DownloadConfig]
16451629
return cls(cls._iter_from_urlpath, urlpath_or_buf, download_config)
16461630

16471631

1648-
class FilesIterable(_IterableFromGenerator):
1632+
class FilesIterable(TrackedIterableFromGenerator):
16491633
"""An iterable of paths from a list of directories or files"""
16501634

16511635
@classmethod

src/datasets/utils/track.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,26 @@ def __repr__(self) -> str:
3737
return f"{self.__class__.__name__}(current={self.last_item})"
3838

3939

40-
class TrackedIterable(Iterable):
41-
def __init__(self) -> None:
40+
class TrackedIterableFromGenerator(Iterable):
41+
"""Utility class to create an iterable from a generator function, in order to reset the generator when needed."""
42+
43+
def __init__(self, generator, *args):
4244
super().__init__()
45+
self.generator = generator
46+
self.args = args
47+
self.last_item = None
48+
49+
def __iter__(self):
50+
for x in self.generator(*self.args):
51+
self.last_item = x
52+
yield x
4353
self.last_item = None
4454

4555
def __repr__(self) -> str:
4656
if self.last_item is None:
47-
super().__repr__()
57+
return super().__repr__()
4858
else:
4959
return f"{self.__class__.__name__}(current={self.last_item})"
60+
61+
def __reduce__(self):
62+
return (self.__class__, (self.generator, *self.args))

0 commit comments

Comments
 (0)