53
53
TypeHandler = type_handlers .TypeHandler
54
54
AggregateHandler = aggregate_handlers .AggregateHandler
55
55
MsgpackHandler = aggregate_handlers .MsgpackHandler
56
- TransformFn = Callable [[PyTree , PyTree , PyTree ], Tuple [PyTree , PyTree ]]
56
+ LegacyTransformFn = Callable [[PyTree , PyTree , PyTree ], Tuple [PyTree , PyTree ]]
57
57
Transform = transform_utils .Transform
58
58
RestoreTransform = transform_utils .RestoreTransform
59
59
JsonCheckpointHandler = json_checkpoint_handler .JsonCheckpointHandler
@@ -849,7 +849,7 @@ def restore(
849
849
restore_args : Optional [PyTree ] = None ,
850
850
transforms : Optional [PyTree ] = None ,
851
851
transforms_default_to_original : bool = True ,
852
- transform_fn : Optional [TransformFn ] = None ,
852
+ legacy_transform_fn : Optional [LegacyTransformFn ] = None ,
853
853
) -> PyTree :
854
854
"""Restores a PyTree from the checkpoint directory at the given path.
855
855
@@ -940,9 +940,9 @@ class TrainState:
940
940
completely.
941
941
See `transform_utils` for further information.
942
942
transforms_default_to_original: See transform_utils.apply_transformations.
943
- transform_fn : WARNING: NOT GENERALLY SUPPORTED. A function which accepts
944
- the `item` argument, a PyTree checkpoint structure and a PyTree of
945
- ParamInfos based on the checkpoint. Returns a transformed PyTree
943
+ legacy_transform_fn : WARNING: NOT GENERALLY SUPPORTED. A function which
944
+ accepts the `item` argument, a PyTree checkpoint structure and a PyTree
945
+ of ParamInfos based on the checkpoint. Returns a transformed PyTree
946
946
matching the desired return tree structure, and a matching ParamInfo
947
947
tree.
948
948
@@ -982,10 +982,12 @@ async def _create_byte_limiter():
982
982
transforms_default_to_original = transforms_default_to_original ,
983
983
)
984
984
985
- if transform_fn is not None and transforms is not None :
986
- raise ValueError ('Cannot provide both `transforms` and `transform_fn`.' )
987
- if transform_fn is not None :
988
- structure , param_infos = transform_fn (item , structure , param_infos )
985
+ if legacy_transform_fn is not None and transforms is not None :
986
+ raise ValueError (
987
+ 'Cannot provide both `transforms` and `legacy_transform_fn`.'
988
+ )
989
+ if legacy_transform_fn is not None :
990
+ structure , param_infos = legacy_transform_fn (item , structure , param_infos )
989
991
if restore_args is None :
990
992
restore_args = jax .tree_util .tree_map (lambda x : RestoreArgs (), item )
991
993
checkpoint_restore_args = restore_args
@@ -1009,7 +1011,7 @@ def _maybe_set_default_restore_types(
1009
1011
self ._maybe_deserialize (structure , param_infos , checkpoint_restore_args )
1010
1012
)
1011
1013
1012
- if not transform_fn :
1014
+ if not legacy_transform_fn :
1013
1015
restored_item = _transform_checkpoint (
1014
1016
item ,
1015
1017
restored_item ,
0 commit comments