Make importing tensorstore optional and move related type hints to comments. #2348

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit on Jul 29, 2022
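In short, the PR replaces Flax's hard dependency on tensorstore with a guarded import plus a module-level success flag, and loosens the type hints that referenced the optionally-imported class. A minimal standalone sketch of the pattern (library and names here are hypothetical, not the exact Flax code):

import logging
from typing import Any, Optional

_HAS_OPTIONAL_DEP = False
try:
  # Hypothetical optional dependency; mirrors the tensorstore-backed imports.
  from optional_lib import FancyManager
  _HAS_OPTIONAL_DEP = True
except ImportError:
  logging.warning('optional_lib not installed; related features disabled.')

def do_work(manager: Optional[Any] = None):  # manager: FancyManager
  """Uses the optional feature only when the import succeeded."""
  if not _HAS_OPTIONAL_DEP or manager is None:
    return None
  return manager  # placeholder for the real work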
39 changes: 25 additions & 14 deletions flax/training/checkpoints.py
@@ -31,11 +31,19 @@
 from flax import serialization
 from flax import traverse_util
 from jax import process_index
-from jax.experimental.gda_serialization.serialization import get_tensorstore_spec
-from jax.experimental.gda_serialization.serialization import GlobalAsyncCheckpointManager
 from jax.experimental.global_device_array import GlobalDeviceArray
 from tensorflow.io import gfile  # pytype: disable=import-error
 
+_IMPORT_GDAM_SUCCESSFUL = False
+try:
+  from jax.experimental.gda_serialization.serialization import get_tensorstore_spec
+  from jax.experimental.gda_serialization.serialization import GlobalAsyncCheckpointManager
+  _IMPORT_GDAM_SUCCESSFUL = True
+except ImportError:
+  logging.warning('GlobalAsyncCheckpointManager is not imported correctly. '
+                  'Checkpointing of GlobalDeviceArrays will not be available. '
+                  'To use the feature, install tensorstore.')
+
 
 # Single-group reg-exps for int or float numerical substrings.
 # captures sign:
@@ -97,7 +105,7 @@ def on_commit_callback(temp_path, final_path):
   logging.info('Finished saving checkpoint to `%s`.', final_path)
 
 
-def _save_gdas(gda_manager: GlobalAsyncCheckpointManager,
+def _save_gdas(gda_manager,
               gda_targets: List[Tuple[GlobalDeviceArray, str]],
               tmp_path: str, final_path: str):
   gda_list, gda_subpaths = zip(*gda_targets)
@@ -115,7 +123,7 @@ def _restore_gdas(state_dict,
                   target: Optional[Any],
                   ckpt_path: str,
                   step: Optional[Union[int, float]] = None,
-                  gda_manager: Optional[GlobalAsyncCheckpointManager] = None):
+                  gda_manager: Optional[Any] = None):
 
   # When target is a single leaf instead of a pytree dict.
   if not isinstance(state_dict, (core.FrozenDict, dict)):
@@ -229,8 +237,7 @@ def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
                     overwrite: bool = False,
                     keep_every_n_steps: Optional[int] = None,
                     async_manager: Optional[AsyncManager] = None,
-                    gda_manager: Optional[
-                        GlobalAsyncCheckpointManager] = None) -> str:
+                    gda_manager: Optional[Any] = None) -> str:
   """Save a checkpoint of the model.
 
   Attempts to be pre-emption safe by writing to temporary before
@@ -249,9 +256,10 @@ def save_checkpoint(ckpt_dir: Union[str, os.PathLike],
     async_manager: if defined, the save will run without blocking the main
       thread. Only works for single host. Note that an ongoing save will still
       block subsequent saves, to make sure overwrite/keep logic works correctly.
-    gda_manager: required if target contains a JAX GlobalDeviceArray. Will save
-      the GDAs to a separate subdirectory with postfix "_gda" asynchronously.
-      Same as async_manager, this will block subsequent saves.
+    gda_manager: required if target contains a JAX GlobalDeviceArray. Type
+      should be GlobalAsyncCheckpointManager (needs Tensorstore to be imported
+      correctly). Will save the GDAs to a separate subdirectory with postfix
+      "_gda" asynchronously. Same as async_manager, this will block subsequent saves.
   Returns:
     Filename of saved checkpoint.
   """
@@ -345,7 +353,7 @@ def save_task():
   else:
     save_task()
 
-  if gda_targets:
+  if gda_targets and _IMPORT_GDAM_SUCCESSFUL:
     if not gda_manager:
       raise errors.GDACheckpointingRequiredError(ckpt_path, step)
     gda_tmp_path, gda_final_path = ckpt_tmp_path + '_gda', ckpt_path + '_gda'
@@ -385,7 +393,7 @@ def restore_checkpoint(
     step: Optional[Union[int, float]] = None,
     prefix: str = 'checkpoint_',
     parallel: bool = True,
-    gda_manager: Optional[GlobalAsyncCheckpointManager] = None) -> PyTree:
+    gda_manager: Optional[Any] = None) -> PyTree:
   """Restore last/best checkpoint from checkpoints in path.
 
   Sorts the checkpoint files naturally, returning the highest-valued
@@ -405,8 +413,10 @@ def restore_checkpoint(
       ckpt_dir must be a directory.
     prefix: str: name prefix of checkpoint files.
     parallel: bool: whether to load seekable checkpoints in parallel, for speed.
-    gda_manager: required if checkpoint contains a JAX GlobalDeviceArray. Will
-      read the GDAs from the separate subdirectory with postfix "_gda".
+    gda_manager: required if checkpoint contains a JAX GlobalDeviceArray. Type
+      should be GlobalAsyncCheckpointManager (needs Tensorstore to be imported
+      correctly). Will read the GDAs from the separate subdirectory with postfix
+      "_gda".
 
   Returns:
     Restored `target` updated from checkpoint file, or if no step specified and
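And the matching restore side, again as a hedged sketch with placeholder names; without tensorstore, _IMPORT_GDAM_SUCCESSFUL is False and the GDA restore step below is skipped entirely:

restored_state = checkpoints.restore_checkpoint(
    ckpt_dir='/tmp/ckpts',  # placeholder; same directory used for saving
    target=train_state,     # placeholder template pytree
    gda_manager=gda_manager)  # reads the "_gda" subdirectory when present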
@@ -462,7 +472,8 @@ def read_chunk(i):
       checkpoint_contents = fp.read()
 
   state_dict = serialization.msgpack_restore(checkpoint_contents)
-  state_dict = _restore_gdas(state_dict, target, ckpt_path, step, gda_manager)
+  if _IMPORT_GDAM_SUCCESSFUL:
+    state_dict = _restore_gdas(state_dict, target, ckpt_path, step, gda_manager)
 
   if target is None:
     return state_dict
1 change: 0 additions & 1 deletion setup.py
@@ -31,7 +31,6 @@
     "msgpack",
     "optax",
     "rich~=11.1",
-    "tensorstore",
     "typing_extensions>=4.1.1",
     "PyYAML>=5.4.1",
 ]
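With tensorstore out of install_requires, users who want GDA checkpointing now install it themselves (pip install tensorstore). A hedged sketch of how downstream code can probe for it, mirroring the guard added in checkpoints.py:

try:
  import tensorstore  # noqa: F401  (probing availability only)
  HAS_TENSORSTORE = True
except ImportError:
  HAS_TENSORSTORE = False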