forked from allenai/allennlp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgradient_descent_trainer.py
1237 lines (1054 loc) · 56.4 KB
/
gradient_descent_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import datetime
import glob
import logging
import math
import os
import re
import time
import warnings
from typing import Optional, Union, List, Dict, Tuple, Any, Type
import torch
from torch.cuda import amp
from torch.nn.utils import clip_grad_norm_
import torch.distributed as dist
from torch.cuda.amp.grad_scaler import OptState
from allennlp.common.checks import ConfigurationError, check_for_gpu
from allennlp.common import util as common_util, Tqdm, Lazy
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict
from allennlp.models.model import Model
from allennlp.nn.parallel import DdpAccelerator, DdpWrappedModel, TorchDdpAccelerator
from allennlp.nn.util import dist_reduce_sum
from allennlp.training.callbacks import ConsoleLoggerCallback
from allennlp.training.callbacks.confidence_checks import ConfidenceChecksCallback
from allennlp.training.callbacks.backward import MixedPrecisionBackwardCallback
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler
from allennlp.training.metric_tracker import MetricTracker
from allennlp.training.momentum_schedulers.momentum_scheduler import MomentumScheduler
from allennlp.training.moving_average import MovingAverage
from allennlp.training.optimizers import Optimizer
from allennlp.training.trainer import Trainer, TrainerCheckpoint
from allennlp.training.callbacks import TrainerCallback
from allennlp.training import util as training_util
logger = logging.getLogger(__name__)
@Trainer.register("gradient_descent", constructor="from_partial_objects")
class GradientDescentTrainer(Trainer):
"""
A trainer for doing supervised learning with gradient descent. It just takes a labeled dataset
and a `DataLoader`, and uses the supplied `Optimizer` to learn the weights for your model over
some fixed number of epochs. You can also pass in a validation data_loader and enable early
stopping. There are many other bells and whistles as well.
Registered as a `Trainer` with the name "gradient_descent" (and is also the default `Trainer`).
The constructor that is registered is [`from_partial_objects`](#from_partial_objects) -
see the arguments to that function for the exact keys that should be used, if you are using
a configuration file. They largely match the arguments to `__init__`, and we don't repeat their
docstrings in `from_partial_objects`.
[0]: https://tinyurl.com/y5mv44fw
# Parameters
model : `Model`, required.
An AllenNLP model to be optimized. Pytorch Modules can also be optimized if
their `forward` method returns a dictionary with a "loss" key, containing a
scalar tensor representing the loss function to be optimized.
If you are training your model using GPUs, your model should already be
on the correct device. (If you are using our `train` command this will be
handled for you.)
In a typical AllenNLP configuration file, this parameter does not get an entry under the
"trainer", it gets constructed separately.
optimizer : `torch.optim.Optimizer`, required.
An instance of a Pytorch Optimizer, instantiated with the parameters of the
model to be optimized.
data_loader : `DataLoader`, required.
A `DataLoader` containing your `Dataset`, yielding padded indexed batches.
In a typical AllenNLP configuration file, this parameter does not get an entry under the
"trainer", it gets constructed separately.
patience : `Optional[int] > 0`, optional (default=`None`)
Number of epochs to be patient before early stopping: the training is stopped
after `patience` epochs with no improvement. If given, it must be `> 0`.
If None, early stopping is disabled.
validation_metric : `Union[str, List[str]]`, optional (default=`"-loss"`)
Validation metric to measure for whether to stop training using patience
and whether to serialize an `is_best` model each epoch. The metric name
must be prepended with either "+" or "-", which specifies whether the metric
is an increasing or decreasing function. If you specify more than one metric,
the metrics will be summed to make the `is_best` decision.
validation_data_loader : `DataLoader`, optional (default=`None`)
A `DataLoader` to use for the validation set. If `None`, then
use the training `DataLoader` with the validation data.
In a typical AllenNLP configuration file, this parameter does not get an entry under the
"trainer", it gets constructed separately.
num_epochs : `int`, optional (default = `20`)
Number of training epochs.
serialization_dir : `str`, optional (default=`None`)
Path to directory for saving and loading model files. Models will not be saved if
this parameter is not passed.
In a typical AllenNLP configuration file, this parameter does not get an entry under the
"trainer", it gets constructed separately.
checkpointer : `Checkpointer`, optional (default=`None`)
A `Checkpointer` is responsible for periodically saving model weights. If none is given
here, we will construct one with default parameters.
cuda_device : `Optional[Union[int, torch.device]]`, optional (default = `None`)
An integer or `torch.device` specifying the CUDA device to use for this process.
If -1, the CPU is used. If `None` and you have a GPU available, that GPU will be used.
!!! Note
If you *don't* intend to use a GPU, but you have one available, you'll need
to explicitly set `cuda_device=-1`.
!!! Note
If you intend to use a GPU, your model already needs to be on the correct device,
which you can do with `model = model.cuda()`.
!!! Note
Data parallelism is controlled at the allennlp train level, so each trainer will have a single GPU.
grad_norm : `Union[float, bool]`, optional (default = `False`)
If a float, gradient norms will be rescaled to have a maximum of this value.
If `True`, the gradient norms will be calculated and passed through to any `TrainerCallbacks`,
but won't be rescaled.
If `False`, gradient norms will not be calculated or rescaled.
grad_clipping : `float`, optional (default = `None`)
If provided, gradients will be clipped after the backward pass (just before each optimizer step) to have an (absolute)
maximum of this value. If you are getting `NaNs` in your gradients during training
that are not solved by using `grad_norm`, you may need this.
learning_rate_scheduler : `LearningRateScheduler`, optional (default = `None`)
If specified, the learning rate will be decayed with respect to
this schedule at the end of each epoch (or batch, if the scheduler implements
the `step_batch` method). If you use `torch.optim.lr_scheduler.ReduceLROnPlateau`,
this will use the `validation_metric` provided to determine if learning has plateaued.
To support updating the learning rate on every batch, this can optionally implement
`step_batch(batch_num_total)` which updates the learning rate given the batch number.
momentum_scheduler : `MomentumScheduler`, optional (default = `None`)
If specified, the momentum will be updated at the end of each batch or epoch
according to the schedule.
moving_average : `MovingAverage`, optional, (default = `None`)
If provided, we will maintain moving averages for all parameters. During training, we
employ a shadow variable for each parameter, which maintains the moving average. During
evaluation, we backup the original parameters and assign the moving averages to corresponding
parameters. Be careful that when saving the checkpoint, we will save the moving averages of
parameters. This is necessary because we want the saved model to perform as well as the validated
model if we load it later. But this may cause problems if you restart the training from checkpoint.
callbacks : `List[TrainerCallback]`, optional (default = `None`)
A list of callbacks that can be called at certain events: e.g. each batch, epoch, and at the start
and end of training, etc.
distributed : `bool`, optional, (default = `False`)
If set, PyTorch's `DistributedDataParallel` is used to train the model in multiple GPUs. This also
requires `world_size` to be greater than 1.
In a typical AllenNLP configuration file, this parameter does not get an entry under the
"trainer", it gets constructed separately (you need a top-level "distributed" key, next to
the "trainer" entry, that specifies a list of "cuda_devices").
local_rank : `int`, optional, (default = `0`)
This is the unique identifier of the `Trainer` in a distributed process group. The GPU device id is
used as the rank.
In a typical AllenNLP configuration file, this parameter does not get an entry under the
"trainer", it gets constructed separately.
world_size : `int`, (default = `1`)
The number of `Trainer` workers participating in the distributed training.
In a typical AllenNLP configuration file, this parameter does not get an entry under the
"trainer", it gets constructed separately.
num_gradient_accumulation_steps : `int`, optional, (default = `1`)
Gradients are accumulated for the given number of steps before doing an optimizer step. This can
be useful to accommodate batches that are larger than the RAM size. Refer [Thomas Wolf's
post][0] for details on Gradient Accumulation.
use_amp : `bool`, optional, (default = `False`)
If `True`, we'll train using [Automatic Mixed Precision](https://pytorch.org/docs/stable/amp.html).
enable_default_callbacks : `bool`, optional (default = `True`)
When `True`, the [`DEFAULT_CALLBACKS`](#default_callbacks) will be used in
addition to any other callbacks listed in the `callbacks` parameter.
When set to `False`, `DEFAULT_CALLBACKS` are not used.
run_confidence_checks : `bool`, optional (default = `True`)
Determines whether model confidence checks, such as
[`NormalizationBiasVerification`](../../confidence_checks/normalization_bias_verification/),
are run.
run_sanity_checks : `bool`, optional (default = `True`)
This parameter is deprecated. Please use `run_confidence_checks` instead.
grad_scaling : `bool`, optional (default = `True`)
When `use_amp` is `True`, this determines whether or not to use a [`GradScaler`]
(https://pytorch.org/docs/stable/amp.html?highlight=gradscaler#torch.cuda.amp.GradScaler).
!!! Note
This parameter is ignored when `use_amp` is `False`.
ddp_wrapped_model : `Optional[DdpWrappedModel]`, optional (default = `None`)
The `model` wrapped with a `DdpAccelerator` for distributed training.
!!! Note
This is required for distributed training.
"""
def __init__(
    self,
    model: Model,
    optimizer: torch.optim.Optimizer,
    data_loader: DataLoader,
    patience: Optional[int] = None,
    validation_metric: Union[str, List[str]] = "-loss",
    validation_data_loader: Optional[DataLoader] = None,
    num_epochs: int = 20,
    serialization_dir: Optional[Union[str, os.PathLike]] = None,
    checkpointer: Optional[Checkpointer] = None,
    cuda_device: Optional[Union[int, torch.device]] = None,
    grad_norm: Union[float, bool] = False,
    grad_clipping: Optional[float] = None,
    learning_rate_scheduler: Optional[LearningRateScheduler] = None,
    momentum_scheduler: Optional[MomentumScheduler] = None,
    moving_average: Optional[MovingAverage] = None,
    callbacks: Optional[List[TrainerCallback]] = None,
    distributed: bool = False,
    local_rank: int = 0,
    world_size: int = 1,
    num_gradient_accumulation_steps: int = 1,
    use_amp: bool = False,
    enable_default_callbacks: bool = True,
    run_confidence_checks: bool = True,
    grad_scaling: bool = True,
    ddp_wrapped_model: Optional[DdpWrappedModel] = None,
    **kwargs,
) -> None:
    """
    Set up the trainer. See the class docstring for the meaning of each parameter.

    `**kwargs` is only inspected for the deprecated `run_sanity_checks` flag, which
    overrides `run_confidence_checks` (with a `DeprecationWarning`) when present.

    The `callbacks` list passed in is copied, so appending the default callbacks never
    mutates the caller's list.

    # Raises

    `ConfigurationError`
        If `patience` is neither `None` nor a positive integer, or if
        `num_gradient_accumulation_steps` is less than 1.
    `ValueError`
        If `distributed` is set without a `ddp_wrapped_model`, or if `use_amp` is
        requested while the trainer's device is the CPU.
    """
    super().__init__(
        serialization_dir=serialization_dir,
        cuda_device=cuda_device,
        distributed=distributed,
        local_rank=local_rank,
        world_size=world_size,
    )

    if "run_sanity_checks" in kwargs:
        warnings.warn(
            "'run_sanity_checks' is deprecated, please use 'run_confidence_checks' instead.",
            DeprecationWarning,
        )
        run_confidence_checks = kwargs["run_sanity_checks"]

    # I am not calling move_to_gpu here, because if the model is
    # not already on the GPU then the optimizer is going to be wrong.
    self.model = model

    self.data_loader = data_loader
    self.data_loader.set_target_device(self.cuda_device)
    self._validation_data_loader = validation_data_loader
    if self._validation_data_loader is not None:
        self._validation_data_loader.set_target_device(self.cuda_device)
    self.optimizer = optimizer

    if patience is None:  # no early stopping
        if validation_data_loader is not None:
            logger.warning(
                "You provided a validation dataset but patience was set to None, "
                "meaning that early stopping is disabled"
            )
    elif (not isinstance(patience, int)) or patience <= 0:
        raise ConfigurationError(
            '{} is an invalid value for "patience": it must be a positive integer '
            "or None (if you want to disable early stopping)".format(patience)
        )

    # Batches are grouped into chunks of this size in `_train_epoch`; a value below 1
    # would otherwise only fail later (division by zero / invalid slice size).
    if num_gradient_accumulation_steps < 1:
        raise ConfigurationError(
            '{} is an invalid value for "num_gradient_accumulation_steps": '
            "it must be a positive integer".format(num_gradient_accumulation_steps)
        )

    # For tracking is_best_so_far and should_stop_early
    self._metric_tracker = MetricTracker(validation_metric, patience)

    self._num_epochs = num_epochs

    self._checkpointer: Optional[Checkpointer] = checkpointer

    self._grad_norm = grad_norm
    self._grad_clipping = grad_clipping

    self._learning_rate_scheduler = learning_rate_scheduler
    self._momentum_scheduler = momentum_scheduler
    self._moving_average = moving_average

    # Copy the caller's list: default callbacks are appended below, and we must not
    # mutate a list the caller still owns (it may be reused for another trainer).
    self._callbacks: List[TrainerCallback] = list(callbacks) if callbacks else []
    default_callbacks = list(DEFAULT_CALLBACKS) if enable_default_callbacks else []
    if run_confidence_checks:
        default_callbacks.append(ConfidenceChecksCallback)
    for callback_cls in default_callbacks:
        # Only add a default callback if no instance of exactly that class is present yet.
        if not any(callback.__class__ == callback_cls for callback in self._callbacks):
            self._callbacks.append(callback_cls(self._serialization_dir))

    self._num_gradient_accumulation_steps = num_gradient_accumulation_steps

    self._ddp_wrapped_model = ddp_wrapped_model
    if distributed:
        # The model needs to be wrapped before initializing the optimizer,
        # so at this point it's too late to wrap the model.
        if ddp_wrapped_model is None:
            raise ValueError("trainer requires 'ddp_wrapped_model' for distributed training")
        # Make sure checkpointer knows if we're working with a sharded model.
        if self._checkpointer is not None:
            self._checkpointer.state_is_sharded = ddp_wrapped_model.is_sharded

    # Enable automatic mixed precision training.
    self._scaler: Optional[amp.GradScaler] = None
    self._use_amp = use_amp
    if self._use_amp:
        if self.cuda_device == torch.device("cpu"):
            raise ValueError("Using AMP requires a cuda device")
        if grad_scaling:
            if self._ddp_wrapped_model is None:
                self._scaler = amp.GradScaler()
            else:
                self._scaler = self._ddp_wrapped_model.init_grad_scaler()

    # training state management
    self._epochs_completed: int = 0
    self._start_after_epochs_completed: int = 0
    self._batches_in_epoch_completed: int = 0
    self._start_after_batches_in_epoch_completed: int = 0
    self._best_model_filename: Optional[str] = None
    self._should_validate_this_epoch: bool = True

    # This is a kind of training state, but it is not serialized with the trainer state, because we can
    # re-create it with `epochs_completed` and `batches_in_epoch_completed`.
    self._total_batches_completed: int = 0
