This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit aa3aefa

Author: pab-vmware
Merge branch 'main' into datasets_feature
2 parents 08d3012 + decb875

5 files changed: +36 -20 lines

CHANGELOG.md (+1)

@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
+- Sanity checks in the `GradientDescentTrainer` can now be turned off by setting the `run_sanity_checks` parameter to `False`.
 - Allow the order of examples in the task cards to be specified explicitly
 - `histogram_interval` parameter is now deprecated in `TensorboardWriter`, please use `distribution_interval` instead.
 - Memory usage is not logged in tensorboard during training now. `ConsoleLoggerCallback` should be used instead.
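The first entry is the headline change in this commit; the `trainer.py` diff below implements it. As a minimal usage sketch (hedged: `model`, `optimizer`, and `data_loader` are hypothetical stand-ins for already-built objects, and the import path is assumed from the file layout below):

from allennlp.training import GradientDescentTrainer

# Sketch only: these three objects are placeholders, not part of this commit.
# model = ...        # an allennlp Model
# optimizer = ...    # a torch Optimizer
# data_loader = ...  # an allennlp DataLoader

trainer = GradientDescentTrainer(
    model,
    optimizer,
    data_loader,
    serialization_dir="/tmp/run",  # hypothetical path
    run_sanity_checks=False,       # new in this commit; defaults to True
)
trainer.train()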

allennlp/sanity_checks/normalization_bias_verification.py (+3 -2)

@@ -1,6 +1,7 @@
 """
-Code based almost entirely on
-https://github.com/awaelchli/pytorch-lightning-snippets/commit/7db53f774715d635c59ef56f21a17634d246b2c5
+Code based almost entirely from the [pytorch-lightning-snippets]
+(https://github.com/awaelchli/pytorch-lightning-snippets/commit/7db53f774715d635c59ef56f21a17634d246b2c5)
+repository.
 """
 
 import torch
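What the verification checks is not spelled out in this diff, but the module name and the linked snippet suggest it flags a learnable bias feeding directly into a normalization layer, where the bias is redundant because the normalization re-centers its input. A minimal PyTorch illustration of that anti-pattern (an assumption, not code from this commit):

import torch.nn as nn

# Redundant: BatchNorm subtracts the batch mean, so the Linear bias that
# precedes it has no effect on the output; it only wastes a parameter.
redundant = nn.Sequential(
    nn.Linear(10, 10, bias=True),
    nn.BatchNorm1d(10),
)

# Preferred: drop the bias on the layer directly before the normalization.
better = nn.Sequential(
    nn.Linear(10, 10, bias=False),
    nn.BatchNorm1d(10),
)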

allennlp/training/trainer.py (+30 -16)

@@ -6,7 +6,7 @@
 import time
 import traceback
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Type
 
 from allennlp.common.util import int_to_device
 
@@ -118,10 +118,10 @@ class GradientDescentTrainer(Trainer):
     stopping. There are many other bells and whistles as well.
 
     Registered as a `Trainer` with the name "gradient_descent" (and is also the default `Trainer`).
-    The constructor that is registered is `from_partial_objects` - see the arguments to that
-    function for the exact keys that should be used, if you are using a configuration file. They
-    largely match the arguments to `__init__`, and we don't repeat their docstrings in
-    `from_partial_objects`.
+    The constructor that is registered is [`from_partial_objects`](#from_partial_objects) -
+    see the arguments to that function for the exact keys that should be used, if you are using
+    a configuration file. They largely match the arguments to `__init__`, and we don't repeat their
+    docstrings in `from_partial_objects`.
 
     [0]: https://tinyurl.com/y5mv44fw
 
@@ -248,6 +248,16 @@ class GradientDescentTrainer(Trainer):
     use_amp : `bool`, optional, (default = `False`)
         If `True`, we'll train using [Automatic Mixed Precision](https://pytorch.org/docs/stable/amp.html).
 
+    enable_default_callbacks : `bool`, optional (default = `True`)
+        When `True`, the [`DEFAULT_CALLBACKS`](#default_callbacks) will be used in
+        addition to any other callbacks listed in the `callbacks` parameter.
+        When set to `False`, `DEFAULT_CALLBACKS` are not used.
+
+    run_sanity_checks : `bool`, optional (default = `True`)
+        Determines whether model sanity checks, such as
+        [`NormalizationBiasVerification`](../../sanity_checks/normalization_bias_verification/),
+        are run.
+
     """
 
     def __init__(
@@ -273,6 +283,8 @@ def __init__(
         world_size: int = 1,
         num_gradient_accumulation_steps: int = 1,
         use_amp: bool = False,
+        enable_default_callbacks: bool = True,
+        run_sanity_checks: bool = True,
     ) -> None:
         super().__init__(serialization_dir, cuda_device, distributed, local_rank, world_size)
 
@@ -316,6 +328,15 @@ def __init__(
         self._moving_average = moving_average
 
         self._callbacks = callbacks or []
+        default_callbacks = list(DEFAULT_CALLBACKS) if enable_default_callbacks else []
+        if run_sanity_checks:
+            default_callbacks.append(SanityChecksCallback)
+        for callback_cls in default_callbacks:
+            for callback in self._callbacks:
+                if callback.__class__ == callback_cls:
+                    break
+            else:
+                self._callbacks.append(callback_cls(self._serialization_dir))
 
         self._batch_num_total = 0
         self._last_log = 0.0  # time of last logging
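This dedup logic moves here from `from_partial_objects` (removed further down in this diff), so defaults are merged no matter which constructor builds the trainer. It leans on Python's `for ... else`, where the `else` block runs only if the inner loop completed without hitting `break`. A self-contained sketch of the same pattern, using toy classes rather than real callbacks:

class ConsoleLogger: ...
class SanityChecks: ...

explicit = [ConsoleLogger()]          # callbacks the user passed in
defaults = [ConsoleLogger, SanityChecks]

for cls in defaults:
    for cb in explicit:
        if cb.__class__ == cls:
            break                     # user already supplied one of these
    else:                             # no break: class not present yet
        explicit.append(cls())

print([type(cb).__name__ for cb in explicit])
# ['ConsoleLogger', 'SanityChecks'] - the explicit ConsoleLogger is not duplicated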
@@ -970,6 +991,7 @@ def from_partial_objects(
         checkpointer: Lazy[Checkpointer] = Lazy(Checkpointer),
         callbacks: List[Lazy[TrainerCallback]] = None,
         enable_default_callbacks: bool = True,
+        run_sanity_checks: bool = True,
     ) -> "Trainer":
         """
         This method exists so that we can have a documented method to construct this class using
@@ -1037,13 +1059,6 @@ def from_partial_objects(
         callbacks_: List[TrainerCallback] = []
         for callback_ in callbacks or []:
             callbacks_.append(callback_.construct(serialization_dir=serialization_dir))
-        if enable_default_callbacks:
-            for callback_cls in DEFAULT_CALLBACKS:
-                for callback in callbacks_:
-                    if callback.__class__ == callback_cls:
-                        break
-                else:
-                    callbacks_.append(callback_cls(serialization_dir))
 
         return cls(
             model,
@@ -1067,13 +1082,12 @@ def from_partial_objects(
             world_size=world_size,
             num_gradient_accumulation_steps=num_gradient_accumulation_steps,
             use_amp=use_amp,
+            enable_default_callbacks=enable_default_callbacks,
+            run_sanity_checks=run_sanity_checks,
         )
 
 
-DEFAULT_CALLBACKS = (
-    SanityChecksCallback,
-    ConsoleLoggerCallback,
-)
+DEFAULT_CALLBACKS: Tuple[Type[TrainerCallback]] = (ConsoleLoggerCallback,)
 """
 The default callbacks used by `GradientDescentTrainer`.
 """

dev-requirements.txt (+1 -1)

@@ -39,7 +39,7 @@ nr.databind.core<0.0.17
 nr.interface<0.0.4
 
 mkdocs==1.1.2
-mkdocs-material>=5.5.0,<7.1.0
+mkdocs-material>=5.5.0,<7.2.0
 markdown-include==0.6.0
 
 #### PACKAGE-UPLOAD PACKAGES ####

tests/training/trainer_test.py (+1 -1)

@@ -847,7 +847,7 @@ def test_sanity_check_default(self):
             serialization_dir=self.TEST_DIR,
             data_loader=data_loader,
             num_epochs=1,
-            enable_default_callbacks=False,
+            run_sanity_checks=False,
         )
 
         # Check is not run, so no failure.
