Commit cc88a73

Author: Flax Authors (committed)
Merge pull request #2476 from IvyZX:perturb
PiperOrigin-RevId: 476490270
2 parents b344022 + 8c0a60a commit cc88a73

2 files changed: +67 −0 lines changed


flax/linen/module.py

Lines changed: 44 additions & 0 deletions
@@ -1462,6 +1462,50 @@ def __call__(self, x):
     self.scope.put_variable(col, name, xs)
     return True
 
+  def perturb(self, name: str, value: T, collection: str = 'perturbations') -> T:
+    """Add a zero-value variable ('perturbation') to the intermediate value.
+
+    The gradient of `value` would be the same as the gradient of this
+    perturbation variable. Therefore, if you define your loss function with
+    both params and perturbations as standalone arguments, you can get the
+    intermediate gradients of `value` by running `jax.grad` on the perturbation
+    argument.
+
+    Note: this is an experimental API and may be tweaked later for better
+    performance and usability.
+    At its current stage, it creates extra dummy variables that occupy extra
+    memory space. Use it only to debug gradients in training.
+
+    Example::
+
+      import jax
+      import jax.numpy as jnp
+      import flax.linen as nn
+
+      class Foo(nn.Module):
+        @nn.compact
+        def __call__(self, x):
+          x = nn.Dense(3)(x)
+          x = self.perturb('dense3', x)
+          return nn.Dense(2)(x)
+
+      def loss(params, perturbations, inputs, targets):
+        variables = {'params': params, 'perturbations': perturbations}
+        preds = model.apply(variables, inputs)
+        return jnp.square(preds - targets).mean()
+
+      x = jnp.ones((2, 9))
+      y = jnp.ones((2, 2))
+      model = Foo()
+      variables = model.init(jax.random.PRNGKey(0), x)
+      intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y)
+      print(intm_grads['dense3'])  # ==> [[-1.456924 -0.44332537 0.02422847]
+                                   #      [-1.456924 -0.44332537 0.02422847]]
+
+    """
+    value += self.variable(collection, name, lambda: jnp.zeros_like(value)).value
+    return value
+
   def tabulate(
       self,
       rngs: Union[PRNGKey, RNGSequences],
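
The docstring example differentiates the loss only with respect to the perturbation argument. As a minimal usage sketch (not part of this commit, reusing the `Foo` and `loss` from the docstring example), passing a tuple `argnums` to `jax.grad` yields parameter gradients and intermediate gradients in a single backward pass:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Foo(nn.Module):
      @nn.compact
      def __call__(self, x):
        x = nn.Dense(3)(x)
        x = self.perturb('dense3', x)  # zero-valued, so the forward output is unchanged
        return nn.Dense(2)(x)

    def loss(params, perturbations, inputs, targets):
      variables = {'params': params, 'perturbations': perturbations}
      preds = model.apply(variables, inputs)
      return jnp.square(preds - targets).mean()

    x = jnp.ones((2, 9))
    y = jnp.ones((2, 2))
    model = Foo()
    variables = model.init(jax.random.PRNGKey(0), x)

    # argnums=(0, 1) differentiates with respect to both standalone arguments:
    # the first result holds parameter gradients, the second the gradients of
    # the perturbation collection, i.e. the intermediate gradients.
    param_grads, intm_grads = jax.grad(loss, argnums=(0, 1))(
        variables['params'], variables['perturbations'], x, y)
    print(jax.tree_util.tree_map(jnp.shape, intm_grads))
    # one entry per perturbation, shaped like the intermediate value: 'dense3' -> (2, 3)

Because the perturbation variables are initialized to zeros and only added to the intermediate value, the forward computation is unchanged; they merely give `jax.grad` a handle on the intermediate activation.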

tests/linen/linen_module_test.py

Lines changed: 23 additions & 0 deletions
@@ -1417,6 +1417,29 @@ def __call__(self, x):
     _, state = Foo().apply({}, 1, capture_intermediates=fn)
     self.assertEqual(state, {'intermediates': {'Bar_0': {'test': (2,)}}})
 
+  def test_perturb(self):
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x):
+        x = nn.Dense(10)(x)
+        x = self.perturb('before_multiply', x)
+        x = 4 * x
+        x = self.perturb('after_multiply', x)
+        return x
+
+    def loss(params, perturbations, inputs, targets):
+      variables = {'params': params, 'perturbations': perturbations}
+      preds = Foo().apply(variables, inputs)
+      return jnp.square(preds - targets).mean()
+
+    x = jax.random.uniform(jax.random.PRNGKey(1), shape=(10, ))
+    y = jax.random.uniform(jax.random.PRNGKey(2), shape=(10, ))
+    variables = Foo().init(jax.random.PRNGKey(0), x)
+    pred = Foo().apply(variables, x)
+    intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y)
+    # activation * 4 so reverse gradient also * 4
+    self.assertTrue(all(intm_grads['after_multiply'] * 4 == intm_grads['before_multiply']))
+
   def test_functional_apply(self):
 
     class Foo(nn.Module):
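
The final assertion rests on the chain rule: the module returns `4 * x`, so the gradient seen at the value captured before the multiply is four times the gradient seen after it. A standalone sketch of that relationship (not part of the commit; plain JAX, no Flax):

    import jax
    import jax.numpy as jnp

    # Loss as a function of the activation taken *before* the multiply ...
    grad_before = jax.grad(lambda b: jnp.square(4 * b - 1.0).mean())(jnp.ones((10,)))
    # ... and as a function of the activation taken *after* the multiply.
    grad_after = jax.grad(lambda a: jnp.square(a - 1.0).mean())(4 * jnp.ones((10,)))

    # d loss / d before == 4 * d loss / d after, which is what test_perturb
    # asserts for the 'before_multiply' and 'after_multiply' perturbations.
    assert jnp.allclose(grad_before, 4 * grad_after)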