@property
def _pytorch_model(self):
    """
    The module to call for forward passes and train/eval mode switches.

    This is the `.model` of the DDP-wrapped model when one was supplied, and the plain
    `model` otherwise.
    """
    wrapped = self._ddp_wrapped_model
    return self.model if wrapped is None else wrapped.model
def clip_gradient(self):
    """
    Clamp every existing gradient element-wise to `[-grad_clipping, grad_clipping]`.

    This is a no-op when `grad_clipping` is `None`. With a `GradScaler`, the gradients are
    unscaled first, unless this optimizer's gradients are already in the unscaled stage
    (`unscale_` must not run twice per optimizer step).
    """
    clip_value = self._grad_clipping
    if clip_value is None:
        return

    scaler = self._scaler
    if scaler is not None:
        stage = scaler._per_optimizer_states[id(self.optimizer)]["stage"]
        if stage is not OptState.UNSCALED:
            scaler.unscale_(self.optimizer)

    params_with_grad = [param for param in self.model.parameters() if param.grad is not None]
    torch.nn.utils.clip_grad_value_(params_with_grad, clip_value)
def rescale_gradients(self) -> Optional[float]:
    """
    Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled.

    Returns the norm of the gradients if `grad_norm` is `True` or a `float`,
    otherwise returns `None`.

    - `grad_norm=False`: nothing is computed; returns `None`.
    - `grad_norm=True`: the total gradient norm is computed and returned, gradients are
      left untouched. If no parameter has a gradient, `0.0` is returned.
    - `grad_norm=<float>`: gradients are rescaled so their total norm is at most this
      value; the norm *before* rescaling is returned.

    When a `GradScaler` is in use, gradients are unscaled first (at most once per optimizer
    step), so the returned norm is the true gradient norm rather than the loss-scaled one.
    """
    if isinstance(self._grad_norm, bool) and not self._grad_norm:
        return None

    if self._scaler is not None:
        # Norms (and clipping) must be computed on unscaled gradients. `unscale_` may only
        # be called once per optimizer step, so skip it if this optimizer's gradients have
        # already been unscaled (e.g. by a callback).
        optimizer_state = self._scaler._per_optimizer_states[id(self.optimizer)]
        if optimizer_state["stage"] is not OptState.UNSCALED:
            self._scaler.unscale_(self.optimizer)

    if not isinstance(self._grad_norm, bool):
        # Sometimes logic for clipping has to implemented within the model, like
        # with FairScale's FullyShardedDataParallel.
        if self._ddp_wrapped_model is not None:
            return self._ddp_wrapped_model.clip_grad_norm_(self._grad_norm).item()
        parameters_to_clip = [p for p in self.model.parameters() if p.grad is not None]
        return clip_grad_norm_(parameters_to_clip, self._grad_norm).item()

    # `grad_norm=True`: report the norm without rescaling.
    grads = [p.grad.detach() for p in self.model.parameters() if p.grad is not None]
    if not grads:
        # `torch.stack` rejects an empty list; with no gradients the norm is zero.
        return 0.0
    return torch.norm(torch.stack([torch.norm(grad) for grad in grads])).item()
