
Commit e83d6fa

lappemic and lhoestq authored
Add batching to IterableDataset (#7054)
* feat: add `.batch()` to `IterableDataset` and introduce new `BatchedExamplesIterable`
* style: formatting...
* refactor: implement feedback to use .map()
* test: add tests for new `batch()` method
* style: formatting...
* fix: remove type hints in `batch_fn()` to fix failing CI
* docs: add section "Batching data in IterableDataset" to "Differences between Dataset and IterableDataset"
* refactor: apply feedback
* docs nit

---------

Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 16fa442 commit e83d6fa
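At a glance, the change lets a streaming dataset yield columnar batches with a single call. A minimal sketch of the new API; the dataset name is a placeholder, and the full examples are in the diffs below:

```python
from datasets import load_dataset

# "some_dataset" is a placeholder, as in the docs added by this commit
ds = load_dataset("some_dataset", split="train", streaming=True)

# Each yielded item is a dict of columns, each a list of up to 32 values
for batch in ds.batch(batch_size=32):
    print({column: len(values) for column, values in batch.items()})
    break
```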

File tree

- docs/source/about_mapstyle_vs_iterable.mdx
- docs/source/package_reference/main_classes.mdx
- docs/source/stream.mdx
- src/datasets/iterable_dataset.py
- tests/test_iterable_dataset.py

5 files changed: +110, -4 lines


docs/source/about_mapstyle_vs_iterable.mdx (-4 lines)

````diff
@@ -205,10 +205,6 @@ for epoch in range(n_epochs):
     pass
 ```
 
-## Checkpoint and resuming differences
-
-If you training loop stops, you may want to restart the training from where it was. To do so you can save a checkpoint of your model and optimizers, as well as your data loader.
-
 To restart the iteration of a map-style dataset, you can simply skip the first examples:
 
 ```python
````
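The docs' code for skipping already-seen examples is cut off in this hunk; as a hedged sketch of the idea (the dataset name and step counter are stand-ins, not the docs' actual snippet):

```python
from datasets import load_dataset

# Placeholder name; loaded without streaming, so this is a map-style Dataset
dataset = load_dataset("some_dataset", split="train")

# Suppose a checkpoint records that 1000 examples were already consumed
resume_index = 1000

# Map-style datasets support random access, so resuming is just a slice
remaining = dataset.select(range(resume_index, len(dataset)))
```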

docs/source/package_reference/main_classes.mdx (+1 line)

````diff
@@ -170,6 +170,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by python generators.
 - rename_column
 - filter
 - shuffle
+- batch
 - skip
 - take
 - load_state_dict
````

docs/source/stream.mdx (+38 lines)

````diff
@@ -318,6 +318,44 @@ You can filter rows in the dataset based on a predicate function using [`Dataset.filter`]:
  {'id': 4, 'text': 'Are you looking for Number the Stars (Essential Modern Classics)? Normally, ...'}]
 ```
 
+## Batch
+
+The `batch` method transforms your `IterableDataset` into an iterable of batches. This is particularly useful when you want to work with batches in your training loop or when using frameworks that expect batched inputs.
+
+<Tip>
+
+There is also a "Batch Processing" option when using the `map` function to apply a function to batches of data, which is discussed in the [Map section](#map) above. The `batch` method described here is different and provides a more direct way to create batches from your dataset.
+
+</Tip>
+
+You can use the `batch` method like this:
+
+```python
+from datasets import load_dataset
+
+# Load a dataset in streaming mode
+dataset = load_dataset("some_dataset", split="train", streaming=True)
+
+# Create batches of 32 samples
+batched_dataset = dataset.batch(batch_size=32)
+
+# Iterate over the batched dataset
+for batch in batched_dataset:
+    print(batch)
+    break
+```
+
+In this example, `batched_dataset` is still an `IterableDataset`, but each item yielded is now a batch of 32 samples instead of a single sample.
+This batching is done on-the-fly as you iterate over the dataset, preserving the memory-efficient nature of `IterableDataset`.
+
+The `batch` method also provides a `drop_last_batch` parameter.
+When set to `True`, it will discard the last batch if it's smaller than the specified `batch_size`.
+This can be useful in scenarios where your downstream processing requires all batches to be of the same size:
+
+```python
+batched_dataset = dataset.batch(batch_size=32, drop_last_batch=True)
+```
+
 ## Stream in a training loop
 
 [`IterableDataset`] can be integrated into a training loop. First, shuffle the dataset:
````
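Tying the new section to the one that follows it, here is a hedged sketch of how `batch` could slot into a streaming training setup; the dataset name, seed, and buffer size are illustrative, not from the docs:

```python
from datasets import load_dataset

dataset = load_dataset("some_dataset", split="train", streaming=True)

# Shuffle examples with a buffer first, then group them into batches
shuffled = dataset.shuffle(seed=42, buffer_size=1000)

# drop_last_batch=True keeps every batch exactly 32 examples long,
# which some frameworks require
batched = shuffled.batch(batch_size=32, drop_last_batch=True)

for batch in batched:
    # batch is a dict of columns, e.g. {"id": [...], "text": [...]}
    ...
```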

src/datasets/iterable_dataset.py (+20 lines)

````diff
@@ -2885,6 +2885,26 @@ def _resolve_features(self):
             token_per_repo_id=self._token_per_repo_id,
         )
 
+    def batch(self, batch_size: int, drop_last_batch: bool = False) -> "IterableDataset":
+        """
+        Group samples from the dataset into batches.
+
+        Args:
+            batch_size (`int`): The number of samples in each batch.
+            drop_last_batch (`bool`, defaults to `False`): Whether to drop the last incomplete batch.
+
+        Example:
+        ```py
+        >>> ds = load_dataset("some_dataset", streaming=True)
+        >>> batched_ds = ds.batch(batch_size=32)
+        ```
+        """
+
+        def batch_fn(unbatched):
+            return {k: [v] for k, v in unbatched.items()}
+
+        return self.map(batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch)
+
 
 def _concatenate_iterable_datasets(
     dsets: List[IterableDataset],
````
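The implementation is a thin wrapper over `map` with `batched=True`: in that mode the function receives a dict of columns (each a list of up to `batch_size` values), and the lists it returns are unrolled back into rows. Wrapping each column in one extra list therefore collapses the whole incoming batch into a single output row. A standalone illustration of that mechanic, with made-up data rather than library internals:

```python
# What map hands to batch_fn when batched=True and batch_size=3
unbatched = {"id": [0, 1, 2], "text": ["Text 0", "Text 1", "Text 2"]}

def batch_fn(unbatched):
    # One extra level of nesting: every output column has length 1,
    # so map emits exactly one row per incoming batch
    return {k: [v] for k, v in unbatched.items()}

assert batch_fn(unbatched) == {
    "id": [[0, 1, 2]],
    "text": [["Text 0", "Text 1", "Text 2"]],
}
# Iterating the mapped dataset then yields that single row:
# {"id": [0, 1, 2], "text": ["Text 0", "Text 1", "Text 2"]}
```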

tests/test_iterable_dataset.py (+51 lines)

````diff
@@ -2176,3 +2176,54 @@ def test_resume_dataloader(dataset: IterableDataset):
     dl = StatefulDataLoader(dataset)
     dl.load_state_dict(state_dict)
     assert remaining == list(dl)
+
+
+def test_iterable_dataset_batch():
+    # Create a simple IterableDataset
+    data = [{"id": i, "text": f"Text {i}"} for i in range(10)]
+    ds = IterableDataset.from_generator(lambda: (x for x in data))
+
+    # Test with batch_size=3, drop_last_batch=False
+    batched_ds = ds.batch(batch_size=3, drop_last_batch=False)
+    batches = list(batched_ds)
+
+    assert len(batches) == 4  # 3 full batches and 1 partial batch
+    for i, batch in enumerate(batches[:3]):  # Check full batches
+        assert len(batch["id"]) == 3
+        assert len(batch["text"]) == 3
+        assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
+        assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]
+
+    # Check last partial batch
+    assert len(batches[3]["id"]) == 1
+    assert len(batches[3]["text"]) == 1
+    assert batches[3]["id"] == [9]
+    assert batches[3]["text"] == ["Text 9"]
+
+    # Test with batch_size=3, drop_last_batch=True
+    batched_ds = ds.batch(batch_size=3, drop_last_batch=True)
+    batches = list(batched_ds)
+
+    assert len(batches) == 3  # Only full batches
+    for i, batch in enumerate(batches):
+        assert len(batch["id"]) == 3
+        assert len(batch["text"]) == 3
+        assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
+        assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]
+
+    # Test with batch_size=4 (doesn't evenly divide dataset size)
+    batched_ds = ds.batch(batch_size=4, drop_last_batch=False)
+    batches = list(batched_ds)
+
+    assert len(batches) == 3  # 2 full batches and 1 partial batch
+    for i, batch in enumerate(batches[:2]):  # Check full batches
+        assert len(batch["id"]) == 4
+        assert len(batch["text"]) == 4
+        assert batch["id"] == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
+        assert batch["text"] == [f"Text {4*i}", f"Text {4*i+1}", f"Text {4*i+2}", f"Text {4*i+3}"]
+
+    # Check last partial batch
+    assert len(batches[2]["id"]) == 2
+    assert len(batches[2]["text"]) == 2
+    assert batches[2]["id"] == [8, 9]
+    assert batches[2]["text"] == ["Text 8", "Text 9"]
````
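The expected batch counts in these tests follow from ceiling and floor division over the 10-example dataset; a quick check of the arithmetic:

```python
import math

n = 10  # dataset size used in the test

assert math.ceil(n / 3) == 4  # batch_size=3, keep partial: 3 full + 1 partial
assert n // 3 == 3            # batch_size=3, drop_last_batch=True: full batches only
assert math.ceil(n / 4) == 3  # batch_size=4, keep partial: 2 full + 1 batch of 2
```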
