This repository was archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
added on_backward
trainer callback
#5249
Merged
Merged
Changes from 4 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from typing import TYPE_CHECKING | ||
import torch | ||
|
||
from allennlp.training.callbacks.callback import TrainerCallback | ||
|
||
if TYPE_CHECKING: | ||
from allennlp.training.gradient_descent_trainer import GradientDescentTrainer | ||
|
||
|
||
@TrainerCallback.register("mixed_precision_backward") | ||
class MixedPrecisionBackwardCallback(TrainerCallback): | ||
""" | ||
Performs backpropagation for mixed precision training. | ||
""" | ||
|
||
def on_backward( | ||
self, | ||
trainer: "GradientDescentTrainer", | ||
loss: torch.FloatTensor, | ||
backward_called: bool, | ||
**kwargs | ||
) -> bool: | ||
if not backward_called: | ||
trainer._scaler.scale(loss).backward() # type: ignore | ||
return True | ||
return False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it an error if this gets called with |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -127,6 +127,26 @@ def setup_method(self): | |||||
self.validation_data_loader.index_with(self.vocab) | ||||||
|
||||||
|
||||||
class ZeroGradientsBackwardCallback(TrainerCallback): | ||||||
""" | ||||||
Zeros all gradients after backpropagation. | ||||||
""" | ||||||
|
||||||
def on_backward( | ||||||
self, | ||||||
trainer: "GradientDescentTrainer", | ||||||
loss: torch.FloatTensor, | ||||||
backward_called: bool, | ||||||
**kwargs, | ||||||
) -> bool: | ||||||
if not backward_called: | ||||||
loss.backward() | ||||||
for param in trainer.model.parameters(): | ||||||
param.grad *= 0.0 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is that really the best way to do that?
Suggested change
I don't know for sure, but I would guess that |
||||||
return True | ||||||
return False | ||||||
|
||||||
|
||||||
class TestTrainer(TrainerTestBase): | ||||||
def test_trainer_can_run(self): | ||||||
trainer = GradientDescentTrainer( | ||||||
|
@@ -168,6 +188,66 @@ def test_trainer_can_run(self): | |||||
assert isinstance(metrics["peak_worker_0_memory_MB"], float) | ||||||
assert metrics["peak_worker_0_memory_MB"] > 0 | ||||||
|
||||||
def test_train_zero_gradients(self): | ||||||
weights = {} | ||||||
for name, param in self.model.named_parameters(): | ||||||
weights[name] = param.data.clone() | ||||||
|
||||||
trainer = GradientDescentTrainer( | ||||||
self.model, | ||||||
self.optimizer, | ||||||
self.data_loader, | ||||||
num_epochs=2, | ||||||
validation_data_loader=self.validation_data_loader, | ||||||
callbacks=[ZeroGradientsBackwardCallback(serialization_dir=self.TEST_DIR)], | ||||||
) | ||||||
trainer.train() | ||||||
|
||||||
# weights should be the same | ||||||
for name, param in self.model.named_parameters(): | ||||||
assert torch.equal(weights[name], param.data) | ||||||
|
||||||
def test_two_backward_callbacks(self): | ||||||
class SecondBackwardCallback(TrainerCallback): | ||||||
""" | ||||||
Changes all gradients to 1 after backpropagation. | ||||||
""" | ||||||
|
||||||
def on_backward( | ||||||
self, | ||||||
trainer: "GradientDescentTrainer", | ||||||
loss: torch.FloatTensor, | ||||||
backward_called: bool, | ||||||
**kwargs, | ||||||
) -> bool: | ||||||
if not backward_called: | ||||||
loss.backward() | ||||||
for param in trainer.model.parameters(): | ||||||
param.grad = torch.ones_like(param.grad, device=param.grad.device) | ||||||
return True | ||||||
return False | ||||||
|
||||||
weights = {} | ||||||
for name, param in self.model.named_parameters(): | ||||||
weights[name] = param.data.clone() | ||||||
|
||||||
trainer = GradientDescentTrainer( | ||||||
self.model, | ||||||
self.optimizer, | ||||||
self.data_loader, | ||||||
num_epochs=2, | ||||||
validation_data_loader=self.validation_data_loader, | ||||||
callbacks=[ | ||||||
ZeroGradientsBackwardCallback(serialization_dir=self.TEST_DIR), | ||||||
SecondBackwardCallback(serialization_dir=self.TEST_DIR), | ||||||
], | ||||||
) | ||||||
trainer.train() | ||||||
|
||||||
# weights should be the same | ||||||
for name, param in self.model.named_parameters(): | ||||||
assert torch.equal(weights[name], param.data) | ||||||
|
||||||
def test_trainer_can_run_exponential_moving_average(self): | ||||||
moving_average = ExponentialMovingAverage(self.model.named_parameters(), decay=0.9999) | ||||||
trainer = GradientDescentTrainer( | ||||||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment isn't accurate anymore, is it?