You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
This guide shows you how to convert a ``flax.training.checkpoints`` call to the equivalent in `Orbax <https://github.com/google/orbax>`_.
7
+
This guide shows how to convert Flax's checkpoint saving and restoring calls — `flax.training.checkpoints.save_checkpoint <https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.save_checkpoint>`__ and `restore_checkpoint <https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints>`__ — to the equivalent `Orbax <https://github.com/google/orbax>`__ methods. Orbax provides a flexible and customizable API for managing checkpoints for various objects. Note that as Flax's checkpointing is being migrated to Orbax from ``flax.training.checkpoints``, all existing features in the Flax API will continue to be supported, but the API will change.
8
8
9
-
See also Orbax's quick start `colab introduction <http://colab.research.google.com/github/google/orbax/blob/main/orbax//checkpoint/orbax_checkpoint.ipynb>`_ and `official documentation <https://github.com/google/orbax/blob/main/docs/checkpoint.md>`_.
9
+
You will learn how to migrate to Orbax through the following scenarios:
10
10
11
-
Alternatively to this page, you can click the "Open in Colab" link above to run the following code in Colab environment.
11
+
* The most common use case: Saving/loading and managing checkpoints
12
+
* A "lightweight" use case: "Pure" saving/loading without the top-level checkpoint manager
13
+
* Restoring checkpoints without a target pytree
14
+
* Async checkpointing
15
+
* Saving/loading a single JAX or NumPy Array
16
+
17
+
To learn more about Orbax, check out the `quick start introductory Colab notebook <http://colab.research.google.com/github/google/orbax/blob/main/orbax//checkpoint/orbax_checkpoint.ipynb>`__ and `the official Orbax documentation <https://github.com/google/orbax/blob/main/docs/checkpoint.md>`_.
18
+
19
+
You can click on "Open in Colab" above to run the code from this guide.
20
+
21
+
Throughout the guide, you will be able to compare code examples with and without the Orbax code.
12
22
13
23
.. testsetup::
14
24
@@ -19,7 +29,7 @@ Alternatively to this page, you can click the "Open in Colab" link above to run
19
29
import jax.numpy as jnp
20
30
import numpy as np
21
31
22
-
# Orbax needs to enable asyncio in colab environment.
32
+
# Orbax needs to have asyncio enabled in the Colab environment.
23
33
import nest_asyncio
24
34
nest_asyncio.apply()
25
35
@@ -32,37 +42,36 @@ Alternatively to this page, you can click the "Open in Colab" link above to run
* Your original Flax ``save_checkpoint()`` or ``save_checkpoint_multiprocess()`` call contains these args: ``prefix``, ``keep``, ``keep_every_n_steps``.
57
+
This section covers the following scenario:
50
58
51
-
* You want to use some automatic management logic for your checkpoints (e.g., delete old data, delete based on metrics/loss, etc).
59
+
* Your original Flax ``save_checkpoint()`` or ``save_checkpoint_multiprocess()`` call contains the following arguments: ``prefix``, ``keep``, ``keep_every_n_steps``; or
60
+
* You want to use some automatic management logic for your checkpoints (for example, for deleting old data, deleting data based on metrics/loss, and so on).
52
61
53
-
Then you should switch to using an ``orbax.CheckpointManager``. This allows you to not only save and load your model, but also manage your checkpoints and delete outdated checkpoints automatically.
62
+
In this case, you need to use ``orbax.CheckpointManager``. This allows you to not only save and load your model, but also manage your checkpoints and delete outdated checkpoints *automatically*.
54
63
55
-
Modify your code to:
64
+
To upgrade your code:
56
65
57
-
1. Create and keep an ``orbax.CheckpointManager`` instance at the top level, customized with ``orbax.CheckpointManagerOptions``
66
+
1. Create and keep an ``orbax.CheckpointManager`` instance at the top level, customized with ``orbax.CheckpointManagerOptions``.
58
67
59
-
2. In runtime, call ``CheckpointManager.save()`` to save your data.
68
+
2. At runtime, call ``orbax.CheckpointManager.save()`` to save your data.
60
69
61
-
3. Call ``CheckpointManager.restore()`` to restore your data.
70
+
3. Then, call ``orbax.CheckpointManager.restore()`` to restore your data.
62
71
63
-
4. If your checkpoint includes some multihost/multiprocess array, you need to pass the correct ``mesh`` into a ``restore_args_from_target()`` to generate the correct ``restore_args`` before restoring.
72
+
4. And, if your checkpoint includes some multi-host/multi-process array, pass the correct ``mesh`` into ``flax.training.orbax_utils.restore_args_from_target()`` to generate the correct ``restore_args`` before restoring.
64
73
65
-
See below for code examples for before and after migration.
74
+
For example:
66
75
67
76
.. codediff::
68
77
:title_left: flax.checkpoints
@@ -71,9 +80,9 @@ See below for code examples for before and after migration.
If you prefer to not maintain a top-level checkpoint manager, you can still save and restore any individual checkpoint with an ``orbax.checkpoint.Checkpointer``. Note that this means you cannot use all the management features.
117
+
If you prefer to not maintain a top-level checkpoint manager, you can still save and restore any individual checkpoint with an ``orbax.checkpoint.Checkpointer``. Note that this means you cannot use all the Orbax management features.
109
118
110
-
For argument ``overwrite`` in ``flax.save_checkpoint()``, use argument ``force`` in ``Checkpointer.save()`` instead.
119
+
To migrate to Orbax code, instead of using the ``overwrite`` argument in ``flax.save_checkpoint()`` use the ``force`` argument in ``orbax.checkpoint.Checkpointer.save()``.
120
+
121
+
For example:
111
122
112
123
.. codediff::
113
124
:title_left: flax.checkpoints
@@ -123,18 +134,20 @@ For argument ``overwrite`` in ``flax.save_checkpoint()``, use argument ``force``
123
134
124
135
PURE_CKPT_DIR = './tmp/pure'
125
136
126
-
ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) # stateless object, can be created on-fly
137
+
ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) # A stateless object, can be created on the fly.
If you need to restore your checkpoints without a target pytree, pass ``item=None`` to ``orbax.checkpoint.Checkpointer`` or ``items=None`` to ``orbax.CheckpointManager``'s ``.restore()`` method, which should trigger the restoration.
136
149
137
-
Pass ``item=None`` to Orbax ``Checkpointer`` or ``items=None`` to ``CheckpointManager``'s ``.restore()`` should trigger restoration.
150
+
For example:
138
151
139
152
.. codediff::
140
153
:title_left: flax.checkpoints
@@ -150,27 +163,29 @@ Pass ``item=None`` to Orbax ``Checkpointer`` or ``items=None`` to ``CheckpointMa
Substitute ``orbax.checkpoint.Checkpointer`` with ``orbax.checkpoint.AsyncCheckpointer`` makes all saves async.
176
+
To make your checkpoint-saving asynchronous, substitute ``orbax.checkpoint.Checkpointer`` with ``orbax.checkpoint.AsyncCheckpointer``.
164
177
165
-
You can later call ``AsyncCheckpointer.wait_until_finished()`` or ``CheckpointerManager.wait_until_finished()`` to wait for the save the complete.
178
+
Then, you can call ``orbax.checkpoint.AsyncCheckpointer.wait_until_finished()`` or Orbax's ``CheckpointerManager.wait_until_finished()`` to wait for the save the complete.
166
179
167
-
See more details on the `checkpoint guide <https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#asynchronized-checkpointing>`_.
180
+
For more details, read the `checkpoint guide <https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#asynchronized-checkpointing>`_.
168
181
169
182
170
-
Save/Load a single JAX or Numpy Array
171
-
-----------------------------------
183
+
Saving/loading a single JAX or NumPy Array
184
+
******************************************
172
185
173
-
``orbax.checkpoint.PyTreeCheckpointHandler``, as the name suggests, is only for pytrees. If you want to save/restore a single Pytree leaf (e.g., an array), use ``orbax.checkpoint.ArrayCheckpointHandler`` instead.
186
+
The ``orbax.checkpoint.PyTreeCheckpointHandler`` class, as the name suggests, can only be used for pytrees. Therefore, if you need to save/restore a single pytree leaf (for example, an array), use ``orbax.checkpoint.ArrayCheckpointHandler`` instead.
187
+
188
+
For example:
174
189
175
190
.. codediff::
176
191
:title_left: flax.checkpoints
@@ -191,8 +206,7 @@ Save/Load a single JAX or Numpy Array
191
206
ckptr.restore(ARR_CKPT_DIR, item=None)
192
207
193
208
209
+
Final words
210
+
***********
194
211
195
-
Final Words
196
-
-----------
197
-
198
-
This guide only shows you how to migrate an existed Flax checkpointing call to Orbax. Orbax as a tool provides much more functionalities and is actively developing new features. Please stay tuned with their `official github repository <https://github.com/google/orbax>`_ for more!
212
+
This guide provides an overview of how to migrate from the "legacy" Flax checkpointing API to the Orbax API. Orbax provides more functionalities and the Orbax team is actively developing new features. Stay tuned and follow the `official Orbax GitHub repository <https://github.com/google/orbax>`__ for more!
0 commit comments