Commit 90e5bf8

Support skip_trying_type (#7483)
* Add skip_trying_type
* Add a test case
* Rename and apply make style
* Apply suggestions from code review

Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 94ccd1b commit 90e5bf8

3 files changed: +51 −3

src/datasets/arrow_dataset.py (+10 −1)
@@ -2851,6 +2851,7 @@ def map(
         suffix_template: str = "_{rank:05d}_of_{num_proc:05d}",
         new_fingerprint: Optional[str] = None,
         desc: Optional[str] = None,
+        try_original_type: Optional[bool] = True,
     ) -> "Dataset":
         """
         Apply a function to all the examples in the table (individually or in batches) and update the table.
@@ -2932,6 +2933,9 @@ def map(
                 If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.
             desc (`str`, *optional*, defaults to `None`):
                 Meaningful description to be displayed alongside with the progress bar while mapping examples.
+            try_original_type (`Optional[bool]`, defaults to `True`):
+                Try to keep the types of the original columns (e.g. int32 -> int32).
+                Set to False if you want to always infer new types.

         Example:

@@ -3022,6 +3026,7 @@ def map(
             "features": features,
             "disable_nullable": disable_nullable,
             "fn_kwargs": fn_kwargs,
+            "try_original_type": try_original_type,
         }

         if new_fingerprint is None:
@@ -3216,6 +3221,7 @@ def _map_single(
         new_fingerprint: Optional[str] = None,
         rank: Optional[int] = None,
         offset: int = 0,
+        try_original_type: Optional[bool] = True,
     ) -> Iterable[tuple[int, bool, Union[int, "Dataset"]]]:
         """Apply a function to all the elements in the table (individually or in batches)
         and update the table (if function does update examples).
@@ -3257,6 +3263,9 @@ def _map_single(
                 If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments
             rank: (`int`, optional, defaults to `None`): If specified, this is the process rank when doing multiprocessing
             offset: (`int`, defaults to 0): If specified, this is an offset applied to the indices passed to `function` if `with_indices=True`.
+            try_original_type: (`Optional[bool]`, defaults to `True`):
+                Try to keep the types of the original columns (e.g. int32 -> int32).
+                Set to False if you want to always infer new types.
         """
         if fn_kwargs is None:
             fn_kwargs = {}
@@ -3528,7 +3537,7 @@ def iter_outputs(shard_iterable):
                 ):
                     writer.write_table(batch.to_arrow())
                 else:
-                    writer.write_batch(batch)
+                    writer.write_batch(batch, try_original_type=try_original_type)
                 num_examples_progress_update += num_examples_in_batch
                 if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL:
                     _time = time.time()
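
Taken together, these changes thread a `try_original_type` flag from `Dataset.map` down to the Arrow writer. A minimal usage sketch of the new flag (the dataset and map function below are illustrative, modeled on the test added in this commit):

from datasets import Dataset

ds = Dataset.from_dict({"labels": [[1, 0, 1], [0, 1, 0]]})  # column inferred as int64

def to_float(batch):
    # Cast every label to float inside the mapped batch.
    return {"labels": [[float(x) for x in row] for row in batch["labels"]]}

# Default: the writer tries the original int64 type first; the integral
# floats cast back safely, so the column stays int64.
kept = ds.map(to_float, batched=True, try_original_type=True)

# Opt out: the column type is re-inferred from the mapped values (float64).
inferred = ds.map(to_float, batched=True, try_original_type=False)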

src/datasets/arrow_writer.py (+7 −1)
@@ -584,13 +584,15 @@ def write_batch(
         self,
         batch_examples: dict[str, list],
         writer_batch_size: Optional[int] = None,
+        try_original_type: Optional[bool] = True,
     ):
         """Write a batch of Example to file.
         Ignores the batch if it appears to be empty,
         preventing a potential schema update of unknown types.

         Args:
             batch_examples: the batch of examples to add.
+            try_original_type: use `try_type` when instantiating OptimizedTypedSequence if `True`, otherwise `try_type = None`.
         """
         if batch_examples and len(next(iter(batch_examples.values()))) == 0:
             return
@@ -615,7 +617,11 @@ def write_batch(
                 arrays.append(array)
                 inferred_features[col] = generate_from_arrow_type(col_values.type)
             else:
-                col_try_type = try_features[col] if try_features is not None and col in try_features else None
+                col_try_type = (
+                    try_features[col]
+                    if try_features is not None and col in try_features and try_original_type
+                    else None
+                )
                 typed_sequence = OptimizedTypedSequence(col_values, type=col_type, try_type=col_try_type, col=col)
                 arrays.append(pa.array(typed_sequence))
                 inferred_features[col] = typed_sequence.get_inferred_type()
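
For intuition about what `try_type` controls here: the writer first attempts the column's original Arrow type and falls back to plain inference only if that conversion fails. A rough standalone sketch of that behavior with pyarrow (`build_array` is a hypothetical helper, not the library's code; `OptimizedTypedSequence` handles more cases than this):

import pyarrow as pa

def build_array(values, try_type=None):
    # Hypothetical stand-in for the try_type behavior: attempt the
    # original type, fall back to type inference if the cast is refused.
    if try_type is not None:
        try:
            return pa.array(values, type=try_type)
        except (pa.ArrowInvalid, pa.ArrowTypeError):
            pass
    return pa.array(values)

print(build_array([1.0, 0.0], try_type=pa.int64()).type)  # int64: safe cast succeeds
print(build_array([1.0, 0.0]).type)                       # double: plain inference
print(build_array([1.5, 0.5], try_type=pa.int64()).type)  # double: lossy cast refused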

tests/test_arrow_dataset.py (+34 −1)
@@ -127,7 +127,13 @@ def inject_fixtures(self, caplog, set_sqlalchemy_silence_uber_warning):
         self._caplog = caplog

     def _create_dummy_dataset(
-        self, in_memory: bool, tmp_dir: str, multiple_columns=False, array_features=False, nested_features=False
+        self,
+        in_memory: bool,
+        tmp_dir: str,
+        multiple_columns=False,
+        array_features=False,
+        nested_features=False,
+        int_to_float=False,
     ) -> Dataset:
         assert int(multiple_columns) + int(array_features) + int(nested_features) < 2
         if multiple_columns:
@@ -151,6 +157,12 @@ def _create_dummy_dataset(
             data = {"nested": [{"a": i, "x": i * 10, "c": i * 100} for i in range(1, 11)]}
             features = Features({"nested": {"a": Value("int64"), "x": Value("int64"), "c": Value("int64")}})
             dset = Dataset.from_dict(data, features=features)
+        elif int_to_float:
+            data = {
+                "text": ["text1", "text2", "text3", "text4"],
+                "labels": [[1, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 1, 1], [0, 0, 0, 1, 0]],
+            }
+            dset = Dataset.from_dict(data)
         else:
             dset = Dataset.from_dict({"filename": ["my_name-train" + "_" + str(x) for x in np.arange(30).tolist()]})
         if not in_memory:
@@ -1123,6 +1135,27 @@ def func(x, i):
                    self.assertListEqual(sorted(dset_test[0].keys()), ["col_1", "col_1_plus_one"])
                    self.assertListEqual(sorted(dset_test.column_names), ["col_1", "col_1_plus_one", "col_2", "col_3"])
                    assert_arrow_metadata_are_synced_with_dataset_features(dset_test)
+        # casting int labels to float labels
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir, int_to_float=True) as dset:
+
+                def _preprocess(examples):
+                    result = {"labels": [list(map(float, labels)) for labels in examples["labels"]]}
+                    return result
+
+                with dset.map(
+                    _preprocess, remove_columns=["labels", "text"], batched=True, try_original_type=True
+                ) as dset_test:
+                    for labels in dset_test["labels"]:
+                        for label in labels:
+                            self.assertIsInstance(label, int)
+
+                with dset.map(
+                    _preprocess, remove_columns=["labels", "text"], batched=True, try_original_type=False
+                ) as dset_test:
+                    for labels in dset_test["labels"]:
+                        for label in labels:
+                            self.assertIsInstance(label, float)

     def test_map_multiprocessing(self, in_memory):
         with tempfile.TemporaryDirectory() as tmp_dir:  # standard
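
Note on the test's expectations: the floats produced by `_preprocess` are all integral (0.0/1.0), so with `try_original_type=True` the writer can cast them back to the original int64 column type without loss and the labels read back as Python ints; with `try_original_type=False` the try-cast is skipped and the column is re-inferred as float64.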
