 # LICENSE file in the root directory of this source tree.
 
 import copy
+import itertools
 import shutil
 import tempfile
 import unittest
…
 )
 
 import torch
+import torch.nn as nn
 from classy_vision.dataset import build_dataset
 from classy_vision.generic.distributed_util import is_distributed_training_run
 from classy_vision.generic.util import get_checkpoint_dict
 from classy_vision.hooks import CheckpointHook, LossLrMeterLoggingHook
 from classy_vision.losses import ClassyLoss, build_loss, register_loss
-from classy_vision.models import build_model
-from classy_vision.optim import build_optimizer
+from classy_vision.models import ClassyModel, build_model
+from classy_vision.optim import SGD, build_optimizer
 from classy_vision.tasks import ClassificationTask, build_task
 from classy_vision.trainer import LocalTrainer
 
@@ -284,3 +286,74 @@ def test_get_classy_state_on_loss(self):
         task = build_task(config)
         task.prepare()
         self.assertIn("alpha", task.get_classy_state()["loss"])
+
+    def test_gradient_clipping(self):
+        # Generate a simple model that has a very high gradient w.r.t. this
+        # loss
+        class SimpleModel(ClassyModel):
+            def __init__(self):
+                super().__init__()
+                self.param = nn.Parameter(torch.tensor(5.0), requires_grad=True)
+
+            def forward(self, x):
+                return x + self.param
+
+            @classmethod
+            def from_config(cls):
+                return cls()
+
+        class SimpleLoss(nn.Module):
+            def forward(self, out, target):
+                return out.pow(2).mean()
+
+        apex_available = True
+        try:
+            import apex  # noqa F401
+        except ImportError:
+            apex_available = False
+
+        def train_with_clipped_gradients(amp_args=None):
+            task = build_task(get_fast_test_task_config())
+            task.set_num_epochs(1)
+            task.set_model(SimpleModel())
+            task.set_loss(SimpleLoss())
+            task.set_meters([])
+            task.set_use_gpu(torch.cuda.is_available())
+            task.set_clip_grad_norm(0.5)
+            task.set_amp_args(amp_args)
+
+            task.set_optimizer(SGD(lr=1))
+
+            trainer = LocalTrainer()
+            trainer.train(task)
+
+            return task.model.param.grad.norm()
+
+        grad_norm = train_with_clipped_gradients(None)
+        self.assertAlmostEqual(grad_norm, 0.5, delta=1e-2)
+
+        if apex_available and torch.cuda.is_available():
+            grad_norm = train_with_clipped_gradients({"opt_level": "O2"})
+            self.assertAlmostEqual(grad_norm, 0.5, delta=1e-2)
+
+    def test_clip_stateful_loss(self):
+        config = get_fast_test_task_config()
+        config["loss"] = {"name": "test_stateful_loss", "in_plane": 256}
+        config["grad_norm_clip"] = grad_norm_clip = 1
+        task = build_task(config)
+        task.set_use_gpu(False)
+        task.prepare()
+
+        # set fake gradients with norm > grad_norm_clip
+        for param in itertools.chain(
+            task.base_model.parameters(), task.base_loss.parameters()
+        ):
+            param.grad = 1.1 + torch.rand(param.shape)
+            self.assertGreater(param.grad.norm(), grad_norm_clip)
+
+        task._clip_gradients(grad_norm_clip)
+
+        for param in itertools.chain(
+            task.base_model.parameters(), task.base_loss.parameters()
+        ):
+            self.assertLessEqual(param.grad.norm(), grad_norm_clip)
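
For reference, what both new tests assert boils down to clipping the combined gradient norm of every trainable parameter, including those of a stateful loss, to a fixed threshold. The following standalone sketch reproduces that behaviour under the assumption that the clipping delegates to torch.nn.utils.clip_grad_norm_ (the task's actual _clip_gradients helper may be implemented differently); the nn.Linear modules are hypothetical stand-ins for a model and a loss with trainable parameters.

import itertools

import torch
import torch.nn as nn

# hypothetical stand-ins for the task's model and stateful loss
model = nn.Linear(4, 2)
stateful_loss = nn.Linear(2, 2)

grad_norm_clip = 1.0

# give every parameter a gradient whose norm exceeds the clip threshold
for param in itertools.chain(model.parameters(), stateful_loss.parameters()):
    param.grad = 1.1 + torch.rand(param.shape)
    assert param.grad.norm() > grad_norm_clip

# clip the total gradient norm across model and loss parameters
# (assumption: this is what the task's gradient clipping reduces to)
torch.nn.utils.clip_grad_norm_(
    itertools.chain(model.parameters(), stateful_loss.parameters()),
    grad_norm_clip,
)

# after clipping, no individual parameter's gradient norm exceeds the threshold
for param in itertools.chain(model.parameters(), stateful_loss.parameters()):
    assert param.grad.norm() <= grad_norm_clip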