Commit e2c3dfd

Merge pull request #3463 from levskaya:vjp_fix

PiperOrigin-RevId: 579977891
Author: Flax Authors
Parents: 85245ad + 5ff36ba

4 files changed: +316 −3 lines changed

docs/api_reference/flax.linen/transformations.rst (2 additions, 0 deletions)

@@ -28,6 +28,8 @@ Transformations
   map_variables
   jvp
   vjp
+  grad
+  value_and_grad
   custom_vjp
   while_loop
   cond

flax/linen/__init__.py (2 additions, 0 deletions)

@@ -144,6 +144,8 @@
     scan as scan,
     switch as switch,
     vjp as vjp,
+    grad as grad,
+    value_and_grad as value_and_grad,
     vmap as vmap,
     while_loop as while_loop,
 )
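With these re-exports, the new transforms become reachable from the public ``flax.linen`` namespace. A quick sanity check (illustrative only, not part of this commit)::

    import flax.linen as nn

    # grad and value_and_grad are now re-exported alongside vjp, vmap, etc.
    assert callable(nn.grad)
    assert callable(nn.value_and_grad)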

flax/linen/transforms.py (183 additions, 3 deletions)

@@ -1055,7 +1055,8 @@ def vjp(
     vjp_variables: lift.CollectionFilter = 'params',
     variables: lift.CollectionFilter = True,
     rngs: lift.PRNGSequenceFilter = True,
-) -> Tuple[Any, Any]:
+    multi_scope: bool = False,
+):
   """A lifted version of ``jax.vjp``.

   See ``jax.vjp`` for the unlifted vector-Jacobian product (backward gradient).
@@ -1105,7 +1106,8 @@ def __call__(self, x, y):
     variables: other variable collections that are available inside `fn` but
       do not receive a cotangent.
     rngs: the prngs that are available inside `fn`.
-
+    multi_scope: for Modules containing multiple scopes from outside modules passed in,
+      allow variable gradients to be returned for multiple scopes instead of raising an error.
   Returns:
     If ``has_aux`` is ``False``, returns a ``(primals_out, vjpfun)`` pair, where
     ``primals_out`` is ``fn(*primals)``.
@@ -1121,7 +1123,7 @@ def __call__(self, x, y):
       (fn,),
       mdl,
       *primals,
-      multi_scope=False,
+      multi_scope=multi_scope,
      has_aux=has_aux,
      reduce_axes=reduce_axes,
      vjp_variables=vjp_variables,
@@ -1130,6 +1132,184 @@
   )


+def value_and_grad(
+    fn: Callable[..., Any],
+    mdl: Module,
+    *primals,
+    has_aux: bool = False,
+    reduce_axes=(),
+    variables: lift.CollectionFilter = True,
+    rngs: lift.PRNGSequenceFilter = True,
+):
+  """A limited, lifted equivalent of ``jax.value_and_grad``.
+
+  Note that for this convenience function, gradients are only calculated for
+  the function inputs, and not with respect to any module variables. The
+  target function must return a scalar-valued output. For a more general
+  lifted vjp, see ``nn.vjp`` for the lifted vector-Jacobian product.
+
+  Example::
+
+    class LearnScale(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        p = self.param('scale', nn.initializers.zeros_init(), ())
+        return p * x * y
+
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        z, (x_grad, y_grad) = nn.value_and_grad(
+            lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
+        return z, x_grad, y_grad
+
+  Args:
+    fn: Function to be differentiated. Its arguments should be arrays, scalars,
+      or standard Python containers of arrays or scalars. It should return an
+      array, scalar, or standard Python container of arrays or scalars. It will
+      receive the scope and primals as arguments.
+    mdl: The module of which the variables will be differentiated.
+    *primals: A sequence of primal values at which the Jacobian of ``fn``
+      should be evaluated. The length of ``primals`` should be equal to the
+      number of positional parameters to ``fn``. Each primal value should be a
+      tuple of arrays, scalar, or standard Python containers thereof.
+    has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the
+      first element is considered the output of the mathematical function to be
+      differentiated and the second element is auxiliary data. Default False.
+    reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
+      ``fn`` implicitly broadcasts a value over that axis, the backward pass
+      will perform a ``psum`` of the corresponding gradient. Otherwise, the
+      grad will be per-example over named axes. For example, if ``'batch'``
+      is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will
+      create a grad function that sums over the batch while ``grad(f, *args)``
+      will create a per-example grad.
+    variables: variable collections that are available inside `fn` but
+      do not receive a cotangent.
+    rngs: the prngs that are available inside `fn`.
+  Returns:
+    If ``has_aux`` is ``False``, returns a ``(primals_out, grads)`` pair, where
+    ``primals_out`` is ``fn(*primals)``. ``grads`` are the gradients for the
+    corresponding primals and do not include the gradients for module variables.
+    If ``has_aux`` is ``True``, returns a
+    ``((primals_out, aux), grads)`` tuple where ``aux`` is the auxiliary data
+    returned by ``fn``.
+  """
+
+  vjp_partial = functools.partial(
+      vjp,
+      fn,
+      mdl,
+      *primals,
+      has_aux=has_aux,
+      reduce_axes=reduce_axes,
+      vjp_variables=False,
+      variables=variables,
+      rngs=rngs,
+      multi_scope=True,
+  )
+
+  if has_aux:
+    out, vjp_fun, aux = vjp_partial()
+    if out.shape != ():
+      raise ValueError(
+          'grad can only work on functions with '
+          f'scalar-valued outputs. out shape={out.shape}'
+      )
+    _, *argument_grads = vjp_fun(jax.numpy.ones_like(out))
+    return (out, aux), argument_grads
+  else:
+    out, vjp_fun = vjp_partial()
+    if out.shape != ():
+      raise ValueError(
+          'grad can only work on functions with '
+          f'scalar-valued outputs. out shape={out.shape}'
+      )
+    _, *argument_grads = vjp_fun(jax.numpy.ones_like(out))
+    return out, argument_grads
+
+
+def grad(
+    fn: Callable[..., Any],
+    mdl: Module,
+    *primals,
+    has_aux: bool = False,
+    reduce_axes=(),
+    variables: lift.CollectionFilter = True,
+    rngs: lift.PRNGSequenceFilter = True,
+):
+  """A limited, lifted equivalent of ``jax.grad``.
+
+  Note that for this convenience function, gradients are only calculated for
+  the function inputs, and not with respect to any module variables. The
+  target function must return a scalar-valued output. For a more general
+  lifted vjp, see ``nn.vjp`` for the lifted vector-Jacobian product.
+
+  Example::
+
+    class LearnScale(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        p = self.param('scale', nn.initializers.zeros_init(), ())
+        return p * x * y
+
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        x_grad, y_grad = nn.grad(
+            lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
+        return x_grad, y_grad
+
+  Args:
+    fn: Function to be differentiated. Its arguments should be arrays, scalars,
+      or standard Python containers of arrays or scalars. It should return an
+      array, scalar, or standard Python container of arrays or scalars. It will
+      receive the scope and primals as arguments.
+    mdl: The module of which the variables will be differentiated.
+    *primals: A sequence of primal values at which the Jacobian of ``fn``
+      should be evaluated. The length of ``primals`` should be equal to the
+      number of positional parameters to ``fn``. Each primal value should be a
+      tuple of arrays, scalar, or standard Python containers thereof.
+    has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the
+      first element is considered the output of the mathematical function to be
+      differentiated and the second element is auxiliary data. Default False.
+    reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
+      ``fn`` implicitly broadcasts a value over that axis, the backward pass
+      will perform a ``psum`` of the corresponding gradient. Otherwise, the
+      grad will be per-example over named axes. For example, if ``'batch'``
+      is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will
+      create a grad function that sums over the batch while ``grad(f, *args)``
+      will create a per-example grad.
+    variables: variable collections that are available inside `fn` but
+      do not receive a cotangent.
+    rngs: the prngs that are available inside `fn`.
+  Returns:
+    If ``has_aux`` is ``False``, returns ``grads``, where ``grads`` are the
+    gradients for the corresponding primals and do not include the gradients
+    for module variables.
+    If ``has_aux`` is ``True``, returns a
+    ``(grads, aux)`` tuple where ``aux`` is the auxiliary data
+    returned by ``fn``.
+  """
+
+  value_and_grad_partial = functools.partial(
+      value_and_grad,
+      fn,
+      mdl,
+      *primals,
+      has_aux=has_aux,
+      reduce_axes=reduce_axes,
+      variables=variables,
+      rngs=rngs,
+  )
+
+  if has_aux:
+    (_, aux), argument_grads = value_and_grad_partial()
+    return argument_grads, aux
+  else:
+    _, argument_grads = value_and_grad_partial()
+    return argument_grads
+
+
 def jvp(
     fn: Callable[..., Any],
     mdl: Module,
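For readers skimming the diff, a minimal end-to-end sketch of driving the new ``nn.value_and_grad`` from outside a module, mirroring the docstring example above and the tests added below (``LearnScale`` and ``Foo`` are illustrative names from those examples, not library API)::

    import jax.numpy as jnp
    from jax import random
    import flax.linen as nn

    class LearnScale(nn.Module):
      @nn.compact
      def __call__(self, x, y):
        # Scalar output, as required by nn.grad / nn.value_and_grad.
        p = self.param('scale', nn.initializers.ones_init(), ())
        return jnp.sum(p * x * y)

    class Foo(nn.Module):
      @nn.compact
      def __call__(self, x, y):
        # Gradients are taken w.r.t. the primals (x, y) only,
        # never w.r.t. the 'scale' parameter.
        z, (x_grad, y_grad) = nn.value_and_grad(
            lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
        return z, x_grad, y_grad

    x = random.uniform(random.key(1), (4,))
    y = random.uniform(random.key(2), (4,))
    vs = Foo().init(random.key(0), x, y)
    z, x_grad, y_grad = Foo().apply(vs, x, y)
    # With scale initialized to 1, d/dx sum(x * y) == y and d/dy == x.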

tests/linen/linen_transforms_test.py (129 additions, 0 deletions)

@@ -2007,6 +2007,135 @@ def __call__(self, x):
     )
     self.assertEqual(jax.tree_map(jnp.shape, vs), outer_expect)

+  def test_grad_simple(self):
+    class LearnScale(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        p = self.param('scale', nn.initializers.ones_init(), ())
+        return jnp.sum(p * x * y)
+
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        x_grad, y_grad = nn.grad(
+            lambda mdl, x, y: mdl(x, y), LearnScale(), x, y
+        )
+        return x_grad, y_grad
+
+    x = random.uniform(random.key(1), (4,))
+    y = random.uniform(random.key(2), (4,))
+    vs = Foo().init(random.key(0), x, y)
+
+    x_grad, y_grad = Foo().apply(vs, x, y)
+    self.assertTrue(tree_allclose(x_grad, y))
+    self.assertTrue(tree_allclose(y_grad, x))
+
+  def test_grad_simple_with_aux(self):
+    class LearnScale(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        p = self.param('scale', nn.initializers.ones_init(), ())
+        return jnp.sum(p * x * y), p
+
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        (x_grad, y_grad), aux = nn.grad(
+            lambda mdl, x, y: mdl(x, y), LearnScale(), x, y, has_aux=True
+        )
+        return aux, x_grad, y_grad
+
+    x = random.uniform(random.key(1), (4,))
+    y = random.uniform(random.key(2), (4,))
+    vs = Foo().init(random.key(0), x, y)
+
+    aux, x_grad, y_grad = Foo().apply(vs, x, y)
+    self.assertTrue(tree_allclose(x_grad, y))
+    self.assertTrue(tree_allclose(y_grad, x))
+    self.assertTrue(tree_allclose(aux, vs['params']['LearnScale_0']['scale']))
+
+  def test_value_and_grad_simple(self):
+    class LearnScale(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        p = self.param('scale', nn.initializers.ones_init(), ())
+        return jnp.sum(p * x * y)
+
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        z, (x_grad, y_grad) = nn.value_and_grad(
+            lambda mdl, x, y: mdl(x, y), LearnScale(), x, y
+        )
+        return z, x_grad, y_grad
+
+    x = random.uniform(random.key(1), (4,))
+    y = random.uniform(random.key(2), (4,))
+    vs = Foo().init(random.key(0), x, y)
+
+    z, x_grad, y_grad = Foo().apply(vs, x, y)
+    self.assertTrue(tree_allclose(x_grad, y))
+    self.assertTrue(tree_allclose(y_grad, x))
+
+  def test_value_and_grad_simple_with_aux(self):
+    class LearnScale(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        p = self.param('scale', nn.initializers.ones_init(), ())
+        return jnp.sum(p * x * y), p
+
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        (z, aux), (x_grad, y_grad) = nn.value_and_grad(
+            lambda mdl, x, y: mdl(x, y), LearnScale(), x, y, has_aux=True
+        )
+        return z, aux, x_grad, y_grad
+
+    x = random.uniform(random.key(1), (4,))
+    y = random.uniform(random.key(2), (4,))
+    vs = Foo().init(random.key(0), x, y)
+
+    z, aux, x_grad, y_grad = Foo().apply(vs, x, y)
+    self.assertTrue(tree_allclose(x_grad, y))
+    self.assertTrue(tree_allclose(y_grad, x))
+    self.assertTrue(tree_allclose(aux, vs['params']['LearnScale_0']['scale']))
+
+  def test_value_and_grad_multiscope(self):
+    class Foo(nn.Module):
+      bar: nn.Module
+
+      @nn.compact
+      def __call__(self, x, y):
+        def fn(self, x, y):
+          qup = nn.Dense(y.shape[-1])
+          delta = y - self.bar(qup(x))
+          return jnp.sum(delta**2)
+
+        z, (x_grad, y_grad) = nn.value_and_grad(fn, self, x, y)
+        return z, x_grad, y_grad
+
+    class Baz(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        bar = nn.Dense(y.shape[-1])
+        return Foo(bar=bar)(x, y)
+
+    x = random.uniform(random.key(1), (4,))
+    y = random.uniform(random.key(2), (4,))
+    vs = Baz().init(random.key(0), x, y)
+    z, x_grad, y_grad = Baz().apply(vs, x, y)
+
+    def comparison_fn(x, y):
+      w1 = vs['params']['Foo_0']['Dense_0']['kernel']
+      w2 = vs['params']['Dense_0']['kernel']
+      delta = y - jnp.dot(jnp.dot(x, w1), w2)
+      return jnp.sum(delta**2)
+
+    self.assertTrue(tree_allclose(comparison_fn(x, y), z))
+    self.assertTrue(tree_allclose(jax.grad(comparison_fn, 0)(x, y), x_grad))
+    self.assertTrue(tree_allclose(jax.grad(comparison_fn, 1)(x, y), y_grad))
+

 if __name__ == '__main__':
   absltest.main()
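The expectations in these tests follow from the mechanism the new transforms use: take a scalar-valued lifted ``vjp`` with variable cotangents disabled, then pull back a cotangent of ones. A rough plain-JAX analogue of that strategy (illustrative sketch only; ``value_and_grad_via_vjp`` is a hypothetical helper, not Flax API)::

    import jax
    import jax.numpy as jnp

    def value_and_grad_via_vjp(f, *primals):
      out, vjp_fun = jax.vjp(f, *primals)
      if out.shape != ():  # mirror the scalar-output check in nn.value_and_grad
        raise ValueError(f'scalar-valued outputs only, got shape {out.shape}')
      grads = vjp_fun(jnp.ones_like(out))  # pull back a cotangent of 1.0
      return out, grads

    x = jnp.arange(3.0)
    y = jnp.arange(3.0) + 1.0
    val, (gx, gy) = value_and_grad_via_vjp(lambda x, y: jnp.sum(x * y), x, y)
    # gx == y and gy == x, matching the x_grad/y_grad assertions above.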
