Skip to content

Commit d70c902

Browse files
Align filename prefix splitting with WebDataset library (#7151)
* Align filename prefix splitting with WebDataset library * Fix import
1 parent 43b1fe1 commit d70c902

File tree

1 file changed

+33
-19
lines changed

1 file changed

+33
-19
lines changed

src/datasets/packaged_modules/webdataset/webdataset.py

Lines changed: 33 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import io
22
import json
3+
import re
34
from itertools import islice
45
from typing import Any, Callable, Dict, List
56

@@ -28,25 +29,26 @@ def _get_pipeline_from_tar(cls, tar_path, tar_iterator):
2829
fs: fsspec.AbstractFileSystem = fsspec.filesystem("memory")
2930
streaming_download_manager = datasets.StreamingDownloadManager()
3031
for filename, f in tar_iterator:
31-
if "." in filename:
32-
example_key, field_name = filename.split(".", 1)
33-
if current_example and current_example["__key__"] != example_key:
34-
yield current_example
35-
current_example = {}
36-
current_example["__key__"] = example_key
37-
current_example["__url__"] = tar_path
38-
current_example[field_name.lower()] = f.read()
39-
if field_name.split(".")[-1] in SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL:
40-
fs.write_bytes(filename, current_example[field_name.lower()])
41-
extracted_file_path = streaming_download_manager.extract(f"memory://{filename}")
42-
with fsspec.open(extracted_file_path) as f:
43-
current_example[field_name.lower()] = f.read()
44-
fs.delete(filename)
45-
data_extension = xbasename(extracted_file_path).split(".")[-1]
46-
else:
47-
data_extension = field_name.split(".")[-1]
48-
if data_extension in cls.DECODERS:
49-
current_example[field_name] = cls.DECODERS[data_extension](current_example[field_name])
32+
example_key, field_name = base_plus_ext(filename)
33+
if example_key is None:
34+
continue
35+
if current_example and current_example["__key__"] != example_key:
36+
yield current_example
37+
current_example = {}
38+
current_example["__key__"] = example_key
39+
current_example["__url__"] = tar_path
40+
current_example[field_name.lower()] = f.read()
41+
if field_name.split(".")[-1] in SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL:
42+
fs.write_bytes(filename, current_example[field_name.lower()])
43+
extracted_file_path = streaming_download_manager.extract(f"memory://{filename}")
44+
with fsspec.open(extracted_file_path) as f:
45+
current_example[field_name.lower()] = f.read()
46+
fs.delete(filename)
47+
data_extension = xbasename(extracted_file_path).split(".")[-1]
48+
else:
49+
data_extension = field_name.split(".")[-1]
50+
if data_extension in cls.DECODERS:
51+
current_example[field_name] = cls.DECODERS[data_extension](current_example[field_name])
5052
if current_example:
5153
yield current_example
5254

@@ -121,6 +123,18 @@ def _generate_examples(self, tar_paths, tar_iterators):
121123
yield f"{tar_idx}_{example_idx}", example
122124

123125

126+
# Source: https://github.com/webdataset/webdataset/blob/87bd5aa41602d57f070f65a670893ee625702f2f/webdataset/tariterators.py#L25
127+
def base_plus_ext(path):
    """Separate a path into its base and its complete extension.

    The split is made at the first dot that follows the (optional)
    directory prefix, so every dotted suffix ends up in the extension:
    ``"a/b/c.tar.gz"`` gives ``("a/b/c", "tar.gz")``.

    Returns:
        A ``(base, allext)`` tuple, or ``(None, None)`` when the path does
        not fit the pattern (for example, when it contains no dot).
    """
    found = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
    if found is None:
        return None, None
    # Group 1 is the base (directories included), group 2 is everything
    # after the splitting dot.
    return found.group(1, 2)
136+
137+
124138
# Obtained with:
125139
# ```
126140
# import PIL.Image

0 commit comments

Comments
 (0)