Skip to content

Add serialization for Partial functions. Completes #2433 #2557

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions flax/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,20 @@ def _restore_namedtuple(xs, state_dict: Dict[str, Any]):
_namedtuple_state_dict,
_restore_namedtuple)

register_serialization_state(
jax.tree_util.Partial,
lambda x: (
{
"args": to_state_dict(x.args),
"keywords": to_state_dict(x.keywords),
}
),
lambda x, sd: jax.tree_util.Partial(
x.func,
*from_state_dict(x.args, sd["args"]),
**from_state_dict(x.keywords, sd["keywords"]),
),
)

# On-the-wire / disk serialization format

Expand Down
11 changes: 11 additions & 0 deletions tests/serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from flax.training import train_state
import jax
from jax import random
from jax.tree_util import Partial
import jax.numpy as jnp
import msgpack
import numpy as np
Expand Down Expand Up @@ -107,6 +108,16 @@ def test_model_serialization(self):
restored_model = serialization.from_state_dict(initial_params, state)
self.assertEqual(restored_model, freeze(state))

def test_partial_serialization(self):
add_one = Partial(jnp.add, 1)
state = serialization.to_state_dict(add_one)
self.assertEqual(state, {
'args': {'0': 1},
'keywords': {}
})
restored_add_one = serialization.from_state_dict(add_one, state)
self.assertEqual(add_one.args, restored_add_one.args)

def test_optimizer_serialization(self):
rng = random.PRNGKey(0)
module = nn.Dense(features=1, kernel_init=nn.initializers.ones)
Expand Down