From 50a0abf39a3bab9cda354ebf39b1a004441eda20 Mon Sep 17 00:00:00 2001
From: harshtrivedi
Date: Tue, 26 Jun 2018 12:41:28 -0400
Subject: [PATCH 1/7] Add support in fine_tune to selectively tune (freeze some parameters set through config file)

---
 allennlp/commands/fine_tune.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/allennlp/commands/fine_tune.py b/allennlp/commands/fine_tune.py
index e51e6d3cd7e..fa45aada778 100644
--- a/allennlp/commands/fine_tune.py
+++ b/allennlp/commands/fine_tune.py
@@ -10,7 +10,7 @@
 import logging
 import os
 from copy import deepcopy
-
+import re
 from allennlp.commands.evaluate import evaluate
 from allennlp.commands.subcommand import Subcommand
 from allennlp.commands.train import datasets_from_params
@@ -166,6 +166,12 @@ def fine_tune_model(model: Model,
     test_data = all_datasets.get('test')
 
     trainer_params = params.pop("trainer")
+    nograd_regex_list = trainer_params.pop("no_grad", ())
+    if nograd_regex_list:
+        nograd_regex = "(" + ")|(".join(nograd_regex_list) + ")"
+        for name, parameter in model.named_parameters():
+            if re.search(nograd_regex, name):
+                parameter.requires_grad_(False)
     trainer = Trainer.from_params(model,
                                   serialization_dir,
                                   iterator,

From f19e5a79a802ccf6778671a0a05cbaa1edd871a8 Mon Sep 17 00:00:00 2001
From: harshtrivedi
Date: Tue, 26 Jun 2018 12:42:44 -0400
Subject: [PATCH 2/7] Add tests for selective fine tuning.

---
 allennlp/tests/commands/fine_tune_test.py | 38 ++++++++++++++++++++++-
 1 file changed, 37 insertions(+), 1 deletion(-)

diff --git a/allennlp/tests/commands/fine_tune_test.py b/allennlp/tests/commands/fine_tune_test.py
index a9aaeee078d..2e97c946d8c 100644
--- a/allennlp/tests/commands/fine_tune_test.py
+++ b/allennlp/tests/commands/fine_tune_test.py
@@ -2,7 +2,10 @@
 import argparse
 
 from allennlp.common.testing import AllenNlpTestCase
-from allennlp.commands.fine_tune import FineTune, fine_tune_model_from_file_paths, fine_tune_model_from_args
+from allennlp.commands.fine_tune import FineTune, fine_tune_model_from_file_paths, \
+    fine_tune_model_from_args, fine_tune_model
+from allennlp.common.params import Params
+from allennlp.models import load_archive
 
 class TestFineTune(AllenNlpTestCase):
     def setUp(self):
@@ -50,3 +53,36 @@ def test_fine_tune_fails_without_required_args(self):
         with self.assertRaises(SystemExit) as context:
             self.parser.parse_args(["fine-tune", "-s", "serialization_dir", "-c", "path/to/config"])
         assert context.exception.code == 2  # argparse code for incorrect usage
+
+    def test_fine_tune_nograd_regex(self):
+        original_model = load_archive(self.model_archive).model
+        name_parameters_original = dict(original_model.named_parameters())
+        regex_lists = [[],
+                       [".*attend_feedforward.*", ".*token_embedder.*"],
+                       [".*compare_feedforward.*"]]
+        for regex_list in regex_lists:
+            params = Params.from_file(self.config_file)
+            params["trainer"]["no_grad"] = regex_list
+            shutil.rmtree(self.serialization_dir, ignore_errors=True)
+            tuned_model = fine_tune_model(model=original_model,
+                                          params=params,
+                                          serialization_dir=self.serialization_dir)
+            # If regex is matched, parameter name should have requires_grad False
+            # If regex is matched, parameter name should have same requires_grad
+            # as the originally loaded model
+            if regex_list:
+                nograd_regex = "(" + ")|(".join(regex_list) + ")"
+                for name, parameter in tuned_model.named_parameters():
+                    if re.search(nograd_regex, name):
+                        assert not parameter.requires_grad
+                    else:
+                        assert parameter.requires_grad \
+                               == name_parameters_original[name].requires_grad
+        # If all parameters have requires_grad=False, then error.
+        with pytest.raises(Exception) as _:
+            params = Params.from_file(self.config_file)
+            params["trainer"]["no_grad"] = ["*"]
+            shutil.rmtree(self.serialization_dir, ignore_errors=True)
+            tuned_model = fine_tune_model(model=original_model,
+                                          params=params,
+                                          serialization_dir=self.serialization_dir)

From e3716031d05672356316f93add9cf3cf5f115d36 Mon Sep 17 00:00:00 2001
From: harshtrivedi
Date: Tue, 26 Jun 2018 12:43:27 -0400
Subject: [PATCH 3/7] Allow for turning off gradients in train command (since in fine-tune as well this is happening with "trainer" configs).

---
 allennlp/commands/train.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/allennlp/commands/train.py b/allennlp/commands/train.py
index dd67f907404..7a99d707aa3 100644
--- a/allennlp/commands/train.py
+++ b/allennlp/commands/train.py
@@ -37,7 +37,7 @@
 import logging
 import os
 from copy import deepcopy
-
+import re
 import torch
 
 from allennlp.commands.evaluate import evaluate
@@ -282,6 +282,12 @@ def train_model(params: Params,
     test_data = all_datasets.get('test')
 
     trainer_params = params.pop("trainer")
+    nograd_regex_list = trainer_params.pop("no_grad", ())
+    if nograd_regex_list:
+        nograd_regex = "(" + ")|(".join(nograd_regex_list) + ")"
+        for name, parameter in model.named_parameters():
+            if re.search(nograd_regex, name):
+                parameter.requires_grad_(False)
     trainer = Trainer.from_params(model,
                                   serialization_dir,
                                   iterator,

From 3af6c81e421ce58b39cccd15afe61e7989f838af Mon Sep 17 00:00:00 2001
From: harshtrivedi
Date: Tue, 26 Jun 2018 12:56:14 -0400
Subject: [PATCH 4/7] Add missing imports in fine_tune_test.py

---
 allennlp/tests/commands/fine_tune_test.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/allennlp/tests/commands/fine_tune_test.py b/allennlp/tests/commands/fine_tune_test.py
index 2e97c946d8c..c38dc9b2a5a 100644
--- a/allennlp/tests/commands/fine_tune_test.py
+++ b/allennlp/tests/commands/fine_tune_test.py
@@ -1,6 +1,8 @@
 # pylint: disable=invalid-name,no-self-use
 import argparse
-
+import re
+import shutil
+import pytest
 from allennlp.common.testing import AllenNlpTestCase
 from allennlp.commands.fine_tune import FineTune, fine_tune_model_from_file_paths, \
     fine_tune_model_from_args, fine_tune_model

From 6f0089b25dad4034d0cf5a81ded99cc96e9830c0 Mon Sep 17 00:00:00 2001
From: harshtrivedi
Date: Tue, 26 Jun 2018 13:54:32 -0400
Subject: [PATCH 5/7] add tests for using 'no_grad' config with train command

---
 allennlp/tests/commands/train_test.py | 54 ++++++++++++++++++++++++++-
 1 file changed, 53 insertions(+), 1 deletion(-)

diff --git a/allennlp/tests/commands/train_test.py b/allennlp/tests/commands/train_test.py
index 38c1244cd39..3a18854bc60 100644
--- a/allennlp/tests/commands/train_test.py
+++ b/allennlp/tests/commands/train_test.py
@@ -2,7 +2,8 @@
 import argparse
 from typing import Iterable
 import os
-
+import shutil
+import re
 import pytest
 import torch
 
@@ -232,3 +233,54 @@ def test_train_with_test_set(self):
                 })
 
         train_model(params, serialization_dir=os.path.join(self.TEST_DIR, 'lazy_test_set'))
+
+    def test_train_nograd_regex(self):
+        params_get = lambda: Params({
+                "model": {
+                        "type": "simple_tagger",
+                        "text_field_embedder": {
+                                "tokens": {
+                                        "type": "embedding",
+                                        "embedding_dim": 5
+                                }
+                        },
+                        "encoder": {
+                                "type": "lstm",
+                                "input_size": 5,
+                                "hidden_size": 7,
+                                "num_layers": 2
+                        }
+                },
+                "dataset_reader": {"type": "sequence_tagging"},
+                "train_data_path": SEQUENCE_TAGGING_DATA_PATH,
"validation_data_path": SEQUENCE_TAGGING_DATA_PATH, + "iterator": {"type": "basic", "batch_size": 2}, + "trainer": { + "num_epochs": 2, + "optimizer": "adam" + } + }) + serialization_dir = os.path.join(self.TEST_DIR, 'test_train_nograd') + regex_lists = [[], + [".*text_field_embedder.*"], + [".*text_field_embedder.*", ".*encoder.*"]] + for regex_list in regex_lists: + params = params_get() + params["trainer"]["no_grad"] = regex_list + shutil.rmtree(serialization_dir, ignore_errors=True) + model = train_model(params, serialization_dir=serialization_dir) + # If regex is matched, parameter name should have requires_grad False + # Or else True + if regex_list: + nograd_regex = "(" + ")|(".join(regex_list) + ")" + for name, parameter in model.named_parameters(): + if re.search(nograd_regex, name): + assert not parameter.requires_grad + else: + assert parameter.requires_grad + # If all parameters have requires_grad=False, then error. + params = params_get() + params["trainer"]["no_grad"] = ["*"] + shutil.rmtree(serialization_dir, ignore_errors=True) + with pytest.raises(Exception) as _: + model = train_model(params, serialization_dir=serialization_dir) From e6f887db403a25e0a0663b7d0eabc61a95d9f587 Mon Sep 17 00:00:00 2001 From: harshtrivedi Date: Tue, 26 Jun 2018 17:59:57 -0400 Subject: [PATCH 6/7] Code cleanup: 1. for regex matches 2. follow import convention --- allennlp/commands/fine_tune.py | 11 +++++------ allennlp/commands/train.py | 11 +++++------ allennlp/tests/commands/fine_tune_test.py | 16 ++++++++-------- allennlp/tests/commands/train_test.py | 13 ++++++------- 4 files changed, 24 insertions(+), 27 deletions(-) diff --git a/allennlp/commands/fine_tune.py b/allennlp/commands/fine_tune.py index fa45aada778..208f4af1c6a 100644 --- a/allennlp/commands/fine_tune.py +++ b/allennlp/commands/fine_tune.py @@ -11,6 +11,7 @@ import os from copy import deepcopy import re + from allennlp.commands.evaluate import evaluate from allennlp.commands.subcommand import Subcommand from allennlp.commands.train import datasets_from_params @@ -166,12 +167,10 @@ def fine_tune_model(model: Model, test_data = all_datasets.get('test') trainer_params = params.pop("trainer") - nograd_regex_list = trainer_params.pop("no_grad", ()) - if nograd_regex_list: - nograd_regex = "(" + ")|(".join(nograd_regex_list) + ")" - for name, parameter in model.named_parameters(): - if re.search(nograd_regex, name): - parameter.requires_grad_(False) + no_grad_regexes = trainer_params.pop("no_grad", ()) + for name, parameter in model.named_parameters(): + if any(re.search(regex, name) for regex in no_grad_regexes): + parameter.requires_grad_(False) trainer = Trainer.from_params(model, serialization_dir, iterator, diff --git a/allennlp/commands/train.py b/allennlp/commands/train.py index 7a99d707aa3..2708e726061 100644 --- a/allennlp/commands/train.py +++ b/allennlp/commands/train.py @@ -38,6 +38,7 @@ import os from copy import deepcopy import re + import torch from allennlp.commands.evaluate import evaluate @@ -282,12 +283,10 @@ def train_model(params: Params, test_data = all_datasets.get('test') trainer_params = params.pop("trainer") - nograd_regex_list = trainer_params.pop("no_grad", ()) - if nograd_regex_list: - nograd_regex = "(" + ")|(".join(nograd_regex_list) + ")" - for name, parameter in model.named_parameters(): - if re.search(nograd_regex, name): - parameter.requires_grad_(False) + no_grad_regexes = trainer_params.pop("no_grad", ()) + for name, parameter in model.named_parameters(): + if any(re.search(regex, name) for regex 
+            parameter.requires_grad_(False)
     trainer = Trainer.from_params(model,
                                   serialization_dir,
                                   iterator,
diff --git a/allennlp/tests/commands/fine_tune_test.py b/allennlp/tests/commands/fine_tune_test.py
index c38dc9b2a5a..a3b8584abe9 100644
--- a/allennlp/tests/commands/fine_tune_test.py
+++ b/allennlp/tests/commands/fine_tune_test.py
@@ -2,7 +2,9 @@
 import argparse
 import re
 import shutil
+
 import pytest
+
 from allennlp.common.testing import AllenNlpTestCase
 from allennlp.commands.fine_tune import FineTune, fine_tune_model_from_file_paths, \
     fine_tune_model_from_args, fine_tune_model
@@ -72,14 +74,12 @@ def test_fine_tune_nograd_regex(self):
             # If regex is matched, parameter name should have requires_grad False
             # If regex is matched, parameter name should have same requires_grad
             # as the originally loaded model
-            if regex_list:
-                nograd_regex = "(" + ")|(".join(regex_list) + ")"
-                for name, parameter in tuned_model.named_parameters():
-                    if re.search(nograd_regex, name):
-                        assert not parameter.requires_grad
-                    else:
-                        assert parameter.requires_grad \
-                               == name_parameters_original[name].requires_grad
+            for name, parameter in tuned_model.named_parameters():
+                if any(re.search(regex, name) for regex in regex_list):
+                    assert not parameter.requires_grad
+                else:
+                    assert parameter.requires_grad \
+                           == name_parameters_original[name].requires_grad
         # If all parameters have requires_grad=False, then error.
         with pytest.raises(Exception) as _:
             params = Params.from_file(self.config_file)
diff --git a/allennlp/tests/commands/train_test.py b/allennlp/tests/commands/train_test.py
index 3a18854bc60..4bfe9caad20 100644
--- a/allennlp/tests/commands/train_test.py
+++ b/allennlp/tests/commands/train_test.py
@@ -4,6 +4,7 @@
 import os
 import shutil
 import re
+
 import pytest
 import torch
 
@@ -271,13 +272,11 @@ def test_train_nograd_regex(self):
             model = train_model(params, serialization_dir=serialization_dir)
             # If regex is matched, parameter name should have requires_grad False
             # Or else True
-            if regex_list:
-                nograd_regex = "(" + ")|(".join(regex_list) + ")"
-                for name, parameter in model.named_parameters():
-                    if re.search(nograd_regex, name):
-                        assert not parameter.requires_grad
-                    else:
-                        assert parameter.requires_grad
+            for name, parameter in model.named_parameters():
+                if any(re.search(regex, name) for regex in regex_list):
+                    assert not parameter.requires_grad
+                else:
+                    assert parameter.requires_grad
         # If all parameters have requires_grad=False, then error.
         params = params_get()
         params["trainer"]["no_grad"] = ["*"]

From 59ac24c0496ca2da96e031e552728cea409ab1d1 Mon Sep 17 00:00:00 2001
From: harshtrivedi
Date: Tue, 26 Jun 2018 21:49:26 -0400
Subject: [PATCH 7/7] Add logging statements for knowing tunable and frozen parameters.
---
 allennlp/commands/fine_tune.py | 14 ++++++++++++++
 allennlp/commands/train.py     | 14 ++++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/allennlp/commands/fine_tune.py b/allennlp/commands/fine_tune.py
index 208f4af1c6a..db9f3b30251 100644
--- a/allennlp/commands/fine_tune.py
+++ b/allennlp/commands/fine_tune.py
@@ -168,9 +168,23 @@ def fine_tune_model(model: Model,
 
     trainer_params = params.pop("trainer")
     no_grad_regexes = trainer_params.pop("no_grad", ())
+
+    nograd_parameter_names = []
+    grad_parameter_names = []
     for name, parameter in model.named_parameters():
         if any(re.search(regex, name) for regex in no_grad_regexes):
             parameter.requires_grad_(False)
+            nograd_parameter_names.append(name)
+        else:
+            grad_parameter_names.append(name)
+
+    logger.info("Following parameters are Frozen (without gradient):")
+    for name in nograd_parameter_names:
+        logger.info(name)
+    logger.info("Following parameters are Tunable (with gradient):")
+    for name in grad_parameter_names:
+        logger.info(name)
+
     trainer = Trainer.from_params(model,
                                   serialization_dir,
                                   iterator,
diff --git a/allennlp/commands/train.py b/allennlp/commands/train.py
index 2708e726061..54e94a3727d 100644
--- a/allennlp/commands/train.py
+++ b/allennlp/commands/train.py
@@ -284,9 +284,23 @@ def train_model(params: Params,
 
     trainer_params = params.pop("trainer")
     no_grad_regexes = trainer_params.pop("no_grad", ())
+
+    nograd_parameter_names = []
+    grad_parameter_names = []
     for name, parameter in model.named_parameters():
         if any(re.search(regex, name) for regex in no_grad_regexes):
             parameter.requires_grad_(False)
+            nograd_parameter_names.append(name)
+        else:
+            grad_parameter_names.append(name)
+
+    logger.info("Following parameters are Frozen (without gradient):")
+    for name in nograd_parameter_names:
+        logger.info(name)
+    logger.info("Following parameters are Tunable (with gradient):")
+    for name in grad_parameter_names:
+        logger.info(name)
+
     trainer = Trainer.from_params(model,
                                   serialization_dir,
                                   iterator,
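
Usage sketch (not part of the patches above): with these commits applied, freezing parameters comes down to listing name regexes under "no_grad" in the "trainer" section of a config. The snippet below mirrors the simple_tagger fixture used in train_test.py; the data paths, serialization directory, and the particular regex are illustrative placeholders rather than values from this PR.

    from allennlp.common.params import Params
    from allennlp.commands.train import train_model

    params = Params({
            "dataset_reader": {"type": "sequence_tagging"},
            "train_data_path": "path/to/train.tsv",       # placeholder path
            "validation_data_path": "path/to/dev.tsv",    # placeholder path
            "model": {
                    "type": "simple_tagger",
                    "text_field_embedder": {"tokens": {"type": "embedding", "embedding_dim": 5}},
                    "encoder": {"type": "lstm", "input_size": 5, "hidden_size": 7, "num_layers": 2}
            },
            "iterator": {"type": "basic", "batch_size": 2},
            "trainer": {
                    "num_epochs": 2,
                    "optimizer": "adam",
                    # Freeze every parameter whose name matches one of these regexes.
                    "no_grad": [".*text_field_embedder.*"]
            }
    })
    model = train_model(params, serialization_dir="/tmp/no_grad_demo")  # placeholder dir

Parameters whose names match a "no_grad" regex get requires_grad set to False before the Trainer is built, so only the remaining parameters are updated, and the frozen and tunable parameter names are logged before training starts.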