@@ -130,9 +130,12 @@ def __init__(
value_and_grad_func: Python callable. The function should return the value
of the loss to be optimized and its gradients. If the argument
``value_func_has_aux`` is ``False`` then the interface should be:
- ``loss, loss_grads = value_and_grad_func(params, batch)``. If
+ ``loss, loss_grads = value_and_grad_func(*args)``. If
``value_func_has_aux`` is ``True`` then the interface should be:
- ``(loss, aux), loss_grads = value_and_grad_func(params, batch)``.
+ ``(loss, aux), loss_grads = value_and_grad_func(*args)``. Here ``args``
+ is ``(params, func_state, rng, batch)``, with ``rng`` omitted if
+ ``value_func_has_rng`` is ``False``, and with ``func_state`` omitted if
+ ``value_func_has_state`` is ``False``.
l2_reg: Scalar. Set this value to tell the optimizer what L2
regularization coefficient you are using (if any). Note the coefficient
appears in the regularizer as ``coeff / 2 * sum(param**2)``. This adds
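A minimal sketch of a ``value_and_grad_func`` matching the new ``*args`` interface described in the hunk above, for the simplest case where ``value_func_has_aux``, ``value_func_has_state``, and ``value_func_has_rng`` are all ``False``, so ``args`` is just ``(params, batch)``. The loss function, parameter structure, and batch shapes here are hypothetical stand-ins, not anything prescribed by this diff:

```python
import jax
import jax.numpy as jnp

# Hypothetical loss: mean squared error of a linear model. `params` is a
# dict of arrays and `batch` is an (inputs, targets) pair; both are stand-ins.
def loss_fn(params, batch):
    inputs, targets = batch
    preds = inputs @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)

# With value_func_has_aux, value_func_has_state, and value_func_has_rng all
# False, ``args`` is just (params, batch), so jax.value_and_grad yields
# exactly the required `loss, loss_grads = value_and_grad_func(*args)` shape.
value_and_grad_func = jax.value_and_grad(loss_fn)

# Example call (shapes are arbitrary):
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
batch = (jnp.ones((8, 3)), jnp.zeros((8,)))
loss, loss_grads = value_and_grad_func(params, batch)
```

If the flags are enabled, the same callable would instead accept ``func_state`` and/or ``rng`` in the positions given above.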
@@ -259,7 +262,7 @@ def __init__(
specifying whether the batch is replicated over multiple devices and
returns the batch size for a single device. (Default: ``None``)
pmap_axis_name: String. The name of the pmap axis to use when
- ``multi_device`` is set to True. (Default: ``curvature_axis``)
+ ``multi_device`` is set to True. (Default: ``kfac_axis``)
forbid_setting_attributes_after_finalize: Boolean. By default after the
object is finalized, you cannot set any of its properties. This is done
in order to protect the user from making changes to the object
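A hedged sketch of how the renamed default in the second hunk shows up at construction time, for a hypothetical pmap-based multi-device setup. Only the argument names that appear in this docstring (``value_and_grad_func``, ``l2_reg``, ``multi_device``, ``pmap_axis_name``) are assumed; other constructor arguments and their defaults may differ across versions:

```python
import kfac_jax

# Hypothetical construction for a multi-device (pmap) setup. After this
# change, omitting pmap_axis_name would default it to "kfac_axis"; passing
# it explicitly keeps the pmap axis name visible at the call site.
optimizer = kfac_jax.Optimizer(
    value_and_grad_func=value_and_grad_func,
    l2_reg=0.0,
    multi_device=True,
    pmap_axis_name="kfac_axis",
)
```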