File tree Expand file tree Collapse file tree 1 file changed +11
-0
lines changed Expand file tree Collapse file tree 1 file changed +11
-0
lines changed Original file line number Diff line number Diff line change 28
28
from flax .training import train_state
29
29
import jax
30
30
from jax import random
31
+ from jax .tree_util import Partial
31
32
import jax .numpy as jnp
32
33
import msgpack
33
34
import numpy as np
@@ -107,6 +108,16 @@ def test_model_serialization(self):
107
108
restored_model = serialization .from_state_dict (initial_params , state )
108
109
self .assertEqual (restored_model , freeze (state ))
109
110
111
+ def test_partial_serialization (self ):
112
+ add_one = Partial (jnp .add , 1 )
113
+ state = serialization .to_state_dict (add_one )
114
+ self .assertEqual (state , {
115
+ 'args' : {'0' : 1 },
116
+ 'keywords' : {}
117
+ })
118
+ restored_add_one = serialization .from_state_dict (add_one , state )
119
+ self .assertEqual (add_one .args , restored_add_one .args )
120
+
110
121
def test_optimizer_serialization (self ):
111
122
rng = random .PRNGKey (0 )
112
123
module = nn .Dense (features = 1 , kernel_init = nn .initializers .ones )
You can’t perform that action at this time.
0 commit comments