@@ -827,19 +827,16 @@ def _maybe_update_inverse_cache(
         pmap_axis_name=self.pmap_axis_name,
     )

-  # TODO(jamesmartens, botev): It's ugly that this method implements the norm
-  # constraint on top of computing the preconditioned gradient. Should refactor.
   @utils.staged
   def _compute_preconditioned_gradient(
       self,
       state: "Optimizer.State",
       grads: Params,
-      coefficient: Optional[Array],
       damping: Array,
-  ) -> Tuple[Params, Optional[Array]]:
-    """Computes the preconditioned gradient, maybe applying norm-constraint."""
+  ) -> Params:
+    """Computes the preconditioned gradient."""

-    preconditioned_grads = self.estimator.multiply_inverse(
+    return self.estimator.multiply_inverse(
         state=state.estimator_state,
         parameter_structured_vector=grads,
         identity_weight=(self.l2_reg + damping) * self._precon_damping_mult,
@@ -849,23 +846,25 @@ def _compute_preconditioned_gradient(
         norm_to_scale_identity_weight_per_block=self._norm_to_scale_identity_weight_per_block,
     )

-    if self._norm_constraint:
-
-      assert not self._use_adaptive_learning_rate
-      assert coefficient is not None
+  @utils.staged
+  def _maybe_apply_norm_constraint(
+      self, grads: Params, preconditioned_grads: Params, coefficient: Array
+  ) -> Tuple[Params, Optional[Params]]:
+    """Scales precon grad to have F-weighted norm <= norm_constraint."""
+    if self._norm_constraint is None:
+      return preconditioned_grads, None

-      sq_norm_grads = utils.inner_product(preconditioned_grads, grads)
+    assert not self._use_adaptive_learning_rate

-      sq_norm_scaled_grads = sq_norm_grads * coefficient ** 2
+    sq_norm_grads = utils.inner_product(preconditioned_grads, grads)
+    sq_norm_scaled_grads = sq_norm_grads * coefficient ** 2

-      max_coefficient = jnp.sqrt(self._norm_constraint / sq_norm_scaled_grads)
-      coefficient = jnp.minimum(max_coefficient, 1)
+    max_coefficient = jnp.sqrt(self._norm_constraint / sq_norm_scaled_grads)
+    coefficient = jnp.minimum(max_coefficient, 1)

-      preconditioned_grads = utils.scalar_mul(preconditioned_grads, coefficient)
-    else:
-      sq_norm_scaled_grads = None
+    precon_grad = utils.scalar_mul(preconditioned_grads, coefficient)

-    return preconditioned_grads, sq_norm_scaled_grads
+    return precon_grad, sq_norm_scaled_grads

   def _compute_quad_change_for_damping(
       self,
@@ -1130,10 +1129,17 @@ def _step(
     state = self._maybe_update_inverse_cache(state, damping)

     # Compute proposed directions
+    preconditioned_gradient = self._compute_preconditioned_gradient(
+        state, grads, damping
+    )
+
+    # constrain the norms
     preconditioned_gradient, sq_norm_scaled_grads = (
-        self._compute_preconditioned_gradient(
-            state, grads, learning_rate, damping)
+        self._maybe_apply_norm_constraint(
+            grads, preconditioned_gradient, learning_rate,
+        )
     )
+
     vectors = (preconditioned_gradient, state.velocities)

     # Compute the coefficients for the vectors
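For context, the norm constraint factored out above caps the curvature-weighted squared norm of the proposed update: with learning rate (coefficient) eta and preconditioned gradient F^{-1} g, it bounds eta^2 * <F^{-1} g, g> by norm_constraint, shrinking the preconditioned gradient only when the bound would be exceeded. Below is a minimal standalone sketch of that scaling rule, assuming plain pytrees of arrays; the function and argument names are illustrative and not part of the kfac_jax API.

# Sketch of the norm-constraint rescaling on pytrees (illustrative names only).
import jax
import jax.numpy as jnp


def apply_norm_constraint(grads, precond_grads, lr, norm_constraint):
  """Rescales precond_grads so that lr**2 * <precond_grads, grads> <= norm_constraint."""
  # <F^{-1} g, g> approximates the curvature-weighted squared norm of the step.
  sq_norm = sum(
      jnp.sum(p * g)
      for p, g in zip(jax.tree_util.tree_leaves(precond_grads),
                      jax.tree_util.tree_leaves(grads))
  )
  sq_norm_scaled = sq_norm * lr ** 2
  # Shrink the update only if it would violate the constraint; otherwise keep it.
  scale = jnp.minimum(jnp.sqrt(norm_constraint / sq_norm_scaled), 1.0)
  return jax.tree_util.tree_map(lambda x: scale * x, precond_grads), sq_norm_scaled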