def batch_outputs(self, batch: TensorDict, for_training: bool) -> Dict[str, torch.Tensor]:
    """
    Does a forward pass on the given batch and returns the output dictionary that the model
    returns, after adding any specified regularization penalty to the loss (if training).

    When `for_training` is `True` and `model.get_regularization_penalty()` is not `None`,
    the penalty is stored under `"reg_loss"` and added to `"loss"` in place. When
    `for_training` is `False` the model's output is returned unchanged.

    # Raises

    `RuntimeError`
        If `for_training` is `True` and the output dictionary has no `"loss"` key.
    """
    output_dict = self._pytorch_model(**batch)

    if not for_training:
        return output_dict

    # An explicit check rather than `assert`: asserts are stripped under `python -O`, and
    # catching AssertionError would also mask unrelated assertion failures raised by the
    # regularizer behind a misleading "missing loss" message.
    if "loss" not in output_dict:
        raise RuntimeError(
            "The model you are trying to optimize does not contain a"
            " 'loss' key in the output of model.forward(inputs)."
        )

    regularization_penalty = self.model.get_regularization_penalty()
    if regularization_penalty is not None:
        output_dict["reg_loss"] = regularization_penalty
        output_dict["loss"] += regularization_penalty
    return output_dict
def _train_epoch(self, epoch: int) -> Dict[str, float]:
    """
    Trains one epoch and returns metrics.

    Batches are read in groups of `num_gradient_accumulation_steps`; each group's losses
    are divided by the group size, back-propagated, and followed by a single optimizer
    step. When resuming from a checkpoint, groups that were already completed are still
    read from the data loader but skipped.

    Returns the epoch's training metrics (an empty dict of training metrics if the whole
    epoch was skipped while catching up with a checkpoint), plus peak CPU/GPU memory usage
    per worker/device in MB.
    """
    logger.info("Epoch %d/%d", epoch, self._num_epochs - 1)
    cpu_memory_usage = []
    for worker, memory in common_util.peak_cpu_memory().items():
        cpu_memory_usage.append((worker, memory))
        logger.info(f"Worker {worker} memory usage: {common_util.format_size(memory)}")
    gpu_memory_usage = []
    for gpu, memory in common_util.peak_gpu_memory().items():
        gpu_memory_usage.append((gpu, memory))
        logger.info(f"GPU {gpu} memory usage: {common_util.format_size(memory)}")

    regularization_penalty = self.model.get_regularization_penalty()

    train_loss = 0.0
    train_reg_loss = None if regularization_penalty is None else 0.0

    # Set the model to "train" mode.
    self._pytorch_model.train()

    # Get tqdm for the training batches
    batch_generator = iter(self.data_loader)
    batch_group_generator = common_util.lazy_groups_of(
        batch_generator, self._num_gradient_accumulation_steps
    )

    logger.info("Training")

    num_training_batches: Union[int, float]
    try:
        len_data_loader = len(self.data_loader)
        num_training_batches = math.ceil(
            len_data_loader / self._num_gradient_accumulation_steps
        )
    except TypeError:
        num_training_batches = float("inf")

    # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the primary's
    # progress is shown
    if self._primary:
        batch_group_generator_tqdm = Tqdm.tqdm(
            batch_group_generator, total=num_training_batches
        )
    else:
        batch_group_generator_tqdm = batch_group_generator

    done_early = False
    for batch_group in batch_group_generator_tqdm:
        if done_early:
            break

        # Still catching up with a checkpoint: consume the group without training on it.
        if self._epochs_completed < self._start_after_epochs_completed or (
            self._epochs_completed == self._start_after_epochs_completed
            and self._batches_in_epoch_completed < self._start_after_batches_in_epoch_completed
        ):
            self._batches_in_epoch_completed += 1
            self._total_batches_completed += 1
            continue

        self.optimizer.zero_grad()

        batch_loss = 0.0
        # Like `batch_loss`, this is the sum of the per-micro-batch shares over the whole
        # group, so it is reset for every group and accumulated (not overwritten).
        batch_reg_loss = None if train_reg_loss is None else 0.0
        batch_group_outputs = []
        for batch in batch_group:
            if self._distributed:
                # Check whether the other workers have stopped already (due to differing amounts of
                # data in each). If so, we can't proceed because we would hang when we hit the
                # barrier implicit in Model.forward. We use a IntTensor instead a BoolTensor
                # here because NCCL process groups apparently don't support BoolTensor.
                done = torch.tensor(0, device=self.cuda_device)
                torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
                if done.item() > 0:
                    done_early = True
                    logger.warning(
                        f"Worker {torch.distributed.get_rank()} finishing training early! "
                        "This implies that there is an imbalance in your training "
                        "data across the workers and that some amount of it will be "
                        "ignored. A small amount of this is fine, but a major imbalance "
                        "should be avoided. Note: This warning will appear unless your "
                        "data is perfectly balanced."
                    )
                    break

            with amp.autocast(self._use_amp):
                batch_outputs = self.batch_outputs(batch, for_training=True)
                batch_group_outputs.append(batch_outputs)
                loss = batch_outputs["loss"]
                reg_loss = batch_outputs.get("reg_loss")
                if torch.isnan(loss):
                    raise ValueError("nan loss encountered")
                loss = loss / len(batch_group)

                batch_loss += loss.item()
                if reg_loss is not None:
                    reg_loss = reg_loss / len(batch_group)
                    reg_loss_share = reg_loss.item()
                    batch_reg_loss += reg_loss_share  # type: ignore
                    train_reg_loss += reg_loss_share  # type: ignore

            # Callbacks may perform the backward pass themselves; fall back to the
            # default (AMP-aware) backward only if none of them did.
            backward_called = False
            for callback in self._callbacks:
                backward_called |= callback.on_backward(self, batch_outputs, backward_called)
            if not backward_called:
                if self._scaler is not None:
                    MixedPrecisionBackwardCallback(self._serialization_dir).on_backward(
                        self, batch_outputs, backward_called
                    )
                else:
                    loss.backward()

        # Another worker finished before this group produced any output: nothing to step.
        if len(batch_group_outputs) <= 0:
            continue

        train_loss += batch_loss

        batch_grad_norm = self.rescale_gradients()
        self.clip_gradient()

        if self._learning_rate_scheduler:
            self._learning_rate_scheduler.step_batch(self._total_batches_completed + 1)
        if self._momentum_scheduler:
            self._momentum_scheduler.step_batch(self._total_batches_completed + 1)

        if self._scaler is not None:
            self._scaler.step(self.optimizer)
            self._scaler.update()
        else:
            self.optimizer.step()

        # Update moving averages
        if self._moving_average is not None:
            self._moving_average.apply(self._total_batches_completed + 1)

        self._batches_in_epoch_completed += 1
        self._total_batches_completed += 1

        # Update the description with the latest metrics
        metrics = training_util.get_metrics(
            self.model,
            train_loss,
            train_reg_loss,
            batch_loss,
            batch_reg_loss,
            self._batches_in_epoch_completed,
        )

        for callback in self._callbacks:
            callback.on_batch(
                self,
                batch_group,
                batch_group_outputs,
                metrics,
                epoch,
                self._batches_in_epoch_completed,
                is_training=True,
                is_primary=self._primary,
                batch_grad_norm=batch_grad_norm,
            )

        if self._primary:
            # Updating tqdm only for the primary as the trainers wouldn't have one
            description = training_util.description_from_metrics(metrics)
            batch_group_generator_tqdm.set_description(description, refresh=False)

        if self._checkpointer is not None:
            self._checkpointer.maybe_save_checkpoint(
                self, self._epochs_completed, self._batches_in_epoch_completed
            )

    if self._distributed and not done_early:
        logger.info(
            f"Worker {torch.distributed.get_rank()} completed its entire epoch (training)."
        )
        # Indicate that we're done so that any workers that have remaining data stop the epoch early.
        done = torch.tensor(1, device=self.cuda_device)
        torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
        assert done.item()

    # Let all workers finish their epoch before computing
    # the final statistics for the epoch.
    if self._distributed:
        dist.barrier()

    if self._epochs_completed < self._start_after_epochs_completed or (
        self._epochs_completed == self._start_after_epochs_completed
        and self._batches_in_epoch_completed - 1 < self._start_after_batches_in_epoch_completed
    ):
        metrics = {}
    else:
        train_loss = dist_reduce_sum(train_loss)
        num_batches = dist_reduce_sum(self._batches_in_epoch_completed)
        if train_reg_loss is not None:
            train_reg_loss = dist_reduce_sum(train_reg_loss)

        metrics = training_util.get_metrics(
            self.model,
            train_loss,
            train_reg_loss,
            batch_loss=None,
            batch_reg_loss=None,
            num_batches=num_batches,
            reset=True,
        )

    for (worker, memory) in cpu_memory_usage:
        metrics["worker_" + str(worker) + "_memory_MB"] = memory / (1024 * 1024)
    for (gpu_num, memory) in gpu_memory_usage:
        metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory / (1024 * 1024)
    return metrics
