This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Fix logging of no-grad parameters. #1448

Merged
4 commits merged on Jul 6, 2018
allennlp/commands/fine_tune.py (6 additions, 9 deletions)
@@ -16,7 +16,8 @@
 from allennlp.commands.subcommand import Subcommand
 from allennlp.commands.train import datasets_from_params
 from allennlp.common import Params
-from allennlp.common.util import prepare_environment, prepare_global_logging
+from allennlp.common.util import prepare_environment, prepare_global_logging, \
+                                 get_frozen_and_tunable_parameter_names
 from allennlp.data.iterators.data_iterator import DataIterator
 from allennlp.models import load_archive, archive_model
 from allennlp.models.archival import CONFIG_NAME
@@ -182,21 +183,17 @@ def fine_tune_model(model: Model,
 
     trainer_params = params.pop("trainer")
     no_grad_regexes = trainer_params.pop("no_grad", ())
 
-    nograd_parameter_names = []
-    grad_parameter_names = []
     for name, parameter in model.named_parameters():
         if any(re.search(regex, name) for regex in no_grad_regexes):
             parameter.requires_grad_(False)
-            nograd_parameter_names.append(name)
-        else:
-            grad_parameter_names.append(name)
-
+    frozen_parameter_names, tunable_parameter_names = \
+                   get_frozen_and_tunable_parameter_names(model)
     logger.info("Following parameters are Frozen (without gradient):")
-    for name in nograd_parameter_names:
+    for name in frozen_parameter_names:
         logger.info(name)
     logger.info("Following parameters are Tunable (with gradient):")
-    for name in grad_parameter_names:
+    for name in tunable_parameter_names:
         logger.info(name)
 
     trainer = Trainer.from_params(model,
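For reference, the "no_grad" value popped from the trainer params above is a list of regular expressions matched against parameter names; any parameter that matches is frozen before training. A minimal sketch of that part of the flow, using an invented regex and a toy module standing in for a real AllenNLP Model (illustration only, not code from this PR):

import re
from collections import OrderedDict

import torch
from allennlp.common import Params

# Illustrative trainer fragment; the "no_grad" regex is invented for this sketch.
trainer_params = Params({"optimizer": "adam", "no_grad": ["^embedder"]})
no_grad_regexes = trainer_params.pop("no_grad", ())

# Toy module standing in for a real AllenNLP Model.
model = torch.nn.Sequential(OrderedDict([
    ("embedder", torch.nn.Embedding(10, 5)),
    ("projection", torch.nn.Linear(5, 2)),
]))

# Same freezing loop as in fine_tune.py and train.py above.
for name, parameter in model.named_parameters():
    if any(re.search(regex, name) for regex in no_grad_regexes):
        parameter.requires_grad_(False)

assert not model.embedder.weight.requires_grad
assert model.projection.weight.requires_grad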
allennlp/commands/train.py (6 additions, 9 deletions)
@@ -45,7 +45,8 @@
 from allennlp.commands.subcommand import Subcommand
 from allennlp.common.checks import ConfigurationError, check_for_gpu
 from allennlp.common import Params
-from allennlp.common.util import prepare_environment, prepare_global_logging
+from allennlp.common.util import prepare_environment, prepare_global_logging, \
+                                 get_frozen_and_tunable_parameter_names
 from allennlp.data import Vocabulary
 from allennlp.data.instance import Instance
 from allennlp.data.dataset_readers.dataset_reader import DatasetReader
@@ -293,21 +294,17 @@ def train_model(params: Params,
 
     trainer_params = params.pop("trainer")
     no_grad_regexes = trainer_params.pop("no_grad", ())
 
-    nograd_parameter_names = []
-    grad_parameter_names = []
     for name, parameter in model.named_parameters():
         if any(re.search(regex, name) for regex in no_grad_regexes):
             parameter.requires_grad_(False)
-            nograd_parameter_names.append(name)
-        else:
-            grad_parameter_names.append(name)
-
+    frozen_parameter_names, tunable_parameter_names = \
+                   get_frozen_and_tunable_parameter_names(model)
     logger.info("Following parameters are Frozen (without gradient):")
-    for name in nograd_parameter_names:
+    for name in frozen_parameter_names:
         logger.info(name)
     logger.info("Following parameters are Tunable (with gradient):")
-    for name in grad_parameter_names:
+    for name in tunable_parameter_names:
         logger.info(name)
 
     trainer = Trainer.from_params(model,
allennlp/common/util.py (10 additions, 0 deletions)
@@ -344,3 +344,13 @@ def is_lazy(iterable: Iterable[A]) -> bool:
     which here just means it's not a list.
     """
     return not isinstance(iterable, list)
+
+def get_frozen_and_tunable_parameter_names(model: torch.nn.Module) -> List:
+    frozen_parameter_names = []
+    tunable_parameter_names = []
+    for name, parameter in model.named_parameters():
+        if not parameter.requires_grad:
+            frozen_parameter_names.append(name)
+        else:
+            tunable_parameter_names.append(name)
+    return [frozen_parameter_names, tunable_parameter_names]
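The new helper reads requires_grad directly off the model at call time, so the logged lists reflect the model's actual state rather than only the regexes passed to the current command. A quick illustrative sketch with a toy module (not part of the PR):

import torch
from allennlp.common.util import get_frozen_and_tunable_parameter_names

# Toy module; any parameter whose requires_grad is False counts as frozen,
# regardless of how it came to be frozen.
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
for parameter in model[0].parameters():
    parameter.requires_grad_(False)

frozen, tunable = get_frozen_and_tunable_parameter_names(model)
assert set(frozen) == {"0.weight", "0.bias"}
assert set(tunable) == {"1.weight", "1.bias"}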
allennlp/tests/common/test_util.py (14 additions, 0 deletions)
@@ -1,5 +1,6 @@
 # pylint: disable=no-self-use,invalid-name
 import sys
+from collections import OrderedDict
 
 import pytest
 import torch
@@ -65,3 +66,16 @@ def test_import_submodules(self):
         assert 'mymodule.submodule' in sys.modules
 
         sys.path.remove(str(self.TEST_DIR))
+
+    def test_get_frozen_and_tunable_parameter_names(self):
+        model = torch.nn.Sequential(OrderedDict([
+                ('conv', torch.nn.Conv1d(5, 5, 5)),
+                ('linear', torch.nn.Linear(5, 10)),
+                ]))
+        named_parameters = dict(model.named_parameters())
+        named_parameters['linear.weight'].requires_grad_(False)
+        named_parameters['linear.bias'].requires_grad_(False)
+        frozen_parameter_names, tunable_parameter_names = \
+                        util.get_frozen_and_tunable_parameter_names(model)
+        assert set(frozen_parameter_names) == {'linear.weight', 'linear.bias'}
+        assert set(tunable_parameter_names) == {'conv.weight', 'conv.bias'}