Skip to content

Commit db690a4

Browse files
author
Flax Authors
committed
Merge pull request #2687 from chiamp:flax_docs_model_surgery
PiperOrigin-RevId: 492575776
2 parents 650dcb2 + 33f95f8 commit db690a4

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

docs/guides/model_surgery.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ Let's create a small convolutional neural network model for our demo.
2323

2424
.. testcode::
2525

26+
import jax
27+
import jax.numpy as jnp
28+
from flax import traverse_util
29+
from flax import linen as nn
30+
from flax.core import freeze
31+
2632
class CNN(nn.Module):
2733
@nn.compact
2834
def __call__(self, x):
@@ -71,7 +77,7 @@ Let's create a small convolutional neural network model for our demo.
7177
})
7278

7379

74-
Next, get a flat dict for doing model surgery as follows:
80+
Next, get a flat dict for doing model surgery by using :meth:`traverse_util.flatten_dict() <flax.traverse_util.flatten_dict>`:
7581

7682
.. testcode::
7783

@@ -134,6 +140,8 @@ optimizer state that mirrors the original state.
134140

135141
.. testcode::
136142

143+
import optax
144+
137145
tx = optax.adam(1.0)
138146
opt_state = tx.init(params)
139147

0 commit comments

Comments
 (0)