def _validation_loss(self, epoch: int) -> Tuple[float, Optional[float], int]:
    """
    Computes the validation loss over the validation data loader.

    Returns a tuple `(val_loss, val_reg_loss, batches_this_epoch)`:
    - the summed validation loss;
    - the summed regularization loss (`None` if the model has no regularization penalty);
    - the number of batches that produced a loss.

    If a moving average is configured, its averaged parameter values are swapped in for
    the duration of validation and the original values are restored afterwards, even if
    validation raises.

    # Raises

    `ConfigurationError`
        If no validation data loader was provided.
    """
    logger.info("Validating")

    self._pytorch_model.eval()

    # Replace parameter values with the shadow values from the moving averages.
    if self._moving_average is not None:
        self._moving_average.assign_average_value()
    try:
        if self._validation_data_loader is not None:
            validation_data_loader = self._validation_data_loader
        else:
            raise ConfigurationError(
                "Validation results cannot be calculated without a validation_data_loader"
            )

        regularization_penalty = self.model.get_regularization_penalty()

        # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the primary's
        # progress is shown
        if self._primary:
            val_generator_tqdm = Tqdm.tqdm(validation_data_loader)
        else:
            val_generator_tqdm = validation_data_loader

        batches_this_epoch = 0
        val_loss = 0.0
        val_batch_loss = 0.0
        val_reg_loss = None if regularization_penalty is None else 0.0
        val_batch_reg_loss = None if regularization_penalty is None else 0.0
        done_early = False
        for batch in val_generator_tqdm:
            if self._distributed:
                # Check whether the other workers have stopped already (due to differing amounts of
                # data in each). If so, we can't proceed because we would hang when we hit the
                # barrier implicit in Model.forward. We use a IntTensor instead a BoolTensor
                # here because NCCL process groups apparently don't support BoolTensor.
                done = torch.tensor(0, device=self.cuda_device)
                torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
                if done.item() > 0:
                    done_early = True
                    logger.warning(
                        f"Worker {torch.distributed.get_rank()} finishing validation early! "
                        "This implies that there is an imbalance in your validation "
                        "data across the workers and that some amount of it will be "
                        "ignored. A small amount of this is fine, but a major imbalance "
                        "should be avoided. Note: This warning will appear unless your "
                        "data is perfectly balanced."
                    )
                    break

            with amp.autocast(self._use_amp):
                batch_outputs = self.batch_outputs(batch, for_training=False)
                loss = batch_outputs.get("loss")
                # `batch_outputs(..., for_training=False)` does not add the regularization
                # penalty, so "reg_loss" is only present if the model itself returns it.
                reg_loss = batch_outputs.get("reg_loss")
                if loss is not None:
                    # You shouldn't necessarily have to compute a loss for validation, so we
                    # allow for `loss` to be None. `batches_this_epoch` therefore counts only
                    # the batches that produced a loss. It is the divisor for the loss metrics
                    # and is also what callbacks receive as the batch number below.
                    batches_this_epoch += 1
                    val_batch_loss = loss.item()
                    val_loss += val_batch_loss
                    if reg_loss is not None:
                        val_batch_reg_loss = reg_loss.item()
                        val_reg_loss += val_batch_reg_loss  # type: ignore

            # Update the description with the latest metrics
            val_metrics = training_util.get_metrics(
                self.model,
                val_loss,
                val_reg_loss,
                val_batch_loss,
                val_batch_reg_loss,
                batches_this_epoch,
            )

            if self._primary:
                # Only the primary worker has a tqdm bar to describe.
                description = training_util.description_from_metrics(val_metrics)
                val_generator_tqdm.set_description(description, refresh=False)

            for callback in self._callbacks:
                callback.on_batch(
                    self,
                    [batch],
                    [batch_outputs],
                    val_metrics,
                    epoch,
                    batches_this_epoch,
                    is_training=False,
                    is_primary=self._primary,
                )

        if self._distributed and not done_early:
            # A normal event, logged at the same level as the training-loop counterpart.
            logger.info(
                f"Worker {torch.distributed.get_rank()} completed its entire epoch (validation)."
            )
            # Indicate that we're done so that any workers that have remaining data stop validation early.
            done = torch.tensor(1, device=self.cuda_device)
            torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
            assert done.item()

        return val_loss, val_reg_loss, batches_this_epoch
    finally:
        # Now restore the original parameter values.
        if self._moving_average is not None:
            self._moving_average.restore()
