Commit 7ad7379

Add IterableDataset.decode with multithreading (#7450)
* add IterableDataset.decode with multithreading
* graceful async ends
* test
* docs
* fix tests
1 parent f09db01 commit 7ad7379

7 files changed: +221 -5 lines changed

docs/source/audio_load.mdx

Lines changed: 24 additions & 0 deletions
@@ -96,3 +96,27 @@ For more information about creating your own `AudioFolder` dataset, take a look
 </Tip>

 For a guide on how to load any type of dataset, take a look at the <a class="underline decoration-sky-400 decoration-2 font-semibold" href="./loading">general loading guide</a>.
+
+## Audio decoding
+
+By default, audio files are decoded sequentially as NumPy arrays when you iterate on a dataset.
+However, it is possible to speed up the dataset significantly using multithreaded decoding:
+
+```python
+>>> import os
+>>> num_threads = min(32, (os.cpu_count() or 1) + 4)
+>>> dataset = dataset.decode(num_threads=num_threads)
+>>> for example in dataset:  # up to 20 times faster!
+...     ...
+```
+
+You can enable multithreading using `num_threads`. This is especially useful to speed up remote data streaming.
+However, it can be slower than `num_threads=0` for local data on fast disks.
+
+If you are not interested in the audio decoded as NumPy arrays and would like to access the path/bytes instead, you can disable decoding:
+
+```python
+>>> dataset = dataset.decode(False)
+```
+
+Note: [`IterableDataset.decode`] is only available for streaming datasets at the moment.
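
The `num_threads` expression used in the documentation snippet above resembles the default worker count documented for Python's `concurrent.futures.ThreadPoolExecutor` (Python 3.8 and later). The following is a minimal sketch of that computation; `default_num_threads` is a hypothetical helper name used here for illustration and is not part of `datasets`.

```python
import os
from typing import Optional


def default_num_threads(cpu_count: Optional[int] = None) -> int:
    """Hypothetical helper: CPU count plus 4 extra threads for I/O-bound work, capped at 32."""
    if cpu_count is None:
        # os.cpu_count() may return None when the count cannot be determined
        cpu_count = os.cpu_count() or 1
    return min(32, cpu_count + 4)


print(default_num_threads(4))   # small machine: CPU count + 4
print(default_num_threads(64))  # large machine: capped at 32
```

The cap keeps the thread count bounded on machines with many cores, while the `+ 4` leaves a few threads available for I/O-bound work on small machines.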

docs/source/image_load.mdx

Lines changed: 24 additions & 0 deletions
@@ -138,3 +138,27 @@ You can load a WebDataset like this:

 >>> dataset = load_dataset("webdataset", data_dir="/path/to/folder", streaming=True)
 ```
+
+## Image decoding
+
+By default, images are decoded sequentially as `PIL.Images` when you iterate on a dataset.
+However, it is possible to speed up the dataset significantly using multithreaded decoding:
+
+```python
+>>> import os
+>>> num_threads = min(32, (os.cpu_count() or 1) + 4)
+>>> dataset = dataset.decode(num_threads=num_threads)
+>>> for example in dataset:  # up to 20 times faster!
+...     ...
+```
+
+You can enable multithreading using `num_threads`. This is especially useful to speed up remote data streaming.
+However, it can be slower than `num_threads=0` for local data on fast disks.
+
+If you are not interested in the images decoded as `PIL.Images` and would like to access the path/bytes instead, you can disable decoding:
+
+```python
+>>> dataset = dataset.decode(False)
+```
+
+Note: [`IterableDataset.decode`] is only available for streaming datasets at the moment.

docs/source/package_reference/main_classes.mdx

Lines changed: 1 addition & 0 deletions
@@ -163,6 +163,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth
     - select_columns
     - cast_column
     - cast
+    - decode
     - __iter__
     - iter
     - map

docs/source/video_load.mdx

Lines changed: 26 additions & 0 deletions
@@ -169,3 +169,29 @@ You can load a WebDataset like this:

 >>> dataset = load_dataset("webdataset", data_dir="/path/to/folder", streaming=True)
 ```
+
+## Video decoding
+
+By default, videos are decoded sequentially as torchvision `VideoReaders` when you iterate on a dataset.
+This sequentially decodes the metadata of the videos, and doesn't read the video frames until you access them.
+
+However, it is possible to speed up the dataset significantly using multithreaded decoding:
+
+```python
+>>> import os
+>>> num_threads = min(32, (os.cpu_count() or 1) + 4)
+>>> dataset = dataset.decode(num_threads=num_threads)
+>>> for example in dataset:  # up to 20 times faster!
+...     ...
+```
+
+You can enable multithreading using `num_threads`. This is especially useful to speed up remote data streaming.
+However, it can be slower than `num_threads=0` for local data on fast disks.
+
+If you are not interested in the videos decoded as torchvision `VideoReaders` and would like to access the path/bytes instead, you can disable decoding:
+
+```python
+>>> dataset = dataset.decode(False)
+```
+
+Note: [`IterableDataset.decode`] is only available for streaming datasets at the moment.

src/datasets/arrow_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -3551,7 +3551,7 @@ def iter_outputs(shard_iterable):
                         task.cancel(msg="KeyboardInterrupt")
                     try:
                         loop.run_until_complete(asyncio.gather(*tasks))
-                    except asyncio.CancelledError:
+                    except (asyncio.CancelledError, ValueError):
                         logger.debug("Tasks canceled.")
                 raise

src/datasets/iterable_dataset.py

Lines changed: 117 additions & 4 deletions
@@ -2,6 +2,7 @@
 import copy
 import inspect
 import itertools
+import multiprocessing.pool
 import sys
 from collections import Counter
 from collections.abc import Iterable, Iterator
@@ -24,6 +25,7 @@
     Value,
     _align_features,
     _check_if_features_can_be_aligned,
+    _visit,
     cast_to_python_objects,
 )
 from .formatting import (
@@ -1010,6 +1012,7 @@ def __init__(
         fn_kwargs: Optional[dict] = None,
         formatting: Optional["FormattingConfig"] = None,
         features: Optional[Features] = None,
+        max_num_running_async_map_functions_in_parallel: Optional[int] = None,
     ):
         super().__init__()
         self.ex_iterable = ex_iterable
@@ -1023,6 +1026,9 @@ def __init__(
         self.fn_kwargs = fn_kwargs or {}
         self.formatting = formatting  # required for iter_arrow
         self._features = features
+        self.max_num_running_async_map_functions_in_parallel = (
+            max_num_running_async_map_functions_in_parallel or config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL
+        )
         # sanity checks
         if formatting and formatting.is_table:
             # batch_size should match for iter_arrow
@@ -1036,6 +1042,8 @@ def __init__(
                     f"The {formatting.format_type.capitalize()}-formatted {type(self).__name__} has batch_size={batch_size if batched else 1} which is"
                     f"different from {ex_iterable.batch_size=} from its underlying iterable."
                 )
+        # to enable graceful ends
+        self._owned_loops_and_tasks: list[tuple[asyncio.AbstractEventLoop, list[asyncio.Task]]] = []

     @property
     def iter_arrow(self):
@@ -1174,6 +1182,7 @@ async def async_apply_function(key_example, indices):
                 loop = asyncio.get_running_loop()
             except RuntimeError:
                 loop = asyncio.new_event_loop()
+                self._owned_loops_and_tasks.append((loop, tasks))
         else:
             loop = None

@@ -1191,15 +1200,15 @@ def iter_outputs():
                     indices.append(i)
                     tasks.append(loop.create_task(async_apply_function(key_example, i)))
                     # keep the total active tasks under a certain number
-                    if len(tasks) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
+                    if len(tasks) >= self.max_num_running_async_map_functions_in_parallel:
                         done, pending = loop.run_until_complete(
                             asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                         )
-                        while tasks and len(pending) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
+                        while tasks and len(pending) >= self.max_num_running_async_map_functions_in_parallel:
                             done, pending = loop.run_until_complete(
                                 asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                             )
-                    if len(tasks) >= 10 * config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
+                    if len(tasks) >= 10 * self.max_num_running_async_map_functions_in_parallel:
                         loop.run_until_complete(tasks[0])
                     # yield finished tasks
                     while tasks and tasks[0].done():
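
The hunk above makes the cap on concurrently running async map functions a per-instance setting. To illustrate the general pattern of bounding in-flight `asyncio` tasks while yielding results in input order, here is a simplified sketch. It is not the library's implementation: `map_with_bounded_tasks`, `max_in_flight`, and `double` are hypothetical names, and the sketch assumes `max_in_flight >= 1` and that no event loop is already running.

```python
import asyncio


def map_with_bounded_tasks(func, inputs, max_in_flight):
    """Yield `await func(x)` for each input, in input order.

    Simplified sketch: at most `max_in_flight` tasks are left unfinished
    at any time (assumes max_in_flight >= 1).
    """
    loop = asyncio.new_event_loop()
    tasks = []  # scheduled tasks, oldest first
    try:
        for x in inputs:
            tasks.append(loop.create_task(func(x)))
            # keep the number of unfinished tasks under the limit
            pending = [t for t in tasks if not t.done()]
            while len(pending) >= max_in_flight:
                _, still_pending = loop.run_until_complete(
                    asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
                )
                pending = list(still_pending)
            # yield finished tasks from the front so output order matches input order
            while tasks and tasks[0].done():
                yield tasks.pop(0).result()
        # drain whatever is still running once the inputs are exhausted
        while tasks:
            yield loop.run_until_complete(tasks.pop(0))
    finally:
        # cancel leftovers (e.g. if the consumer stops early), then close the loop
        for task in tasks:
            task.cancel()
        if tasks:
            loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
        loop.close()


async def double(x):
    await asyncio.sleep(0.001 * (5 - x))  # later inputs finish sooner
    return 2 * x


results = list(map_with_bounded_tasks(double, range(5), max_in_flight=2))
print(results)
```

Because finished tasks are only yielded from the front of the list, a slow early task delays the output of later ones even if they have already completed; that is the price of preserving input order.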
@@ -1257,7 +1266,7 @@ def iter_outputs():
                     task.cancel(msg="KeyboardInterrupt")
                 try:
                     loop.run_until_complete(asyncio.gather(*tasks))
-                except asyncio.CancelledError:
+                except (asyncio.CancelledError, ValueError):
                     logger.debug("Tasks canceled.")
             raise

@@ -1347,6 +1356,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExample
             fn_kwargs=self.fn_kwargs,
             formatting=self.formatting,
             features=self.features,
+            max_num_running_async_map_functions_in_parallel=self.max_num_running_async_map_functions_in_parallel,
         )

     def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "MappedExamplesIterable":
@@ -1363,6 +1373,7 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "M
             fn_kwargs=self.fn_kwargs,
             formatting=self.formatting,
             features=self.features,
+            max_num_running_async_map_functions_in_parallel=self.max_num_running_async_map_functions_in_parallel,
         )

     @property
@@ -3189,6 +3200,99 @@ def cast(
             token_per_repo_id=self._token_per_repo_id,
         )

+    def decode(self, enable: bool = True, num_threads: int = 0) -> "IterableDataset":
+        """
+        Enable or disable the dataset features decoding for audio, image, video.
+
+        When enabled (default), media types are decoded:
+
+        * audio -> dict of "array" and "sampling_rate" and "path"
+        * image -> PIL.Image
+        * video -> torchvision.io.VideoReader
+
+        You can enable multithreading using `num_threads`. This is especially useful to speed up remote
+        data streaming. However, it can be slower than `num_threads=0` for local data on fast disks.
+
+        Disabling decoding is useful if you want to iterate on the paths or bytes of the media files
+        without actually decoding their content. To disable decoding you can use `.decode(False)`, which
+        is equivalent to calling `.cast()` or `.cast_column()` with all the Audio, Image and Video types
+        set to `decode=False`.
+
+        Args:
+            enable (`bool`, defaults to `True`):
+                Enable or disable features decoding.
+            num_threads (`int`, defaults to `0`):
+                Enable multithreading for features decoding.
+
+        Returns:
+            `IterableDataset`: A copy of the dataset with casted features.
+
+        Examples:
+
+        Disable decoding:
+
+        ```py
+        >>> from datasets import load_dataset
+        >>> ds = load_dataset("sshh12/planet-textures", split="train", streaming=True)
+        >>> next(iter(ds))
+        {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=2048x1024>,
+         'text': 'A distant celestial object with an icy crust, displaying a light blue shade, covered with round pits and rugged terrains.'}
+        >>> ds = ds.decode(False)
+        >>> ds.features
+        {'image': Image(mode=None, decode=False, id=None),
+         'text': Value(dtype='string', id=None)}
+        >>> next(iter(ds))
+        {
+          'image': {
+            'path': 'hf://datasets/sshh12/planet-textures@69dc4cef7a5c4b2cfe387727ec8ea73d4bff7302/train/textures/0000.png',
+            'bytes': None
+          },
+          'text': 'A distant celestial object with an icy crust, displaying a light blue shade, covered with round pits and rugged terrains.'
+        }
+        ```
+
+        Speed up streaming with multithreading:
+
+        ```py
+        >>> import os
+        >>> from datasets import load_dataset
+        >>> from tqdm import tqdm
+        >>> ds = load_dataset("sshh12/planet-textures", split="train", streaming=True)
+        >>> num_threads = min(32, (os.cpu_count() or 1) + 4)
+        >>> ds = ds.decode(num_threads=num_threads)
+        >>> for _ in tqdm(ds):  # up to 20 times faster!
+        ...     ...
+        ```
+        """
+        if not self.features:
+            raise ValueError(
+                "Features decoding is only available for datasets with known features, but features are Unknown. "
+                "Please set the datasets features with `ds = ds.cast(features)`."
+            )
+        ds = self
+
+        def set_decoding(decode: bool, feature):
+            if hasattr(feature, "decode"):
+                feature.decode = decode
+
+        if enable and num_threads > 0:
+            disabled_decoding_features = self.features.copy()
+            enabled_decoding_features = self.features.copy()
+
+            _visit(disabled_decoding_features, partial(set_decoding, False))
+            _visit(enabled_decoding_features, partial(set_decoding, True))
+            ds = ds.cast(disabled_decoding_features)
+            pool = multiprocessing.pool.ThreadPool(num_threads)
+            func = partial(_apply_async, pool, enabled_decoding_features.decode_example)
+            ds = ds.map(func, features=enabled_decoding_features)
+            assert isinstance(ds._ex_iterable, MappedExamplesIterable)
+            ds._ex_iterable.max_num_running_async_map_functions_in_parallel = 2 * num_threads
+        else:
+            features = ds.features.copy()
+            _visit(features, partial(set_decoding, enable))
+            ds = ds.cast(features)
+        return ds
+
     def _step(self, step: int, offset: int) -> "IterableDataset":
         ex_iterable = StepExamplesIterable(self._ex_iterable, step=step, offset=offset)
         return IterableDataset(
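
The `decode` method above toggles a `decode` flag on every feature that has one, working on a copy of the features so the original dataset is left untouched. The following is a minimal, library-free sketch of that idea. `MediaFeature`, `visit`, and `with_decoding` are hypothetical stand-ins for illustration (the real code uses `Features.copy()` and the internal `_visit` helper), and the sketch assumes the copy is a deep copy.

```python
import copy
from dataclasses import dataclass
from functools import partial


@dataclass
class MediaFeature:
    """Hypothetical stand-in for a decodable feature type such as Image or Audio."""

    decode: bool = True


def visit(features, func):
    """Hypothetical simplified visitor: apply func to every leaf of a nested dict."""
    for value in features.values():
        if isinstance(value, dict):
            visit(value, func)
        else:
            func(value)


def set_decoding(decode, feature):
    # only touch features that actually expose a `decode` flag
    if hasattr(feature, "decode"):
        feature.decode = decode


def with_decoding(features, enable):
    """Return a deep copy of `features` with every decode flag set to `enable`."""
    new_features = copy.deepcopy(features)
    visit(new_features, partial(set_decoding, enable))
    return new_features


features = {"image": MediaFeature(), "meta": {"thumbnail": MediaFeature(), "label": "string"}}
disabled = with_decoding(features, False)
print(disabled["image"].decode, features["image"].decode)
```

Copying before mutating matters here: the multithreaded branch needs two variants of the same features at once, one with decoding disabled for the cast and one with decoding enabled for the mapped function.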
@@ -3407,3 +3511,12 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_s
         distributed=distributed,
         token_per_repo_id=dataset._token_per_repo_id,
     )
+
+
+async def _apply_async(pool, func, x):
+    future = pool.apply_async(func, (x,))
+    while True:
+        if future.ready():
+            return future.get()
+        else:
+            await asyncio.sleep(0)
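
The `_apply_async` helper above bridges a thread pool and `asyncio`: it submits blocking work to the pool and yields control to the event loop until the result is ready. The sketch below shows the same polling pattern in a standalone script; `apply_async`, `slow_square`, and `main` are illustrative names, and the 0.05 s sleep merely stands in for blocking I/O or decoding.

```python
import asyncio
import multiprocessing.pool
import time


async def apply_async(pool, func, x):
    """Run func(x) in the thread pool, polling for completion without blocking the event loop."""
    future = pool.apply_async(func, (x,))
    while not future.ready():
        # hand control back to the event loop so other coroutines can progress
        await asyncio.sleep(0)
    return future.get()


def slow_square(x):
    time.sleep(0.05)  # stands in for blocking I/O or decoding work
    return x * x


async def main():
    with multiprocessing.pool.ThreadPool(4) as pool:
        # all four calls run concurrently in the pool; gather preserves argument order
        return await asyncio.gather(*(apply_async(pool, slow_square, x) for x in range(4)))


results = asyncio.run(main())
print(results)
```

Polling with `asyncio.sleep(0)` keeps the code simple but busy-waits while work is in flight; wrapping the call with `loop.run_in_executor` would be one alternative that avoids the spin, at the cost of managing a `concurrent.futures` executor instead of a `ThreadPool`.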

tests/test_iterable_dataset.py

Lines changed: 28 additions & 0 deletions
@@ -2474,3 +2474,31 @@ def test_iterable_dataset_batch():
         assert len(batch["text"]) == 3
         assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
         assert batch["text"] == [f"Text {3 * i}", f"Text {3 * i + 1}", f"Text {3 * i + 2}"]
+
+
+class DecodableFeature:
+    decode_example_num_calls = 0
+
+    def __init__(self):
+        self.decode = True
+
+    def decode_example(self, example, token_per_repo_id=None):
+        type(self).decode_example_num_calls += 1
+        return "decoded" if self.decode else example
+
+
+def test_decode():
+    data = [{"i": i} for i in range(10)]
+    features = Features({"i": DecodableFeature()})
+    ds = IterableDataset.from_generator(lambda: (x for x in data), features=features)
+    assert next(iter(ds)) == {"i": "decoded"}
+    assert DecodableFeature.decode_example_num_calls == 1
+    ds = ds.decode(False)
+    assert next(iter(ds)) == {"i": 0}
+    assert DecodableFeature.decode_example_num_calls == 1
+    ds = ds.decode(True)
+    assert next(iter(ds)) == {"i": "decoded"}
+    assert DecodableFeature.decode_example_num_calls == 2
+    ds = ds.decode(num_threads=1)
+    assert next(iter(ds)) == {"i": "decoded"}
+    assert DecodableFeature.decode_example_num_calls == 4
