Skip to content

Commit 03516e0

Browse files
committed
fix decoding tests
1 parent e914f5f commit 03516e0

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tests/test_iterable_dataset.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pickle
33
import time
44
from copy import deepcopy
5+
from dataclasses import dataclass
56
from itertools import chain, cycle, islice
67
from unittest.mock import patch
78

@@ -2487,6 +2488,7 @@ def test_iterable_dataset_batch():
24872488
assert batch["text"] == [f"Text {3 * i}", f"Text {3 * i + 1}", f"Text {3 * i + 2}"]
24882489

24892490

2491+
@dataclass
24902492
class DecodableFeature:
24912493
decode_example_num_calls = 0
24922494

@@ -2497,15 +2499,18 @@ def decode_example(self, example, token_per_repo_id=None):
24972499
type(self).decode_example_num_calls += 1
24982500
return "decoded" if self.decode else example
24992501

2502+
def __call__(self):
2503+
return pa.string()
2504+
25002505

25012506
def test_decode():
2502-
data = [{"i": i} for i in range(10)]
2507+
data = [{"i": str(i)} for i in range(10)]
25032508
features = Features({"i": DecodableFeature()})
25042509
ds = IterableDataset.from_generator(lambda: (x for x in data), features=features)
25052510
assert next(iter(ds)) == {"i": "decoded"}
25062511
assert DecodableFeature.decode_example_num_calls == 1
25072512
ds = ds.decode(False)
2508-
assert next(iter(ds)) == {"i": 0}
2513+
assert next(iter(ds)) == {"i": "0"}
25092514
assert DecodableFeature.decode_example_num_calls == 1
25102515
ds = ds.decode(True)
25112516
assert next(iter(ds)) == {"i": "decoded"}

0 commit comments

Comments
 (0)