This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 9fcc795

sai-prasanna authored and DeNeutoy committed
Learning Rate Finder (#1776)
Adds a new command `find-lr` that lets one search for a learning rate range where the loss drops rapidly. This addresses feature request #537. Refer to the [blog post](https://medium.com/@surmenok/estimating-optimal-learning-rate-for-a-deep-neural-network-ce32f2556ce0) linked in that issue for an overview of how the finder works. The main change beyond the new command is making a few of the fields in `Trainer` "public" (i.e. removing the underscore from their names). I have used matplotlib to plot the learning rate vs. loss graph. I haven't written unit tests yet; if the current code is OK, I will, though I am a little unsure what exactly to test.
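For readers skimming the diff, a minimal invocation of the new subcommand might look like the sketch below; the config path and output directory are made-up placeholders, not files from this PR:

```bash
# Hypothetical example: run the learning rate finder against an experiment config.
# "my_experiment.json" and "lr_search/" are illustrative names only.
allennlp find-lr my_experiment.json \
    --serialization-dir lr_search \
    --start-lr 1e-5 \
    --end-lr 10 \
    --num-batches 100
```

The command writes an `lr-losses.png` plot of (smoothed) loss against learning rate into the serialization directory; the useful learning rate range is where that curve drops most steeply.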
1 parent 63836c4 commit 9fcc795

8 files changed: +521 -44 lines changed

allennlp/commands/__init__.py

+2
@@ -12,6 +12,7 @@
 from allennlp.commands.dry_run import DryRun
 from allennlp.commands.subcommand import Subcommand
 from allennlp.commands.test_install import TestInstall
+from allennlp.commands.find_learning_rate import FindLearningRate
 from allennlp.commands.train import Train
 from allennlp.common.util import import_submodules

@@ -42,6 +43,7 @@ def main(prog: str = None,
             "fine-tune": FineTune(),
             "dry-run": DryRun(),
             "test-install": TestInstall(),
+            "find-lr": FindLearningRate(),

             # Superseded by overrides
             **subcommand_overrides
allennlp/commands/find_learning_rate.py
+293
@@ -0,0 +1,293 @@
"""
The ``find-lr`` subcommand can be used to find a good learning rate for a model.
It requires a configuration file and a directory in
which to write the results.

.. code-block:: bash

    $ allennlp find-lr --help
    usage: allennlp find-lr [-h] -s SERIALIZATION_DIR [-o OVERRIDES]
                            [--start-lr START_LR] [--end-lr END_LR]
                            [--num-batches NUM_BATCHES]
                            [--stopping-factor STOPPING_FACTOR] [--linear]
                            param_path

    Find a learning rate range where loss decreases quickly for the specified
    model and dataset.

    positional arguments:
      param_path            path to parameter file describing the model to be
                            trained

    optional arguments:
      -h, --help            show this help message and exit
      -s SERIALIZATION_DIR, --serialization-dir SERIALIZATION_DIR
                            The directory in which to save results.
      -o OVERRIDES, --overrides OVERRIDES
                            a JSON structure used to override the experiment
                            configuration
      --start-lr START_LR   Learning rate to start the search.
      --end-lr END_LR       Learning rate up to which search is done.
      --num-batches NUM_BATCHES
                            Number of mini-batches to run the learning rate finder.
      --stopping-factor STOPPING_FACTOR
                            Stop the search when the current loss exceeds the best
                            loss recorded by a multiple of the stopping factor.
      --linear              Increase learning rate linearly instead of exponentially.
"""
from typing import List, Optional, Tuple
import argparse
import re
import os
import math
import logging
import matplotlib; matplotlib.use('Agg')  # pylint: disable=multiple-statements,wrong-import-position
import matplotlib.pyplot as plt  # pylint: disable=wrong-import-position

from allennlp.commands.subcommand import Subcommand  # pylint: disable=wrong-import-position
from allennlp.commands.train import datasets_from_params  # pylint: disable=wrong-import-position
from allennlp.common.checks import ConfigurationError, check_for_gpu  # pylint: disable=wrong-import-position
from allennlp.common import Params, Tqdm  # pylint: disable=wrong-import-position
from allennlp.common.util import prepare_environment  # pylint: disable=wrong-import-position
from allennlp.data import Vocabulary, DataIterator  # pylint: disable=wrong-import-position
from allennlp.models import Model  # pylint: disable=wrong-import-position
from allennlp.training import Trainer  # pylint: disable=wrong-import-position


logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

class FindLearningRate(Subcommand):
    def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
        # pylint: disable=protected-access
        description = '''Find a learning rate range where loss decreases quickly
                         for the specified model and dataset.'''
        subparser = parser.add_parser(name, description=description, help='Find a learning rate range.')

        subparser.add_argument('param_path',
                               type=str,
                               help='path to parameter file describing the model to be trained')
        subparser.add_argument('-s', '--serialization-dir',
                               required=True,
                               type=str,
                               help='The directory in which to save results.')

        subparser.add_argument('-o', '--overrides',
                               type=str,
                               default="",
                               help='a JSON structure used to override the experiment configuration')
        subparser.add_argument('--start-lr',
                               type=float,
                               default=1e-5,
                               help='Learning rate to start the search.')
        subparser.add_argument('--end-lr',
                               type=float,
                               default=10,
                               help='Learning rate up to which search is done.')
        subparser.add_argument('--num-batches',
                               type=int,
                               default=100,
                               help='Number of mini-batches to run the learning rate finder.')
        subparser.add_argument('--stopping-factor',
                               type=float,
                               default=4.0,
                               help='Stop the search when the current loss exceeds the best loss recorded by '
                                    'a multiple of the stopping factor.')
        subparser.add_argument('--linear',
                               action='store_true',
                               help='Increase learning rate linearly instead of exponentially.')

        subparser.set_defaults(func=find_learning_rate_from_args)

        return subparser

def find_learning_rate_from_args(args: argparse.Namespace) -> None:
    """
    Start learning rate finder for the given args.
    """
    params = Params.from_file(args.param_path, args.overrides)
    find_learning_rate_model(params, args.serialization_dir,
                             args.start_lr, args.end_lr,
                             args.num_batches, args.linear, args.stopping_factor)

def find_learning_rate_model(params: Params,
                             serialization_dir: str,
                             start_lr: float,
                             end_lr: float,
                             num_batches: int,
                             linear_steps: bool,
                             stopping_factor: Optional[float]) -> None:
    """
    Runs learning rate search for the given ``num_batches`` and saves the results in ``serialization_dir``.

    Parameters
    ----------
    params : ``Params``
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir : ``str``
        The directory in which to save results.
    start_lr: ``float``
        Learning rate to start the search.
    end_lr: ``float``
        Learning rate up to which the search is done.
    num_batches: ``int``
        Number of mini-batches to run the learning rate finder.
    linear_steps: ``bool``
        If ``True``, increase the learning rate linearly; otherwise exponentially.
    stopping_factor: ``Optional[float]``
        Stop the search when the current loss exceeds the best loss recorded by
        a multiple of the stopping factor. If ``None``, the search proceeds until ``end_lr``.
    """

    if os.path.exists(serialization_dir) and os.listdir(serialization_dir):
        raise ConfigurationError(f'Serialization directory {serialization_dir} already exists and is '
                                 f'not empty.')

    prepare_environment(params)
    os.makedirs(serialization_dir, exist_ok=True)

    check_for_gpu(params.get('trainer').get('cuda_device', -1))

    all_datasets = datasets_from_params(params)
    datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

    for dataset in datasets_for_vocab_creation:
        if dataset not in all_datasets:
            raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

    logger.info("From dataset instances, %s will be considered for vocabulary creation.",
                ", ".join(datasets_for_vocab_creation))
    vocab = Vocabulary.from_params(
            params.pop("vocabulary", {}),
            (instance for key, dataset in all_datasets.items()
             for instance in dataset
             if key in datasets_for_vocab_creation)
    )

    model = Model.from_params(vocab=vocab, params=params.pop('model'))
    iterator = DataIterator.from_params(params.pop("iterator"))
    iterator.index_with(vocab)

    train_data = all_datasets['train']

    trainer_params = params.pop("trainer")
    no_grad_regexes = trainer_params.pop("no_grad", ())
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

    trainer = Trainer.from_params(model,
                                  serialization_dir,
                                  iterator,
                                  train_data,
                                  params=trainer_params,
                                  validation_data=None,
                                  validation_iterator=None)

    logger.info(f'Starting learning rate search from {start_lr} to {end_lr} in {num_batches} iterations.')
    learning_rates, losses = search_learning_rate(trainer, start_lr,
                                                  end_lr, num_batches,
                                                  linear_steps, stopping_factor)
    logger.info('Finished learning rate search.')
    losses = _smooth(losses, 0.98)

    _save_plot(learning_rates, losses, os.path.join(serialization_dir, 'lr-losses.png'))

def search_learning_rate(trainer: Trainer,
                         start_lr: float = 1e-5,
                         end_lr: float = 10,
                         num_batches: int = 100,
                         linear_steps: bool = False,
                         stopping_factor: Optional[float] = 4.0) -> Tuple[List[float], List[float]]:
    """
    Runs a training loop on the model using :class:`~allennlp.training.trainer.Trainer`,
    increasing the learning rate from ``start_lr`` to ``end_lr`` and recording the losses.

    Parameters
    ----------
    trainer: :class:`~allennlp.training.trainer.Trainer`
    start_lr: ``float``
        The learning rate to start the search.
    end_lr: ``float``
        The learning rate up to which the search is done.
    num_batches: ``int``
        Number of batches to run the learning rate finder.
    linear_steps: ``bool``
        If ``True``, increase the learning rate linearly; otherwise exponentially.
    stopping_factor: ``Optional[float]``
        Stop the search when the current loss exceeds the best loss recorded by
        a multiple of the stopping factor. If ``None``, the search proceeds until ``end_lr``.

    Returns
    -------
    (learning_rates, losses): ``Tuple[List[float], List[float]]``
        Returns a list of learning rates and the corresponding losses.
        Note: the losses are recorded before applying the corresponding learning rate.
    """
    if num_batches <= 10:
        raise ConfigurationError('The number of iterations for the learning rate finder should be greater than 10.')

    trainer.model.train()

    train_generator = trainer.iterator(trainer.train_data,
                                       shuffle=trainer.shuffle)
    train_generator_tqdm = Tqdm.tqdm(train_generator,
                                     total=num_batches)

    learning_rates = []
    losses = []
    best = 1e9
    if linear_steps:
        lr_update_factor = (end_lr - start_lr) / num_batches
    else:
        lr_update_factor = (end_lr / start_lr) ** (1.0 / num_batches)

    for i, batch in enumerate(train_generator_tqdm):

        if linear_steps:
            current_lr = start_lr + (lr_update_factor * i)
        else:
            current_lr = start_lr * (lr_update_factor ** i)

        for param_group in trainer.optimizer.param_groups:
            param_group['lr'] = current_lr

        trainer.optimizer.zero_grad()
        loss = trainer.batch_loss(batch, for_training=True)
        loss.backward()
        loss = loss.detach().cpu().item()

        if stopping_factor is not None and (math.isnan(loss) or loss > stopping_factor * best):
            logger.info(f'Loss ({loss}) exceeds stopping_factor * lowest recorded loss.')
            break

        trainer.rescale_gradients()
        trainer.optimizer.step()

        learning_rates.append(current_lr)
        losses.append(loss)

        if loss < best and i > 10:
            best = loss

        if i == num_batches:
            break

    return learning_rates, losses

def _smooth(values: List[float], beta: float) -> List[float]:
    """ Exponential smoothing of values """
    avg_value = 0.
    smoothed = []
    for i, value in enumerate(values):
        # Exponential moving average with bias correction for the early terms,
        # so the first few smoothed values are not pulled towards zero.
        avg_value = beta * avg_value + (1 - beta) * value
        smoothed.append(avg_value / (1 - beta ** (i + 1)))
    return smoothed

def _save_plot(learning_rates: List[float], losses: List[float], save_path: str):
    plt.ylabel('loss')
    plt.xlabel('learning rate (log10 scale)')
    plt.xscale('log')
    plt.plot(learning_rates, losses)
    logger.info(f'Saving learning_rate vs loss plot to {save_path}.')
    plt.savefig(save_path)
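Not part of this commit, but as a sketch of how the values returned by `search_learning_rate` might be consumed programmatically: one common heuristic, consistent with the "loss drops rapidly" idea in the description above, is to pick the learning rate at which the smoothed loss falls most steeply. The `suggest_lr` helper below is hypothetical and purely illustrative.

```python
# Hypothetical helper, not included in this PR: pick the learning rate at which
# the (smoothed) loss decreases most steeply between consecutive batches.
from typing import List


def suggest_lr(learning_rates: List[float], losses: List[float]) -> float:
    best_drop = 0.0
    best_lr = learning_rates[0]
    for i in range(1, len(losses)):
        drop = losses[i - 1] - losses[i]  # positive when the loss is falling
        if drop > best_drop:
            best_drop = drop
            best_lr = learning_rates[i]
    return best_lr


# Usage sketch (assumes a configured ``trainer`` as built in find_learning_rate_model):
# learning_rates, losses = search_learning_rate(trainer)
# print(suggest_lr(learning_rates, _smooth(losses, 0.98)))
```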
