
Commit c3b5ed7

zero grad optimization (#4673)
1 parent 9dabf3f commit c3b5ed7

File tree (5 files changed: +20 -4 lines):

  CHANGELOG.md
  allennlp/commands/find_learning_rate.py
  allennlp/common/testing/model_test_case.py
  allennlp/predictors/predictor.py
  allennlp/training/trainer.py

CHANGELOG.md

+2

@@ -49,6 +49,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed a bug where `cached_path()` would fail if passed a `cache_dir` with the user home shortcut `~/`.
 - Fixed a bug in our doc building script where markdown links did not render properly
   if the "href" part of the link (the part inside the `()`) was on a new line.
+- Changed how gradients are zeroed out with an optimization. See [this video from NVIDIA](https://www.youtube.com/watch?v=9mS1fIYj1So)
+  at around the 9 minute mark.


 ## [v1.1.0](https://github.com/allenai/allennlp/releases/tag/v1.1.0) - 2020-09-08
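
For reference, here is a minimal sketch of the two patterns side by side; it is not taken from this repository, and the toy `model`, `optimizer`, and random data are purely illustrative. `optimizer.zero_grad()` writes zeros into every existing `.grad` tensor, and the next `backward()` call then reads those zeros while accumulating into them; setting each `p.grad` to `None` instead lets autograd assign a fresh gradient tensor on the next backward pass, skipping both the memset and the extra read, which is the optimization this commit adopts.

import torch

model = torch.nn.Linear(10, 1)                            # toy model, illustration only
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs, targets = torch.randn(8, 10), torch.randn(8, 1)

# Conventional pattern: zero_grad() memsets every existing .grad tensor, and the
# subsequent backward() reads those zeros while accumulating new gradients into them.
optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(inputs), targets)
loss.backward()
optimizer.step()

# Pattern adopted in this commit: drop the .grad tensors entirely. On the next
# backward() pass autograd assigns brand-new gradient tensors instead of
# accumulating into zeroed buffers, avoiding the memset and the extra read.
for param_group in optimizer.param_groups:
    for p in param_group["params"]:
        p.grad = None
loss = torch.nn.functional.mse_loss(model(inputs), targets)
loss.backward()
optimizer.step()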

allennlp/commands/find_learning_rate.py

+5 -1

@@ -287,8 +287,12 @@ def search_learning_rate(

         for param_group in trainer.optimizer.param_groups:
             param_group["lr"] = current_lr
+            # Zero gradients.
+            # NOTE: this is actually more efficient than calling `self.optimizer.zero_grad()`
+            # because it avoids a read op when the gradients are first updated below.
+            for p in param_group["params"]:
+                p.grad = None

-        trainer.optimizer.zero_grad()
         loss = trainer.batch_outputs(batch, for_training=True)["loss"]
         loss.backward()
         loss = loss.detach().cpu().item()

allennlp/common/testing/model_test_case.py

+2 -1

@@ -277,7 +277,8 @@ def check_model_computes_gradients_correctly(
         disable_dropout: bool = True,
     ):
         print("Checking gradients")
-        model.zero_grad()
+        for p in model.parameters():
+            p.grad = None
         model.train()

         original_dropouts: Dict[str, float] = {}

allennlp/predictors/predictor.py

+5 -1

@@ -109,7 +109,11 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict
         )

         loss = outputs["loss"]
-        self._model.zero_grad()
+        # Zero gradients.
+        # NOTE: this is actually more efficient than calling `self._model.zero_grad()`
+        # because it avoids a read op when the gradients are first updated below.
+        for p in self._model.parameters():
+            p.grad = None
         loss.backward()

         for hook in hooks:

allennlp/training/trainer.py

+6 -1

@@ -581,7 +581,12 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
             self._batch_num_total += 1
             batch_num_total = self._batch_num_total

-            self.optimizer.zero_grad()
+            # Zero gradients.
+            # NOTE: this is actually more efficient than calling `self.optimizer.zero_grad()`
+            # because it avoids a read op when the gradients are first updated below.
+            for param_group in self.optimizer.param_groups:
+                for p in param_group["params"]:
+                    p.grad = None

             batch_group_outputs = []
             for batch in batch_group:
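
As a side note on this design choice: the hand-rolled loops above predate PyTorch's built-in support for the same behavior. In later PyTorch releases (the flag was added after this commit), the optimizer can be asked to do this directly, assuming a version of PyTorch that provides it:

# Roughly equivalent in newer PyTorch releases: `set_to_none=True` makes
# zero_grad() set each p.grad to None instead of zeroing the buffers in place.
optimizer.zero_grad(set_to_none=True)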
