2
2
import copy
3
3
import inspect
4
4
import itertools
5
+ import multiprocessing .pool
5
6
import sys
6
7
from collections import Counter
7
8
from collections .abc import Iterable , Iterator
24
25
Value ,
25
26
_align_features ,
26
27
_check_if_features_can_be_aligned ,
28
+ _visit ,
27
29
cast_to_python_objects ,
28
30
)
29
31
from .formatting import (
@@ -1010,6 +1012,7 @@ def __init__(
1010
1012
fn_kwargs : Optional [dict ] = None ,
1011
1013
formatting : Optional ["FormattingConfig" ] = None ,
1012
1014
features : Optional [Features ] = None ,
1015
+ max_num_running_async_map_functions_in_parallel : Optional [int ] = None ,
1013
1016
):
1014
1017
super ().__init__ ()
1015
1018
self .ex_iterable = ex_iterable
@@ -1023,6 +1026,9 @@ def __init__(
1023
1026
self .fn_kwargs = fn_kwargs or {}
1024
1027
self .formatting = formatting # required for iter_arrow
1025
1028
self ._features = features
1029
+ self .max_num_running_async_map_functions_in_parallel = (
1030
+ max_num_running_async_map_functions_in_parallel or config .MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL
1031
+ )
1026
1032
# sanity checks
1027
1033
if formatting and formatting .is_table :
1028
1034
# batch_size should match for iter_arrow
@@ -1036,6 +1042,8 @@ def __init__(
1036
1042
f"The { formatting .format_type .capitalize ()} -formatted { type (self ).__name__ } has batch_size={ batch_size if batched else 1 } which is"
1037
1043
f"different from { ex_iterable .batch_size = } from its underlying iterable."
1038
1044
)
1045
+ # to enable graceful ends
1046
+ self ._owned_loops_and_tasks : list [tuple [asyncio .AbstractEventLoop , list [asyncio .Task ]]] = []
1039
1047
1040
1048
@property
1041
1049
def iter_arrow (self ):
@@ -1174,6 +1182,7 @@ async def async_apply_function(key_example, indices):
1174
1182
loop = asyncio .get_running_loop ()
1175
1183
except RuntimeError :
1176
1184
loop = asyncio .new_event_loop ()
1185
+ self ._owned_loops_and_tasks .append ((loop , tasks ))
1177
1186
else :
1178
1187
loop = None
1179
1188
@@ -1191,15 +1200,15 @@ def iter_outputs():
1191
1200
indices .append (i )
1192
1201
tasks .append (loop .create_task (async_apply_function (key_example , i )))
1193
1202
# keep the total active tasks under a certain number
1194
- if len (tasks ) >= config . MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL :
1203
+ if len (tasks ) >= self . max_num_running_async_map_functions_in_parallel :
1195
1204
done , pending = loop .run_until_complete (
1196
1205
asyncio .wait (tasks , return_when = asyncio .FIRST_COMPLETED )
1197
1206
)
1198
- while tasks and len (pending ) >= config . MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL :
1207
+ while tasks and len (pending ) >= self . max_num_running_async_map_functions_in_parallel :
1199
1208
done , pending = loop .run_until_complete (
1200
1209
asyncio .wait (tasks , return_when = asyncio .FIRST_COMPLETED )
1201
1210
)
1202
- if len (tasks ) >= 10 * config . MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL :
1211
+ if len (tasks ) >= 10 * self . max_num_running_async_map_functions_in_parallel :
1203
1212
loop .run_until_complete (tasks [0 ])
1204
1213
# yield finished tasks
1205
1214
while tasks and tasks [0 ].done ():
@@ -1257,7 +1266,7 @@ def iter_outputs():
1257
1266
task .cancel (msg = "KeyboardInterrupt" )
1258
1267
try :
1259
1268
loop .run_until_complete (asyncio .gather (* tasks ))
1260
- except asyncio .CancelledError :
1269
+ except ( asyncio .CancelledError , ValueError ) :
1261
1270
logger .debug ("Tasks canceled." )
1262
1271
raise
1263
1272
@@ -1347,6 +1356,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExample
1347
1356
fn_kwargs = self .fn_kwargs ,
1348
1357
formatting = self .formatting ,
1349
1358
features = self .features ,
1359
+ max_num_running_async_map_functions_in_parallel = self .max_num_running_async_map_functions_in_parallel ,
1350
1360
)
1351
1361
1352
1362
def shard_data_sources (self , num_shards : int , index : int , contiguous = True ) -> "MappedExamplesIterable" :
@@ -1363,6 +1373,7 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "M
1363
1373
fn_kwargs = self .fn_kwargs ,
1364
1374
formatting = self .formatting ,
1365
1375
features = self .features ,
1376
+ max_num_running_async_map_functions_in_parallel = self .max_num_running_async_map_functions_in_parallel ,
1366
1377
)
1367
1378
1368
1379
@property
@@ -3189,6 +3200,99 @@ def cast(
3189
3200
token_per_repo_id = self ._token_per_repo_id ,
3190
3201
)
3191
3202
3203
def decode(self, enable: bool = True, num_threads: int = 0) -> "IterableDataset":
    """
    Enable or disable the dataset features decoding for audio, image, video.

    When enabled (default), media types are decoded:

    * audio -> dict of "array" and "sampling_rate" and "path"
    * image -> PIL.Image
    * video -> torchvision.io.VideoReader

    Multithreaded decoding can be enabled with `num_threads`, which is
    especially useful to speed up remote data streaming (it can be slower
    than `num_threads=0` for local data on fast disks).

    Disabling decoding (`.decode(False)`) yields the raw paths/bytes of the
    media files instead of their decoded content, equivalent to casting all
    the Audio, Image and Video features with `decode=False`.

    Args:
        enable (`bool`, defaults to `True`):
            Enable or disable features decoding.
        num_threads (`int`, defaults to `0`):
            Enable multithreading for features decoding.

    Returns:
        `IterableDataset`: A copy of the dataset with casted features.
    """
    if not self.features:
        raise ValueError(
            "Features decoding is only available for datasets with known features, but features are Unknown. "
            "Please set the datasets features with `ds = ds.cast(features)`."
        )

    def set_decoding(decode: bool, feature):
        # Only media-like features (Audio, Image, Video) expose a `decode` flag.
        if hasattr(feature, "decode"):
            feature.decode = decode

    if not (enable and num_threads > 0):
        # Plain path: flip the decode flag on every decodable feature and cast.
        plain_features = self.features.copy()
        _visit(plain_features, partial(set_decoding, enable))
        return self.cast(plain_features)

    # Multithreaded path: stream raw (undecoded) examples, then decode each
    # example concurrently through a thread pool via an async map function.
    raw_features = self.features.copy()
    decoded_features = self.features.copy()
    _visit(raw_features, partial(set_decoding, False))
    _visit(decoded_features, partial(set_decoding, True))

    dataset = self.cast(raw_features)
    # NOTE(review): the pool is never explicitly closed — it appears to be
    # intentionally owned by the mapped function for the dataset's lifetime;
    # confirm there is no teardown path that should release it.
    pool = multiprocessing.pool.ThreadPool(num_threads)
    decode_async = partial(_apply_async, pool, decoded_features.decode_example)
    dataset = dataset.map(decode_async, features=decoded_features)
    assert isinstance(dataset._ex_iterable, MappedExamplesIterable)
    # Allow up to two in-flight async decode calls per worker thread so the
    # pool stays saturated while results are awaited.
    dataset._ex_iterable.max_num_running_async_map_functions_in_parallel = 2 * num_threads
    return dataset
3295
+
3192
3296
def _step (self , step : int , offset : int ) -> "IterableDataset" :
3193
3297
ex_iterable = StepExamplesIterable (self ._ex_iterable , step = step , offset = offset )
3194
3298
return IterableDataset (
@@ -3407,3 +3511,12 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_s
3407
3511
distributed = distributed ,
3408
3512
token_per_repo_id = dataset ._token_per_repo_id ,
3409
3513
)
3514
+
3515
+
3516
+ async def _apply_async (pool , func , x ):
3517
+ future = pool .apply_async (func , (x ,))
3518
+ while True :
3519
+ if future .ready ():
3520
+ return future .get ()
3521
+ else :
3522
+ await asyncio .sleep (0 )
0 commit comments