Skip to content

Commit ebb8a8e

Browse files
cpgaffney1Orbax Authors
authored andcommitted
Internal change.
PiperOrigin-RevId: 557561606
1 parent 5947368 commit ebb8a8e

File tree

3 files changed

+17
-17
lines changed

3 files changed

+17
-17
lines changed

checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
TypeHandler = type_handlers.TypeHandler
5454
AggregateHandler = aggregate_handlers.AggregateHandler
5555
MsgpackHandler = aggregate_handlers.MsgpackHandler
56-
TransformFn = Callable[[PyTree, PyTree, PyTree], Tuple[PyTree, PyTree]]
56+
LegacyTransformFn = Callable[[PyTree, PyTree, PyTree], Tuple[PyTree, PyTree]]
5757
Transform = transform_utils.Transform
5858
RestoreTransform = transform_utils.RestoreTransform
5959
JsonCheckpointHandler = json_checkpoint_handler.JsonCheckpointHandler
@@ -849,7 +849,7 @@ def restore(
849849
restore_args: Optional[PyTree] = None,
850850
transforms: Optional[PyTree] = None,
851851
transforms_default_to_original: bool = True,
852-
transform_fn: Optional[TransformFn] = None,
852+
legacy_transform_fn: Optional[LegacyTransformFn] = None,
853853
) -> PyTree:
854854
"""Restores a PyTree from the checkpoint directory at the given path.
855855
@@ -940,9 +940,9 @@ class TrainState:
940940
completely.
941941
See `transform_utils` for further information.
942942
transforms_default_to_original: See transform_utils.apply_transformations.
943-
transform_fn: WARNING: NOT GENERALLY SUPPORTED. A function which accepts
944-
the `item` argument, a PyTree checkpoint structure and a PyTree of
945-
ParamInfos based on the checkpoint. Returns a transformed PyTree
943+
legacy_transform_fn: WARNING: NOT GENERALLY SUPPORTED. A function which
944+
accepts the `item` argument, a PyTree checkpoint structure and a PyTree
945+
of ParamInfos based on the checkpoint. Returns a transformed PyTree
946946
matching the desired return tree structure, and a matching ParamInfo
947947
tree.
948948
@@ -982,10 +982,12 @@ async def _create_byte_limiter():
982982
transforms_default_to_original=transforms_default_to_original,
983983
)
984984

985-
if transform_fn is not None and transforms is not None:
986-
raise ValueError('Cannot provide both `transforms` and `transform_fn`.')
987-
if transform_fn is not None:
988-
structure, param_infos = transform_fn(item, structure, param_infos)
985+
if legacy_transform_fn is not None and transforms is not None:
986+
raise ValueError(
987+
'Cannot provide both `transforms` and `legacy_transform_fn`.'
988+
)
989+
if legacy_transform_fn is not None:
990+
structure, param_infos = legacy_transform_fn(item, structure, param_infos)
989991
if restore_args is None:
990992
restore_args = jax.tree_util.tree_map(lambda x: RestoreArgs(), item)
991993
checkpoint_restore_args = restore_args
@@ -1009,7 +1011,7 @@ def _maybe_set_default_restore_types(
10091011
self._maybe_deserialize(structure, param_infos, checkpoint_restore_args)
10101012
)
10111013

1012-
if not transform_fn:
1014+
if not legacy_transform_fn:
10131015
restored_item = _transform_checkpoint(
10141016
item,
10151017
restored_item,

checkpoint/orbax/checkpoint/type_handlers.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,7 @@ def __init__(
443443
'Must provide a ts.Context if use_ocdbt is True. Ensure that the'
444444
' context contains a coordinator address.'
445445
)
446-
self._ts_context = ts_context or ts.Context(
447-
{'file_io_concurrency': {'limit': 128}}
448-
)
446+
self._ts_context = ts_context or serialization.TS_CONTEXT
449447

450448
def enable_ocdbt(self, ts_context: ts.Context) -> None:
451449
self._use_ocdbt = True
@@ -651,9 +649,7 @@ def __init__(
651649
'Must provide a ts.Context if use_ocdbt is True. Ensure that the'
652650
' context contains a coordinator address.'
653651
)
654-
self._ts_context = ts_context or ts.Context(
655-
{'file_io_concurrency': {'limit': 128}}
656-
)
652+
self._ts_context = ts_context or serialization.TS_CONTEXT
657653

658654
def enable_ocdbt(self, ts_context: ts.Context) -> None:
659655
self._use_ocdbt = True

checkpoint/orbax/checkpoint/value_metadata.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ class ArrayMetadata(Metadata):
3636
Tuple of integers describing the array shape.
3737
shards:
3838
Tuple of integers indicating how many shards each dimension is divided
39-
into. May be None if the array is not sharded.
39+
into. E.g. a dimension may be 1 if it is unsharded, or 2 if it is divided
40+
into 2 chunks.
41+
May be None if the array is not sharded.
4042
dtype:
4143
Dtype of array elements.
4244
"""

0 commit comments

Comments
 (0)