This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 7aaf5b0

vreis authored and facebook-github-bot committed
Implement gradient clipping (#643)
Summary:
Pull Request resolved: #643

Add support for gradient clipping in ClassificationTask

Reviewed By: mannatsingh
Differential Revision: D24736675
fbshipit-source-id: 9ed5c7a26f1708a81cf0d61f052629e1ff093983
1 parent e3ac96c commit 7aaf5b0

File tree: 2 files changed (+107, -2 lines)

  classy_vision/tasks/classification_task.py
  test/tasks_classification_task_test.py

classy_vision/tasks/classification_task.py (+32)

@@ -123,6 +123,7 @@ class ClassificationTask(ClassyTask):
     :var data_iterator: Iterator which can be used to obtain batches
     :var losses: Loss curve
     :var perf_log: list of training speed measurements, to be logged
+    :var clip_grad_norm: maximum gradient norm (default None)
     """

     def __init__(self):
@@ -165,6 +166,7 @@ def __init__(self):
         self.dataloader_mp_context = "spawn"
         self.bn_weight_decay = False
         self._train_only = True
+        self.clip_grad_norm = None

     def set_use_gpu(self, use_gpu: bool):
         self.use_gpu = use_gpu
@@ -175,6 +177,19 @@ def set_use_gpu(self, use_gpu: bool):

         return self

+    def set_clip_grad_norm(self, clip_grad_norm: Optional[float]):
+        """Sets maximum gradient norm.
+
+        None means gradient clipping is disabled. Defaults to None."""
+        self.clip_grad_norm = clip_grad_norm
+        if clip_grad_norm is None:
+            logging.info("Disabled gradient norm clipping.")
+        else:
+            logging.info(
+                f"Enabled gradient norm clipping with threshold: {clip_grad_norm}"
+            )
+        return self
+
     def set_checkpoint(self, checkpoint_path: str):
         """Sets checkpoint on task.

@@ -489,6 +504,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
             .set_distributed_options(**distributed_options)
             .set_hooks(hooks)
             .set_bn_weight_decay(config.get("bn_weight_decay", False))
+            .set_clip_grad_norm(config.get("clip_grad_norm"))
         )

         if not test_only:
@@ -934,10 +950,26 @@ def run_optimizer(self, loss):
         else:
             self.optimizer.backward(loss)

+        if self.clip_grad_norm is not None:
+            self._clip_gradients(self.clip_grad_norm)
+
         self.check_inf_nan(loss)

         self.optimizer.step(where=self.where)

+    def _clip_gradients(self, max_norm):
+        def all_params(optimizer):
+            for group in optimizer.param_groups:
+                for p in group["params"]:
+                    yield p
+
+        if self.amp_args is not None:
+            params_iter = apex.amp.master_params(self.optimizer)
+        else:
+            params_iter = all_params(self.optimizer)
+
+        nn.utils.clip_grad_norm_(params_iter, max_norm)
+
     def update_meters(self, model_output, sample):
         target = sample["target"].detach().cpu()
         model_output = model_output.detach().cpu()
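
Usage sketch (not part of the commit): with the changes above, gradient clipping can be enabled either through the task config, which from_config() forwards to set_clip_grad_norm(), or by calling the setter directly. The snippet assumes an otherwise valid ClassificationTask config (my_task_config is hypothetical); only the "clip_grad_norm" key and the setter come from this diff.

    from classy_vision.tasks import build_task

    # Enable clipping via the config; omitting the key (or passing None)
    # leaves clipping disabled.
    config = my_task_config            # hypothetical, otherwise valid task config
    config["clip_grad_norm"] = 0.5     # clip the total gradient norm to 0.5
    task = build_task(config)

    # Or set it on an existing task; the setter returns self for chaining.
    task.set_clip_grad_norm(0.5)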

test/tasks_classification_task_test.py (+75, -2)

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import copy
+import itertools
 import shutil
 import tempfile
 import unittest
@@ -17,13 +18,14 @@
 )

 import torch
+import torch.nn as nn
 from classy_vision.dataset import build_dataset
 from classy_vision.generic.distributed_util import is_distributed_training_run
 from classy_vision.generic.util import get_checkpoint_dict
 from classy_vision.hooks import CheckpointHook, LossLrMeterLoggingHook
 from classy_vision.losses import ClassyLoss, build_loss, register_loss
-from classy_vision.models import build_model
-from classy_vision.optim import build_optimizer
+from classy_vision.models import ClassyModel, build_model
+from classy_vision.optim import SGD, build_optimizer
 from classy_vision.tasks import ClassificationTask, build_task
 from classy_vision.trainer import LocalTrainer

@@ -284,3 +286,74 @@ def test_get_classy_state_on_loss(self):
         task = build_task(config)
         task.prepare()
         self.assertIn("alpha", task.get_classy_state()["loss"])
+
+    def test_gradient_clipping(self):
+        # Generate a simple model that has a very high gradient w.r.t. to this
+        # loss
+        class SimpleModel(ClassyModel):
+            def __init__(self):
+                super().__init__()
+                self.param = nn.Parameter(torch.tensor(5.0), requires_grad=True)
+
+            def forward(self, x):
+                return x + self.param
+
+            @classmethod
+            def from_config(cls):
+                return cls()
+
+        class SimpleLoss(nn.Module):
+            def forward(self, out, target):
+                return out.pow(2).mean()
+
+        apex_available = True
+        try:
+            import apex  # noqa F401
+        except ImportError:
+            apex_available = False
+
+        def train_with_clipped_gradients(amp_args=None):
+            task = build_task(get_fast_test_task_config())
+            task.set_num_epochs(1)
+            task.set_model(SimpleModel())
+            task.set_loss(SimpleLoss())
+            task.set_meters([])
+            task.set_use_gpu(torch.cuda.is_available())
+            task.set_clip_grad_norm(0.5)
+            task.set_amp_args(amp_args)
+
+            task.set_optimizer(SGD(lr=1))
+
+            trainer = LocalTrainer()
+            trainer.train(task)
+
+            return task.model.param.grad.norm()
+
+        grad_norm = train_with_clipped_gradients(None)
+        self.assertAlmostEqual(grad_norm, 0.5, delta=1e-2)
+
+        if apex_available and torch.cuda.is_available():
+            grad_norm = train_with_clipped_gradients({"opt_level": "O2"})
+            self.assertAlmostEqual(grad_norm, 0.5, delta=1e-2)
+
+    def test_clip_stateful_loss(self):
+        config = get_fast_test_task_config()
+        config["loss"] = {"name": "test_stateful_loss", "in_plane": 256}
+        config["grad_norm_clip"] = grad_norm_clip = 1
+        task = build_task(config)
+        task.set_use_gpu(False)
+        task.prepare()
+
+        # set fake gradients with norm > grad_norm_clip
+        for param in itertools.chain(
+            task.base_model.parameters(), task.base_loss.parameters()
+        ):
+            param.grad = 1.1 + torch.rand(param.shape)
+            self.assertGreater(param.grad.norm(), grad_norm_clip)
+
+        task._clip_gradients(grad_norm_clip)
+
+        for param in itertools.chain(
+            task.base_model.parameters(), task.base_loss.parameters()
+        ):
+            self.assertLessEqual(param.grad.norm(), grad_norm_clip)
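
For reference, a minimal standalone sketch (not from the commit) of the behavior the new tests assert: torch.nn.utils.clip_grad_norm_ rescales gradients in place so that their total norm is at most max_norm.

    import torch
    import torch.nn as nn

    # Toy parameter with a large gradient, mirroring SimpleModel/SimpleLoss above.
    p = nn.Parameter(torch.tensor(5.0))
    loss = p.pow(2).mean()   # d(loss)/dp = 2 * p = 10.0
    loss.backward()
    print(p.grad.norm())     # tensor(10.)

    # Clip the total gradient norm to 0.5, as task.set_clip_grad_norm(0.5) would.
    nn.utils.clip_grad_norm_([p], max_norm=0.5)
    print(p.grad.norm())     # tensor(0.5000)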
