@@ -1462,6 +1462,50 @@ def __call__(self, x):
    self.scope.put_variable(col, name, xs)
    return True
+  def perturb(self, name: str, value: T, collection: str = 'perturbations') -> T:
+    """Add a zero-valued variable ('perturbation') to the intermediate value.
+
+    The gradient of `value` will be the same as the gradient of this
+    perturbation variable. Therefore, if you define your loss function with
+    both params and perturbations as standalone arguments, you can get the
+    intermediate gradients of `value` by running `jax.grad` on the perturbation
+    argument.
+
+    Note: this is an experimental API and may be tweaked later for better
+    performance and usability.
+    In its current state, it creates extra dummy variables that occupy extra
+    memory space. Use it only to debug gradients in training.
+
+    Example::
+
+      import jax
+      import jax.numpy as jnp
+      import flax.linen as nn
+
+      class Foo(nn.Module):
+        @nn.compact
+        def __call__(self, x):
+          x = nn.Dense(3)(x)
+          x = self.perturb('dense3', x)
+          return nn.Dense(2)(x)
+
+      def loss(params, perturbations, inputs, targets):
+        variables = {'params': params, 'perturbations': perturbations}
+        preds = model.apply(variables, inputs)
+        return jnp.square(preds - targets).mean()
+
+      x = jnp.ones((2, 9))
+      y = jnp.ones((2, 2))
+      model = Foo()
+      variables = model.init(jax.random.PRNGKey(0), x)
+      intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y)
+      print(intm_grads['dense3'])  # ==> [[-1.456924   -0.44332537  0.02422847]
+                                   #      [-1.456924   -0.44332537  0.02422847]]
+
+    """
+    value += self.variable(collection, name, lambda: jnp.zeros_like(value)).value
+    return value
+
  def tabulate(
      self,
      rngs: Union[PRNGKey, RNGSequences],
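
A minimal end-to-end sketch of how the new `perturb` API might be exercised, assuming the `Foo` module and `loss` function from the docstring example above; differentiating with respect to both arguments yields the parameter gradients alongside the intermediate ('perturbation') gradients::

  import jax
  import jax.numpy as jnp

  # Assumes Foo and loss are defined as in the docstring example above.
  model = Foo()
  x = jnp.ones((2, 9))
  y = jnp.ones((2, 2))
  variables = model.init(jax.random.PRNGKey(0), x)

  # init() populates both the trainable 'params' and the zero-valued
  # 'perturbations' collection created by self.perturb('dense3', ...).
  params = variables['params']
  perturbations = variables['perturbations']

  # Taking the gradient with respect to both arguments returns the usual
  # parameter gradients together with the gradient of the intermediate
  # value tagged 'dense3'.
  param_grads, intm_grads = jax.grad(loss, argnums=(0, 1))(params, perturbations, x, y)
  print(intm_grads['dense3'].shape)  # (2, 3)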