Add support for selective finetune (freeze parameters by regex from config file) #1427
Changes from 5 commits
@@ -10,7 +10,7 @@
import logging
import os
from copy import deepcopy

import re
from allennlp.commands.evaluate import evaluate
from allennlp.commands.subcommand import Subcommand
from allennlp.commands.train import datasets_from_params
@@ -166,6 +166,12 @@ def fine_tune_model(model: Model,
    test_data = all_datasets.get('test')

    trainer_params = params.pop("trainer")
    nograd_regex_list = trainer_params.pop("no_grad", ())
    if nograd_regex_list:
        nograd_regex = "(" + ")|(".join(nograd_regex_list) + ")"
Reviewer: this feels error-prone to me, I'd rather just [...] it's cleaner and I can't imagine the time difference being noticeable.
Author: Sure, will change that.
        for name, parameter in model.named_parameters():
            if re.search(nograd_regex, name):
                parameter.requires_grad_(False)
    trainer = Trainer.from_params(model,
                                  serialization_dir,
                                  iterator,
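The reviewer's suggested alternative is not captured above. As an assumption about what it might look like, a per-pattern check could replace the joined regex; this is a sketch, not the reviewer's actual snippet, and `freeze_by_regexes` is a hypothetical helper name:

```python
import re
from typing import Sequence

import torch


def freeze_by_regexes(model: torch.nn.Module, nograd_regex_list: Sequence[str]) -> None:
    """Hypothetical helper: freeze any parameter whose name matches one of the
    patterns, testing each regex separately instead of joining them into one."""
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in nograd_regex_list):
            parameter.requires_grad_(False)
```

Checking each pattern with `any()` keeps the regexes independent, which is presumably the "cleaner" behaviour the reviewer had in mind, at the cost of a few extra `re.search` calls per parameter.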
@@ -37,7 +37,7 @@
import logging
import os
from copy import deepcopy

import re
Reviewer: same nit.
Author: sure.
import torch

from allennlp.commands.evaluate import evaluate
@@ -282,6 +282,12 @@ def train_model(params: Params,
    test_data = all_datasets.get('test')

    trainer_params = params.pop("trainer")
    nograd_regex_list = trainer_params.pop("no_grad", ())
    if nograd_regex_list:
        nograd_regex = "(" + ")|(".join(nograd_regex_list) + ")"
        for name, parameter in model.named_parameters():
            if re.search(nograd_regex, name):
                parameter.requires_grad_(False)
Reviewer: same comment here.
Author: sure.
    trainer = Trainer.from_params(model,
                                  serialization_dir,
                                  iterator,
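For context, this is the user-facing knob the PR title describes: a list of regexes under the trainer section of the config. A minimal sketch of reading it the way the diff does (the surrounding trainer values are illustrative and mirror the tests further below):

```python
from allennlp.common.params import Params

# Illustrative config fragment: "no_grad" sits alongside the usual trainer options.
params = Params({
    "trainer": {
        "num_epochs": 2,
        "optimizer": "adam",
        "no_grad": [".*text_field_embedder.*", ".*encoder.*"]
    }
})

trainer_params = params.pop("trainer")
nograd_regex_list = trainer_params.pop("no_grad", ())  # the two regexes above
```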
@@ -1,8 +1,13 @@
# pylint: disable=invalid-name,no-self-use
import argparse

import re
import shutil
import pytest
Reviewer: same nit: blank line before and after.
Author: ok
from allennlp.common.testing import AllenNlpTestCase
from allennlp.commands.fine_tune import FineTune, fine_tune_model_from_file_paths, fine_tune_model_from_args
from allennlp.commands.fine_tune import FineTune, fine_tune_model_from_file_paths, \
    fine_tune_model_from_args, fine_tune_model
from allennlp.common.params import Params
from allennlp.models import load_archive

class TestFineTune(AllenNlpTestCase):
    def setUp(self):
@@ -50,3 +55,36 @@ def test_fine_tune_fails_without_required_args(self):
        with self.assertRaises(SystemExit) as context:
            self.parser.parse_args(["fine-tune", "-s", "serialization_dir", "-c", "path/to/config"])
        assert context.exception.code == 2  # argparse code for incorrect usage
    def test_fine_tune_nograd_regex(self):
        original_model = load_archive(self.model_archive).model
        name_parameters_original = dict(original_model.named_parameters())
        regex_lists = [[],
                       [".*attend_feedforward.*", ".*token_embedder.*"],
                       [".*compare_feedforward.*"]]
        for regex_list in regex_lists:
            params = Params.from_file(self.config_file)
            params["trainer"]["no_grad"] = regex_list
            shutil.rmtree(self.serialization_dir, ignore_errors=True)
            tuned_model = fine_tune_model(model=original_model,
                                          params=params,
                                          serialization_dir=self.serialization_dir)
            # If regex is matched, parameter name should have requires_grad False
            # If regex is not matched, parameter name should have same requires_grad
            # as the originally loaded model
            if regex_list:
                nograd_regex = "(" + ")|(".join(regex_list) + ")"
Reviewer: I'm more ok with this being in the test, but I'd still prefer the other way.
Author: no problem, i can change that.
                for name, parameter in tuned_model.named_parameters():
                    if re.search(nograd_regex, name):
                        assert not parameter.requires_grad
                    else:
                        assert parameter.requires_grad \
                            == name_parameters_original[name].requires_grad
        # If all parameters have requires_grad=False, then error.
        with pytest.raises(Exception) as _:
            params = Params.from_file(self.config_file)
            params["trainer"]["no_grad"] = ["*"]
            shutil.rmtree(self.serialization_dir, ignore_errors=True)
            tuned_model = fine_tune_model(model=original_model,
                                          params=params,
                                          serialization_dir=self.serialization_dir)
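The final block only asserts that some exception is raised once every parameter is frozen. As an assumption about the mechanism (the diff does not say where the failure comes from), one way this surfaces in plain PyTorch is that an optimizer built over only the trainable parameters receives an empty list:

```python
import torch

# Toy model with every parameter frozen (made-up example, not the test's model).
model = torch.nn.Linear(3, 2)
for parameter in model.parameters():
    parameter.requires_grad_(False)

trainable = [p for p in model.parameters() if p.requires_grad]
try:
    torch.optim.Adam(trainable)  # nothing left to optimize
except ValueError as err:
    print(err)  # "optimizer got an empty parameter list"
```

Note also that `"*"` on its own is not a valid Python regular expression, so the joined pattern `"(*)"` would raise `re.error` before training even starts; either way the `pytest.raises(Exception)` check is satisfied.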
@@ -2,7 +2,8 @@
import argparse
from typing import Iterable
import os

import shutil
import re
Reviewer: same nit: blank line before.
Author: ok
import pytest
import torch
@@ -232,3 +233,54 @@ def test_train_with_test_set(self):
        })

        train_model(params, serialization_dir=os.path.join(self.TEST_DIR, 'lazy_test_set'))
    def test_train_nograd_regex(self):
        params_get = lambda: Params({
            "model": {
                "type": "simple_tagger",
                "text_field_embedder": {
                    "tokens": {
                        "type": "embedding",
                        "embedding_dim": 5
                    }
                },
                "encoder": {
                    "type": "lstm",
                    "input_size": 5,
                    "hidden_size": 7,
                    "num_layers": 2
                }
            },
            "dataset_reader": {"type": "sequence_tagging"},
            "train_data_path": SEQUENCE_TAGGING_DATA_PATH,
            "validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
            "iterator": {"type": "basic", "batch_size": 2},
            "trainer": {
                "num_epochs": 2,
                "optimizer": "adam"
            }
        })
        serialization_dir = os.path.join(self.TEST_DIR, 'test_train_nograd')
        regex_lists = [[],
                       [".*text_field_embedder.*"],
                       [".*text_field_embedder.*", ".*encoder.*"]]
        for regex_list in regex_lists:
            params = params_get()
            params["trainer"]["no_grad"] = regex_list
            shutil.rmtree(serialization_dir, ignore_errors=True)
            model = train_model(params, serialization_dir=serialization_dir)
            # If regex is matched, parameter name should have requires_grad False
            # Or else True
            if regex_list:
                nograd_regex = "(" + ")|(".join(regex_list) + ")"
Reviewer: same.
                for name, parameter in model.named_parameters():
                    if re.search(nograd_regex, name):
                        assert not parameter.requires_grad
                    else:
                        assert parameter.requires_grad
        # If all parameters have requires_grad=False, then error.
        params = params_get()
        params["trainer"]["no_grad"] = ["*"]
        shutil.rmtree(serialization_dir, ignore_errors=True)
        with pytest.raises(Exception) as _:
            model = train_model(params, serialization_dir=serialization_dir)
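To make the matching logic concrete, here is a standalone sketch of what the joined regex from the diff selects; the parameter names below are invented for illustration, real names come from `model.named_parameters()`:

```python
import re

parameter_names = [
    "text_field_embedder.token_embedder_tokens.weight",  # invented example name
    "encoder._module.weight_ih_l0",                      # invented example name
    "tag_projection_layer._module.weight",               # invented example name
]

regex_list = [".*text_field_embedder.*", ".*encoder.*"]
nograd_regex = "(" + ")|(".join(regex_list) + ")"  # "(.*text_field_embedder.*)|(.*encoder.*)"

for name in parameter_names:
    status = "frozen" if re.search(nograd_regex, name) else "trainable"
    print(f"{name}: {status}")
```

The first two names match and would be frozen; the last one keeps `requires_grad=True`.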
Reviewer: nit: I like to leave blank lines between the grouped imports: (standard library) -> (external libraries) -> (allennlp modules).
Author: Ohh!! I didn't realise the purpose of the blanks...
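For illustration, the grouping the reviewer describes looks like this (the module choices here are arbitrary examples):

```python
# standard library
import os
import re

# external libraries
import torch

# allennlp modules
from allennlp.common.params import Params
```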