
Commit 77298a9

HarshTrivedi authored and DeNeutoy committed
Fix logging of no-grad parameters. (#1448)
Fixes the issue in no-grad parameter logging reported by @rulai-huajunzeng in #1427: if parameters had already been set to `requires_grad=False` by some means other than the `no_grad` regexes, they were logged as Tunable instead of Frozen. This PR fixes that. I have also capitalized the headings; they are more distinguishable that way amidst a long list of parameters.
1 parent bef52ed · commit 77298a9
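
As a quick illustration of the behaviour being fixed (not part of the commit; a minimal sketch using a made-up toy model and the new helper added below in allennlp/common/util.py):

import torch
from allennlp.common.util import get_frozen_and_tunable_parameter_names

# Toy model: the embedding weight is frozen directly on the parameter,
# not through the trainer's `no_grad` regexes.
model = torch.nn.Sequential(torch.nn.Embedding(10, 4), torch.nn.Linear(4, 2))
model[0].weight.requires_grad_(False)

# The helper inspects `requires_grad` on each named parameter, so the
# directly-frozen weight is reported as Frozen rather than Tunable.
frozen, tunable = get_frozen_and_tunable_parameter_names(model)
print(frozen)   # ['0.weight']
print(tunable)  # ['1.weight', '1.bias']

Before this change, only names matched by the `no_grad` regexes were collected into the frozen list, so a parameter frozen this way would have been logged under the Tunable heading.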

File tree

  allennlp/commands/fine_tune.py
  allennlp/commands/train.py
  allennlp/common/util.py
  allennlp/tests/common/test_util.py

4 files changed: +36 −18 lines changed


allennlp/commands/fine_tune.py

+6 −9

@@ -16,7 +16,8 @@
 from allennlp.commands.subcommand import Subcommand
 from allennlp.commands.train import datasets_from_params
 from allennlp.common import Params
-from allennlp.common.util import prepare_environment, prepare_global_logging
+from allennlp.common.util import prepare_environment, prepare_global_logging, \
+                                 get_frozen_and_tunable_parameter_names
 from allennlp.data.iterators.data_iterator import DataIterator
 from allennlp.models import load_archive, archive_model
 from allennlp.models.archival import CONFIG_NAME
@@ -182,21 +183,17 @@ def fine_tune_model(model: Model,
 
     trainer_params = params.pop("trainer")
     no_grad_regexes = trainer_params.pop("no_grad", ())
-
-    nograd_parameter_names = []
-    grad_parameter_names = []
     for name, parameter in model.named_parameters():
         if any(re.search(regex, name) for regex in no_grad_regexes):
             parameter.requires_grad_(False)
-            nograd_parameter_names.append(name)
-        else:
-            grad_parameter_names.append(name)
 
+    frozen_parameter_names, tunable_parameter_names = \
+                   get_frozen_and_tunable_parameter_names(model)
     logger.info("Following parameters are Frozen (without gradient):")
-    for name in nograd_parameter_names:
+    for name in frozen_parameter_names:
         logger.info(name)
     logger.info("Following parameters are Tunable (with gradient):")
-    for name in grad_parameter_names:
+    for name in tunable_parameter_names:
         logger.info(name)
 
     trainer = Trainer.from_params(model,

allennlp/commands/train.py

+6 −9

@@ -45,7 +45,8 @@
 from allennlp.commands.subcommand import Subcommand
 from allennlp.common.checks import ConfigurationError, check_for_gpu
 from allennlp.common import Params
-from allennlp.common.util import prepare_environment, prepare_global_logging
+from allennlp.common.util import prepare_environment, prepare_global_logging, \
+                                 get_frozen_and_tunable_parameter_names
 from allennlp.data import Vocabulary
 from allennlp.data.instance import Instance
 from allennlp.data.dataset_readers.dataset_reader import DatasetReader
@@ -293,21 +294,17 @@ def train_model(params: Params,
 
     trainer_params = params.pop("trainer")
     no_grad_regexes = trainer_params.pop("no_grad", ())
-
-    nograd_parameter_names = []
-    grad_parameter_names = []
     for name, parameter in model.named_parameters():
         if any(re.search(regex, name) for regex in no_grad_regexes):
             parameter.requires_grad_(False)
-            nograd_parameter_names.append(name)
-        else:
-            grad_parameter_names.append(name)
 
+    frozen_parameter_names, tunable_parameter_names = \
+                   get_frozen_and_tunable_parameter_names(model)
     logger.info("Following parameters are Frozen (without gradient):")
-    for name in nograd_parameter_names:
+    for name in frozen_parameter_names:
         logger.info(name)
     logger.info("Following parameters are Tunable (with gradient):")
-    for name in grad_parameter_names:
+    for name in tunable_parameter_names:
         logger.info(name)
 
     trainer = Trainer.from_params(model,

allennlp/common/util.py

+10 −0

@@ -344,3 +344,13 @@ def is_lazy(iterable: Iterable[A]) -> bool:
     which here just means it's not a list.
     """
     return not isinstance(iterable, list)
+
+def get_frozen_and_tunable_parameter_names(model: torch.nn.Module) -> List:
+    frozen_parameter_names = []
+    tunable_parameter_names = []
+    for name, parameter in model.named_parameters():
+        if not parameter.requires_grad:
+            frozen_parameter_names.append(name)
+        else:
+            tunable_parameter_names.append(name)
+    return [frozen_parameter_names, tunable_parameter_names]
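
A short note on the helper's design: it is annotated as returning List and returns a list of two name lists rather than a tuple; the callers in fine_tune.py and train.py above (and the test below) simply unpack it as a (frozen, tunable) pair, which works because it is a two-element sequence.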

allennlp/tests/common/test_util.py

+14 −0

@@ -1,5 +1,6 @@
 # pylint: disable=no-self-use,invalid-name
 import sys
+from collections import OrderedDict
 
 import pytest
 import torch
@@ -65,3 +66,16 @@ def test_import_submodules(self):
         assert 'mymodule.submodule' in sys.modules
 
         sys.path.remove(str(self.TEST_DIR))
+
+    def test_get_frozen_and_tunable_parameter_names(self):
+        model = torch.nn.Sequential(OrderedDict([
+                ('conv', torch.nn.Conv1d(5, 5, 5)),
+                ('linear', torch.nn.Linear(5, 10)),
+                ]))
+        named_parameters = dict(model.named_parameters())
+        named_parameters['linear.weight'].requires_grad_(False)
+        named_parameters['linear.bias'].requires_grad_(False)
+        frozen_parameter_names, tunable_parameter_names = \
+                    util.get_frozen_and_tunable_parameter_names(model)
+        assert set(frozen_parameter_names) == {'linear.weight', 'linear.bias'}
+        assert set(tunable_parameter_names) == {'conv.weight', 'conv.bias'}