def train(self) -> Dict[str, Any]:
    """
    Trains the supplied model with the supplied parameters.

    Restores training state from a checkpoint in the serialization directory (if any),
    fires `on_start` on every callback, runs the training loop via `_try_train`, and
    returns the metrics dict it produced.

    Whatever happens in the loop, the primary worker consolidates the best model state
    and every callback gets its `on_end` call. `on_end` runs even if consolidation itself
    fails. If training raised, `on_end` receives `metrics=None` and `epoch=None`.

    Raises `ConfigurationError` (chained from the original `RuntimeError`) when the
    checkpoint cannot be restored.
    """
    try:
        self._maybe_restore_checkpoint()
    except RuntimeError as e:
        raise ConfigurationError(
            f"Could not recover training from the checkpoint in {self._serialization_dir}. "
            "Did you mean to output to a different serialization directory or delete the "
            "existing serialization directory?"
        ) from e

    # Callbacks get their `on_start` call even when we're starting from a checkpoint.
    for callback in self._callbacks:
        callback.on_start(self, is_primary=self._primary)

    # Set default values in case of failure
    epoch = None
    metrics = None

    try:
        metrics, epoch = self._try_train()
        return metrics
    finally:
        try:
            if self._primary:
                self._finalize_best_model_state()
        finally:
            # Always let callbacks clean up, even if consolidating the best model failed.
            for callback in self._callbacks:
                callback.on_end(self, metrics=metrics, epoch=epoch, is_primary=self._primary)
