Skip to content

Commit fa9d5c7

Browse files
james-martensKfacJaxDev
authored and
KfacJaxDev
committed
Revising docstring for optimizer class.
PiperOrigin-RevId: 466954884
1 parent ca07235 commit fa9d5c7

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

kfac_jax/_src/optimizer.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,12 @@ def __init__(
130130
value_and_grad_func: Python callable. The function should return the value
131131
of the loss to be optimized and its gradients. If the argument
132132
``value_func_has_aux`` is ``False`` then the interface should be:
133-
``loss, loss_grads = value_and_grad_func(params, batch)``. If
133+
``loss, loss_grads = value_and_grad_func(*args)``. If
134134
``value_func_has_aux`` is ``True`` then the interface should be:
135-
``(loss, aux), loss_grads = value_and_grad_func(params, batch)``.
135+
``(loss, aux), loss_grads = value_and_grad_func(*args)``. Here ``args``
136+
is ``(params, func_state, rng, batch)``, with ``rng`` omitted if
137+
``value_func_has_rng`` is ``False``, and with ``func_state`` omitted if
138+
``value_func_has_state`` is ``False``.
136139
l2_reg: Scalar. Set this value to tell the optimizer what L2
137140
regularization coefficient you are using (if any). Note the coefficient
138141
appears in the regularizer as ``coeff / 2 * sum(param**2)``. This adds
@@ -259,7 +262,7 @@ def __init__(
259262
specifying whether the batch is replicated over multiple devices and
260263
returns the batch size for a single device. (Default: ``None``)
261264
pmap_axis_name: String. The name of the pmap axis to use when
262-
``multi_device`` is set to True. (Default: ``curvature_axis``)
265+
``multi_device`` is set to True. (Default: ``kfac_axis``)
263266
forbid_setting_attributes_after_finalize: Boolean. By default after the
264267
object is finalized, you can not set any of its properties. This is done
265268
in order to protect the user from making changes to the object

0 commit comments

Comments
 (0)