
Commit cac954e

transfer learning guide
1 parent 521f516 commit cac954e

File tree

4 files changed: +494 -0 lines changed


docs/guides/index.rst

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ Guides
   state_params
   setup_or_nncompact
   model_surgery
+  transfer_learning
   extracting_intermediates
   lr_schedule

docs/guides/transfer_learning.ipynb

Lines changed: 312 additions & 0 deletions
@@ -0,0 +1,312 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Transfer learning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This guide demonstrates various parts of the transfer learning workflow with Flax. Depending on your task, you can use a pretrained model as a feature extractor or fine-tune the entire model. This guide uses simple classification as a default task. You will learn how to:\n",
"\n",
"* Load a pretrained model from HuggingFace [Transformers](https://huggingface.co/docs/transformers/index) and extract a specific sub-module from that pretrained model.\n",
"* Create the classifier model.\n",
"* Transfer the pretrained parameters to the new model structure.\n",
"* Set up optimization for training different parts of the model separately with [Optax](https://optax.readthedocs.io/).\n",
"* Set up the model for training.\n",
"\n",
"**Note:** Depending on your task, some of the content in this guide may be suboptimal. For example, if you are only going to train a linear classifier on top of a pretrained model, it may be better to just extract the feature embeddings once, which can result in much faster training, and you can use specialized algorithms for linear regression or logistic classification. This guide shows how to do transfer learning with all the model parameters."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"# Note that the Transformers library doesn't use the latest Flax version.\n",
"! pip install transformers[flax]\n",
"# Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support,\n",
"# visit https://github.com/google/jax#installation.\n",
"! pip install -U flax jax jaxlib"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create a function for model loading\n",
"\n",
"To load a pre-trained classifier, you can create a custom function that will return a [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics) and its pretrained variables.\n",
"\n",
"In the code below, the `load_model` function uses HuggingFace's `FlaxCLIPModel` from the [Transformers](https://huggingface.co/docs/transformers/index) library and extracts the underlying `FlaxCLIPModule` (note that `FlaxCLIPModel` itself is not a Flax `Module`, while the extracted `FlaxCLIPModule` is):"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"from IPython.display import clear_output\n",
"from transformers import FlaxCLIPModel\n",
"\n",
"# Note: FlaxCLIPModel is not a Flax Module\n",
"def load_model():\n",
"  clip = FlaxCLIPModel.from_pretrained('openai/clip-vit-base-patch32')\n",
"  clear_output(wait=False) # Clear the loading messages\n",
"  module = clip.module # Extract the Flax Module\n",
"  variables = {'params': clip.params} # Extract the parameters\n",
"  return module, variables"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Extract a sub-model from the loaded trained model\n",
"\n",
"Calling `load_model` from the snippet above returns the `FlaxCLIPModule`, which is composed of text and vision sub-modules.\n",
"\n",
"Suppose you want to extract the `vision_model` sub-module defined inside `.setup()` and its variables. To do this, you can use [`nn.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.apply) to run a helper function that will grant you access to submodules and their variables:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import flax.linen as nn\n",
"\n",
"clip, clip_variables = load_model()\n",
"\n",
"def extract_submodule(clip):\n",
"  vision_model = clip.vision_model.clone()\n",
"  variables = clip.vision_model.variables\n",
"  return vision_model, variables\n",
"\n",
"vision_model, vision_model_variables = nn.apply(extract_submodule, clip)(clip_variables)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that `.clone()` was used here to get an unbound copy of `vision_model`. This is important to avoid leakage, as bound modules contain their variables.\n",
"\n",
"### Create the classifier\n",
"\n",
"Next, create a `Classifier` model as a [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics), consisting of a `backbone` (the pretrained vision model) and a `head` (the classifier)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import jax\n",
"\n",
"class Classifier(nn.Module):\n",
"  num_classes: int\n",
"  backbone: nn.Module\n",
"\n",
"  @nn.compact\n",
"  def __call__(self, x):\n",
"    x = self.backbone(x).pooler_output\n",
"    x = nn.Dense(self.num_classes, name='head')(x)\n",
"    return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, pass the `vision_model` sub-module as the backbone to the `Classifier` to create the complete model.\n",
"\n",
"You can randomly initialize the model's variables using some toy data for demonstration purposes."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"num_classes = 3\n",
"model = Classifier(num_classes=num_classes, backbone=vision_model)\n",
"\n",
"x = jnp.ones((1, 224, 224, 3))\n",
"variables = model.init(jax.random.PRNGKey(1), x)"
]
},
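{
"cell_type": "markdown",
"metadata": {},
"source": [
"At this point the `head` parameters are new and randomly initialized. As a quick, optional check (a minimal sketch, not required for the workflow), you can inspect their shapes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: inspect the shapes of the randomly initialized classifier head.\n",
"jax.tree_util.tree_map(jnp.shape, variables['params']['head'])"
]
},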
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Transfer the parameters\n",
"\n",
"Since `variables` are randomly initialized, you now have to transfer the parameters from `vision_model_variables` to the complete `variables` at the appropriate location. This can be done by unfreezing the `variables`, updating the `backbone` parameters, and freezing the `variables` again:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from flax.core.frozen_dict import freeze\n",
"\n",
"variables = variables.unfreeze()\n",
"variables['params']['backbone'] = vision_model_variables['params']\n",
"variables = freeze(variables)"
]
},
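{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, you can sanity-check the transfer by verifying that the `backbone` parameters inside `variables` now match the pretrained ones (a minimal sketch):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: every backbone leaf should now equal the corresponding pretrained array.\n",
"assert jax.tree_util.tree_all(jax.tree_util.tree_map(\n",
"  lambda a, b: bool((a == b).all()),\n",
"  variables['params']['backbone'], vision_model_variables['params']))"
]
},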
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Optimization\n",
"\n",
"If you need to train different parts of the model separately, you have three options:\n",
"\n",
"1. Use `stop_gradient`.\n",
"2. Filter the parameters for `jax.grad`.\n",
"3. Use multiple optimizers for different parameters.\n",
"\n",
"While each could be useful in different situations, it is recommended to use multiple optimizers via [Optax](https://optax.readthedocs.io/)'s [`optax.multi_transform`](https://optax.readthedocs.io/en/latest/api.html#optax.multi_transform) because it is efficient and can be easily extended to implement differential learning rates. To use `optax.multi_transform` you have to do three things:\n",
"\n",
"1. Define some parameter partitions.\n",
"2. Create a mapping between partitions and their optimizer.\n",
"3. Create a pytree with the same shape as the parameters, but with its leaves containing the corresponding partition label.\n",
"\n",
"## Freeze layers\n",
"\n",
"To freeze layers with `optax.multi_transform`, create the `trainable` and `frozen` parameter partitions.\n",
"\n",
"In the example below:\n",
"\n",
"- For the `trainable` parameters, use the Adam (`optax.adam`) optimizer.\n",
"- For the `frozen` parameters, use `optax.set_to_zero`, which zeros out the gradients.\n",
"- To map parameters to partitions, you can use the [`flax.traverse_util.path_aware_map`](https://flax.readthedocs.io/en/latest/api_reference/flax.traverse_util.html#flax.traverse_util.path_aware_map) function. By leveraging the `path` argument, you can map the `backbone` parameters to `frozen` and the rest to `trainable`."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"FrozenDict({\n",
"    backbone: {\n",
"        embeddings: {\n",
"            class_embedding: 'frozen',\n",
"            patch_embedding: {\n",
"                kernel: 'frozen',\n",
"            },\n",
"        },\n",
"    },\n",
"    head: {\n",
"        bias: 'trainable',\n",
"        kernel: 'trainable',\n",
"    },\n",
"})"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from flax import traverse_util\n",
"import optax\n",
"\n",
"partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}\n",
"param_partitions = freeze(traverse_util.path_aware_map(\n",
"  lambda path, v: 'frozen' if 'backbone' in path else 'trainable', variables['params']))\n",
"tx = optax.multi_transform(partition_optimizers, param_partitions)\n",
"\n",
"# visualize a subset of the param_partitions structure\n",
"flat = list(traverse_util.flatten_dict(param_partitions).items())\n",
"freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])))"
]
},
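{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instead of freezing the `backbone` entirely, you could also fine-tune it with a much smaller learning rate than the `head` — the _differential learning rates_ idea discussed next. Here is a minimal sketch; the partition labels and learning rates below are illustrative, not part of this guide's training setup:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: differential learning rates instead of freezing the backbone.\n",
"# 'backbone' and 'head' are illustrative partition labels.\n",
"diff_lr_optimizers = {'head': optax.adam(5e-3), 'backbone': optax.adam(5e-5)}\n",
"diff_lr_partitions = freeze(traverse_util.path_aware_map(\n",
"  lambda path, v: 'backbone' if 'backbone' in path else 'head', variables['params']))\n",
"diff_lr_tx = optax.multi_transform(diff_lr_optimizers, diff_lr_partitions)"
]
},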
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To implement _differential learning rates_, simply replace `optax.set_to_zero` with the optimizer of your choice (as in the sketch above); you can choose different optimizers and partitioning schemes depending on your needs.\n",
"\n",
"For more information on advanced optimizers, refer to Optax's [Combining Optimizers](https://optax.readthedocs.io/en/latest/api.html#combining-optimizers) documentation.\n",
"\n",
"## Create the `TrainState` object for model training\n",
"\n",
"Once you define your module, variables, and optimizer, you can construct the `TrainState` object and proceed to train the model as you normally would."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from flax.training.train_state import TrainState\n",
"\n",
"state = TrainState.create(\n",
"  apply_fn=model.apply,\n",
"  params=variables['params'],\n",
"  tx=tx)"
]
},
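{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, a typical training step with this `state` could look like the following minimal sketch (the batch format, integer `labels`, and cross-entropy loss are illustrative assumptions):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of a train step: gradients for the frozen 'backbone' partition are\n",
"# zeroed out by optax.set_to_zero, so only the head is updated.\n",
"@jax.jit\n",
"def train_step(state, images, labels):\n",
"  def loss_fn(params):\n",
"    logits = state.apply_fn({'params': params}, images)\n",
"    one_hot = jax.nn.one_hot(labels, num_classes)\n",
"    return optax.softmax_cross_entropy(logits, one_hot).mean()\n",
"\n",
"  loss, grads = jax.value_and_grad(loss_fn)(state.params)\n",
"  return state.apply_gradients(grads=grads), loss"
]
}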
],
"metadata": {
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3.8.10 ('.venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.14"
},
"vscode": {
"interpreter": {
"hash": "ec7c69eb752b35b8fd728edc4753e382b54c10c43e6028c93b5837f81a552f5c"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
