Skip to content

Commit a6c9329

Browse files
Jake VanderPlasKfacJaxDev
Jake VanderPlas
authored and
KfacJaxDev
committed
[LSC] Ignore incorrect type annotations related to jax.numpy APIs
PiperOrigin-RevId: 568283267
1 parent e51dc5a commit a6c9329

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

kfac_jax/_src/optimizer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,7 @@ def _coefficients_and_quad_change(
865865
else:
866866
quad_change = jnp.nan
867867

868-
return coefficients, quad_change
868+
return coefficients, quad_change # pytype: disable=bad-return-type # jnp-type
869869

870870
@utils.auto_scope_method
871871
def _update_damping(

kfac_jax/_src/patches_second_moment.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ def general_loop_body(i, image):
628628
else:
629629
wf_n = weighting_array[in_spec.n_axis]
630630
wf_spatial = [weighting_array.shape[a] for a in in_spec.spatial_axes]
631-
wf_sizes = in_spec.create_shape(wf_n, jnp.ones([]), *wf_spatial)
631+
wf_sizes = in_spec.create_shape(wf_n, jnp.ones([]), *wf_spatial) # pytype: disable=wrong-arg-types # jnp-type
632632
wf_i = _slice_array(weighting_array, index, wf_sizes)
633633
else:
634634
wf_i = None

kfac_jax/_src/utils/types.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,4 @@ def get_float_dtype_and_check_consistency(obj: ArrayTree) -> DType:
8181
else:
8282
raise ValueError("Non-float dtype detected.")
8383

84-
return dtype
84+
return dtype # pytype: disable=bad-return-type # jnp-type

0 commit comments

Comments
 (0)