Support serialization for jax.tree_util.Partial #2433

Closed
@PhilipVinc

Description

jax.tree_util.Partial is a pytree-aware version of functools.partial: its bound positional and keyword arguments (typically arrays) are the pytree leaves, while the wrapped function is treated as static.
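As a quick illustration of that pytree behaviour, the sketch below flattens a Partial, maps over it and passes it through `jax.jit`. The function `affine` is a hypothetical example, not part of JAX or Flax.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def affine(scale, x, shift=0.0):
    """Hypothetical example function: scale * x + shift."""
    return scale * x + shift


# Bind `scale` positionally and `shift` by keyword; both become pytree leaves.
p = Partial(affine, jnp.array(2.0), shift=jnp.array(1.0))

# Leaves are the bound arguments (positional first, then keyword values);
# `affine` itself is kept in the static treedef, not among the leaves.
leaves, treedef = jax.tree_util.tree_flatten(p)

# Because the arguments are leaves, tree transformations reach them...
scaled = jax.tree_util.tree_map(lambda a: a * 10, p)

# ...and the Partial can be passed through jax.jit as an ordinary argument.
apply = jax.jit(lambda f, x: f(x))
result = apply(p, 3.0)
```

Since the function is static and only the arguments are leaves, the arrays are the only state a serializer would need to store.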

Since all of a Partial's state lives in those arguments, I believe it makes sense to make it aware of Flax serialization as well. This would allow a Partial to be used inside the structures we save as checkpoints.
The definition would be as simple as the following:

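from flax import serialization
import jax

# jax.tree_util.Partial does not support Flax serialization out of the box;
# this registration should be upstreamed to Flax.
serialization.register_serialization_state(
    jax.tree_util.Partial,
    # to_state_dict: store only the bound arguments; the wrapped function
    # is not serialized.
    lambda x: {
        "args": serialization.to_state_dict(x.args),
        "keywords": serialization.to_state_dict(x.keywords),
    },
    # from_state_dict: rebuild the Partial around the template's function,
    # filling in the restored positional and keyword arguments.
    lambda x, sd: jax.tree_util.Partial(
        x.func,
        *serialization.from_state_dict(x.args, sd["args"]),
        **serialization.from_state_dict(x.keywords, sd["keywords"]),
    ),
)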

Would you accept a contribution adding this?

Metadata

Assignees

Labels

Priority: P2 - no schedule. Best effort response and resolution. We have no plan to work on this at the moment.

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests