File tree Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -23,6 +23,12 @@ Let's create a small convolutional neural network model for our demo.
23
23
24
24
.. testcode ::
25
25
26
+ import jax
27
+ import jax.numpy as jnp
28
+ from flax import traverse_util
29
+ from flax import linen as nn
30
+ from flax.core import freeze
31
+
26
32
class CNN(nn.Module):
27
33
@nn.compact
28
34
def __call__(self, x):
@@ -71,7 +77,7 @@ Let's create a small convolutional neural network model for our demo.
71
77
})
72
78
73
79
74
- Next, get a flat dict for doing model surgery as follows :
80
+ Next, get a flat dict for doing model surgery by using :meth: ` traverse_util.flatten_dict() <flax.traverse_util.flatten_dict> ` :
75
81
76
82
.. testcode ::
77
83
@@ -134,6 +140,8 @@ optimizer state that mirrors the original state.
134
140
135
141
.. testcode ::
136
142
143
+ import optax
144
+
137
145
tx = optax.adam(1.0)
138
146
opt_state = tx.init(params)
139
147
You can’t perform that action at this time.
0 commit comments