This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Add support for selective finetune (freeze parameters by regex from config file) #1427

Merged: 9 commits merged into allenai:master from HarshTrivedi:selective-finetune on Jun 28, 2018

Conversation

@HarshTrivedi (Contributor) commented Jun 26, 2018

  • Add support in fine_tune to selectively tune
  • Add tests for selective fine tuning.
  • Allow for turning off gradients in train command.

This PR adds support to the fine-tune command for freezing parameters/layers by passing regexes in the config file. In the config file passed to the fine-tune command:

"trainer": {
    ...
    "no_grad": ["*conv*", ".*text_embedder*"]
}

The above will freeze the conv layers and text_embedders but not any other parameters.
I believe the best place to expose the no_grad key is within the trainer key, and since trainer is also used in the train command, I allowed turning off gradients via no_grad in the train command as well.

Although I primarily intended this selective turning off of gradients for the fine-tune command, it can be useful in the train command as well: e.g. if a module's parameters are loaded / transferred from some other pretrained model and one wants to freeze them there.

More context: Issue #1298
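For reference, here is a minimal sketch of how the no_grad regexes from the config could be applied; the function name freeze_by_regex is illustrative only (not a helper added in this PR), and the loop mirrors the snippets discussed in the review below:

import re

import torch


def freeze_by_regex(model: torch.nn.Module, no_grad_regexes) -> None:
    # Freeze every parameter whose name matches any of the regexes;
    # all other parameters stay trainable.
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

# For example, with the config above:
# freeze_by_regex(model, [".*conv.*", ".*text_embedder.*"])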

@joelgrus (Contributor) left a comment

looks good to me, modulo the small changes I requested, I'll let @DeNeutoy weigh in too

@@ -166,6 +166,12 @@ def fine_tune_model(model: Model,
    test_data = all_datasets.get('test')

    trainer_params = params.pop("trainer")
    nograd_regex_list = trainer_params.pop("no_grad", ())
    if nograd_regex_list:
        nograd_regex = "(" + ")|(".join(nograd_regex_list) + ")"

Contributor

this feels error-prone to me, I'd rather just

no_grad_regexes = trainer_params.pop("no_grad", ())
for name, parameter in model.named_parameters():
    if any(re.search(regex, name) for regex in no_grad_regexes):
        parameter.requires_grad_(False)

it's cleaner and I can't imagine the time difference being noticeable

Contributor Author

Sure, will change that.

nograd_regex = "(" + ")|(".join(nograd_regex_list) + ")"
for name, parameter in model.named_parameters():
    if re.search(nograd_regex, name):
        parameter.requires_grad_(False)

Contributor

same comment here

Contributor Author

sure.

@@ -10,7 +10,7 @@
import logging
import os
from copy import deepcopy

import re

Contributor

nit: I like to leave blank lines between the grouped imports: (standard library) -> (external libraries) -> (allennlp modules)

Contributor Author

Ohh!! I didn't realise the purpose of blanks...

@@ -37,7 +37,7 @@
import logging
import os
from copy import deepcopy

import re

Contributor

same nit

Contributor Author

sure.


import re
import shutil
import pytest

Contributor

same nit: blank line before and after pytest

Contributor Author

ok

@@ -2,7 +2,8 @@
import argparse
from typing import Iterable
import os

import shutil
import re

Contributor

same nit: blank line before import pytest

Contributor Author

ok

# If regex is matched, parameter name should have same requires_grad
# as the originally loaded model
if regex_list:
    nograd_regex = "(" + ")|(".join(regex_list) + ")"

Contributor

I'm more ok with this being in the test, but I'd still prefer the other way

Contributor Author

No problem, I can change that.

# If regex is matched, parameter name should have requires_grad False
# Or else True
if regex_list:
    nograd_regex = "(" + ")|(".join(regex_list) + ")"

Contributor

same

Contributor Author

Btw, I had done something similar elsewhere in a previous commit (here) and used it here.
Should I change it there as well?
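For context, a hedged sketch of the kind of assertion these test snippets are making, written with the any(re.search(...)) form the reviewer prefers (check_requires_grad is an illustrative name, not the actual test helper in the PR):

import re

import torch


def check_requires_grad(model: torch.nn.Module, no_grad_regexes) -> None:
    # Parameters whose names match a no_grad regex should be frozen;
    # everything else should still require gradients.
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            assert not parameter.requires_grad, name
        else:
            assert parameter.requires_grad, name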

@joelgrus joelgrus requested a review from DeNeutoy June 26, 2018 20:16
@DeNeutoy (Contributor) left a comment

Looks good, one broad point - could you make sure to add logging which prints out which exact parameters have been frozen and which have not been? This is the sort of thing which you could easily mis-specify, such that weird stuff happens, like accidentally training the biases of the frozen model or something.

@HarshTrivedi (Contributor Author)

True! Sure, I can add that in logs. thanks

@DeNeutoy (Contributor)

Also FYI - thanks for your contributions to allennlp - your PRs and conduct are exemplary and I'm glad that you find the library useful.

@HarshTrivedi (Contributor Author)

@DeNeutoy Thank you! allennlp is being very helpful to me in what I am currently doing and am sure others who try will find the same. I am very glad to contribute in allennlp as well : ) 👍

@HarshTrivedi (Contributor Author) commented Jun 26, 2018

Btw, do we want to log both frozen and non-frozen parameters? Most of the time, most of the parameters will be non-frozen, and with all of them logged the terminal output looks quite messy. How about logging only the frozen ones?

This is from one test: link.

18:17:25 - INFO - allennlp.commands.train - Parameters without gradient (Frozen) : ['text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._char_embedding_weights', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_0.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_0.bias', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_1.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_1.bias', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_2.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_2.bias', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_3.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_3.bias', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_4.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder.char_conv_4.bias', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._highways._layers.0.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._highways._layers.0.bias', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._highways._layers.1.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._highways._layers.1.bias', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._projection.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._projection.bias', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_0.input_linearity.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_0.state_linearity.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_0.state_linearity.bias', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_0.state_projection.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_0.input_linearity.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_0.state_linearity.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_0.state_linearity.bias', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_0.state_projection.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.input_linearity.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.state_linearity.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.state_linearity.bias', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.state_projection.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_1.input_linearity.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_1.state_linearity.weight', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_1.state_linearity.bias', 'text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.backward_layer_1.state_projection.weight', 'text_field_embedder.token_embedder_elmo._elmo.scalar_mix_0.gamma', 
'text_field_embedder.token_embedder_elmo._elmo.scalar_mix_0.scalar_parameters.0', 'text_field_embedder.token_embedder_elmo._elmo.scalar_mix_0.scalar_parameters.1', 'text_field_embedder.token_embedder_elmo._elmo.scalar_mix_0.scalar_parameters.2', 'text_field_embedder.token_embedder_tokens.weight', 'encoder._module.weight_ih_l0', 'encoder._module.weight_hh_l0', 'encoder._module.bias_ih_l0', 'encoder._module.bias_hh_l0', 'encoder._module.weight_ih_l0_reverse', 'encoder._module.weight_hh_l0_reverse', 'encoder._module.bias_ih_l0_reverse', 'encoder._module.bias_hh_l0_reverse', 'encoder._module.weight_ih_l1', 'encoder._module.weight_hh_l1', 'encoder._module.bias_ih_l1', 'encoder._module.bias_hh_l1', 'encoder._module.weight_ih_l1_reverse', 'encoder._module.weight_hh_l1_reverse', 'encoder._module.bias_ih_l1_reverse', 'encoder._module.bias_hh_l1_reverse', 'tag_projection_layer._module.weight', 'tag_projection_layer._module.bias', 'crf.transitions', 'crf._constraint_mask', 'crf.start_transitions', 'crf.end_transitions']
18:17:25 - INFO - allennlp.commands.train - Parameters with gradient    (Tunable): []

Edit again: sorry, the frozen and tunable labels are reversed in the log above, but the point stands.

@DeNeutoy (Contributor)

Put the individual parameters in separate logging statements in a for loop.
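A sketch of what per-parameter logging in a for loop might look like (the list names frozen_parameter_names and tunable_parameter_names are placeholders, not the exact variables used in the PR):

import logging

logger = logging.getLogger(__name__)

# Placeholder lists; in practice these would be filled while applying the
# no_grad regexes.
frozen_parameter_names: list = []
tunable_parameter_names: list = []

# One logging statement per parameter instead of one huge list.
for name in frozen_parameter_names:
    logger.info("Parameter is Frozen (no gradient): %s", name)
for name in tunable_parameter_names:
    logger.info("Parameter is Tunable (with gradient): %s", name)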

@HarshTrivedi (Contributor Author)

This should be good now.

@DeNeutoy DeNeutoy merged commit 7664b12 into allenai:master Jun 28, 2018
    parameter.requires_grad_(False)
    nograd_parameter_names.append(name)
else:
    grad_parameter_names.append(name)

@rulai-huajunzeng

There is only one problem here: if the parameter's requires_grad is already False, for example a non-trainable embedding, the log will show that it is tunable. Not a very big problem, but it looks a bit confusing sometimes.

Contributor Author

Thank you for catching that! I agree it would be confusing. Will fix that.

@HarshTrivedi HarshTrivedi deleted the selective-finetune branch June 30, 2018 14:17
DeNeutoy pushed a commit that referenced this pull request Jul 6, 2018
Fix issue in no-grad parameters logging as mentioned by @rulai-huajunzeng in (#1427).

If parameters were already set `requires_grad=False` not through the nograd regex but by other means, they were logged as Tunable instead of Frozen. This PR fixes that.

I have made the headings capitalized. They are more distinguishable this way amidst a long list of parameters.
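An illustrative sketch of the fix described above, assuming the classification is done by the parameter's actual requires_grad value after the regexes are applied (the function and variable names are placeholders, not the merged code):

import re

import torch


def classify_parameters(model: torch.nn.Module, no_grad_regexes):
    # Freeze regex-matched parameters, then report every parameter by its
    # resulting requires_grad value, so parameters frozen by other means
    # (e.g. a non-trainable embedding) are also listed as frozen.
    frozen, tunable = [], []
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)
        if parameter.requires_grad:
            tunable.append(name)
        else:
            frozen.append(name)
    return frozen, tunable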
gabrielStanovsky pushed a commit to gabrielStanovsky/allennlp that referenced this pull request Sep 7, 2018
Add support for selective finetune (freeze parameters by regex from config file) (allenai#1427)

* Add support in fine_tune to selectively tune (freeze some parameters set through config file)

* Add tests for selective fine tuning.

* Allow for turning off gradients in train command (since in fine-tune as well this is happening with "trainer" configs).

* Add missing imports in fine_tune_test.py

* add tests for using 'no_grad' config with train command

* Code cleanup: 1. for regex matches 2. follow import convention

* Add logging statements for knowing tunable and frozen parameters.
gabrielStanovsky pushed a commit to gabrielStanovsky/allennlp that referenced this pull request Sep 7, 2018
Fix issue in no-grad parameters logging as mentioned by @rulai-huajunzeng in (allenai#1427).

If parameters were already set `requires_grad=False` not through the nograd regex but by other means, they were logged as Tunable instead of Frozen. This PR fixes that.

I have made the headings capitalized. They are more distinguishable this way amidst a long list of parameters.