Add support for selective finetune (freeze parameters by regex from config file) #1427
Conversation
looks good to me, modulo the small changes I requested, I'll let @DeNeutoy weigh in too
allennlp/commands/fine_tune.py
Outdated
@@ -166,6 +166,12 @@ def fine_tune_model(model: Model,
    test_data = all_datasets.get('test')

    trainer_params = params.pop("trainer")
    nograd_regex_list = trainer_params.pop("no_grad", ())
    if nograd_regex_list:
        nograd_regex = "(" + ")|(".join(nograd_regex_list) + ")"
this feels error-prone to me, I'd rather just
no_grad_regexes = trainer_params.pop("no_grad", ())
for name, parameter in model.named_parameters():
    if any(re.search(regex, name) for regex in no_grad_regexes):
        parameter.requires_grad_(False)
it's cleaner and I can't imagine the time difference being noticeable
Sure, will change that.
allennlp/commands/train.py
Outdated
nograd_regex = "(" + ")|(".join(nograd_regex_list) + ")" | ||
for name, parameter in model.named_parameters(): | ||
if re.search(nograd_regex, name): | ||
parameter.requires_grad_(False) |
same comment here
sure.
allennlp/commands/fine_tune.py
Outdated
@@ -10,7 +10,7 @@
import logging
import os
from copy import deepcopy

import re
nit: I like to leave blank lines between the grouped imports: (standard library) -> (external libraries) -> (allennlp modules)
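For example, the grouped imports would look roughly like this (the `torch` and `Model` imports are just illustrative members of the external and allennlp groups):

```python
# standard library
import logging
import os
import re
from copy import deepcopy

# external libraries
import torch

# allennlp modules
from allennlp.models.model import Model
```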
Ohh!! I didn't realise the purpose of blanks...
allennlp/commands/train.py
Outdated
@@ -37,7 +37,7 @@
import logging
import os
from copy import deepcopy

import re
same nit
sure.
import re
import shutil
import pytest
same nit: blank line before and after pytest
ok
@@ -2,7 +2,8 @@
import argparse
from typing import Iterable
import os

import shutil
import re
same nit: blank line before import pytest
ok
# If regex is matched, parameter name should have same requires_grad
# as the originally loaded model
if regex_list:
    nograd_regex = "(" + ")|(".join(regex_list) + ")"
I'm more ok with this being in the test, but I'd still prefer the other way
no problem, i can change that.
# If regex is matched, parameter name should have requires_grad False
# Or else True
if regex_list:
    nograd_regex = "(" + ")|(".join(regex_list) + ")"
same
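For reference, a minimal sketch of this second check rewritten per-parameter as suggested, assuming the `regex_list` and `model` objects from the diff above:

```python
for name, parameter in model.named_parameters():
    if any(re.search(regex, name) for regex in regex_list):
        # Parameters matching a no_grad regex should have been frozen.
        assert not parameter.requires_grad
    else:
        assert parameter.requires_grad
```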
Looks good, one broad point - could you make sure to add logging which prints out which exact parameters have been frozen and which have not been? This is the sort of thing which you could easily mis-specify, such that weird stuff happens, like accidentally training the biases of the frozen model or something.
True! Sure, I can add that in the logs. Thanks.
Also FYI - thanks for your contributions to allennlp - your PRs and conduct are exemplary and I'm glad that you find the library useful.
@DeNeutoy Thank you! allennlp has been very helpful to me in what I am currently doing, and I am sure others who try it will find the same. I am very glad to contribute to allennlp as well :) 👍
Btw, do we want to log both frozen and non-frozen parameters? Most of the time, most of the parameters will be non-frozen, and with all of those parameters logged it looks quite messy on the terminal. How about logging only the frozen ones?
Edit Again!: Put the individual parameters in separate logging statements in a for loop. This should be good now.
    parameter.requires_grad_(False)
    nograd_parameter_names.append(name)
else:
    grad_parameter_names.append(name)
There is only one problem here: if the parameter's requires_grad is already False, for example for a non-trainable embedding, the log will show that it is tunable. Not a very big problem, but it looks a bit confusing sometimes.
Thank you for catching that! I agree it would be confusing. Will fix that.
Fix issue in no-grad parameters logging as mentioned by @rulai-huajunzeng in (#1427). If parameters were already set `requires_grad=False` not through the no-grad regexes but by other means, they were logged as Tunable instead of Frozen. This PR fixes that. I have also capitalized the headings; they are more distinguishable this way amidst a long list of parameters.
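A minimal sketch of what the corrected freeze-and-log logic could look like; the helper name, list names, and logger setup are assumed for illustration and are not necessarily the exact code in the PR:

```python
import logging
import re

logger = logging.getLogger(__name__)

def freeze_and_log_parameters(model, no_grad_regexes):
    """Freeze parameters matching the regexes and log each parameter's final status."""
    frozen_parameter_names = []
    tunable_parameter_names = []
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)
        # Classify by the parameter's actual requires_grad flag, so parameters that
        # were already frozen by other means are also reported as Frozen.
        if parameter.requires_grad:
            tunable_parameter_names.append(name)
        else:
            frozen_parameter_names.append(name)
    logger.info("The following parameters are Frozen (without gradient):")
    for name in frozen_parameter_names:
        logger.info(name)
    logger.info("The following parameters are Tunable (with gradient):")
    for name in tunable_parameter_names:
        logger.info(name)
```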
…onfig file) (allenai#1427)
* Add support in fine_tune to selectively tune (freeze some parameters set through config file)
* Add tests for selective fine tuning.
* Allow for turning off gradients in train command (since in fine-tune as well this is happening with "trainer" configs).
* Add missing imports in fine_tune_test.py
* add tests for using 'no_grad' config with train command
* Code cleanup: 1. for regex matches 2. follow import convention
* Add logging statements for knowing tunable and frozen parameters.
This PR adds support in the `fine-tune` command to freeze parameters (e.g. whole layers) by passing regexes in the config file. In the config file passed to the `fine-tune` command, a `no_grad` list of regexes under the `trainer` key (see the sketch below) will freeze, for example, the `conv` layers and `text_embedder`s but not other parameters.

I believed the best place to make the `no_grad` key accessible is within the `trainer` key, and `trainer` is also used in the `train` command, so I allowed turning off gradients via `no_grad` in the `train` command as well. Although I primarily intended this selective turning off of gradients for the `fine-tune` command, it can be useful in the `train` command too: e.g. if one's module parameters are loaded / transferred from some other pretrained model and one wants to freeze them there.

More context: Issue (#1298)
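A minimal sketch of the relevant part of such a config, assuming the `conv` and `text_embedder` regexes mentioned above; the `optimizer` and `num_epochs` entries are illustrative placeholders rather than part of this PR:

```json
{
    "trainer": {
        "no_grad": ["conv", "text_embedder"],
        "optimizer": {"type": "adam"},
        "num_epochs": 5
    }
}
```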