import time
import traceback
from contextlib import contextmanager
- from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Type

from allennlp.common.util import int_to_device

@@ -118,10 +118,10 @@ class GradientDescentTrainer(Trainer):
    stopping. There are many other bells and whistles as well.

    Registered as a `Trainer` with the name "gradient_descent" (and is also the default `Trainer`).
-     The constructor that is registered is `from_partial_objects` - see the arguments to that
-     function for the exact keys that should be used, if you are using a configuration file. They
-     largely match the arguments to `__init__`, and we don't repeat their docstrings in
-     `from_partial_objects`.
+     The constructor that is registered is [`from_partial_objects`](#from_partial_objects) -
+     see the arguments to that function for the exact keys that should be used, if you are using
+     a configuration file. They largely match the arguments to `__init__`, and we don't repeat their
+     docstrings in `from_partial_objects`.

    [0]: https://tinyurl.com/y5mv44fw

@@ -248,6 +248,16 @@ class GradientDescentTrainer(Trainer):
    use_amp : `bool`, optional, (default = `False`)
        If `True`, we'll train using [Automatic Mixed Precision](https://pytorch.org/docs/stable/amp.html).

+     enable_default_callbacks : `bool`, optional (default = `True`)
+         When `True`, the [`DEFAULT_CALLBACKS`](#default_callbacks) will be used in
+         addition to any other callbacks listed in the `callbacks` parameter.
+         When set to `False`, `DEFAULT_CALLBACKS` are not used.
+ 
+     run_sanity_checks : `bool`, optional (default = `True`)
+         Determines whether model sanity checks, such as
+         [`NormalizationBiasVerification`](../../sanity_checks/normalization_bias_verification/),
+         are run.
+ 
    """

    def __init__(
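With both flags now plain `__init__` arguments, direct (non-config-file) construction can opt in or out of the defaults. A minimal usage sketch, assuming a `model`, `optimizer`, and `train_loader` built elsewhere (those names are placeholders, not part of this change):

```python
# Hedged usage sketch, NOT part of this diff: `model`, `optimizer`, and
# `train_loader` are hypothetical placeholders for objects built elsewhere.
from allennlp.training import GradientDescentTrainer

trainer = GradientDescentTrainer(
    model=model,
    optimizer=optimizer,
    data_loader=train_loader,
    serialization_dir="output/",
    num_epochs=5,
    enable_default_callbacks=True,  # attach DEFAULT_CALLBACKS (ConsoleLoggerCallback)
    run_sanity_checks=False,        # skip the SanityChecksCallback
)
```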
@@ -273,6 +283,8 @@ def __init__(
        world_size: int = 1,
        num_gradient_accumulation_steps: int = 1,
        use_amp: bool = False,
+         enable_default_callbacks: bool = True,
+         run_sanity_checks: bool = True,
    ) -> None:
        super().__init__(serialization_dir, cuda_device, distributed, local_rank, world_size)

@@ -316,6 +328,15 @@ def __init__(
        self._moving_average = moving_average

        self._callbacks = callbacks or []
+         default_callbacks = list(DEFAULT_CALLBACKS) if enable_default_callbacks else []
+         if run_sanity_checks:
+             default_callbacks.append(SanityChecksCallback)
+         for callback_cls in default_callbacks:
+             for callback in self._callbacks:
+                 if callback.__class__ == callback_cls:
+                     break
+             else:
+                 self._callbacks.append(callback_cls(self._serialization_dir))

        self._batch_num_total = 0
        self._last_log = 0.0  # time of last logging
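The nested loop added above relies on Python's `for`/`else`: the `else` branch runs only when the inner loop finishes without hitting `break`, i.e. when no user-supplied callback is already an instance of the default class. A self-contained sketch of the same pattern (the callback class names here are illustrative placeholders, not the real AllenNLP callbacks):

```python
# Standalone illustration of the for/else de-duplication used above.
# A default class is instantiated only if the caller did not already
# provide a callback of that exact class.
class LoggerCallback: ...
class CheckpointCallback: ...

user_callbacks = [LoggerCallback()]                   # supplied by the caller
default_classes = [LoggerCallback, CheckpointCallback]

for cls in default_classes:
    for cb in user_callbacks:
        if cb.__class__ == cls:
            break                                     # caller already provided one
    else:                                             # inner loop never broke
        user_callbacks.append(cls())                  # add the missing default

print([type(cb).__name__ for cb in user_callbacks])
# -> ['LoggerCallback', 'CheckpointCallback']
```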
@@ -970,6 +991,7 @@ def from_partial_objects(
        checkpointer: Lazy[Checkpointer] = Lazy(Checkpointer),
        callbacks: List[Lazy[TrainerCallback]] = None,
        enable_default_callbacks: bool = True,
+         run_sanity_checks: bool = True,
    ) -> "Trainer":
        """
        This method exists so that we can have a documented method to construct this class using
@@ -1037,13 +1059,6 @@ def from_partial_objects(
        callbacks_: List[TrainerCallback] = []
        for callback_ in callbacks or []:
            callbacks_.append(callback_.construct(serialization_dir=serialization_dir))
-         if enable_default_callbacks:
-             for callback_cls in DEFAULT_CALLBACKS:
-                 for callback in callbacks_:
-                     if callback.__class__ == callback_cls:
-                         break
-                 else:
-                     callbacks_.append(callback_cls(serialization_dir))

        return cls(
            model,
@@ -1067,13 +1082,12 @@ def from_partial_objects(
            world_size=world_size,
            num_gradient_accumulation_steps=num_gradient_accumulation_steps,
            use_amp=use_amp,
+             enable_default_callbacks=enable_default_callbacks,
+             run_sanity_checks=run_sanity_checks,
        )


- DEFAULT_CALLBACKS = (
-     SanityChecksCallback,
-     ConsoleLoggerCallback,
- )
+ DEFAULT_CALLBACKS: Tuple[Type[TrainerCallback]] = (ConsoleLoggerCallback,)
"""
The default callbacks used by `GradientDescentTrainer`.
"""