
Commit 7664b12

HarshTrivedi authored and DeNeutoy committed
Add support for selective finetune (freeze parameters by regex from config file) (#1427)
* Add support in fine_tune to selectively tune (freeze some parameters set through the config file).
* Add tests for selective fine-tuning.
* Allow turning off gradients in the train command as well (it reads the same "trainer" config that fine-tune uses).
* Add missing imports in fine_tune_test.py.
* Add tests for using the 'no_grad' config with the train command.
* Code cleanup: (1) regex matches, (2) follow import conventions.
* Add logging statements listing tunable and frozen parameters.
1 parent 8855042 commit 7664b12
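
The new behavior is driven by a "no_grad" list of regular expressions inside the "trainer" section of the configuration. As a rough illustration (the regexes below are invented for the example, not taken from this commit), the relevant fragment of a training or fine-tuning config could look like this, expressed here through Params:

from allennlp.common.params import Params

# Illustrative fragment only: the regex patterns are made up for this example.
# Any parameter whose name matches one of them has requires_grad set to False
# before the Trainer is built.
trainer_fragment = Params({
    "trainer": {
        "num_epochs": 2,
        "optimizer": "adam",
        "no_grad": [".*token_embedder.*", ".*attend_feedforward.*"]
    }
})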

File tree

4 files changed: +128 -1 lines

allennlp/commands/fine_tune.py

+19
@@ -10,6 +10,7 @@
 import logging
 import os
 from copy import deepcopy
+import re
 
 from allennlp.commands.evaluate import evaluate
 from allennlp.commands.subcommand import Subcommand
@@ -180,6 +181,24 @@ def fine_tune_model(model: Model,
     test_data = all_datasets.get('test')
 
     trainer_params = params.pop("trainer")
+    no_grad_regexes = trainer_params.pop("no_grad", ())
+
+    nograd_parameter_names = []
+    grad_parameter_names = []
+    for name, parameter in model.named_parameters():
+        if any(re.search(regex, name) for regex in no_grad_regexes):
+            parameter.requires_grad_(False)
+            nograd_parameter_names.append(name)
+        else:
+            grad_parameter_names.append(name)
+
+    logger.info("Following parameters are Frozen (without gradient):")
+    for name in nograd_parameter_names:
+        logger.info(name)
+    logger.info("Following parameters are Tunable (with gradient):")
+    for name in grad_parameter_names:
+        logger.info(name)
+
     trainer = Trainer.from_params(model,
                                   serialization_dir,
                                   iterator,
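
The hunk above is the core of the feature: each parameter name is matched against the "no_grad" regexes and, on a match, its gradient is switched off before the Trainer is constructed. A minimal standalone sketch of the same idea in plain PyTorch (the toy model and regex below are invented for illustration; this is not the AllenNLP code itself):

import re
import torch

def freeze_by_regex(model: torch.nn.Module, no_grad_regexes) -> None:
    # Turn off gradients for every parameter whose name matches a regex.
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

# Toy usage: freeze only the embedding weights of a two-module model.
model = torch.nn.Sequential(torch.nn.Embedding(10, 5), torch.nn.Linear(5, 2))
freeze_by_regex(model, [r"0\.weight"])
print([name for name, p in model.named_parameters() if not p.requires_grad])
# prints: ['0.weight']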

allennlp/commands/train.py

+19
@@ -37,6 +37,7 @@
 import logging
 import os
 from copy import deepcopy
+import re
 
 import torch
 
@@ -282,6 +283,24 @@ def train_model(params: Params,
     test_data = all_datasets.get('test')
 
     trainer_params = params.pop("trainer")
+    no_grad_regexes = trainer_params.pop("no_grad", ())
+
+    nograd_parameter_names = []
+    grad_parameter_names = []
+    for name, parameter in model.named_parameters():
+        if any(re.search(regex, name) for regex in no_grad_regexes):
+            parameter.requires_grad_(False)
+            nograd_parameter_names.append(name)
+        else:
+            grad_parameter_names.append(name)
+
+    logger.info("Following parameters are Frozen (without gradient):")
+    for name in nograd_parameter_names:
+        logger.info(name)
+    logger.info("Following parameters are Tunable (with gradient):")
+    for name in grad_parameter_names:
+        logger.info(name)
+
     trainer = Trainer.from_params(model,
                                   serialization_dir,
                                   iterator,
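
train.py receives the same regex pass and the same per-parameter logging as fine_tune.py. For a quick numeric summary instead of the per-name log, a small helper along these lines (not part of the commit) could be used:

import torch

def count_parameters(model: torch.nn.Module) -> str:
    # Summarize how many parameter elements are tunable vs. frozen.
    tunable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return "tunable: {}, frozen: {}".format(tunable, frozen)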

allennlp/tests/commands/fine_tune_test.py

+39 -1
@@ -1,8 +1,15 @@
 # pylint: disable=invalid-name,no-self-use
 import argparse
+import re
+import shutil
+
+import pytest
 
 from allennlp.common.testing import AllenNlpTestCase
-from allennlp.commands.fine_tune import FineTune, fine_tune_model_from_file_paths, fine_tune_model_from_args
+from allennlp.commands.fine_tune import FineTune, fine_tune_model_from_file_paths, \
+        fine_tune_model_from_args, fine_tune_model
+from allennlp.common.params import Params
+from allennlp.models import load_archive
 
 class TestFineTune(AllenNlpTestCase):
     def setUp(self):
@@ -50,3 +57,34 @@ def test_fine_tune_fails_without_required_args(self):
         with self.assertRaises(SystemExit) as context:
             self.parser.parse_args(["fine-tune", "-s", "serialization_dir", "-c", "path/to/config"])
         assert context.exception.code == 2  # argparse code for incorrect usage
+
+    def test_fine_tune_nograd_regex(self):
+        original_model = load_archive(self.model_archive).model
+        name_parameters_original = dict(original_model.named_parameters())
+        regex_lists = [[],
+                       [".*attend_feedforward.*", ".*token_embedder.*"],
+                       [".*compare_feedforward.*"]]
+        for regex_list in regex_lists:
+            params = Params.from_file(self.config_file)
+            params["trainer"]["no_grad"] = regex_list
+            shutil.rmtree(self.serialization_dir, ignore_errors=True)
+            tuned_model = fine_tune_model(model=original_model,
+                                          params=params,
+                                          serialization_dir=self.serialization_dir)
+            # If the regex matches, the parameter should have requires_grad=False;
+            # otherwise it should keep the same requires_grad
+            # as in the originally loaded model.
+            for name, parameter in tuned_model.named_parameters():
+                if any(re.search(regex, name) for regex in regex_list):
+                    assert not parameter.requires_grad
+                else:
+                    assert parameter.requires_grad \
+                        == name_parameters_original[name].requires_grad
+        # If all parameters have requires_grad=False, then error.
+        with pytest.raises(Exception) as _:
+            params = Params.from_file(self.config_file)
+            params["trainer"]["no_grad"] = ["*"]
+            shutil.rmtree(self.serialization_dir, ignore_errors=True)
+            tuned_model = fine_tune_model(model=original_model,
+                                          params=params,
+                                          serialization_dir=self.serialization_dir)

allennlp/tests/commands/train_test.py

+51
@@ -2,6 +2,8 @@
 import argparse
 from typing import Iterable
 import os
+import shutil
+import re
 
 import pytest
 import torch
@@ -232,3 +234,52 @@ def test_train_with_test_set(self):
         })
 
         train_model(params, serialization_dir=os.path.join(self.TEST_DIR, 'lazy_test_set'))
+
+    def test_train_nograd_regex(self):
+        params_get = lambda: Params({
+                "model": {
+                        "type": "simple_tagger",
+                        "text_field_embedder": {
+                                "tokens": {
+                                        "type": "embedding",
+                                        "embedding_dim": 5
+                                }
+                        },
+                        "encoder": {
+                                "type": "lstm",
+                                "input_size": 5,
+                                "hidden_size": 7,
+                                "num_layers": 2
+                        }
+                },
+                "dataset_reader": {"type": "sequence_tagging"},
+                "train_data_path": SEQUENCE_TAGGING_DATA_PATH,
+                "validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
+                "iterator": {"type": "basic", "batch_size": 2},
+                "trainer": {
+                        "num_epochs": 2,
+                        "optimizer": "adam"
+                }
+        })
+        serialization_dir = os.path.join(self.TEST_DIR, 'test_train_nograd')
+        regex_lists = [[],
+                       [".*text_field_embedder.*"],
+                       [".*text_field_embedder.*", ".*encoder.*"]]
+        for regex_list in regex_lists:
+            params = params_get()
+            params["trainer"]["no_grad"] = regex_list
+            shutil.rmtree(serialization_dir, ignore_errors=True)
+            model = train_model(params, serialization_dir=serialization_dir)
+            # If regex is matched, parameter name should have requires_grad False
+            # Or else True
+            for name, parameter in model.named_parameters():
+                if any(re.search(regex, name) for regex in regex_list):
+                    assert not parameter.requires_grad
+                else:
+                    assert parameter.requires_grad
+        # If all parameters have requires_grad=False, then error.
+        params = params_get()
+        params["trainer"]["no_grad"] = ["*"]
+        shutil.rmtree(serialization_dir, ignore_errors=True)
+        with pytest.raises(Exception) as _:
+            model = train_model(params, serialization_dir=serialization_dir)
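
Both test files end with a case that sets "no_grad" so that every parameter is frozen and expect an exception. One plausible source of that error, sketched here in plain PyTorch under the assumption that only parameters with requires_grad=True are handed to the optimizer, is that the optimizer receives an empty parameter list:

import torch

model = torch.nn.Linear(5, 2)
for parameter in model.parameters():
    parameter.requires_grad_(False)  # freeze everything

trainable = [p for p in model.parameters() if p.requires_grad]
try:
    torch.optim.Adam(trainable)
except ValueError as error:
    print(error)  # "optimizer got an empty parameter list"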