def _try_train(self) -> Tuple[Dict[str, Any], int]:
    """
    Run the epoch loop and return `(metrics, epoch)`.

    Epochs already completed according to the restored checkpoint are still iterated,
    so that data loaders consume their randomness identically, but are otherwise skipped.

    For every other epoch this method:
    - validates (when a validation loader is configured and validation is due) and
      updates the metric tracker;
    - dumps `metrics_epoch_{epoch}.json`;
    - steps the schedulers and fires `on_epoch` callbacks;
    - maybe saves a checkpoint and the best model weights;
    - checks early stopping.

    Before returning, the model holds either its finalized current weights or the best
    weights saved during training. `epoch` is the epoch at which training stopped early,
    or `num_epochs - 1` if the loop ran to completion.
    """
    logger.info("Beginning training.")

    val_metrics: Dict[str, float] = {}
    metrics: Dict[str, Any] = {}
    training_start_time = None

    metrics["best_epoch"] = self._metric_tracker.best_epoch
    for key, value in self._metric_tracker.best_epoch_metrics.items():
        metrics["best_validation_" + key] = value

    for epoch in range(self._num_epochs):
        epoch_start_time = time.time()
        train_metrics = self._train_epoch(epoch)

        if self._epochs_completed < self._start_after_epochs_completed:
            # We're still catching up with the checkpoint, so we do nothing.
            # Note that we have to call _train_epoch() even when we know the epoch is skipped. We have to
            # read from the data loader, because the data loader and dataset readers might use randomness,
            # and we have to make sure we consume exactly the same instances in exactly the same way every
            # time we train, even when starting from a checkpoint, so that we update the randomness
            # generators in the same way each time.
            self._epochs_completed += 1
            self._batches_in_epoch_completed = 0
            continue
        if training_start_time is None:
            training_start_time = epoch_start_time

        # Track the peak memory usage reported for each GPU and data-loader worker.
        for key, value in train_metrics.items():
            if key.startswith(("gpu_", "worker_")) and key.endswith("_memory_MB"):
                metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

        this_epoch_val_metric: float = 0.0
        if self._validation_data_loader is not None and self._should_validate_this_epoch:
            val_metrics = self._compute_validation_metrics(epoch)
            # Check validation metric for early stopping
            this_epoch_val_metric = self._metric_tracker.combined_score(val_metrics)
            self._metric_tracker.add_metrics(val_metrics)
        # NOTE(review): when validation is skipped this epoch, `val_metrics` still holds the
        # most recently computed values and is re-reported as `validation_*` below.

        # Create overall metrics dict
        training_elapsed_time = time.time() - training_start_time
        metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
        metrics["epoch"] = epoch

        for key, value in train_metrics.items():
            metrics["training_" + key] = value
        for key, value in val_metrics.items():
            metrics["validation_" + key] = value

        if self._metric_tracker.is_best_so_far():
            # Update all the best_ metrics.
            # (Otherwise they just stay the same as they were.)
            metrics["best_epoch"] = epoch
            for key, value in val_metrics.items():
                metrics["best_validation_" + key] = value
            self._metric_tracker.best_epoch_metrics = val_metrics

        if self._serialization_dir and self._primary:
            common_util.dump_metrics(
                os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"),
                metrics,
            )

        # The Scheduler API is agnostic to whether your schedule requires a validation metric -
        # if it doesn't, the validation metric passed here is ignored.
        if self._learning_rate_scheduler:
            self._learning_rate_scheduler.step(this_epoch_val_metric)
        if self._momentum_scheduler:
            self._momentum_scheduler.step(this_epoch_val_metric)
        for callback in self._callbacks:
            callback.on_epoch(self, metrics=metrics, epoch=epoch, is_primary=self._primary)

        self._epochs_completed += 1
        self._batches_in_epoch_completed = 0

        checkpoint_saved = False
        if self._checkpointer is not None:
            # The checkpointer saves state from the learning rate scheduler, momentum scheduler, moving
            # average, and callbacks, so we have to make sure those are updated before we save the
            # checkpoint here.
            checkpoint_saved = self._checkpointer.maybe_save_checkpoint(
                self, self._epochs_completed, self._batches_in_epoch_completed
            )
            # Wait for each primary process to finish saving the checkpoint
            if self._distributed:
                dist.barrier()

        if self._serialization_dir and self._metric_tracker.is_best_so_far():
            self._save_best_model_state(checkpoint_saved)

        # Wait for the primary process to finish saving the best. Every worker reaches this
        # barrier, whether or not it wrote anything.
        if self._distributed:
            dist.barrier()

        epoch_elapsed_time = time.time() - epoch_start_time
        logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

        if self._metric_tracker.should_stop_early():
            logger.info("Ran out of patience. Stopping training.")
            break

        if epoch < self._num_epochs - 1:
            # `training_elapsed_time` only covers epochs actually trained in this run
            # (skipped checkpoint epochs are excluded), so average over exactly those.
            epochs_trained_this_run = (epoch + 1) - self._start_after_epochs_completed
            time_per_epoch = training_elapsed_time / epochs_trained_this_run
            # Note: If the first non-skipped epoch is half skipped (because it was checkpointed half-way
            # through), then this estimate is going to be optimistic.
            estimated_time_remaining = time_per_epoch * (self._num_epochs - (epoch + 1))
            formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
            logger.info("Estimated training time remaining: %s", formatted_time)
    else:
        epoch = self._num_epochs - 1

    # Load the best model state before returning
    if self._best_model_filename is None or self._metric_tracker.is_best_so_far():
        self._finalize_model()
    else:
        # The model we're loading here has already been finalized.
        self._load_model_state(self._best_model_filename)

    return metrics, epoch


