This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Add support for selective finetune (freeze parameters by regex from config file) #1427

Merged
merged 9 commits on Jun 28, 2018
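For context: this PR adds an optional "no_grad" key to the trainer section of the experiment config; any model parameter whose name matches one of the listed regexes has requires_grad set to False before training. A minimal sketch of such a trainer section (the "no_grad" key and the example regexes are taken from the diff below; the other values are only placeholders):

# Illustrative trainer section only: "no_grad" is the key added by this PR,
# the regexes come from the tests below, the remaining values are placeholders.
trainer_config = {
    "num_epochs": 2,
    "optimizer": "adam",
    "no_grad": [".*text_field_embedder.*", ".*encoder.*"]
}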
8 changes: 7 additions & 1 deletion allennlp/commands/fine_tune.py
@@ -10,7 +10,7 @@
import logging
import os
from copy import deepcopy

import re
Contributor

nit: I like to leave blank lines between the grouped imports: (standard library) -> (external libraries) -> (allennlp modules)
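For instance, the grouping looks like this (illustrative only, using modules that appear elsewhere in this diff):

import os          # standard library
import re

import torch       # external libraries

from allennlp.models import Model   # allennlp modules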

Contributor Author

Ohh!! I didn't realise the purpose of blanks...

from allennlp.commands.evaluate import evaluate
from allennlp.commands.subcommand import Subcommand
from allennlp.commands.train import datasets_from_params
@@ -166,6 +166,12 @@ def fine_tune_model(model: Model,
test_data = all_datasets.get('test')

trainer_params = params.pop("trainer")
nograd_regex_list = trainer_params.pop("no_grad", ())
if nograd_regex_list:
    nograd_regex = "(" + ")|(".join(nograd_regex_list) + ")"
Contributor

this feels error-prone to me, I'd rather just

no_grad_regexes = trainer_params.pop("no_grad", ())
for name, parameter in model.named_parameters():
    if any(re.search(regex, name) for regex in no_grad_regexes):
        parameter.requires_grad_(False)

it's cleaner and I can't imagine the time difference being noticeable
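To make the "error-prone" point concrete (an illustration, not code from the PR): with an empty no_grad list the joined pattern collapses to "()", which matches every parameter name, which is exactly why the diff above guards the join with if nograd_regex_list:; the per-regex form needs no guard because any() over an empty sequence is False.

import re

name = "encoder.weight"  # hypothetical parameter name

# Joined form: an empty regex list degenerates to the pattern "()", which
# matches everything, so every parameter would be frozen without the guard.
joined = "(" + ")|(".join([]) + ")"
print(bool(re.search(joined, name)))                # True

# Per-regex form: any() over an empty sequence is False, so nothing is frozen.
print(any(re.search(regex, name) for regex in []))  # False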

Contributor Author

Sure, will change that.

    for name, parameter in model.named_parameters():
        if re.search(nograd_regex, name):
            parameter.requires_grad_(False)
trainer = Trainer.from_params(model,
                              serialization_dir,
                              iterator,
8 changes: 7 additions & 1 deletion allennlp/commands/train.py
@@ -37,7 +37,7 @@
import logging
import os
from copy import deepcopy

import re
Contributor

same nit

Contributor Author

sure.

import torch

from allennlp.commands.evaluate import evaluate
@@ -282,6 +282,12 @@ def train_model(params: Params,
test_data = all_datasets.get('test')

trainer_params = params.pop("trainer")
nograd_regex_list = trainer_params.pop("no_grad", ())
if nograd_regex_list:
    nograd_regex = "(" + ")|(".join(nograd_regex_list) + ")"
    for name, parameter in model.named_parameters():
        if re.search(nograd_regex, name):
            parameter.requires_grad_(False)
Contributor

same comment here

Contributor Author

sure.

trainer = Trainer.from_params(model,
                              serialization_dir,
                              iterator,
42 changes: 40 additions & 2 deletions allennlp/tests/commands/fine_tune_test.py
@@ -1,8 +1,13 @@
# pylint: disable=invalid-name,no-self-use
import argparse

import re
import shutil
import pytest
Contributor

same nit: blank line before and after pytest

Contributor Author

ok

from allennlp.common.testing import AllenNlpTestCase
from allennlp.commands.fine_tune import FineTune, fine_tune_model_from_file_paths, fine_tune_model_from_args
from allennlp.commands.fine_tune import FineTune, fine_tune_model_from_file_paths, \
    fine_tune_model_from_args, fine_tune_model
from allennlp.common.params import Params
from allennlp.models import load_archive

class TestFineTune(AllenNlpTestCase):
def setUp(self):
@@ -50,3 +55,36 @@ def test_fine_tune_fails_without_required_args(self):
with self.assertRaises(SystemExit) as context:
    self.parser.parse_args(["fine-tune", "-s", "serialization_dir", "-c", "path/to/config"])
assert context.exception.code == 2  # argparse code for incorrect usage

def test_fine_tune_nograd_regex(self):
    original_model = load_archive(self.model_archive).model
    name_parameters_original = dict(original_model.named_parameters())
    regex_lists = [[],
                   [".*attend_feedforward.*", ".*token_embedder.*"],
                   [".*compare_feedforward.*"]]
    for regex_list in regex_lists:
        params = Params.from_file(self.config_file)
        params["trainer"]["no_grad"] = regex_list
        shutil.rmtree(self.serialization_dir, ignore_errors=True)
        tuned_model = fine_tune_model(model=original_model,
                                      params=params,
                                      serialization_dir=self.serialization_dir)
        # If regex is matched, parameter name should have requires_grad False
        # If regex is not matched, parameter name should have same requires_grad
        # as the originally loaded model
        if regex_list:
            nograd_regex = "(" + ")|(".join(regex_list) + ")"
Contributor

I'm more ok with this being in the test, but I'd still prefer the other way

Contributor Author

no problem, i can change that.

            for name, parameter in tuned_model.named_parameters():
                if re.search(nograd_regex, name):
                    assert not parameter.requires_grad
                else:
                    assert parameter.requires_grad \
                           == name_parameters_original[name].requires_grad
    # If all parameters have requires_grad=False, then error.
    with pytest.raises(Exception) as _:
        params = Params.from_file(self.config_file)
        params["trainer"]["no_grad"] = ["*"]
        shutil.rmtree(self.serialization_dir, ignore_errors=True)
        tuned_model = fine_tune_model(model=original_model,
                                      params=params,
                                      serialization_dir=self.serialization_dir)
54 changes: 53 additions & 1 deletion allennlp/tests/commands/train_test.py
@@ -2,7 +2,8 @@
import argparse
from typing import Iterable
import os

import shutil
import re
Contributor

same nit: blank line before import pytest

Contributor Author

ok

import pytest
import torch

@@ -232,3 +233,54 @@ def test_train_with_test_set(self):
})

train_model(params, serialization_dir=os.path.join(self.TEST_DIR, 'lazy_test_set'))

def test_train_nograd_regex(self):
    params_get = lambda: Params({
        "model": {
            "type": "simple_tagger",
            "text_field_embedder": {
                "tokens": {
                    "type": "embedding",
                    "embedding_dim": 5
                }
            },
            "encoder": {
                "type": "lstm",
                "input_size": 5,
                "hidden_size": 7,
                "num_layers": 2
            }
        },
        "dataset_reader": {"type": "sequence_tagging"},
        "train_data_path": SEQUENCE_TAGGING_DATA_PATH,
        "validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
        "iterator": {"type": "basic", "batch_size": 2},
        "trainer": {
            "num_epochs": 2,
            "optimizer": "adam"
        }
    })
    serialization_dir = os.path.join(self.TEST_DIR, 'test_train_nograd')
    regex_lists = [[],
                   [".*text_field_embedder.*"],
                   [".*text_field_embedder.*", ".*encoder.*"]]
    for regex_list in regex_lists:
        params = params_get()
        params["trainer"]["no_grad"] = regex_list
        shutil.rmtree(serialization_dir, ignore_errors=True)
        model = train_model(params, serialization_dir=serialization_dir)
        # If regex is matched, parameter name should have requires_grad False
        # Or else True
        if regex_list:
            nograd_regex = "(" + ")|(".join(regex_list) + ")"
Contributor

same

Contributor Author

Btw, I had done something similar elsewhere in a previous commit: here, and used that here.
Should I change it there as well?

            for name, parameter in model.named_parameters():
                if re.search(nograd_regex, name):
                    assert not parameter.requires_grad
                else:
                    assert parameter.requires_grad
    # If all parameters have requires_grad=False, then error.
    params = params_get()
    params["trainer"]["no_grad"] = ["*"]
    shutil.rmtree(serialization_dir, ignore_errors=True)
    with pytest.raises(Exception) as _:
        model = train_model(params, serialization_dir=serialization_dir)
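As a quick sanity check outside the test suite (a sketch only: train_model and the "no_grad" key come from this PR, the inspection below is plain PyTorch):

# Suppose model = train_model(params, ...) was run with
# params["trainer"]["no_grad"] = [".*text_field_embedder.*"].
# The frozen/trainable split can then be read off the returned model.
frozen = [name for name, param in model.named_parameters() if not param.requires_grad]
trainable = [name for name, param in model.named_parameters() if param.requires_grad]
print("frozen:", frozen)        # embedding parameters matched by the regex
print("trainable:", trainable)  # everything else, e.g. the encoder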