Commit 43fd659

[Resumable IterableDataset] Add IterableDataset state_dict (#6658)

* add ex_iterable state_dict() and load_state_dict()
* style
* resuming + IterableDataset state_dict() + load_state_dict()
* minor
* fix tests
* fix one more test
* implement CyclingMultiSourcesExamplesIterable resuming
* remove unused code
* fix spark
* fix spark tests
* add test dependency
* enable long paths for git
* fix git command
* no additional deps for windows
* fix tests
* fix iter_arrow resuming
* tests
* fix map and filter resuming
* fix spark
* mark spark resuming as experimental
* docs
* fix docs
* docs
* add note
* add to docs too

1 parent 3d95159 commit 43fd659

12 files changed: +872 −149 lines changed


.github/workflows/ci.yml

+4 −3

@@ -56,9 +56,10 @@ jobs:
       - name: Install uv
         run: pip install --upgrade uv
       - name: Install dependencies
-        run: |
-          uv pip install --system "datasets[tests,metrics-tests] @ ."
-          uv pip install --system -r additional-tests-requirements.txt --no-deps
+        run: uv pip install --system "datasets[tests,metrics-tests] @ ."
+      - name: Install dependencies (latest versions)
+        if: ${{ matrix.os == 'ubuntu-latest' }}
+        run: uv pip install --system -r additional-tests-requirements.txt --no-deps
       - name: Install dependencies (latest versions)
         if: ${{ matrix.deps_versions == 'deps-latest' }}
         run: uv pip install --system --upgrade pyarrow huggingface-hub dill

additional-tests-requirements.txt

+1 −0

@@ -1,4 +1,5 @@
 unbabel-comet>=1.0.0
+git+https://github.com/pytorch/data.git
 git+https://github.com/google-research/bleurt.git
 git+https://github.com/ns-moosavi/coval.git
 git+https://github.com/hendrycks/math.git

docs/source/about_mapstyle_vs_iterable.mdx

+31 −0

@@ -205,6 +205,37 @@ for epoch in range(n_epochs):
     pass
 ```
 
+## Checkpoint and resuming differences
+
+If your training loop stops, you may want to restart the training from where it was. To do so you can save a checkpoint of your model and optimizers, as well as your data loader.
+
+To restart the iteration of a map-style dataset, you can simply skip the first examples:
+
+```python
+my_dataset = my_dataset.select(range(start_index, len(my_dataset)))
+```
+
+But if you use a `DataLoader` with a `Sampler`, you should instead save the state of your sampler (you might have to write a custom sampler that allows resuming).
+
+On the other hand, iterable datasets don't provide random access to a specific example index to resume from. But you can use [`IterableDataset.state_dict`] and [`IterableDataset.load_state_dict`] to resume from a checkpoint instead, similarly to what you can do for models and optimizers:
+
+```python
+>>> iterable_dataset = Dataset.from_dict({"a": range(6)}).to_iterable_dataset(num_shards=3)
+>>> # save in the middle of training
+>>> state_dict = iterable_dataset.state_dict()
+>>> # and resume later
+>>> iterable_dataset.load_state_dict(state_dict)
+```
+
+Under the hood, the iterable dataset keeps track of the current shard being read and the example index in the current shard, and it stores this info in the `state_dict`.
+
+To resume from a checkpoint, the dataset skips all the shards that were previously read to restart from the current shard.
+Then it reads the shard and skips examples until it reaches the exact example from the checkpoint.
+
+Therefore restarting a dataset is quite fast, since it will not re-read the shards that have already been iterated on. Still, resuming a dataset is generally not instantaneous since it has to restart reading from the beginning of the current shard and skip examples until it reaches the checkpoint location.
+
+This can be used with the `StatefulDataLoader` from `torchdata`, see [streaming with a PyTorch DataLoader](./use_with_pytorch#stream-data).
+
 ## Switch from map-style to iterable
 
 If you want to benefit from the "lazy" behavior of an [`IterableDataset`] or their speed advantages, you can switch your map-style [`Dataset`] to an [`IterableDataset`]:
docs/source/package_reference/main_classes.mdx

+2 −0

@@ -172,6 +172,8 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth
     - shuffle
     - skip
     - take
+    - load_state_dict
+    - state_dict
     - info
     - split
     - builder_name
docs/source/stream.mdx

+58 −0

@@ -360,3 +360,61 @@ Lastly, create a simple training loop and start training:
 </frameworkcontent>
 
 <!-- TODO: Write the TF content! -->
+
+### Save a dataset checkpoint and resume iteration
+
+If your training loop stops, you may want to restart the training from where it was. To do so you can save a checkpoint of your model and optimizers, as well as your data loader.
+
+Iterable datasets don't provide random access to a specific example index to resume from, but you can use [`IterableDataset.state_dict`] and [`IterableDataset.load_state_dict`] to resume from a checkpoint instead, similarly to what you can do for models and optimizers:
+
+```python
+>>> iterable_dataset = Dataset.from_dict({"a": range(6)}).to_iterable_dataset(num_shards=3)
+>>> for idx, example in enumerate(iterable_dataset):
+...     print(example)
+...     if idx == 2:
+...         state_dict = iterable_dataset.state_dict()
+...         print("checkpoint")
+...         break
+>>> iterable_dataset.load_state_dict(state_dict)
+>>> print("restart from checkpoint")
+>>> for example in iterable_dataset:
+...     print(example)
+```
+
+Returns:
+
+```
+{'a': 0}
+{'a': 1}
+{'a': 2}
+checkpoint
+restart from checkpoint
+{'a': 3}
+{'a': 4}
+{'a': 5}
+```
+
+Under the hood, the iterable dataset keeps track of the current shard being read and the example index in the current shard, and it stores this info in the `state_dict`.
+
+To resume from a checkpoint, the dataset skips all the shards that were previously read to restart from the current shard.
+Then it reads the shard and skips examples until it reaches the exact example from the checkpoint.
+
+Therefore restarting a dataset is quite fast, since it will not re-read the shards that have already been iterated on. Still, resuming a dataset is generally not instantaneous since it has to restart reading from the beginning of the current shard and skip examples until it reaches the checkpoint location.
+
+This can be used with the `StatefulDataLoader` from `torchdata`:
+
+```python
+>>> from torchdata.stateful_dataloader import StatefulDataLoader
+>>> iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train")
+>>> dataloader = StatefulDataLoader(iterable_dataset, batch_size=32, num_workers=4)
+>>> # checkpoint
+>>> state_dict = dataloader.state_dict()  # uses iterable_dataset.state_dict() under the hood
+>>> # resume from checkpoint
+>>> dataloader.load_state_dict(state_dict)  # uses iterable_dataset.load_state_dict() under the hood
+```
+
+<Tip>
+
+Resuming returns exactly where the checkpoint was saved except in two cases: 1) examples from shuffle buffers are lost when resuming and the buffers are refilled with new data, and 2) combinations of `.with_format(arrow)` and batched `.map()` may skip one batch.
+
+</Tip>
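The new docs section says to checkpoint the model and optimizers together with the data loader. A minimal sketch of how those pieces could be combined (not from this commit; the stand-in linear model and dummy loss are placeholders, and the checkpoint step is arbitrary):

```python
# Illustrative sketch only: keep model, optimizer and dataloader state in one checkpoint.
# A real loop would compute the loss from `batch`; a stand-in model and dummy loss keep
# the example self-contained. Assumes torch and torchdata are installed.
import torch
from datasets import load_dataset
from torchdata.stateful_dataloader import StatefulDataLoader

iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train")
dataloader = StatefulDataLoader(iterable_dataset, batch_size=32, num_workers=4)

model = torch.nn.Linear(1, 1)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters())

checkpoint = None
for step, batch in enumerate(dataloader):
    loss = model(torch.ones(1)).sum()  # placeholder loss that ignores `batch`
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if step == 100:
        # everything needed to resume, in one dict; in practice you would torch.save() it
        checkpoint = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "dataloader": dataloader.state_dict(),  # wraps IterableDataset.state_dict()
        }
        break

# resuming: restore each component from the same dict
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
dataloader.load_state_dict(checkpoint["dataloader"])
```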

docs/source/use_with_pytorch.mdx

+14 −0

@@ -213,6 +213,20 @@ If the dataset is split in several shards (i.e. if the dataset consists of multi
 
 In this case each worker is given a subset of the list of shards to stream from.
 
+If you need a DataLoader that you can checkpoint and resume in the middle of training, you can use the `StatefulDataLoader` from [torchdata](https://github.com/pytorch/data):
+
+```py
+>>> from torchdata.stateful_dataloader import StatefulDataLoader
+>>> my_iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train")
+>>> dataloader = StatefulDataLoader(my_iterable_dataset, batch_size=32, num_workers=4)
+>>> # save in the middle of training
+>>> state_dict = dataloader.state_dict()
+>>> # and resume later
+>>> dataloader.load_state_dict(state_dict)
+```
+
+This is possible thanks to [`IterableDataset.state_dict`] and [`IterableDataset.load_state_dict`].
+
 ### Distributed
 
 To split your dataset across your training nodes, you can use [`datasets.distributed.split_dataset_by_node`]:
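To survive an actual process restart, the dataloader state also has to be persisted. A small sketch under the same setup as the docs snippet above (not from this commit; the file name and checkpoint step are arbitrary, and it assumes the state dict is picklable):

```python
# Sketch, not from this commit: write the StatefulDataLoader state to disk so a new
# process can resume streaming where the previous one stopped.
import pickle

from datasets import load_dataset
from torchdata.stateful_dataloader import StatefulDataLoader

def make_dataloader():
    # the dataset and dataloader must be rebuilt identically before restoring state
    ds = load_dataset("deepmind/code_contests", streaming=True, split="train")
    return StatefulDataLoader(ds, batch_size=32, num_workers=4)

dataloader = make_dataloader()
for step, batch in enumerate(dataloader):
    if step == 10:
        with open("loader_state.pkl", "wb") as f:
            pickle.dump(dataloader.state_dict(), f)
        break

# in a fresh process: rebuild, restore the state, then keep iterating
dataloader = make_dataloader()
with open("loader_state.pkl", "rb") as f:
    dataloader.load_state_dict(pickle.load(f))
for batch in dataloader:
    ...  # training continues right after the checkpoint
```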
