`docs/source/about_mapstyle_vs_iterable.mdx`

## Checkpoint and resuming differences
If your training loop stops, you may want to restart the training from where it left off. To do so, you can save a checkpoint of your model and optimizers, as well as your data loader.

To restart the iteration of a map-style dataset, you can simply skip the first examples:
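For example, a minimal sketch using [`Dataset.select`], assuming your training loop tracked the number of examples already consumed in a `start_index` variable (the variable name is illustrative):

```python
# skip the examples that were already seen before the interruption;
# `start_index` is the number of examples consumed so far (tracked by your training loop)
my_dataset = my_dataset.select(range(start_index, len(my_dataset)))
```
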
But if you use a `DataLoader` with a `Sampler`, you should instead save the state of your sampler (you might have to write a custom sampler that allows resuming).

On the other hand, iterable datasets don't provide random access to a specific example index to resume from. But you can use [`IterableDataset.state_dict`] and [`IterableDataset.load_state_dict`] to resume from a checkpoint instead, similarly to what you can do for models and optimizers:
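A minimal sketch, assuming `iterable_dataset` is any [`IterableDataset`] (for instance one obtained with [`Dataset.to_iterable_dataset`] or by loading with `streaming=True`):

```python
# iterate and save the dataset state after a few examples
for idx, example in enumerate(iterable_dataset):
    if idx == 2:
        state_dict = iterable_dataset.state_dict()
        break

# later, restore the state and continue iterating
iterable_dataset.load_state_dict(state_dict)
for example in iterable_dataset:
    pass  # resumes from the example right after the checkpoint

```
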
Under the hood, the iterable dataset keeps track of the current shard being read and the example index within that shard, and it stores this information in the `state_dict`.

To resume from a checkpoint, the dataset skips all the shards that were previously read to restart from the current shard.
Then it reads the shard and skips examples until it reaches the exact example from the checkpoint.
Therefore restarting a dataset is quite fast, since it will not re-read the shards that have already been iterated on. Still, resuming a dataset is generally not instantaneous since it has to restart reading from the beginning of the current shard and skip examples until it reaches the checkpoint location.
This can be used with the `StatefulDataLoader` from `torchdata`; see [streaming with a PyTorch DataLoader](./use_with_pytorch#stream-data).

## Switch from map-style to iterable
If you want to benefit from the "lazy" behavior of an [`IterableDataset`] or its speed advantages, you can switch your map-style [`Dataset`] to an [`IterableDataset`]:
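For instance, a short sketch using [`Dataset.to_iterable_dataset`] (the `num_shards` value here is an illustrative choice; sharding enables parallel streaming and faster resuming):

```python
my_iterable_dataset = my_dataset.to_iterable_dataset(num_shards=64)
```
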
`docs/source/stream.mdx`

### Save a dataset checkpoint and resume iteration
If your training loop stops, you may want to restart the training from where it left off. To do so, you can save a checkpoint of your model and optimizers, as well as your data loader.

Iterable datasets don't provide random access to a specific example index to resume from, but you can use [`IterableDataset.state_dict`] and [`IterableDataset.load_state_dict`] to resume from a checkpoint instead, similarly to what you can do for models and optimizers:
```python
>>> from datasets import Dataset
>>> # example setup: a small dataset with six rows split in three shards (illustrative)
>>> iterable_dataset = Dataset.from_dict({"a": range(6)}).to_iterable_dataset(num_shards=3)
>>> for idx, example in enumerate(iterable_dataset):
...     print(example)
...     if idx == 2:
...         state_dict = iterable_dataset.state_dict()
...         print("checkpoint")
...         break
>>> iterable_dataset.load_state_dict(state_dict)
>>> print("restart from checkpoint")
>>> for example in iterable_dataset:
...     print(example)
```

Returns:

```
{'a': 0}
{'a': 1}
{'a': 2}
checkpoint
restart from checkpoint
{'a': 3}
{'a': 4}
{'a': 5}
```

Under the hood, the iterable dataset keeps track of the current shard being read and the example index within that shard, and it stores this information in the `state_dict`.

To resume from a checkpoint, the dataset skips all the shards that were previously read to restart from the current shard.
Then it reads the shard and skips examples until it reaches the exact example from the checkpoint.
Therefore restarting a dataset is quite fast, since it will not re-read the shards that have already been iterated on. Still, resuming a dataset is generally not instantaneous since it has to restart reading from the beginning of the current shard and skip examples until it reaches the checkpoint location.
This can be used with the `StatefulDataLoader` from `torchdata`:

```python
>>> from torchdata.stateful_dataloader import StatefulDataLoader
>>> # wrap the iterable dataset from the example above (batch size and worker count are illustrative)
>>> dataloader = StatefulDataLoader(iterable_dataset, batch_size=32, num_workers=4)
>>> # save in the middle of training
>>> state_dict = dataloader.state_dict()  # uses iterable_dataset.state_dict() under the hood
>>> # resume from checkpoint
>>> dataloader.load_state_dict(state_dict)  # uses iterable_dataset.load_state_dict() under the hood
```

<Tip>
Resuming returns exactly where the checkpoint was saved, except in two cases: 1) examples from shuffle buffers are lost when resuming and the buffers are refilled with new data, and 2) combinations of `.with_format(arrow)` and batched `.map()` may skip one batch.

</Tip>

`docs/source/use_with_pytorch.mdx`

If the dataset is split in several shards (i.e. if the dataset consists of multiple data files), each worker is given a subset of the list of shards to stream from.

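As a sketch, assuming `my_iterable_dataset` is made of several shards:

```python
>>> from torch.utils.data import DataLoader
>>> # with num_workers=4, each worker streams from its own subset of the shards in parallel
>>> dataloader = DataLoader(my_iterable_dataset, num_workers=4)
```
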
If you need a DataLoader that you can checkpoint and resume in the middle of training, you can use the `StatefulDataLoader` from [torchdata](https://github.com/pytorch/data):
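A minimal sketch, assuming `my_iterable_dataset` is a streamed [`IterableDataset`] (the batch size and worker count are illustrative):

```python
>>> from torchdata.stateful_dataloader import StatefulDataLoader
>>> dataloader = StatefulDataLoader(my_iterable_dataset, batch_size=32, num_workers=4)
>>> # save the dataloader state in the middle of training
>>> state_dict = dataloader.state_dict()  # uses my_iterable_dataset.state_dict() under the hood
>>> # later, resume from where you left off
>>> dataloader.load_state_dict(state_dict)  # uses my_iterable_dataset.load_state_dict() under the hood
```
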