Skip to content

Commit 382e4e8

Browse files
8bitmp3IvyZX
authored andcommitted
Update Flax Migrate checkpointing to Orbax guide
1 parent f935e54 commit 382e4e8

File tree

3 files changed

+207
-104
lines changed

3 files changed

+207
-104
lines changed

docs/guides/orbax_upgrade_guide.rst

Lines changed: 61 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
.. image:: https://colab.research.google.com/assets/colab-badge.svg
2-
:target: https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/orbax_upgrade_guide.ipynb
2+
:target: https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/orbax_upgrade_guide.ipynb
33

4-
Upgrading my codebase to Orbax
4+
Migrate checkpointing to Orbax
55
==============================
66

7-
This guide shows you how to convert a ``flax.training.checkpoints`` call to the equivalent in `Orbax <https://github.com/google/orbax>`_.
7+
This guide shows how to convert Flax's checkpoint saving and restoring calls — `flax.training.checkpoints.save_checkpoint <https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.save_checkpoint>`__ and `restore_checkpoint <https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints>`__ — to the equivalent `Orbax <https://github.com/google/orbax>`__ methods. Orbax provides a flexible and customizable API for managing checkpoints for various objects. Note that as Flax's checkpointing is being migrated to Orbax from ``flax.training.checkpoints``, all existing features in the Flax API will continue to be supported, but the API will change.
88

9-
See also Orbax's quick start `colab introduction <http://colab.research.google.com/github/google/orbax/blob/main/orbax//checkpoint/orbax_checkpoint.ipynb>`_ and `official documentation <https://github.com/google/orbax/blob/main/docs/checkpoint.md>`_.
9+
You will learn how to migrate to Orbax through the following scenarios:
1010

11-
Alternatively to this page, you can click the "Open in Colab" link above to run the following code in Colab environment.
11+
* The most common use case: Saving/loading and managing checkpoints
12+
* A "lightweight" use case: "Pure" saving/loading without the top-level checkpoint manager
13+
* Restoring checkpoints without a target pytree
14+
* Async checkpointing
15+
* Saving/loading a single JAX or NumPy Array
16+
17+
To learn more about Orbax, check out the `quick start introductory Colab notebook <http://colab.research.google.com/github/google/orbax/blob/main/orbax//checkpoint/orbax_checkpoint.ipynb>`__ and `the official Orbax documentation <https://github.com/google/orbax/blob/main/docs/checkpoint.md>`_.
18+
19+
You can click on "Open in Colab" above to run the code from this guide.
20+
21+
Throughout the guide, you will be able to compare code examples with and without the Orbax code.
1222

1323
.. testsetup::
1424

@@ -19,7 +29,7 @@ Alternatively to this page, you can click the "Open in Colab" link above to run
1929
import jax.numpy as jnp
2030
import numpy as np
2131

22-
# Orbax needs to enable asyncio in colab environment.
32+
# Orbax needs to have asyncio enabled in the Colab environment.
2333
import nest_asyncio
2434
nest_asyncio.apply()
2535

@@ -32,37 +42,36 @@ Alternatively to this page, you can click the "Open in Colab" link above to run
3242

3343

3444
Setup
35-
---------------------------------------
45+
*****
3646

3747
.. testcode::
3848

39-
# Some pytrees to showcase
49+
# Create some dummy variables for this example.
4050
MAX_STEPS = 5
4151
CKPT_PYTREE = [12, {'foo': 'str', 'bar': np.array((2, 3))}, [1, 4, 10]]
4252
TARGET_PYTREE = [0, {'foo': '', 'bar': np.array((0))}, [0, 0, 0]]
4353

44-
Most Common Case: Save/Load + Management
45-
---------------------------------------
46-
47-
Follow this if:
54+
Most common use case: Saving/loading and managing checkpoints
55+
*************************************************************
4856

49-
* Your original Flax ``save_checkpoint()`` or ``save_checkpoint_multiprocess()`` call contains these args: ``prefix``, ``keep``, ``keep_every_n_steps``.
57+
This section covers the following scenario:
5058

51-
* You want to use some automatic management logic for your checkpoints (e.g., delete old data, delete based on metrics/loss, etc).
59+
* Your original Flax ``save_checkpoint()`` or ``save_checkpoint_multiprocess()`` call contains the following arguments: ``prefix``, ``keep``, ``keep_every_n_steps``; or
60+
* You want to use some automatic management logic for your checkpoints (for example, for deleting old data, deleting data based on metrics/loss, and so on).
5261

