diff --git a/flax/serialization.py b/flax/serialization.py index 1d8404417..a7506763a 100644 --- a/flax/serialization.py +++ b/flax/serialization.py @@ -193,6 +193,20 @@ def _restore_namedtuple(xs, state_dict: Dict[str, Any]): _namedtuple_state_dict, _restore_namedtuple) +register_serialization_state( + jax.tree_util.Partial, + lambda x: ( + { + "args": to_state_dict(x.args), + "keywords": to_state_dict(x.keywords), + } + ), + lambda x, sd: jax.tree_util.Partial( + x.func, + *from_state_dict(x.args, sd["args"]), + **from_state_dict(x.keywords, sd["keywords"]), + ), +) # On-the-wire / disk serialization format diff --git a/tests/serialization_test.py b/tests/serialization_test.py index 8219f84cc..e1d6c549d 100644 --- a/tests/serialization_test.py +++ b/tests/serialization_test.py @@ -28,6 +28,7 @@ from flax.training import train_state import jax from jax import random +from jax.tree_util import Partial import jax.numpy as jnp import msgpack import numpy as np @@ -107,6 +108,16 @@ def test_model_serialization(self): restored_model = serialization.from_state_dict(initial_params, state) self.assertEqual(restored_model, freeze(state)) + def test_partial_serialization(self): + add_one = Partial(jnp.add, 1) + state = serialization.to_state_dict(add_one) + self.assertEqual(state, { + 'args': {'0': 1}, + 'keywords': {} + }) + restored_add_one = serialization.from_state_dict(add_one, state) + self.assertEqual(add_one.args, restored_add_one.args) + def test_optimizer_serialization(self): rng = random.PRNGKey(0) module = nn.Dense(features=1, kernel_init=nn.initializers.ones)