Skip to content

Commit 390b42c

Browse files
committed
Add serialization test for Partial
1 parent 6d2d804 commit 390b42c

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tests/serialization_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from flax.training import train_state
2929
import jax
3030
from jax import random
31+
from jax.tree_util import Partial
3132
import jax.numpy as jnp
3233
import msgpack
3334
import numpy as np
@@ -107,6 +108,16 @@ def test_model_serialization(self):
107108
restored_model = serialization.from_state_dict(initial_params, state)
108109
self.assertEqual(restored_model, freeze(state))
109110

111+
def test_partial_serialization(self):
112+
add_one = Partial(jnp.add, 1)
113+
state = serialization.to_state_dict(add_one)
114+
self.assertEqual(state, {
115+
'args': {'0': 1},
116+
'keywords': {}
117+
})
118+
restored_add_one = serialization.from_state_dict(add_one, state)
119+
self.assertEqual(add_one.args, restored_add_one.args)
120+
110121
def test_optimizer_serialization(self):
111122
rng = random.PRNGKey(0)
112123
module = nn.Dense(features=1, kernel_init=nn.initializers.ones)

0 commit comments

Comments
 (0)