53-
Then you should switch to using an ``orbax.CheckpointManager``. This allows you to not only save and load your model, but also manage your checkpoints and delete outdated checkpoints automatically.
62+
In this case, you need to use ``orbax.CheckpointManager``. This allows you to not only save and load your model, but also manage your checkpoints and delete outdated checkpoints *automatically*.
5463

55-
Modify your code to:
64+
To upgrade your code:
5665

57-
1. Create and keep an ``orbax.CheckpointManager`` instance at the top level, customized with ``orbax.CheckpointManagerOptions``
66+
1. Create and keep an ``orbax.CheckpointManager`` instance at the top level, customized with ``orbax.CheckpointManagerOptions``.
5867

59-
2. In runtime, call ``CheckpointManager.save()`` to save your data.
68+
2. At runtime, call ``orbax.CheckpointManager.save()`` to save your data.
6069

61-
3. Call ``CheckpointManager.restore()`` to restore your data.
70+
3. Then, call ``orbax.CheckpointManager.restore()`` to restore your data.
6271

63-
4. If your checkpoint includes some multihost/multiprocess array, you need to pass the correct ``mesh`` into a ``restore_args_from_target()`` to generate the correct ``restore_args`` before restoring.
72+
4. And, if your checkpoint includes some multi-host/multi-process array, pass the correct ``mesh`` into ``flax.training.orbax_utils.restore_args_from_target()`` to generate the correct ``restore_args`` before restoring.
6473

65-
See below for code examples for before and after migration.
74+
For example:
6675

6776
.. codediff::
6877
:title_left: flax.checkpoints
@@ -71,9 +80,9 @@ See below for code examples for before and after migration.
7180

7281
CKPT_DIR = './tmp/'
7382

74-
# Inside a training loop
83+
# Inside your training loop
7584
for step in range(MAX_STEPS):
76-
# ... do your training ...
85+
# do training
7786
checkpoints.save_checkpoint(CKPT_DIR, CKPT_PYTREE, step=step,
7887
prefix='test_', keep=3, keep_every_n_steps=2)
7988

@@ -84,16 +93,16 @@ See below for code examples for before and after migration.
8493

8594
CKPT_DIR = './tmp/'
8695

87-
# At top level
96+
# At the top level
8897
mgr_options = orbax.checkpoint.CheckpointManagerOptions(
8998
max_to_keep=3, keep_period=2, step_prefix='test_')
9099
ckpt_mgr = orbax.checkpoint.CheckpointManager(
91100
CKPT_DIR,
92101
orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_options)
93102

94-
# Inside a training loop
103+
# Inside your training loop
95104
for step in range(MAX_STEPS):
96-
# ... do your training ...
105+
# do training
97106
save_args = flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE)
98107
ckpt_mgr.save(step, CKPT_PYTREE, save_kwargs={'save_args': save_args})
99108

@@ -102,12 +111,14 @@ See below for code examples for before and after migration.
102111
ckpt_mgr.restore(4, items=TARGET_PYTREE, restore_kwargs={'restore_args': restore_args})
103112

104113

105-
Lightweight Case: Pure Save/Load without Setup
106-
-----------------------------------
114+
A "lightweight" use case: "Pure" saving/loading without the top-level checkpoint manager
115+
****************************************************************************************
107116

108-
If you prefer to not maintain a top-level checkpoint manager, you can still save and restore any individual checkpoint with an ``orbax.checkpoint.Checkpointer``. Note that this means you cannot use all the management features.
117+
If you prefer to not maintain a top-level checkpoint manager, you can still save and restore any individual checkpoint with an ``orbax.checkpoint.Checkpointer``. Note that this means you cannot use all the Orbax management features.
109118

110-
For argument ``overwrite`` in ``flax.save_checkpoint()``, use argument ``force`` in ``Checkpointer.save()`` instead.
119+
To migrate to Orbax code, instead of using the ``overwrite`` argument in ``flax.save_checkpoint()`` use the ``force`` argument in ``orbax.checkpoint.Checkpointer.save()``.
120+
121+
For example:
111122

112123
.. codediff::
113124
:title_left: flax.checkpoints
@@ -123,18 +134,20 @@ For argument ``overwrite`` in ``flax.save_checkpoint()``, use argument ``force``
123134

124135
PURE_CKPT_DIR = './tmp/pure'
125136

126-
ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) # stateless object, can be created on-fly
137+
ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) # A stateless object, can be created on the fly.
127138
ckptr.save(PURE_CKPT_DIR, CKPT_PYTREE,
128139
save_args=flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE), force=True)
129140
ckptr.restore(PURE_CKPT_DIR, item=TARGET_PYTREE,
130141
restore_args=flax.training.orbax_utils.restore_args_from_target(TARGET_PYTREE, mesh=None))
131142

132143

133144

134-
Restore without a target pytree
135-
-----------------------------------
145+
Restoring checkpoints without a target pytree
146+
*********************************************
147+
148+
If you need to restore your checkpoints without a target pytree, pass ``item=None`` to ``orbax.checkpoint.Checkpointer`` or ``items=None`` to ``orbax.CheckpointManager``'s ``.restore()`` method, which should trigger the restoration.
136149

137-
Pass ``item=None`` to Orbax ``Checkpointer`` or ``items=None`` to ``CheckpointManager``'s ``.restore()`` should trigger restoration.
150+
For example:
138151

139152
.. codediff::
140153
:title_left: flax.checkpoints
@@ -150,27 +163,29 @@ Pass ``item=None`` to Orbax ``Checkpointer`` or ``items=None`` to ``CheckpointMa
150163

151164
NOTARGET_CKPT_DIR = './tmp/no_target'
152165

153-
# stateless object, can be created on-fly
166+
# A stateless object, can be created on the fly.
154167
ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
155168
ckptr.save(NOTARGET_CKPT_DIR, CKPT_PYTREE,
156169
save_args=flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE))
157170
ckptr.restore(NOTARGET_CKPT_DIR, item=None)
158171

159172

160-
Async Checkpointing
161-
-----------------------------------
173+
Async checkpointing
174+
*******************
162175

163-
Substitute ``orbax.checkpoint.Checkpointer`` with ``orbax.checkpoint.AsyncCheckpointer`` makes all saves async.
176+
To make your checkpoint-saving asynchronous, substitute ``orbax.checkpoint.Checkpointer`` with ``orbax.checkpoint.AsyncCheckpointer``.
164177

165-
You can later call ``AsyncCheckpointer.wait_until_finished()`` or ``CheckpointerManager.wait_until_finished()`` to wait for the save the complete.
178+
Then, you can call ``orbax.checkpoint.AsyncCheckpointer.wait_until_finished()`` or Orbax's ``CheckpointerManager.wait_until_finished()`` to wait for the save the complete.
166179

167-
See more details on the `checkpoint guide <https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#asynchronized-checkpointing>`_.
180+
For more details, read the `checkpoint guide <https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#asynchronized-checkpointing>`_.
168181

169182

170-
Save/Load a single JAX or Numpy Array
171-
-----------------------------------
183+
Saving/loading a single JAX or NumPy Array
184+
******************************************
172185

173-
``orbax.checkpoint.PyTreeCheckpointHandler``, as the name suggests, is only for pytrees. If you want to save/restore a single Pytree leaf (e.g., an array), use ``orbax.checkpoint.ArrayCheckpointHandler`` instead.
186+
The ``orbax.checkpoint.PyTreeCheckpointHandler`` class, as the name suggests, can only be used for pytrees. Therefore, if you need to save/restore a single pytree leaf (for example, an array), use ``orbax.checkpoint.ArrayCheckpointHandler`` instead.
187+
188+
For example:
174189

175190
.. codediff::
176191
:title_left: flax.checkpoints
@@ -191,8 +206,7 @@ Save/Load a single JAX or Numpy Array
191206
ckptr.restore(ARR_CKPT_DIR, item=None)
192207

193208

209+
Final words
210+
***********
194211

195-
Final Words
196-
-----------
197-
198-
This guide only shows you how to migrate an existed Flax checkpointing call to Orbax. Orbax as a tool provides much more functionalities and is actively developing new features. Please stay tuned with their `official github repository <https://github.com/google/orbax>`_ for more!
212+
This guide provides an overview of how to migrate from the "legacy" Flax checkpointing API to the Orbax API. Orbax provides more functionalities and the Orbax team is actively developing new features. Stay tuned and follow the `official Orbax GitHub repository <https://github.com/google/orbax>`__ for more!

0 commit comments

Comments
 (0)