Skip to content

Transfer Learning Guide #2394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/guides/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Guides
state_params
setup_or_nncompact
model_surgery
transfer_learning
extracting_intermediates
lr_schedule

Expand Down
312 changes: 312 additions & 0 deletions docs/guides/transfer_learning.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,312 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Transfer learning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This guide demonstrates various parts of the transfer learning workflow with Flax. Depending on your task, you can use a pretrained model as a feature extractor or fine-tune the entire model. This guide uses simple classification as a default task. You will learn how to:\n",
"\n",
"* Load a pretrained model from HuggingFace [Transformers](https://huggingface.co/docs/transformers/index) and extract a specific sub-module from that pretrained model.\n",
"* Create the classifier model.\n",
"* Transfer the pretrained parameters to the new model structure.\n",
"* Set up optimization for training different parts of the model separately with [Optax](https://optax.readthedocs.io/).\n",
"* Set up the model for training.\n",
"\n",
"**Note:** Depending on your task, some of the content in this guide may be suboptimal. For example, if you are only going to train a linear classifier on top of a pretrained model, it may be better to just extract the feature embeddings once, which can result in much faster training, and you can use specialized algorithms for linear regression or logistic classification. This guide shows how to do transfer learning with all the model parameters."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"# Note that the Transformers library doesn't use the latest Flax version.\n",
"! pip install transformers[flax]\n",
"# Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support,\n",
"# visit https://github.com/google/jax#installation.\n",
"! pip install -U flax jax jaxlib"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create a function for model loading\n",
"\n",
"To load a pre-trained classifier, you can create a custom function that will return a [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics) and its pretrained variables.\n",
"\n",
"In the code below, the `load_model` function uses HuggingFace's `FlaxCLIPVisionModel` model from the [Transformers](https://huggingface.co/docs/transformers/index) library and extracts a `FlaxCLIPModule` module (note that it is not a Flax `Module`):"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"from IPython.display import clear_output\n",
"from transformers import FlaxCLIPModel\n",
"\n",
"# Note: FlaxCLIPModel is not a Flax Module\n",
"def load_model():\n",
" clip = FlaxCLIPModel.from_pretrained('openai/clip-vit-base-patch32')\n",
" clear_output(wait=False) # Clear the loading messages\n",
" module = clip.module # Extract the Flax Module\n",
" variables = {'params': clip.params} # Extract the parameters\n",
" return module, variables"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Extract a sub-model from the loaded trained model\n",
"\n",
"Calling `load_model` from the snippet above returns the `FlaxCLIPModule`, which is composed of text and vision sub-modules.\n",
"\n",
"Suppose you want to extract the `vision_model` sub-module defined inside `.setup()` and its variables. To do this you can use [`nn.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.apply) to run a helper function that will grant you access to submodules and their variables:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import flax.linen as nn\n",
"\n",
"clip, clip_variables = load_model()\n",
"\n",
"def extract_submodule(clip):\n",
" vision_model = clip.vision_model.clone()\n",
" variables = clip.vision_model.variables\n",
" return vision_model, variables\n",
"\n",
"vision_model, vision_model_variables = nn.apply(extract_submodule, clip)(clip_variables)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that here `.clone()` was used to get an unbounded copy of `vision_model`, this is important to avoid leakage as bounded modules contain their variables.\n",
"\n",
"### Create the classifier\n",
"\n",
"Next create a `Classifier` model with [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics), consisting of a `backbone` (the pretrained vision model) and a `head` (the classifier)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import jax\n",
"\n",
"class Classifier(nn.Module):\n",
" num_classes: int\n",
" backbone: nn.Module\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" x = self.backbone(x).pooler_output\n",
" x = nn.Dense(self.num_classes, name='head')(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, pass the `vision_model` sub-module as the backbone to the `Classifier` to create the complete model.\n",
"\n",
"You can randomly initialize the model's variables using some toy data for demonstration purposes."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"num_classes = 3\n",
"model = Classifier(num_classes=num_classes, backbone=vision_model)\n",
"\n",
"x = jnp.ones((1, 224, 224, 3))\n",
"variables = model.init(jax.random.PRNGKey(1), x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Transfer the parameters\n",
"\n",
"Since `variables` are randomly initialized, you now have to transfer the parameters from `vision_model_variables` to the complete `variables` at the appropriate location. This can be done by unfreezing the `variables`, updating the `backbone` parameters, and freezing the `variables` again:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from flax.core.frozen_dict import freeze\n",
"\n",
"variables = variables.unfreeze()\n",
"variables['params']['backbone'] = vision_model_variables['params']\n",
"variables = freeze(variables)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Optimization\n",
"\n",
"If you need to to train different parts of the model separately, you have two options:\n",
"\n",
"1. Use `stop_gradient`.\n",
"2. Filter the parameters for `jax.grad`.\n",
"3. Use multiple optimizers for different parameters.\n",
"\n",
"While each could be useful in different situations, its recommended to use use multiple optimizers via [Optax](https://optax.readthedocs.io/)'s [`optax.multi_transform`](https://optax.readthedocs.io/en/latest/api.html#optax.multi_transform) because it is efficient and can be easily extended to implement differential learning rates. To use `optax.multi_transform` you have to do two things:\n",
"\n",
"1. Define some parameter partitions.\n",
"2. Create a mapping between partitions and their optimizer.\n",
"3. Create a pytree with the same shape as the parameters but its leaves containing the corresponding partition label.\n",
"\n",
"## Freeze layers\n",
"\n",
"To freeze layers with `optax.multi_transform`, create the `trainable` and `frozen` parameter partitions.\n",
"\n",
"In the example below:\n",
"\n",
"- For the `trainable` parameters use the Adam (`optax.adam`) optimizer.\n",
"- For the `frozen` parameters use `optax.set_to_zero`, which zeros-out the gradients.\n",
"- To map parameters to partitions, you can use the [`flax.traverse_util.path_aware_map`](https://flax.readthedocs.io/en/latest/api_reference/flax.traverse_util.html#flax.traverse_util.path_aware_map) function, by leveraging the `path` argument you can map the `backbone` parameters to `frozen` and the rest to `trainable`."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"FrozenDict({\n",
" backbone: {\n",
" embeddings: {\n",
" class_embedding: 'frozen',\n",
" patch_embedding: {\n",
" kernel: 'frozen',\n",
" },\n",
" },\n",
" },\n",
" head: {\n",
" bias: 'trainable',\n",
" kernel: 'trainable',\n",
" },\n",
"})"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from flax import traverse_util\n",
"import optax\n",
"\n",
"partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}\n",
"param_partitions = freeze(traverse_util.path_aware_map(\n",
" lambda path, v: 'frozen' if 'backbone' in path else 'trainable', variables['params']))\n",
"tx = optax.multi_transform(partition_optimizers, param_partitions)\n",
"\n",
"# visualize a subset of the param_partitions structure\n",
"flat = list(traverse_util.flatten_dict(param_partitions).items())\n",
"freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To implement _differential learning rates_ simply replace `optax.set_to_zero` with the optimizer of your choice, you can choose different optimizers and partitioning schemes depending on your needs.\n",
"\n",
"For more information on advanced optimizers, refer to Optax's [Combining Optimizers](https://optax.readthedocs.io/en/latest/api.html#combining-optimizers) documentation.\n",
"\n",
"## Create the `TrainState` object for model training\n",
"\n",
"Once you define your module, variables, and optimizer, you can construct the `TrainState` object and proceed to train the model as you normally would."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from flax.training.train_state import TrainState\n",
"\n",
"state = TrainState.create(\n",
" apply_fn=model.apply,\n",
" params=variables['params'],\n",
" tx=tx)"
]
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3.8.10 ('.venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.14"
},
"vscode": {
"interpreter": {
"hash": "ec7c69eb752b35b8fd728edc4753e382b54c10c43e6028c93b5837f81a552f5c"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading