@@ -1055,7 +1055,8 @@ def vjp(
     vjp_variables: lift.CollectionFilter = 'params',
     variables: lift.CollectionFilter = True,
     rngs: lift.PRNGSequenceFilter = True,
-) -> Tuple[Any, Any]:
+    multi_scope: bool = False,
+):
   """A lifted version of ``jax.vjp``.
 
   See ``jax.vjp`` for the unlifted vector-Jacobian product (backward gradient).
@@ -1105,7 +1106,8 @@ def __call__(self, x, y):
     variables: other variables collections that are available inside `fn` but
       do not receive a cotangent.
     rngs: the prngs that are available inside `fn`.
-
+    multi_scope: for Modules containing multiple scopes from outside modules passed in,
+      allow variable gradients to be returned for multiple scopes instead of raising an error.
   Returns:
     If ``has_aux`` is ``False``, returns a ``(primals_out, vjpfun)`` pair, where
     ``primals_out`` is ``fn(*primals)``.
@@ -1121,7 +1123,7 @@ def __call__(self, x, y):
       (fn,),
       mdl,
       *primals,
-      multi_scope=False,
+      multi_scope=multi_scope,
       has_aux=has_aux,
       reduce_axes=reduce_axes,
       vjp_variables=vjp_variables,
@@ -1130,6 +1132,184 @@ def __call__(self, x, y):
   )
 
 
+def value_and_grad(
+    fn: Callable[..., Any],
+    mdl: Module,
+    *primals,
+    has_aux: bool = False,
+    reduce_axes=(),
+    variables: lift.CollectionFilter = True,
+    rngs: lift.PRNGSequenceFilter = True,
+):
+  """A limited, lifted equivalent of ``jax.value_and_grad``.
+
+  Note that for this convenience function, gradients are only calculated for
+  the function inputs, and not with respect to any module variables. The
+  target function must return a scalar-valued output. For a more general
+  lifted vjp, see ``nn.vjp`` for the lifted vector-Jacobian product.
+
+  Example::
+
+    class LearnScale(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        p = self.param('scale', nn.initializers.zeros_init(), ())
+        return p * x * y
+
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        z, (x_grad, y_grad) = nn.value_and_grad(
+            lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
+        return z, x_grad, y_grad
+
+  Args:
+    fn: Function to be differentiated. Its arguments should be arrays, scalars,
+      or standard Python containers of arrays or scalars. It should return an
+      array, scalar, or standard Python container of arrays or scalars. It will
+      receive the scope and primals as arguments.
+    mdl: The module of which the variables will be differentiated.
+    *primals: A sequence of primal values at which the Jacobian of ``fn``
+      should be evaluated. The length of ``primals`` should be equal to the
+      number of positional parameters to ``fn``. Each primal value should be a
+      tuple of arrays, scalar, or standard Python containers thereof.
+    has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the
+      first element is considered the output of the mathematical function to be
+      differentiated and the second element is auxiliary data. Default False.
+    reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
+      ``fn`` implicitly broadcasts a value over that axis, the backward pass
+      will perform a ``psum`` of the corresponding gradient. Otherwise, the
+      grad will be per-example over named axes. For example, if ``'batch'``
+      is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will
+      create a grad function that sums over the batch while ``grad(f, *args)``
+      will create a per-example grad.
+    variables: variables collections that are available inside `fn` but
+      do not receive a cotangent.
+    rngs: the prngs that are available inside `fn`.
+  Returns:
+    If ``has_aux`` is ``False``, returns a ``primals_out, grads`` pair, where
+    ``primals_out`` is ``fn(*primals)``. ``grads`` are the gradients for the
+    corresponding primals and do not include the gradients for module variables.
+    If ``has_aux`` is ``True``, returns a
+    ``(primals_out, aux), grads`` tuple where ``aux`` is the auxiliary data
+    returned by ``fn``.
+  """
+
+  vjp_partial = functools.partial(
+      vjp,
+      fn,
+      mdl,
+      *primals,
+      has_aux=has_aux,
+      reduce_axes=reduce_axes,
+      vjp_variables=False,
+      variables=variables,
+      rngs=rngs,
+      multi_scope=True,
+  )
+
+  if has_aux:
+    out, vjp_fun, aux = vjp_partial()
+    if out.shape != ():
+      raise ValueError(
+          'grad can only work on functions with '
+          f'scalar-valued outputs. out shape={out.shape}'
+      )
+    _, *argument_grads = vjp_fun(jax.numpy.ones_like(out))
+    return (out, aux), argument_grads
+  else:
+    out, vjp_fun = vjp_partial()
+    if out.shape != ():
+      raise ValueError(
+          'grad can only work on functions with '
+          f'scalar-valued outputs. out shape={out.shape}'
+      )
+    _, *argument_grads = vjp_fun(jax.numpy.ones_like(out))
+    return out, argument_grads
+
+
+def grad(
+    fn: Callable[..., Any],
+    mdl: Module,
+    *primals,
+    has_aux: bool = False,
+    reduce_axes=(),
+    variables: lift.CollectionFilter = True,
+    rngs: lift.PRNGSequenceFilter = True,
+):
+  """A limited, lifted equivalent of ``jax.grad``.
+
+  Note that for this convenience function, gradients are only calculated for
+  the function inputs, and not with respect to any module variables. The
+  target function must return a scalar-valued output. For a more general
+  lifted vjp, see ``nn.vjp`` for the lifted vector-Jacobian product.
+
+  Example::
+
+    class LearnScale(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        p = self.param('scale', nn.initializers.zeros_init(), ())
+        return p * x * y
+
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x, y):
+        x_grad, y_grad = nn.grad(
+            lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
+        return x_grad, y_grad
+
+  Args:
+    fn: Function to be differentiated. Its arguments should be arrays, scalars,
+      or standard Python containers of arrays or scalars. It should return an
+      array, scalar, or standard Python container of arrays or scalars. It will
+      receive the scope and primals as arguments.
+    mdl: The module of which the variables will be differentiated.
+    *primals: A sequence of primal values at which the Jacobian of ``fn``
+      should be evaluated. The length of ``primals`` should be equal to the
+      number of positional parameters to ``fn``. Each primal value should be a
+      tuple of arrays, scalar, or standard Python containers thereof.
+    has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the
+      first element is considered the output of the mathematical function to be
+      differentiated and the second element is auxiliary data. Default False.
+    reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
+      ``fn`` implicitly broadcasts a value over that axis, the backward pass
+      will perform a ``psum`` of the corresponding gradient. Otherwise, the
+      grad will be per-example over named axes. For example, if ``'batch'``
+      is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will
+      create a grad function that sums over the batch while ``grad(f, *args)``
+      will create a per-example grad.
+    variables: variables collections that are available inside `fn` but
+      do not receive a cotangent.
+    rngs: the prngs that are available inside `fn`.
+  Returns:
+    If ``has_aux`` is ``False``, returns ``grads``, where ``grads`` are the
+    gradients for the corresponding primals and do not include the gradients
+    for module variables.
+    If ``has_aux`` is ``True``, returns a
+    ``(grads, aux)`` tuple where ``aux`` is the auxiliary data
+    returned by ``fn``.
+  """
+
+  value_and_grad_partial = functools.partial(
+      value_and_grad,
+      fn,
+      mdl,
+      *primals,
+      has_aux=has_aux,
+      reduce_axes=reduce_axes,
+      variables=variables,
+      rngs=rngs,
+  )
+
+  if has_aux:
+    (_, aux), argument_grads = value_and_grad_partial()
+    return argument_grads, aux
+  else:
+    _, argument_grads = value_and_grad_partial()
+    return argument_grads
+
+
 def jvp(
     fn: Callable[..., Any],
     mdl: Module,
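
Below is a minimal, self-contained usage sketch of the new ``nn.value_and_grad`` wrapper, assembled from the docstring example in the diff above. The driver code at the bottom (the concrete inputs, the PRNG key, and the expected zero gradients) is illustrative only and is not part of the diff; it assumes the 'scale' parameter is initialized to zero by ``zeros_init``.

import jax
import jax.numpy as jnp
import flax.linen as nn

class LearnScale(nn.Module):
  @nn.compact
  def __call__(self, x, y):
    p = self.param('scale', nn.initializers.zeros_init(), ())
    return p * x * y

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x, y):
    # Gradients are taken only w.r.t. the inputs (x, y); the 'scale' parameter
    # of LearnScale receives no cotangent because value_and_grad calls the
    # lifted vjp with vjp_variables=False.
    z, (x_grad, y_grad) = nn.value_and_grad(
        lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
    return z, x_grad, y_grad

x, y = jnp.array(2.0), jnp.array(3.0)
variables = Foo().init(jax.random.PRNGKey(0), x, y)
z, x_grad, y_grad = Foo().apply(variables, x, y)
# With 'scale' == 0: z == 0, x_grad == scale * y == 0, y_grad == scale * x == 0.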