
Add support for selective finetune (freeze parameters by regex from config file) #1427

Merged
merged 9 commits on Jun 28, 2018
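For context, the new option is read from the trainer section of the experiment configuration: every regex listed under no_grad is tested against the model's parameter names with re.search, and matching parameters get requires_grad set to False before the Trainer is built. A minimal sketch of such a trainer section, built with the same Params class the tests use; the regexes and the other trainer settings are illustrative, not taken from this PR:

```python
from allennlp.common.params import Params

# Hypothetical trainer section using the "no_grad" key added by this PR.
# Any parameter whose name matches one of these regexes is frozen
# (requires_grad set to False) before training starts.
trainer_params = Params({
    "num_epochs": 2,
    "optimizer": "adam",
    "no_grad": [".*token_embedder.*", ".*attend_feedforward.*"]
})
```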
19 changes: 19 additions & 0 deletions allennlp/commands/fine_tune.py
@@ -10,6 +10,7 @@
import logging
import os
from copy import deepcopy
import re

from allennlp.commands.evaluate import evaluate
from allennlp.commands.subcommand import Subcommand
@@ -180,6 +181,24 @@ def fine_tune_model(model: Model,
    test_data = all_datasets.get('test')

    trainer_params = params.pop("trainer")
    no_grad_regexes = trainer_params.pop("no_grad", ())

    nograd_parameter_names = []
    grad_parameter_names = []
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)
            nograd_parameter_names.append(name)
        else:
            grad_parameter_names.append(name)

Review comment: There is only one problem here: if a parameter's requires_grad is already False (for example, a non-trainable embedding), the log will still show it as tunable. Not a very big problem, but it can look a bit confusing.

Contributor Author:
Thank you for catching that! I agree it would be confusing. Will fix that.
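A minimal sketch of one way to address that, assuming the tunable/frozen split used for logging is taken from each parameter's actual requires_grad value after the regexes have been applied (an illustration only, not necessarily the fix made in a later commit):

```python
# Variant of the loop above: freeze by regex first, then classify by the
# actual requires_grad flag, so parameters that were already frozen
# (e.g. a non-trainable embedding) are logged as frozen rather than tunable.
nograd_parameter_names = []
grad_parameter_names = []
for name, parameter in model.named_parameters():
    if any(re.search(regex, name) for regex in no_grad_regexes):
        parameter.requires_grad_(False)
    if parameter.requires_grad:
        grad_parameter_names.append(name)
    else:
        nograd_parameter_names.append(name)
```

The logging below would then reflect the actual state of each parameter.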


logger.info("Following parameters are Frozen (without gradient):")
for name in nograd_parameter_names:
logger.info(name)
logger.info("Following parameters are Tunable (with gradient):")
for name in grad_parameter_names:
logger.info(name)

trainer = Trainer.from_params(model,
serialization_dir,
iterator,
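Note on matching semantics: the check uses re.search, so a pattern matches if it occurs anywhere in a parameter name; the leading and trailing .* used in the tests are therefore not strictly required. A small illustration with hypothetical regexes and parameter names (not from this PR):

```python
import re

# Illustrative regexes and parameter names.
no_grad_regexes = ["token_embedder", r"encoder\..*\.bias"]
parameter_names = [
    "text_field_embedder.token_embedder_tokens.weight",
    "encoder.layer_0.bias",
    "classifier.weight",
]
for name in parameter_names:
    frozen = any(re.search(regex, name) for regex in no_grad_regexes)
    print(f"{name} -> {'frozen' if frozen else 'tunable'}")
# text_field_embedder.token_embedder_tokens.weight -> frozen
# encoder.layer_0.bias -> frozen
# classifier.weight -> tunable
```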
19 changes: 19 additions & 0 deletions allennlp/commands/train.py
@@ -37,6 +37,7 @@
import logging
import os
from copy import deepcopy
import re

import torch

@@ -282,6 +283,24 @@ def train_model(params: Params,
    test_data = all_datasets.get('test')

    trainer_params = params.pop("trainer")
    no_grad_regexes = trainer_params.pop("no_grad", ())

    nograd_parameter_names = []
    grad_parameter_names = []
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)
            nograd_parameter_names.append(name)
        else:
            grad_parameter_names.append(name)

logger.info("Following parameters are Frozen (without gradient):")
for name in nograd_parameter_names:
logger.info(name)
logger.info("Following parameters are Tunable (with gradient):")
for name in grad_parameter_names:
logger.info(name)

trainer = Trainer.from_params(model,
serialization_dir,
iterator,
40 changes: 39 additions & 1 deletion allennlp/tests/commands/fine_tune_test.py
@@ -1,8 +1,15 @@
# pylint: disable=invalid-name,no-self-use
import argparse
import re
import shutil

import pytest

from allennlp.common.testing import AllenNlpTestCase
from allennlp.commands.fine_tune import FineTune, fine_tune_model_from_file_paths, fine_tune_model_from_args
from allennlp.commands.fine_tune import FineTune, fine_tune_model_from_file_paths, \
        fine_tune_model_from_args, fine_tune_model
from allennlp.common.params import Params
from allennlp.models import load_archive

class TestFineTune(AllenNlpTestCase):
    def setUp(self):
@@ -50,3 +57,34 @@ def test_fine_tune_fails_without_required_args(self):
        with self.assertRaises(SystemExit) as context:
            self.parser.parse_args(["fine-tune", "-s", "serialization_dir", "-c", "path/to/config"])
        assert context.exception.code == 2  # argparse code for incorrect usage

    def test_fine_tune_nograd_regex(self):
        original_model = load_archive(self.model_archive).model
        name_parameters_original = dict(original_model.named_parameters())
        regex_lists = [[],
                       [".*attend_feedforward.*", ".*token_embedder.*"],
                       [".*compare_feedforward.*"]]
        for regex_list in regex_lists:
            params = Params.from_file(self.config_file)
            params["trainer"]["no_grad"] = regex_list
            shutil.rmtree(self.serialization_dir, ignore_errors=True)
            tuned_model = fine_tune_model(model=original_model,
                                          params=params,
                                          serialization_dir=self.serialization_dir)
            # If a regex matches, the parameter should have requires_grad False.
            # Otherwise it should keep the same requires_grad as in the
            # originally loaded model.
            for name, parameter in tuned_model.named_parameters():
                if any(re.search(regex, name) for regex in regex_list):
                    assert not parameter.requires_grad
                else:
                    assert parameter.requires_grad \
                           == name_parameters_original[name].requires_grad
        # If all parameters are frozen (requires_grad=False), training should error.
        with pytest.raises(Exception) as _:
            params = Params.from_file(self.config_file)
params["trainer"]["no_grad"] = ["*"]
            shutil.rmtree(self.serialization_dir, ignore_errors=True)
            tuned_model = fine_tune_model(model=original_model,
                                          params=params,
                                          serialization_dir=self.serialization_dir)
51 changes: 51 additions & 0 deletions allennlp/tests/commands/train_test.py
@@ -2,6 +2,8 @@
import argparse
from typing import Iterable
import os
import shutil
import re

import pytest
import torch
@@ -232,3 +234,52 @@ def test_train_with_test_set(self):
        })

        train_model(params, serialization_dir=os.path.join(self.TEST_DIR, 'lazy_test_set'))

    def test_train_nograd_regex(self):
        params_get = lambda: Params({
            "model": {
                "type": "simple_tagger",
                "text_field_embedder": {
                    "tokens": {
                        "type": "embedding",
                        "embedding_dim": 5
                    }
                },
                "encoder": {
                    "type": "lstm",
                    "input_size": 5,
                    "hidden_size": 7,
                    "num_layers": 2
                }
            },
            "dataset_reader": {"type": "sequence_tagging"},
            "train_data_path": SEQUENCE_TAGGING_DATA_PATH,
            "validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
            "iterator": {"type": "basic", "batch_size": 2},
            "trainer": {
                "num_epochs": 2,
                "optimizer": "adam"
            }
        })
        serialization_dir = os.path.join(self.TEST_DIR, 'test_train_nograd')
        regex_lists = [[],
                       [".*text_field_embedder.*"],
                       [".*text_field_embedder.*", ".*encoder.*"]]
        for regex_list in regex_lists:
            params = params_get()
            params["trainer"]["no_grad"] = regex_list
            shutil.rmtree(serialization_dir, ignore_errors=True)
            model = train_model(params, serialization_dir=serialization_dir)
            # If a regex matches, the parameter should have requires_grad False;
            # otherwise it should have requires_grad True.
            for name, parameter in model.named_parameters():
                if any(re.search(regex, name) for regex in regex_list):
                    assert not parameter.requires_grad
                else:
                    assert parameter.requires_grad
        # If all parameters are frozen (requires_grad=False), training should error.
        params = params_get()
params["trainer"]["no_grad"] = ["*"]
        shutil.rmtree(serialization_dir, ignore_errors=True)
        with pytest.raises(Exception) as _:
            model = train_model(params, serialization_dir=serialization_dir)