
Commit a76cb21

KfacJaxDev authored and KfacJaxDev committed
Remove gradient normalization from the preconditioning function
PiperOrigin-RevId: 620240464
1 parent fd99145 commit a76cb21
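
Summary of the change: the old _compute_preconditioned_gradient both preconditioned the gradient and applied the optional norm constraint; this commit splits it into two staged methods. As the diff below shows, _step now calls them in sequence roughly like this (argument names as they appear in _step):

preconditioned_gradient = self._compute_preconditioned_gradient(
    state, grads, damping
)
preconditioned_gradient, sq_norm_scaled_grads = (
    self._maybe_apply_norm_constraint(
        grads, preconditioned_gradient, learning_rate,
    )
)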

File tree

1 file changed (+26 −20 lines)


kfac_jax/_src/optimizer.py

+26-20
@@ -827,19 +827,16 @@ def _maybe_update_inverse_cache(
         pmap_axis_name=self.pmap_axis_name,
     )

-  # TODO(jamesmartens, botev): It's ugly that this method implements the norm
-  # constraint on top of computing the preconditioned gradient. Should refactor.
   @utils.staged
   def _compute_preconditioned_gradient(
       self,
       state: "Optimizer.State",
       grads: Params,
-      coefficient: Optional[Array],
       damping: Array,
-  ) -> Tuple[Params, Optional[Array]]:
-    """Computes the preconditioned gradient, maybe applying norm-constraint."""
+  ) -> Params:
+    """Computes the preconditioned gradient."""

-    preconditioned_grads = self.estimator.multiply_inverse(
+    return self.estimator.multiply_inverse(
         state=state.estimator_state,
         parameter_structured_vector=grads,
         identity_weight=(self.l2_reg + damping) * self._precon_damping_mult,
@@ -849,23 +846,25 @@ def _compute_preconditioned_gradient(
         norm_to_scale_identity_weight_per_block=self._norm_to_scale_identity_weight_per_block,
     )

-    if self._norm_constraint:
-
-      assert not self._use_adaptive_learning_rate
-      assert coefficient is not None
+  @utils.staged
+  def _maybe_apply_norm_constraint(
+      self, grads: Params, preconditioned_grads: Params, coefficient: Array
+  ) -> Tuple[Params, Optional[Params]]:
+    """Scales precon grad to have F-weighted norm <= norm_constraint."""
+    if self._norm_constraint is None:
+      return preconditioned_grads, None

-      sq_norm_grads = utils.inner_product(preconditioned_grads, grads)
+    assert not self._use_adaptive_learning_rate

-      sq_norm_scaled_grads = sq_norm_grads * coefficient ** 2
+    sq_norm_grads = utils.inner_product(preconditioned_grads, grads)
+    sq_norm_scaled_grads = sq_norm_grads * coefficient ** 2

-      max_coefficient = jnp.sqrt(self._norm_constraint / sq_norm_scaled_grads)
-      coefficient = jnp.minimum(max_coefficient, 1)
+    max_coefficient = jnp.sqrt(self._norm_constraint / sq_norm_scaled_grads)
+    coefficient = jnp.minimum(max_coefficient, 1)

-      preconditioned_grads = utils.scalar_mul(preconditioned_grads, coefficient)
-    else:
-      sq_norm_scaled_grads = None
+    precon_grad = utils.scalar_mul(preconditioned_grads, coefficient)

-    return preconditioned_grads, sq_norm_scaled_grads
+    return precon_grad, sq_norm_scaled_grads

   def _compute_quad_change_for_damping(
       self,
@@ -1130,10 +1129,17 @@ def _step(
     state = self._maybe_update_inverse_cache(state, damping)

     # Compute proposed directions
+    preconditioned_gradient = self._compute_preconditioned_gradient(
+        state, grads, damping
+    )
+
+    # constrain the norms
     preconditioned_gradient, sq_norm_scaled_grads = (
-        self._compute_preconditioned_gradient(
-            state, grads, learning_rate, damping)
+        self._maybe_apply_norm_constraint(
+            grads, preconditioned_gradient, learning_rate,
+        )
     )
+
     vectors = (preconditioned_gradient, state.velocities)

     # Compute the coefficients for the vectors
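
For readers unfamiliar with the norm-constraint step that _maybe_apply_norm_constraint now isolates, here is a minimal standalone sketch of the same scaling logic. It is not the kfac_jax implementation: the inner_product and scalar_mul helpers are assumed stand-ins for kfac_jax's utils functions, and the toy pytrees at the end are made up purely for illustration.

import jax
import jax.numpy as jnp


def inner_product(tree_a, tree_b):
  # Sum of elementwise products across two pytrees with matching structure.
  sums = jax.tree_util.tree_map(lambda a, b: jnp.sum(a * b), tree_a, tree_b)
  return sum(jax.tree_util.tree_leaves(sums))


def scalar_mul(tree, scalar):
  # Multiplies every leaf of a pytree by a scalar.
  return jax.tree_util.tree_map(lambda x: scalar * x, tree)


def apply_norm_constraint(grads, precon_grads, coefficient, norm_constraint):
  # <precon_grads, grads> = g^T F^{-1} g is the F-weighted squared norm of the
  # preconditioned gradient; scaling the update by `coefficient` (the learning
  # rate in _step) scales that squared norm by coefficient**2.
  sq_norm_grads = inner_product(precon_grads, grads)
  sq_norm_scaled_grads = sq_norm_grads * coefficient ** 2
  # Shrink the coefficient just enough to satisfy the constraint, never grow it.
  max_coefficient = jnp.sqrt(norm_constraint / sq_norm_scaled_grads)
  coefficient = jnp.minimum(max_coefficient, 1.0)
  return scalar_mul(precon_grads, coefficient), sq_norm_scaled_grads


# Toy usage with hypothetical pytrees and values:
grads = {"w": jnp.ones((3,)), "b": jnp.ones(())}
precon_grads = {"w": 2.0 * jnp.ones((3,)), "b": 0.5 * jnp.ones(())}
step, sq_norm = apply_norm_constraint(
    grads, precon_grads, coefficient=0.1, norm_constraint=1e-3
)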

0 commit comments