def _compute_validation_metrics(self, epoch: int) -> Dict[str, float]:
    """
    Run validation for `epoch` under `torch.no_grad()` and return the validation metrics.

    The loss, regularization loss, and batch count are passed through `dist_reduce_sum`
    before the metrics are computed. The model's metrics are reset (`reset=True`).
    """
    with torch.no_grad():
        val_loss, val_reg_loss, num_batches = self._validation_loss(epoch)

        # It is safe again to wait till the validation is done. This is
        # important to get the metrics right.
        if self._distributed:
            dist.barrier()

        val_loss = dist_reduce_sum(val_loss)
        num_batches = dist_reduce_sum(num_batches)
        if val_reg_loss is not None:
            val_reg_loss = dist_reduce_sum(val_reg_loss)

        return training_util.get_metrics(
            self.model,
            val_loss,
            val_reg_loss,
            batch_loss=None,
            batch_reg_loss=None,
            num_batches=num_batches,
            reset=True,
        )


def _save_best_model_state(self, checkpoint_saved: bool) -> None:
    """
    Record the current (best-so-far) model weights in `self._best_model_filename`.

    With a sharded DDP model, every worker writes its own `best_w{rank}.th`. Otherwise
    only the primary worker writes `best.th`.

    With a moving average, the averaged weights are written and the raw weights are
    restored afterwards. Without one, if the checkpointer just saved a checkpoint, that
    model-state file is hard-linked (or copied, where hard links are unsupported) instead
    of being serialized again.
    """
    import shutil

    if self._ddp_wrapped_model is not None and self._ddp_wrapped_model.is_sharded:
        # Each worker saves its own shard for now (we combine the shards later).
        self._best_model_filename = os.path.join(
            self._serialization_dir, f"best_w{self._rank}.th"
        )
        should_save_model_state = True
    else:
        self._best_model_filename = os.path.join(self._serialization_dir, "best.th")
        should_save_model_state = self._primary

    if not should_save_model_state:
        return

    if self._moving_average is not None:
        self._moving_average.assign_average_value()
        try:
            self._save_model_state(self._best_model_filename)
        finally:
            self._moving_average.restore()
    elif self._checkpointer is not None and checkpoint_saved:
        # Reuse the model state the checkpointer just wrote instead of serializing again.
        last_checkpoint = self._checkpointer.find_latest_checkpoint()
        assert last_checkpoint is not None
        model_state_file, _ = last_checkpoint
        if os.path.exists(self._best_model_filename):
            os.remove(self._best_model_filename)
        try:
            os.link(model_state_file, self._best_model_filename)
        except OSError:
            # Not every file system supports hard links; fall back to a plain copy.
            # (A genuine I/O problem will make the copy raise as well.)
            shutil.copyfile(model_state_file, self._best_model_filename)
    else:
        self._save_model_state(self._best_model_filename)
