Commit 15ffefe

[WebDataset] Add .pth support for torch tensors (#6920)
pth support
1 parent 048c789

File tree

3 files changed: +71 -6 lines changed
src/datasets/packaged_modules/webdataset/webdataset.py (+14 -4)

@@ -7,6 +7,7 @@
 import pyarrow as pa

 import datasets
+from datasets.features.features import cast_to_python_objects


 logger = datasets.utils.logging.get_logger(__name__)

@@ -64,7 +65,10 @@ def _split_generators(self, dl_manager):
                 "The TAR archives of the dataset should be in WebDataset format, "
                 "but the files in the archive don't share the same prefix or the same types."
             )
-        pa_tables = [pa.Table.from_pylist([example]) for example in first_examples]
+        pa_tables = [
+            pa.Table.from_pylist(cast_to_python_objects([example], only_1d_for_numpy=True))
+            for example in first_examples
+        ]
         if datasets.config.PYARROW_VERSION.major < 14:
             inferred_arrow_schema = pa.concat_tables(pa_tables, promote=True).schema
         else:
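
Why the inference change is needed: `pa.Table.from_pylist` cannot infer an Arrow type for a raw `torch.Tensor`, so the first examples are cast to plain Python/NumPy objects before schema inference. A minimal sketch of the failure mode and the fix, not the library's exact call site:

    import pyarrow as pa
    import torch

    from datasets.features.features import cast_to_python_objects

    example = {"pth": torch.ones(128)}

    # Arrow's type inference does not recognize torch.Tensor values.
    try:
        pa.Table.from_pylist([example])
    except pa.ArrowInvalid:
        pass

    # Casting first lets Arrow infer a list<float> column, which surfaces
    # as Sequence(Value("float32")) in the dataset features.
    table = pa.Table.from_pylist(cast_to_python_objects([example], only_1d_for_numpy=True))
    print(table.schema)  # pth: list<item: float>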
@@ -256,16 +260,21 @@ def cbor_loads(data: bytes):
     return cbor.loads(data)


+def torch_loads(data: bytes):
+    import torch
+
+    return torch.load(io.BytesIO(data), weights_only=True)
+
+
 # Obtained by checking `decoders` in `webdataset.autodecode`
 # and removing unsafe extension decoders.
 # Removed Pickle decoders:
 # - "pyd": lambda data: pickle.loads(data)
 # - "pickle": lambda data: pickle.loads(data)
-# Removed Torch decoders:
-# - "pth": lambda data: torch_loads(data)
-# Modified NumPy decoders to fix CVE-2019-6446 (add allow_pickle=False):
+# Modified NumPy decoders to fix CVE-2019-6446 (add allow_pickle=False and weights_only=True):
 # - "npy": npy_loads,
 # - "npz": lambda data: np.load(io.BytesIO(data)),
+# - "pth": lambda data: torch_loads(data)
 DECODERS = {
     "txt": text_loads,
     "text": text_loads,

@@ -284,5 +293,6 @@ def cbor_loads(data: bytes):
     "npy": npy_loads,
     "npz": npz_loads,
     "cbor": cbor_loads,
+    "pth": torch_loads,
 }
 WebDataset.DECODERS = DECODERS
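
What the new decoder does: `weights_only=True` restricts `torch.load` to tensors and primitive containers, refusing arbitrary pickled objects, which is what lets `.pth` support come back without the pickle risk that originally kept it out of `DECODERS`. A minimal round-trip sketch, with `torch_loads` copied from the diff above:

    import io

    import torch

    def torch_loads(data: bytes):
        import torch

        return torch.load(io.BytesIO(data), weights_only=True)

    # Serialize a tensor the way a WebDataset .pth entry stores it...
    buffer = io.BytesIO()
    torch.save(torch.ones(128), buffer)

    # ...and decode it back from raw bytes, as the builder does per example.
    tensor = torch_loads(buffer.getvalue())
    assert tensor.shape == (128,)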

tests/fixtures/files.py (+10 -0)

@@ -551,6 +551,16 @@ def audio_file():
     return os.path.join("tests", "features", "data", "test_audio_44100.wav")


+@pytest.fixture(scope="session")
+def tensor_file(tmp_path_factory):
+    import torch
+
+    path = tmp_path_factory.mktemp("data") / "tensor.pth"
+    with open(path, "wb") as f:
+        torch.save(torch.ones(128), f)
+    return path
+
+
 @pytest.fixture(scope="session")
 def zip_image_path(image_file, tmp_path_factory):
     path = tmp_path_factory.mktemp("data") / "dataset.img.zip"
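
With fixtures like these, the feature can be exercised end to end through the packaged `webdataset` builder. A hedged usage sketch (the file names `tensor.pth` and `shard.tar` are made up for illustration):

    import tarfile

    import torch
    from datasets import load_dataset

    # Build a small WebDataset shard: one .pth entry per example key.
    torch.save(torch.ones(128), "tensor.pth")
    with tarfile.open("shard.tar", "w") as f:
        for i in range(2):
            f.add("tensor.pth", f"{i:05d}.pth")

    # The webdataset builder now decodes the .pth entries via torch_loads.
    ds = load_dataset("webdataset", data_files={"train": ["shard.tar"]}, split="train")
    print(ds[0]["pth"][:3])  # a list of float32 values after Arrow encoding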

tests/packaged_modules/test_webdataset.py (+47 -2)

@@ -4,10 +4,10 @@
 import numpy as np
 import pytest

-from datasets import Audio, DownloadManager, Features, Image, Value
+from datasets import Audio, DownloadManager, Features, Image, Sequence, Value
 from datasets.packaged_modules.webdataset.webdataset import WebDataset

-from ..utils import require_pil, require_sndfile
+from ..utils import require_pil, require_sndfile, require_torch


 @pytest.fixture

@@ -50,6 +50,20 @@ def bad_wds_file(tmp_path, image_file, text_file):
     return str(filename)


+@pytest.fixture
+def tensor_wds_file(tmp_path, tensor_file):
+    json_file = tmp_path / "data.json"
+    filename = tmp_path / "file.tar"
+    num_examples = 3
+    with json_file.open("w", encoding="utf-8") as f:
+        f.write(json.dumps({"text": "this is a text"}))
+    with tarfile.open(str(filename), "w") as f:
+        for example_idx in range(num_examples):
+            f.add(json_file, f"{example_idx:05d}.json")
+            f.add(tensor_file, f"{example_idx:05d}.pth")
+    return str(filename)
+
+
 @require_pil
 def test_image_webdataset(image_wds_file):
     import PIL.Image

@@ -145,3 +159,34 @@ def test_webdataset_with_features(image_wds_file):
     assert isinstance(decoded["json"], dict)
     assert isinstance(decoded["json"]["caption"], str)
     assert isinstance(decoded["jpg"], PIL.Image.Image)
+
+
+@require_torch
+def test_tensor_webdataset(tensor_wds_file):
+    import torch
+
+    data_files = {"train": [tensor_wds_file]}
+    webdataset = WebDataset(data_files=data_files)
+    split_generators = webdataset._split_generators(DownloadManager())
+    assert webdataset.info.features == Features(
+        {
+            "__key__": Value("string"),
+            "__url__": Value("string"),
+            "json": {"text": Value("string")},
+            "pth": Sequence(Value("float32")),
+        }
+    )
+    assert len(split_generators) == 1
+    split_generator = split_generators[0]
+    assert split_generator.name == "train"
+    generator = webdataset._generate_examples(**split_generator.gen_kwargs)
+    _, examples = zip(*generator)
+    assert len(examples) == 3
+    assert isinstance(examples[0]["json"], dict)
+    assert isinstance(examples[0]["json"]["text"], str)
+    assert isinstance(examples[0]["pth"], torch.Tensor)  # keep encoded to avoid unnecessary copies
+    encoded = webdataset.info.features.encode_example(examples[0])
+    decoded = webdataset.info.features.decode_example(encoded)
+    assert isinstance(decoded["json"], dict)
+    assert isinstance(decoded["json"]["text"], str)
+    assert isinstance(decoded["pth"], list)
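
The final assertions pin down the type round trip: the raw example keeps the `torch.Tensor`, while encoding through the inferred `Sequence(Value("float32"))` feature flattens it to a plain list. A minimal sketch of that behavior in isolation (assuming `encode_example` casts tensors via `cast_to_python_objects`, consistent with the inference change above):

    import torch
    from datasets import Features, Sequence, Value

    features = Features({"pth": Sequence(Value("float32"))})
    example = {"pth": torch.ones(4)}

    # Encoding casts the tensor to Python objects for Arrow storage...
    encoded = features.encode_example(example)
    # ...and decoding a plain Sequence(Value) feature yields a Python list.
    decoded = features.decode_example(encoded)
    assert isinstance(decoded["pth"], list)
    assert decoded["pth"] == [1.0, 1.0, 1.0, 1.0]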
