 import numpy as np
 import pytest
 
-from datasets import Audio, DownloadManager, Features, Image, Value
+from datasets import Audio, DownloadManager, Features, Image, Sequence, Value
 from datasets.packaged_modules.webdataset.webdataset import WebDataset
 
-from ..utils import require_pil, require_sndfile
+from ..utils import require_pil, require_sndfile, require_torch
 
 
 @pytest.fixture
@@ -50,6 +50,20 @@ def bad_wds_file(tmp_path, image_file, text_file):
     return str(filename)
 
 
+@pytest.fixture
+def tensor_wds_file(tmp_path, tensor_file):
+    json_file = tmp_path / "data.json"
+    filename = tmp_path / "file.tar"
+    num_examples = 3
+    with json_file.open("w", encoding="utf-8") as f:
+        f.write(json.dumps({"text": "this is a text"}))
+    with tarfile.open(str(filename), "w") as f:
+        for example_idx in range(num_examples):
+            f.add(json_file, f"{example_idx:05d}.json")
+            f.add(tensor_file, f"{example_idx:05d}.pth")
+    return str(filename)
+
+
 @require_pil
 def test_image_webdataset(image_wds_file):
     import PIL.Image
@@ -145,3 +159,34 @@ def test_webdataset_with_features(image_wds_file):
     assert isinstance(decoded["json"], dict)
     assert isinstance(decoded["json"]["caption"], str)
     assert isinstance(decoded["jpg"], PIL.Image.Image)
+
+
+@require_torch
+def test_tensor_webdataset(tensor_wds_file):
+    import torch
+
+    data_files = {"train": [tensor_wds_file]}
+    webdataset = WebDataset(data_files=data_files)
+    split_generators = webdataset._split_generators(DownloadManager())
+    assert webdataset.info.features == Features(
+        {
+            "__key__": Value("string"),
+            "__url__": Value("string"),
+            "json": {"text": Value("string")},
+            "pth": Sequence(Value("float32")),
+        }
+    )
+    assert len(split_generators) == 1
+    split_generator = split_generators[0]
+    assert split_generator.name == "train"
+    generator = webdataset._generate_examples(**split_generator.gen_kwargs)
+    _, examples = zip(*generator)
+    assert len(examples) == 3
+    assert isinstance(examples[0]["json"], dict)
+    assert isinstance(examples[0]["json"]["text"], str)
+    assert isinstance(examples[0]["pth"], torch.Tensor)  # keep encoded to avoid unnecessary copies
+    encoded = webdataset.info.features.encode_example(examples[0])
+    decoded = webdataset.info.features.decode_example(encoded)
+    assert isinstance(decoded["json"], dict)
+    assert isinstance(decoded["json"]["text"], str)
+    assert isinstance(decoded["pth"], list)
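Note: the new tensor_wds_file fixture depends on a tensor_file fixture defined elsewhere in the test suite and not shown in this diff. A minimal sketch of what such a fixture could look like, assuming a small 1-D float32 tensor written with torch.save (the fixture name, scope, and tensor shape here are illustrative assumptions, not the repository's actual fixture):

import pytest
import torch


@pytest.fixture
def tensor_file(tmp_path):
    # Hypothetical fixture: save a small 1-D float32 tensor to a .pth file so the
    # webdataset loader can decode it as Sequence(Value("float32")) in the test above.
    filename = tmp_path / "tensor.pth"
    torch.save(torch.zeros(128, dtype=torch.float32), str(filename))
    return str(filename)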