
Commit 12f2b27

Flax Authors committed
Merge pull request #2668 from IvyZX:flax-pen1
PiperOrigin-RevId: 494558631
2 parents 2946b3c + 1adf0d8 commit 12f2b27

File tree

3 files changed: +392 −197 lines changed


docs/guides/model_surgery.ipynb

Lines changed: 252 additions & 0 deletions
(Notebook version of the guide, paired with `docs/guides/model_surgery.md` via jupytext; its markdown and code cells mirror the Markdown guide shown below.)

docs/guides/model_surgery.md

Lines changed: 140 additions & 0 deletions

---
jupyter:
  jupytext:
    formats: md,ipynb
    text_representation:
      extension: .md
      format_name: markdown
      format_version: '1.3'
      jupytext_version: 1.13.8
  kernelspec:
    display_name: Python 3 (ipykernel)
    language: python
    name: python3
---

Model Surgery
==============================

Usually, Flax modules and optimizers track and update the params for you. But there may be times when you want to do some model surgery and tweak the param tensors yourself. This guide shows you how.

## Setup

```python tags=["skip-execution"]
!pip install --upgrade -q pip jax jaxlib flax
```

```python
import functools

import jax
import jax.numpy as jnp
from flax import traverse_util
from flax import linen as nn
from flax.core import freeze
import optax
```

Surgery with Flax Modules
--------------------------------

Let's create a small convolutional neural network model for our demo.

As usual, you can run `CNN.init(...)['params']` to get the `params` that you pass around and modify at every step of your training.

```python
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

def get_initial_params(key):
  init_shape = jnp.ones((1, 28, 28, 1), jnp.float32)
  initial_params = CNN().init(key, init_shape)['params']
  return initial_params

key = jax.random.PRNGKey(0)
params = get_initial_params(key)

jax.tree_util.tree_map(jnp.shape, params)
```

Note that what is returned as `params` is a `FrozenDict`, which contains a few JAX arrays as kernels and biases.

A `FrozenDict` is nothing more than a read-only dict, and Flax made it read-only because of the functional nature of JAX: JAX arrays are immutable, and the new `params` need to replace the old `params`. Making the dict read-only ensures that no in-place mutation of the dict can happen accidentally during training and updating.

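As a quick illustration of that immutability, here is a minimal sketch; it assumes `flax.core.unfreeze`, the counterpart of `freeze`, which returns an ordinary mutable copy.

```python
from flax.core import unfreeze

try:
  # In-place assignment on a FrozenDict is rejected.
  params['Dense_1']['kernel'] = jnp.zeros((256, 10))
except Exception as e:
  print(f'{type(e).__name__}: {e}')

# unfreeze() gives back a regular nested dict that can be mutated freely.
mutable_params = unfreeze(params)
```
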
One way to actually modify the params outside of a Flax module is to explicitly flatten them into a mutable dict. Note that you can use a separator `sep` to join all nested keys. If no `sep` is given, the key will be a tuple of all nested keys.

```python
# Get a flattened key-value dict.
flat_params = traverse_util.flatten_dict(params, sep='/')

jax.tree_util.tree_map(jnp.shape, flat_params)
```

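For reference, a small sketch of the tuple-key form you get when no `sep` is passed:

```python
# Without `sep`, each key is a tuple of the nested path components,
# e.g. ('Conv_0', 'kernel') instead of 'Conv_0/kernel'.
flat_params_tuples = traverse_util.flatten_dict(params)
list(flat_params_tuples.keys())[:3]
```
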
Now you can do whatever you want with the params. When you are done, unflatten them back and use them in future training.

```python
# Somehow modify a layer, e.g. normalize a kernel.
dense_kernel = flat_params['Dense_1/kernel']
flat_params['Dense_1/kernel'] = dense_kernel / jnp.linalg.norm(dense_kernel)

# Unflatten.
unflat_params = traverse_util.unflatten_dict(flat_params, sep='/')
# Refreeze.
unflat_params = freeze(unflat_params)
jax.tree_util.tree_map(jnp.shape, unflat_params)
```

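As a sanity check, the modified params can be plugged straight back into the module via `apply`; below is a minimal sketch with a dummy MNIST-shaped batch.

```python
# Forward pass with the surgically modified params on a dummy input batch.
dummy_batch = jnp.ones((4, 28, 28, 1), jnp.float32)
logits = CNN().apply({'params': unflat_params}, dummy_batch)
logits.shape  # (4, 10)
```
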
Surgery with Optimizers
--------------------------------

When using Optax as an optimizer, the `opt_state` is actually a nested tuple of the states of the individual gradient transformations that compose the optimizer. These states contain pytrees that mirror the parameter tree, and they can be modified in the same way: flatten, modify, unflatten, and then recreate a new optimizer state that mirrors the original one.

```python
tx = optax.adam(1.0)
opt_state = tx.init(params)

# The optimizer state is a tuple of gradient transformation states.
jax.tree_util.tree_map(jnp.shape, opt_state)
```

The pytrees inside the optimizer state follow the same structure as the parameters and can be flattened and modified in exactly the same way.

```python
flat_mu = traverse_util.flatten_dict(opt_state[0].mu, sep='/')
flat_nu = traverse_util.flatten_dict(opt_state[0].nu, sep='/')

jax.tree_util.tree_map(jnp.shape, flat_mu)
```

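For instance, one possible edit (a minimal sketch) is to reset the Adam first-moment estimate of the layer whose kernel was rescaled above:

```python
# Zero out the stale first-moment (momentum) estimate for the modified layer.
flat_mu['Dense_1/kernel'] = jnp.zeros_like(flat_mu['Dense_1/kernel'])
```
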
After modification, re-create the optimizer state and use it for future training.

```python
opt_state = (
    opt_state[0]._replace(
        mu=traverse_util.unflatten_dict(flat_mu, sep='/'),
        nu=traverse_util.unflatten_dict(flat_nu, sep='/'),
    ),
) + opt_state[1:]
jax.tree_util.tree_map(jnp.shape, opt_state)
```
