`jax.tree_util.Partial` is a PyTree-aware version of `functools.partial` whose leaves are its bound positional and keyword arguments (typically arrays).
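To make the PyTree behaviour concrete, here is a small sketch (the function `affine` and its values are hypothetical). It shows that the bound arguments of a `Partial` are its leaves while the wrapped function is treated as static, which is why a `Partial` can be passed through `jax.jit` and `tree_map`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def affine(scale, x, shift=0.0):
    # Hypothetical example function: scale * x + shift.
    return scale * x + shift


# Bind the leading positional argument and one keyword argument.
p = Partial(affine, jnp.array(2.0), shift=jnp.array(1.0))

# The bound args/keywords are the PyTree leaves; `affine` itself is static.
leaves = jax.tree_util.tree_leaves(p)
print(len(leaves))  # 2


# Because Partial is a PyTree, it can be passed directly into a jitted function.
@jax.jit
def apply(f, x):
    return f(x)


print(apply(p, 3.0))  # computes 2 * 3 + 1

# tree_map transforms the bound arrays and keeps the same wrapped function.
doubled = jax.tree_util.tree_map(lambda a: 2 * a, p)
print(doubled(3.0))  # computes 4 * 3 + 2
```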
Therefore, I believe it makes sense to make it aware of Flax serialization as well. This would allow using `Partial` inside the structures we use for checkpointing.
The definition would be as simple as the following:
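```python
import jax
from flax import serialization

# jax.tree_util.Partial does not support Flax serialization out of the box.
# This registration should be upstreamed to Flax.
serialization.register_serialization_state(
    jax.tree_util.Partial,
    # to_state_dict: store only the bound positional and keyword arguments.
    lambda x: {
        "args": serialization.to_state_dict(x.args),
        "keywords": serialization.to_state_dict(x.keywords),
    },
    # from_state_dict: rebuild the Partial around the template's callable.
    lambda x, sd: jax.tree_util.Partial(
        x.func,
        *serialization.from_state_dict(x.args, sd["args"]),
        **serialization.from_state_dict(x.keywords, sd["keywords"]),
    ),
)
```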
Would you accept this as a contribution?