def _save_model_state(self, path: str) -> None:
if self._ddp_wrapped_model is not None:
torch.save(self._ddp_wrapped_model.state_dict(), path)
else:
torch.save(self.model.state_dict(), path)
def _load_model_state(self, path: str) -> None:
if self._ddp_wrapped_model is not None:
self._ddp_wrapped_model.load_state_dict(torch.load(path))
else:
self._pytorch_model.load_state_dict(torch.load(path))
def _finalize_model(self) -> None:
"""If we have a moving average, we have to finalize the model at the end of training."""
if self._moving_average is not None:
self._moving_average.assign_average_value()
def _finalize_best_model_state(self) -> None:
    """
    The best model weights might be saved in sharded files, in which case we gather them
    up and save them to a single 'best.th' file.

    This only happens when there is a serialization directory and the DDP wrapper is
    sharded. The per-worker files are named `best_w{rank}.th`. They are handed to
    `consolidate_sharded_state` sorted by their numeric rank, because raw glob order is
    filesystem-dependent and a lexicographic sort would put `best_w10.th` before
    `best_w2.th`. NOTE(review): this assumes consolidation expects shards in rank
    order — confirm against the wrapper implementation.
    """
    if not (
        self._serialization_dir
        and self._ddp_wrapped_model is not None
        and self._ddp_wrapped_model.is_sharded
    ):
        return

    import re

    logger.info("Consolidating sharded model states")
    shard_name = re.compile(r"best_w(\d+)\.th")
    # Escape the directory so glob metacharacters in it (e.g. '[') are matched literally.
    pattern = os.path.join(glob.escape(os.fspath(self._serialization_dir)), "best_w*.th")
    ranked_files = []
    for path in glob.iglob(pattern):
        match = shard_name.fullmatch(os.path.basename(path))
        if match is not None:
            ranked_files.append((int(match.group(1)), path))
    sharded_model_state_files = [path for _, path in sorted(ranked_files)]

    full_model_state = self._ddp_wrapped_model.consolidate_sharded_state(
        sharded_model_state_files
    )
    self._best_model_filename = os.path.join(self._serialization_dir, "best.th")
    torch.save(full_model_state, self._best_model_filename)
def get_checkpoint_state(self) -> Optional[TrainerCheckpoint]:
if self._distributed:
assert self._ddp_wrapped_model is not None