"""
The ``find-lr`` subcommand can be used to find a good learning rate for a model.
It requires a configuration file and a directory in which to write the results.

.. code-block:: bash

    $ allennlp find-lr --help
    usage: allennlp find-lr [-h] -s SERIALIZATION_DIR [-o OVERRIDES]
                            [--start-lr START_LR] [--end-lr END_LR]
                            [--num-batches NUM_BATCHES]
                            [--stopping-factor STOPPING_FACTOR] [--linear]
                            param_path

    Find a learning rate range where loss decreases quickly for the specified
    model and dataset.

    positional arguments:
      param_path            path to parameter file describing the model to be
                            trained

    optional arguments:
      -h, --help            show this help message and exit
      -s SERIALIZATION_DIR, --serialization-dir SERIALIZATION_DIR
                            The directory in which to save results.
      -o OVERRIDES, --overrides OVERRIDES
                            a JSON structure used to override the experiment
                            configuration
      --start-lr START_LR   Learning rate to start the search.
      --end-lr END_LR       Learning rate up to which search is done.
      --num-batches NUM_BATCHES
                            Number of mini-batches to run the learning rate
                            finder.
      --stopping-factor STOPPING_FACTOR
                            Stop the search when the current loss exceeds the
                            best loss recorded by a multiple of the stopping
                            factor.
      --linear              Increase learning rate linearly instead of
                            exponentially.
"""
from typing import List, Optional, Tuple
import argparse
import re
import os
import math
import logging
import matplotlib; matplotlib.use('Agg')  # pylint: disable=multiple-statements,wrong-import-position
import matplotlib.pyplot as plt  # pylint: disable=wrong-import-position

from allennlp.commands.subcommand import Subcommand  # pylint: disable=wrong-import-position
from allennlp.commands.train import datasets_from_params  # pylint: disable=wrong-import-position
from allennlp.common.checks import ConfigurationError, check_for_gpu  # pylint: disable=wrong-import-position
from allennlp.common import Params, Tqdm  # pylint: disable=wrong-import-position
from allennlp.common.util import prepare_environment  # pylint: disable=wrong-import-position
from allennlp.data import Vocabulary, DataIterator  # pylint: disable=wrong-import-position
from allennlp.models import Model  # pylint: disable=wrong-import-position
from allennlp.training import Trainer  # pylint: disable=wrong-import-position


logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class FindLearningRate(Subcommand):
    def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
        # pylint: disable=protected-access
        description = '''Find a learning rate range where loss decreases quickly
                         for the specified model and dataset.'''
        subparser = parser.add_parser(name, description=description, help='Find a learning rate range.')

        subparser.add_argument('param_path',
                               type=str,
                               help='path to parameter file describing the model to be trained')
        subparser.add_argument('-s', '--serialization-dir',
                               required=True,
                               type=str,
                               help='The directory in which to save results.')

        subparser.add_argument('-o', '--overrides',
                               type=str,
                               default="",
                               help='a JSON structure used to override the experiment configuration')
        subparser.add_argument('--start-lr',
                               type=float,
                               default=1e-5,
                               help='Learning rate to start the search.')
        subparser.add_argument('--end-lr',
                               type=float,
                               default=10,
                               help='Learning rate up to which search is done.')
        subparser.add_argument('--num-batches',
                               type=int,
                               default=100,
                               help='Number of mini-batches to run the learning rate finder.')
        subparser.add_argument('--stopping-factor',
                               type=float,
                               default=4.0,
                               help='Stop the search when the current loss exceeds the best loss '
                                    'recorded by a multiple of the stopping factor.')
        subparser.add_argument('--linear',
                               action='store_true',
                               help='Increase learning rate linearly instead of exponentially.')

        subparser.set_defaults(func=find_learning_rate_from_args)

        return subparser

def find_learning_rate_from_args(args: argparse.Namespace) -> None:
    """
    Start the learning rate finder from the given command-line arguments.
    """
    params = Params.from_file(args.param_path, args.overrides)
    find_learning_rate_model(params, args.serialization_dir,
                             args.start_lr, args.end_lr,
                             args.num_batches, args.linear, args.stopping_factor)

def find_learning_rate_model(params: Params,
                             serialization_dir: str,
                             start_lr: float,
                             end_lr: float,
                             num_batches: int,
                             linear_steps: bool,
                             stopping_factor: Optional[float]) -> None:
    """
    Runs a learning rate search for the given ``num_batches`` and saves the results
    in ``serialization_dir``.

    Parameters
    ----------
    params : ``Params``
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir : ``str``
        The directory in which to save results.
    start_lr : ``float``
        Learning rate to start the search.
    end_lr : ``float``
        Learning rate up to which search is done.
    num_batches : ``int``
        Number of mini-batches to run the learning rate finder.
    linear_steps : ``bool``
        Increase the learning rate linearly if ``True``; otherwise exponentially.
    stopping_factor : ``float``
        Stop the search when the current loss exceeds the best loss recorded by
        a multiple of the stopping factor. If ``None``, the search proceeds until ``end_lr``.
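
    For example, assuming an experiment configuration at ``experiment.jsonnet`` and
    an output directory ``lr_search_output`` (both names are only illustrative), the
    search could be run programmatically with::

        params = Params.from_file('experiment.jsonnet')
        find_learning_rate_model(params,
                                 serialization_dir='lr_search_output',
                                 start_lr=1e-5,
                                 end_lr=10,
                                 num_batches=100,
                                 linear_steps=False,
                                 stopping_factor=4.0)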
    """

    if os.path.exists(serialization_dir) and os.listdir(serialization_dir):
        raise ConfigurationError(f'Serialization directory {serialization_dir} already exists and is '
                                 f'not empty.')

    prepare_environment(params)
    os.makedirs(serialization_dir, exist_ok=True)

    check_for_gpu(params.get('trainer').get('cuda_device', -1))

    all_datasets = datasets_from_params(params)
    datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

    for dataset in datasets_for_vocab_creation:
        if dataset not in all_datasets:
            raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

    logger.info("From dataset instances, %s will be considered for vocabulary creation.",
                ", ".join(datasets_for_vocab_creation))
    vocab = Vocabulary.from_params(
            params.pop("vocabulary", {}),
            (instance for key, dataset in all_datasets.items()
             for instance in dataset
             if key in datasets_for_vocab_creation)
    )

    model = Model.from_params(vocab=vocab, params=params.pop('model'))
    iterator = DataIterator.from_params(params.pop("iterator"))
    iterator.index_with(vocab)

    train_data = all_datasets['train']

    trainer_params = params.pop("trainer")
    no_grad_regexes = trainer_params.pop("no_grad", ())
    # Freeze any parameters whose names match the "no_grad" regexes.
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

    trainer = Trainer.from_params(model,
                                  serialization_dir,
                                  iterator,
                                  train_data,
                                  params=trainer_params,
                                  validation_data=None,
                                  validation_iterator=None)

    logger.info(f'Starting learning rate search from {start_lr} to {end_lr} in {num_batches} iterations.')
    learning_rates, losses = search_learning_rate(trainer, start_lr,
                                                  end_lr, num_batches,
                                                  linear_steps, stopping_factor)
    logger.info('Finished learning rate search.')
    losses = _smooth(losses, 0.98)

    _save_plot(learning_rates, losses, os.path.join(serialization_dir, 'lr-losses.png'))

def search_learning_rate(trainer: Trainer,
                         start_lr: float = 1e-5,
                         end_lr: float = 10,
                         num_batches: int = 100,
                         linear_steps: bool = False,
                         stopping_factor: Optional[float] = 4.0) -> Tuple[List[float], List[float]]:
    """
    Runs a training loop on the model using :class:`~allennlp.training.trainer.Trainer`,
    increasing the learning rate from ``start_lr`` to ``end_lr`` and recording the losses.

    Parameters
    ----------
    trainer : :class:`~allennlp.training.trainer.Trainer`
    start_lr : ``float``
        The learning rate to start the search.
    end_lr : ``float``
        The learning rate up to which search is done.
    num_batches : ``int``
        Number of batches to run the learning rate finder.
    linear_steps : ``bool``
        Increase the learning rate linearly if ``True``; otherwise exponentially.
    stopping_factor : ``float``
        Stop the search when the current loss exceeds the best loss recorded by
        a multiple of the stopping factor. If ``None``, the search proceeds until ``end_lr``.

    Returns
    -------
    (learning_rates, losses) : ``Tuple[List[float], List[float]]``
        Lists of learning rates and corresponding losses.
        Note: each loss is recorded before the optimizer step at the corresponding
        learning rate is applied.
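
    For example, given a ``trainer`` that has already been built (e.g. via
    ``Trainer.from_params``, as in :func:`find_learning_rate_model`)::

        learning_rates, losses = search_learning_rate(trainer,
                                                       start_lr=1e-5,
                                                       end_lr=10,
                                                       num_batches=100)

    The returned lists can then be smoothed and plotted, as
    :func:`find_learning_rate_model` does via ``_smooth`` and ``_save_plot``.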
    """
    if num_batches <= 10:
        raise ConfigurationError('The number of iterations for the learning rate finder '
                                 'should be greater than 10.')

    trainer.model.train()

    train_generator = trainer.iterator(trainer.train_data,
                                       shuffle=trainer.shuffle)
    train_generator_tqdm = Tqdm.tqdm(train_generator,
                                     total=num_batches)

    learning_rates = []
    losses = []
    best = 1e9
    if linear_steps:
        lr_update_factor = (end_lr - start_lr) / num_batches
    else:
        lr_update_factor = (end_lr / start_lr) ** (1.0 / num_batches)

    for i, batch in enumerate(train_generator_tqdm):

        if linear_steps:
            current_lr = start_lr + (lr_update_factor * i)
        else:
            current_lr = start_lr * (lr_update_factor ** i)

        # Apply the current learning rate to every parameter group.
        for param_group in trainer.optimizer.param_groups:
            param_group['lr'] = current_lr

        trainer.optimizer.zero_grad()
        loss = trainer.batch_loss(batch, for_training=True)
        loss.backward()
        loss = loss.detach().cpu().item()

        # Stop early if the loss has diverged relative to the best loss seen so far.
        if stopping_factor is not None and (math.isnan(loss) or loss > stopping_factor * best):
            logger.info(f'Loss ({loss}) exceeds stopping_factor * lowest recorded loss.')
            break

        trainer.rescale_gradients()
        trainer.optimizer.step()

        learning_rates.append(current_lr)
        losses.append(loss)

        # Track the lowest loss, ignoring the first few (noisy) batches.
        if loss < best and i > 10:
            best = loss

        if i == num_batches:
            break

    return learning_rates, losses


def _smooth(values: List[float], beta: float) -> List[float]:
    """ Exponential smoothing of values """
    avg_value = 0.
    smoothed = []
    for i, value in enumerate(values):
        avg_value = beta * avg_value + (1 - beta) * value
        # Bias correction (as in Adam) so that early values are not skewed
        # towards the initial zero average.
        smoothed.append(avg_value / (1 - beta ** (i + 1)))
    return smoothed


def _save_plot(learning_rates: List[float], losses: List[float], save_path: str):
    plt.ylabel('loss')
    plt.xlabel('learning rate (log10 scale)')
    plt.xscale('log')
    plt.plot(learning_rates, losses)
    logger.info(f'Saving learning_rate vs loss plot to {save_path}.')
    plt.savefig(save_path)