Transfer Learning Guide #2394
Merged: copybara-service merged 1 commit into google:main from cgarciae:transfer-learning-guide on Oct 21, 2022
@@ -0,0 +1,312 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Transfer learning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This guide demonstrates various parts of the transfer learning workflow with Flax. Depending on your task, you can use a pretrained model as a feature extractor or fine-tune the entire model. This guide uses simple classification as a default task. You will learn how to:\n",
"\n",
"* Load a pretrained model from HuggingFace [Transformers](https://huggingface.co/docs/transformers/index) and extract a specific sub-module from that pretrained model.\n",
"* Create the classifier model.\n",
"* Transfer the pretrained parameters to the new model structure.\n",
"* Set up optimization for training different parts of the model separately with [Optax](https://optax.readthedocs.io/).\n",
"* Set up the model for training.\n",
"\n",
"**Note:** Depending on your task, some of the content in this guide may be suboptimal. For example, if you are only going to train a linear classifier on top of a pretrained model, it may be better to just extract the feature embeddings once, which can result in much faster training and lets you use specialized algorithms for linear regression or logistic classification. This guide shows how to do transfer learning with all the model parameters."
]
},
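{
"cell_type": "markdown",
"metadata": {},
"source": [
"The snippet below is a minimal sketch of that faster alternative, assuming a pretrained `vision_model` and its `vision_model_variables` like the ones extracted later in this guide; scikit-learn is an assumed extra dependency, not something this guide requires:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LogisticRegression  # Assumed extra dependency\n",
"\n",
"def fit_linear_probe(vision_model, vision_model_variables, images, labels):\n",
"  # Extract the (frozen) feature embeddings once...\n",
"  embeddings = vision_model.apply(vision_model_variables, images).pooler_output\n",
"  # ...then fit a specialized linear classifier on top of them.\n",
"  return LogisticRegression(max_iter=1000).fit(np.asarray(embeddings), labels)"
]
},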
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"# Note that the Transformers library doesn't use the latest Flax version.\n",
"! pip install transformers[flax]\n",
"# Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support,\n",
"# visit https://github.com/google/jax#installation.\n",
"! pip install -U flax jax jaxlib"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create a function for model loading\n",
"\n",
"To load a pre-trained classifier, you can create a custom function that will return a [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics) and its pretrained variables.\n",
"\n",
"In the code below, the `load_model` function uses HuggingFace's `FlaxCLIPModel` from the [Transformers](https://huggingface.co/docs/transformers/index) library and extracts its inner `FlaxCLIPModule` (note that `FlaxCLIPModel` itself is not a Flax `Module`):"
]
},
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%%capture\n", | ||
"from IPython.display import clear_output\n", | ||
"from transformers import FlaxCLIPModel\n", | ||
"\n", | ||
"# Note: FlaxCLIPModel is not a Flax Module\n", | ||
"def load_model():\n", | ||
" clip = FlaxCLIPModel.from_pretrained('openai/clip-vit-base-patch32')\n", | ||
" clear_output(wait=False) # Clear the loading messages\n", | ||
" module = clip.module # Extract the Flax Module\n", | ||
" variables = {'params': clip.params} # Extract the parameters\n", | ||
" return module, variables" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Extract a sub-model from the loaded trained model\n", | ||
"\n", | ||
"Calling `load_model` from the snippet above returns the `FlaxCLIPModule`, which is composed of text and vision sub-modules.\n", | ||
"\n", | ||
"Suppose you want to extract the `vision_model` sub-module defined inside `.setup()` and its variables. To do this you can use [`nn.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.apply) to run a helper function that will grant you access to submodules and their variables:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import flax.linen as nn\n", | ||
"\n", | ||
"clip, clip_variables = load_model()\n", | ||
"\n", | ||
"def extract_submodule(clip):\n", | ||
" vision_model = clip.vision_model.clone()\n", | ||
" variables = clip.vision_model.variables\n", | ||
" return vision_model, variables\n", | ||
"\n", | ||
"vision_model, vision_model_variables = nn.apply(extract_submodule, clip)(clip_variables)" | ||
] | ||
}, | ||
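{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, you can run the extracted sub-module on a dummy image. This is a small sketch; the `(1, 224, 224, 3)` input shape and the 768-dimensional `pooler_output` are assumptions based on the CLIP ViT-B/32 configuration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"\n",
"dummy_image = jnp.ones((1, 224, 224, 3))  # A batch with one dummy image\n",
"out = vision_model.apply(vision_model_variables, dummy_image)\n",
"print(out.pooler_output.shape)  # Expected: (1, 768) for ViT-B/32"
]
},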
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that `.clone()` was used here to get an unbound copy of `vision_model`. This is important to avoid state leakage, because bound modules contain their variables.\n",
"\n",
"### Create the classifier\n",
"\n",
"Next, create a `Classifier` model with [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics), consisting of a `backbone` (the pretrained vision model) and a `head` (the classifier)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import jax\n",
"\n",
"class Classifier(nn.Module):\n",
"  num_classes: int\n",
"  backbone: nn.Module\n",
"\n", | ||
" @nn.compact\n", | ||
" def __call__(self, x):\n", | ||
" x = self.backbone(x).pooler_output\n", | ||
" x = nn.Dense(self.num_classes, name='head')(x)\n", | ||
" return x" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Then, pass the `vision_model` sub-module as the backbone to the `Classifier` to create the complete model.\n", | ||
"\n", | ||
"You can randomly initialize the model's variables using some toy data for demonstration purposes." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"num_classes = 3\n", | ||
"model = Classifier(num_classes=num_classes, backbone=vision_model)\n", | ||
"\n", | ||
"x = jnp.ones((1, 224, 224, 3))\n", | ||
"variables = model.init(jax.random.PRNGKey(1), x)" | ||
] | ||
}, | ||
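{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `init` produced, you can map each parameter leaf to its shape (a small sketch; the `head` shapes follow from the assumed 768-dimensional backbone output and `num_classes = 3`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The parameters contain the pretrained 'backbone' and the new 'head'.\n",
"print(list(variables['params'].keys()))\n",
"# Map every leaf of the 'head' sub-tree to its shape for a compact overview.\n",
"print(jax.tree_util.tree_map(jnp.shape, variables['params']['head']))"
]
},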
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Transfer the parameters\n",
"\n",
"Since `variables` are randomly initialized, you now have to transfer the parameters from `vision_model_variables` to the complete `variables` at the appropriate location. This can be done by unfreezing the `variables`, updating the `backbone` parameters, and freezing the `variables` again:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from flax.core.frozen_dict import freeze\n",
"\n",
"variables = variables.unfreeze()\n",
"variables['params']['backbone'] = vision_model_variables['params']\n",
"variables = freeze(variables)"
]
},
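{
"cell_type": "markdown",
"metadata": {},
"source": [
"To verify the transfer, you can check that every `backbone` leaf in `variables` now matches the corresponding pretrained parameter (a small sketch):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Every backbone leaf should now be identical to its pretrained counterpart.\n",
"match = jax.tree_util.tree_all(jax.tree_util.tree_map(\n",
"  jnp.array_equal, variables['params']['backbone'], vision_model_variables['params']))\n",
"print(match)  # Expected: True"
]
},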
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Optimization\n",
"\n",
"If you need to train different parts of the model separately, you have three options:\n",
"\n",
"1. Use `stop_gradient`.\n",
"2. Filter the parameters for `jax.grad`.\n",
"3. Use multiple optimizers for different parameters.\n",
"\n",
"While each could be useful in different situations, it's recommended to use multiple optimizers via [Optax](https://optax.readthedocs.io/)'s [`optax.multi_transform`](https://optax.readthedocs.io/en/latest/api.html#optax.multi_transform) because it is efficient and can be easily extended to implement differential learning rates (options 1 and 2 are sketched after the freezing example below). To use `optax.multi_transform` you have to do three things:\n",
"\n",
"1. Define some parameter partitions.\n",
"2. Create a mapping between partitions and their optimizer.\n",
"3. Create a pytree with the same shape as the parameters, but with its leaves containing the corresponding partition label.\n",
"\n",
"## Freeze layers\n",
"\n",
"To freeze layers with `optax.multi_transform`, create the `trainable` and `frozen` parameter partitions.\n",
"\n",
"In the example below:\n",
"\n",
"- For the `trainable` parameters, use the Adam (`optax.adam`) optimizer.\n",
"- For the `frozen` parameters, use `optax.set_to_zero`, which zeros out the gradients.\n",
"- To map parameters to partitions, you can use the [`flax.traverse_util.path_aware_map`](https://flax.readthedocs.io/en/latest/api_reference/flax.traverse_util.html#flax.traverse_util.path_aware_map) function; by leveraging its `path` argument, you can map the `backbone` parameters to `frozen` and the rest to `trainable`."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"FrozenDict({\n",
"    backbone: {\n",
"        embeddings: {\n",
"            class_embedding: 'frozen',\n",
"            patch_embedding: {\n",
"                kernel: 'frozen',\n",
"            },\n",
"        },\n",
"    },\n",
"    head: {\n",
"        bias: 'trainable',\n",
"        kernel: 'trainable',\n",
"    },\n",
"})"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from flax import traverse_util\n",
"import optax\n",
"\n",
"partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}\n",
"param_partitions = freeze(traverse_util.path_aware_map(\n",
"  lambda path, v: 'frozen' if 'backbone' in path else 'trainable', variables['params']))\n",
"tx = optax.multi_transform(partition_optimizers, param_partitions)\n",
"\n",
"# Visualize a subset of the param_partitions structure.\n",
"flat = list(traverse_util.flatten_dict(param_partitions).items())\n",
"freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])))"
]
},
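{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, here is a minimal sketch of options 1 and 2 from the list above; the squared-logits loss is a placeholder for illustration only. Option 1 stops gradients from flowing into the backbone inside the model, while option 2 differentiates the loss with respect to the `head` parameters only:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Option 1: stop gradients at the backbone output inside the model.\n",
"class StopGradientClassifier(nn.Module):\n",
"  num_classes: int\n",
"  backbone: nn.Module\n",
"\n",
"  @nn.compact\n",
"  def __call__(self, x):\n",
"    x = jax.lax.stop_gradient(self.backbone(x).pooler_output)\n",
"    x = nn.Dense(self.num_classes, name='head')(x)\n",
"    return x\n",
"\n",
"# Option 2: differentiate with respect to the 'head' parameters only.\n",
"def loss_fn(head_params, backbone_params, x):\n",
"  params = {'head': head_params, 'backbone': backbone_params}\n",
"  logits = model.apply({'params': params}, x)\n",
"  return jnp.mean(logits ** 2)  # Placeholder loss for illustration\n",
"\n",
"head_grads = jax.grad(loss_fn)(\n",
"  variables['params']['head'], variables['params']['backbone'], x)"
]
},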
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To implement _differential learning rates_, simply replace `optax.set_to_zero` with the optimizer of your choice; you can choose different optimizers and partitioning schemes depending on your needs, as sketched below.\n",
"\n",
"For more information on advanced optimizers, refer to Optax's [Combining Optimizers](https://optax.readthedocs.io/en/latest/api.html#combining-optimizers) documentation."
]
},
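{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, a differential learning rate setup could look like the following sketch; the `1e-5` and `5e-3` learning rates are arbitrary choices for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Label the backbone and head separately...\n",
"lr_partitions = freeze(traverse_util.path_aware_map(\n",
"  lambda path, v: 'backbone' if 'backbone' in path else 'head', variables['params']))\n",
"# ...and give each partition its own optimizer and learning rate.\n",
"tx_diff_lr = optax.multi_transform(\n",
"  {'backbone': optax.adam(1e-5), 'head': optax.adam(5e-3)}, lr_partitions)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create the `TrainState` object for model training\n",
"\n",
"Once you define your module, variables, and optimizer, you can construct the `TrainState` object and proceed to train the model as you normally would."
]
},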
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from flax.training.train_state import TrainState\n",
"\n",
"state = TrainState.create(\n",
"  apply_fn=model.apply,\n",
"  params=variables['params'],\n",
"  tx=tx)"
]
},
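{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, a training step might look like the following sketch. The cross-entropy loss and integer `labels` are illustrative assumptions; because the `frozen` partition uses `optax.set_to_zero`, the backbone parameters remain unchanged after `apply_gradients`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def train_step(state, images, labels):\n",
"  def loss_fn(params):\n",
"    logits = state.apply_fn({'params': params}, images)\n",
"    return optax.softmax_cross_entropy_with_integer_labels(\n",
"      logits=logits, labels=labels).mean()\n",
"\n",
"  loss, grads = jax.value_and_grad(loss_fn)(state.params)\n",
"  state = state.apply_gradients(grads=grads)\n",
"  return state, loss\n",
"\n",
"# Example usage with the toy input from earlier:\n",
"# state, loss = train_step(state, x, jnp.array([0]))"
]
}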
],
"metadata": {
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3.8.10 ('.venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.14"
},
"vscode": {
"interpreter": {
"hash": "ec7c69eb752b35b8fd728edc4753e382b54c10c43e6028c93b5837f81a552f5c"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}