Skip to content

Commit ead089d

Browse files
add split argument to Generator (#7015)
* add split argument to Generator, from_generator, AbstractDatasetInputStream, GeneratorDatasetInputStream
* split generator review feedbacks
* import Split
* tag added version in iterable_dataset, rollback change in _concatenate_iterable_datasets
* rm useless Generator __init__
* docstring formatting

  Co-authored-by: Albert Villanova del Moral <[email protected]>
* format docstring

  Co-authored-by: Albert Villanova del Moral <[email protected]>
* fix test_dataset_from_generator_split[None]

---------

Co-authored-by: Albert Villanova del Moral <[email protected]>
1 parent 92bdab5 commit ead089d

File tree

5 files changed

+39
-12
lines changed

5 files changed

+39
-12
lines changed

src/datasets/arrow_dataset.py

+6
Original file line numberDiff line numberDiff line change
@@ -1068,6 +1068,7 @@ def from_generator(
10681068
keep_in_memory: bool = False,
10691069
gen_kwargs: Optional[dict] = None,
10701070
num_proc: Optional[int] = None,
1071+
split: NamedSplit = Split.TRAIN,
10711072
**kwargs,
10721073
):
10731074
"""Create a Dataset from a generator.
@@ -1090,6 +1091,10 @@ def from_generator(
10901091
If `num_proc` is greater than one, then all list values in `gen_kwargs` must be the same length. These values will be split between calls to the generator. The number of shards will be the minimum of the shortest list in `gen_kwargs` and `num_proc`.
10911092
10921093
<Added version="2.7.0"/>
1094+
split ([`NamedSplit`], defaults to `Split.TRAIN`):
1095+
Split name to be assigned to the dataset.
1096+
1097+
<Added version="2.21.0"/>
10931098
**kwargs (additional keyword arguments):
10941099
Keyword arguments to be passed to :[`GeneratorConfig`].
10951100
@@ -1126,6 +1131,7 @@ def from_generator(
11261131
keep_in_memory=keep_in_memory,
11271132
gen_kwargs=gen_kwargs,
11281133
num_proc=num_proc,
1134+
split=split,
11291135
**kwargs,
11301136
).read()
11311137

src/datasets/io/generator.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Callable, Optional
22

3-
from .. import Features
3+
from .. import Features, NamedSplit, Split
44
from ..packaged_modules.generator.generator import Generator
55
from .abc import AbstractDatasetInputStream
66

@@ -15,6 +15,7 @@ def __init__(
1515
streaming: bool = False,
1616
gen_kwargs: Optional[dict] = None,
1717
num_proc: Optional[int] = None,
18+
split: NamedSplit = Split.TRAIN,
1819
**kwargs,
1920
):
2021
super().__init__(
@@ -30,13 +31,14 @@ def __init__(
3031
features=features,
3132
generator=generator,
3233
gen_kwargs=gen_kwargs,
34+
split=split,
3335
**kwargs,
3436
)
3537

3638
def read(self):
3739
# Build iterable dataset
3840
if self.streaming:
39-
dataset = self.builder.as_streaming_dataset(split="train")
41+
dataset = self.builder.as_streaming_dataset(split=self.builder.config.split)
4042
# Build regular (map-style) dataset
4143
else:
4244
download_config = None
@@ -52,6 +54,6 @@ def read(self):
5254
num_proc=self.num_proc,
5355
)
5456
dataset = self.builder.as_dataset(
55-
split="train", verification_mode=verification_mode, in_memory=self.keep_in_memory
57+
split=self.builder.config.split, verification_mode=verification_mode, in_memory=self.keep_in_memory
5658
)
5759
return dataset

src/datasets/iterable_dataset.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned, cast_to_python_objects
2020
from .formatting import PythonFormatter, TensorFormatter, get_format_type_from_alias, get_formatter
2121
from .info import DatasetInfo
22-
from .splits import NamedSplit
22+
from .splits import NamedSplit, Split
2323
from .table import cast_table_to_features, read_schema_from_file, table_cast
2424
from .utils.logging import get_logger
2525
from .utils.py_utils import Literal
@@ -2083,6 +2083,7 @@ def from_generator(
20832083
generator: Callable,
20842084
features: Optional[Features] = None,
20852085
gen_kwargs: Optional[dict] = None,
2086+
split: NamedSplit = Split.TRAIN,
20862087
) -> "IterableDataset":
20872088
"""Create an Iterable Dataset from a generator.
20882089
@@ -2095,7 +2096,10 @@ def from_generator(
20952096
Keyword arguments to be passed to the `generator` callable.
20962097
You can define a sharded iterable dataset by passing the list of shards in `gen_kwargs`.
20972098
This can be used to improve shuffling and when iterating over the dataset with multiple workers.
2099+
split ([`NamedSplit`], defaults to `Split.TRAIN`):
2100+
Split name to be assigned to the dataset.
20982101
2102+
<Added version="2.21.0"/>
20992103
Returns:
21002104
`IterableDataset`
21012105
@@ -2126,10 +2130,7 @@ def from_generator(
21262130
from .io.generator import GeneratorDatasetInputStream
21272131

21282132
return GeneratorDatasetInputStream(
2129-
generator=generator,
2130-
features=features,
2131-
gen_kwargs=gen_kwargs,
2132-
streaming=True,
2133+
generator=generator, features=features, gen_kwargs=gen_kwargs, streaming=True, split=split
21332134
).read()
21342135

21352136
@staticmethod

src/datasets/packaged_modules/generator/generator.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ class GeneratorConfig(datasets.BuilderConfig):
99
generator: Optional[Callable] = None
1010
gen_kwargs: Optional[dict] = None
1111
features: Optional[datasets.Features] = None
12+
split: datasets.NamedSplit = datasets.Split.TRAIN
1213

1314
def __post_init__(self):
1415
super().__post_init__()
@@ -26,7 +27,7 @@ def _info(self):
2627
return datasets.DatasetInfo(features=self.config.features)
2728

2829
def _split_generators(self, dl_manager):
29-
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs=self.config.gen_kwargs)]
30+
return [datasets.SplitGenerator(name=self.config.split, gen_kwargs=self.config.gen_kwargs)]
3031

3132
def _generate_examples(self, **gen_kwargs):
3233
for idx, ex in enumerate(self.config.generator(**gen_kwargs)):

tests/test_arrow_dataset.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -3871,10 +3871,11 @@ def _gen():
38713871
return _gen
38723872

38733873

3874-
def _check_generator_dataset(dataset, expected_features):
3874+
def _check_generator_dataset(dataset, expected_features, split):
38753875
assert isinstance(dataset, Dataset)
38763876
assert dataset.num_rows == 4
38773877
assert dataset.num_columns == 3
3878+
assert dataset.split == split
38783879
assert dataset.column_names == ["col_1", "col_2", "col_3"]
38793880
for feature, expected_dtype in expected_features.items():
38803881
assert dataset.features[feature].dtype == expected_dtype
@@ -3886,7 +3887,7 @@ def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, t
38863887
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
38873888
with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
38883889
dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
3889-
_check_generator_dataset(dataset, expected_features)
3890+
_check_generator_dataset(dataset, expected_features, NamedSplit("train"))
38903891

38913892

38923893
@pytest.mark.parametrize(
@@ -3907,7 +3908,23 @@ def test_dataset_from_generator_features(features, data_generator, tmp_path):
39073908
Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
39083909
)
39093910
dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir)
3910-
_check_generator_dataset(dataset, expected_features)
3911+
_check_generator_dataset(dataset, expected_features, NamedSplit("train"))
3912+
3913+
3914+
@pytest.mark.parametrize(
3915+
"split",
3916+
[None, NamedSplit("train"), "train", NamedSplit("foo"), "foo"],
3917+
)
3918+
def test_dataset_from_generator_split(split, data_generator, tmp_path):
3919+
cache_dir = tmp_path / "cache"
3920+
default_expected_split = "train"
3921+
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
3922+
expected_split = split if split else default_expected_split
3923+
if split:
3924+
dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, split=split)
3925+
else:
3926+
dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir)
3927+
_check_generator_dataset(dataset, expected_features, expected_split)
39113928

39123929

39133930
@require_not_windows

0 commit comments

Comments (0)