
Commit 88d53d1

[WebDataset] Support compressed files (#6931)
* support compressed files in webdataset
* ignore windows behavior
* again
1 parent 670e1cf commit 88d53d1

4 files changed (+57 -12 lines)

src/datasets/features/features.py

+6 -10

@@ -312,21 +312,17 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_cas
             True,
         )
     elif config.TORCH_AVAILABLE and "torch" in sys.modules and isinstance(obj, torch.Tensor):
+        if obj.dtype == torch.bfloat16:
+            return _cast_to_python_objects(
+                obj.detach().to(torch.float).cpu().numpy(),
+                only_1d_for_numpy=only_1d_for_numpy,
+                optimize_list_casting=optimize_list_casting,
+            )[0], True
         if obj.ndim == 0:
             return obj.detach().cpu().numpy()[()], True
         elif not only_1d_for_numpy or obj.ndim == 1:
             return obj.detach().cpu().numpy(), True
         else:
-            if obj.dtype == torch.bfloat16:
-                return (
-                    [
-                        _cast_to_python_objects(
-                            x, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
-                        )[0]
-                        for x in obj.detach().to(torch.float).cpu().numpy()
-                    ],
-                    True,
-                )
             return (
                 [
                     _cast_to_python_objects(
src/datasets/packaged_modules/webdataset/webdataset.py

+15 -2

@@ -3,11 +3,13 @@
 from itertools import islice
 from typing import Any, Callable, Dict, List

+import fsspec
 import numpy as np
 import pyarrow as pa

 import datasets
 from datasets.features.features import cast_to_python_objects
+from datasets.utils.file_utils import SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL, xbasename


 logger = datasets.utils.logging.get_logger(__name__)
@@ -23,6 +25,8 @@ class WebDataset(datasets.GeneratorBasedBuilder):
     @classmethod
     def _get_pipeline_from_tar(cls, tar_path, tar_iterator):
         current_example = {}
+        fs: fsspec.AbstractFileSystem = fsspec.filesystem("memory")
+        streaming_download_manager = datasets.StreamingDownloadManager()
         for filename, f in tar_iterator:
             if "." in filename:
                 example_key, field_name = filename.split(".", 1)
@@ -32,8 +36,17 @@ def _get_pipeline_from_tar(cls, tar_path, tar_iterator):
                 current_example["__key__"] = example_key
                 current_example["__url__"] = tar_path
                 current_example[field_name.lower()] = f.read()
-                if field_name.split(".")[-1] in cls.DECODERS:
-                    current_example[field_name] = cls.DECODERS[field_name.split(".")[-1]](current_example[field_name])
+                if field_name.split(".")[-1] in SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL:
+                    fs.write_bytes(filename, current_example[field_name.lower()])
+                    extracted_file_path = streaming_download_manager.extract(f"memory://{filename}")
+                    with fsspec.open(extracted_file_path) as f:
+                        current_example[field_name.lower()] = f.read()
+                    fs.delete(filename)
+                    data_extension = xbasename(extracted_file_path).split(".")[-1]
+                else:
+                    data_extension = field_name.split(".")[-1]
+                if data_extension in cls.DECODERS:
+                    current_example[field_name] = cls.DECODERS[data_extension](current_example[field_name])
         if current_example:
             yield current_example
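How the new branch works: when a member's last extension is a known single-file compression format, its raw bytes are staged on fsspec's in-memory filesystem, StreamingDownloadManager.extract turns the memory:// URL into a chained decompression URL, and reading that URL yields the decompressed payload; the payload's real extension then drives the DECODERS lookup. The round-trip can be sketched in isolation (the member name and bytes below are made-up examples, and the exact shape of the extracted URL is an assumption):

import gzip

import fsspec

import datasets  # importing datasets registers its compression filesystems with fsspec

# Stage a hypothetical gzipped tar member on the in-memory filesystem
fs = fsspec.filesystem("memory")
fs.write_bytes("00000.txt.gz", gzip.compress(b"hello webdataset"))

# extract() returns a URL that decompresses on read,
# e.g. "gzip://00000.txt::memory://00000.txt.gz"
dl_manager = datasets.StreamingDownloadManager()
extracted = dl_manager.extract("memory://00000.txt.gz")
with fsspec.open(extracted) as f:
    assert f.read() == b"hello webdataset"
fs.delete("00000.txt.gz")  # clean up the staged bytes, as the builder does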

src/datasets/utils/file_utils.py

+3 -0

@@ -762,6 +762,9 @@ def readline(f: io.RawIOBase):
     # archive compression
     "zip": "zip",
 }
+SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL = {
+    fs_class.extension.lstrip("."): fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS
+}
 SINGLE_FILE_COMPRESSION_PROTOCOLS = {fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS}
 SINGLE_SLASH_AFTER_PROTOCOL_PATTERN = re.compile(r"(?<!:):/")
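COMPRESSION_FILESYSTEMS is the list of single-file compression filesystems that datasets registers with fsspec, so the new extension-to-protocol mapping stays in sync with SINGLE_FILE_COMPRESSION_PROTOCOLS below it. With the filesystems present around the time of this commit, the comprehension should evaluate to roughly:

# Approximate contents, assuming the bz2/gzip/lz4/xz/zstd
# compression filesystems that datasets ships:
SINGLE_FILE_COMPRESSION_EXTENSION_TO_PROTOCOL = {
    "bz2": "bz2",
    "gz": "gzip",
    "lz4": "lz4",
    "xz": "xz",
    "zst": "zstd",
}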

tests/packaged_modules/test_webdataset.py

+33 -0

@@ -10,6 +10,16 @@
 from ..utils import require_pil, require_sndfile, require_torch


+@pytest.fixture
+def gzipped_text_wds_file(tmp_path, text_gz_path):
+    filename = tmp_path / "file.tar"
+    num_examples = 3
+    with tarfile.open(str(filename), "w") as f:
+        for example_idx in range(num_examples):
+            f.add(text_gz_path, f"{example_idx:05d}.txt.gz")
+    return str(filename)
+
+
 @pytest.fixture
 def image_wds_file(tmp_path, image_file):
     json_file = tmp_path / "data.json"
@@ -64,6 +74,29 @@ def tensor_wds_file(tmp_path, tensor_file):
     return str(filename)


+@require_pil
+def test_gzipped_text_webdataset(gzipped_text_wds_file, text_path):
+    data_files = {"train": [gzipped_text_wds_file]}
+    webdataset = WebDataset(data_files=data_files)
+    split_generators = webdataset._split_generators(DownloadManager())
+    assert webdataset.info.features == Features(
+        {
+            "__key__": Value("string"),
+            "__url__": Value("string"),
+            "txt.gz": Value("string"),
+        }
+    )
+    assert len(split_generators) == 1
+    split_generator = split_generators[0]
+    assert split_generator.name == "train"
+    generator = webdataset._generate_examples(**split_generator.gen_kwargs)
+    _, examples = zip(*generator)
+    assert len(examples) == 3
+    assert isinstance(examples[0]["txt.gz"], str)
+    with open(text_path, "r") as f:
+        assert examples[0]["txt.gz"].replace("\r\n", "\n") == f.read().replace("\r\n", "\n")
+
+
 @require_pil
 def test_image_webdataset(image_wds_file):
     import PIL.Image
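End to end, the change means a WebDataset tar whose members are individually compressed now loads transparently. A sketch, assuming a local train.tar laid out like the fixture above (members named 00000.txt.gz and so on):

from datasets import load_dataset

# "train.tar" is a hypothetical archive with gzipped text members; each
# compressed field is decompressed before decoding and keeps its full
# field name ("txt.gz") as the column name, as the test above asserts.
ds = load_dataset("webdataset", data_files={"train": ["train.tar"]}, split="train")
print(ds[0]["txt.gz"])  # decompressed, decoded text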
