Skip to content

Commit b7d430f

Browse files
committed
jnp.repeat: don't cast repeats to array, as they must be static.
1 parent e02faab commit b7d430f

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

jax/_src/numpy/lax_numpy.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -6989,9 +6989,10 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
69896989
[3, 3, 4, 4, 4, 4, 4]], dtype=int32)
69906990
"""
69916991
if core.is_dim(repeats):
6992-
arr = util.ensure_arraylike("repeat", a)
6992+
util.check_arraylike("repeat", a)
69936993
else:
6994-
arr, repeats = util.ensure_arraylike("repeat", a, repeats)
6994+
util.check_arraylike("repeat", a, repeats)
6995+
arr = asarray(a)
69956996

69966997
if axis is None:
69976998
arr = arr.ravel()

tests/array_extensibility_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def __getitem__(self, shape) -> jax.ShapeDtypeStruct:
442442
NumPyAPI.sig(jnp.real, Complex[5]),
443443
NumPyAPI.sig(jnp.reciprocal, Float[5]),
444444
NumPyAPI.sig(jnp.remainder, Float[5], Float[5]),
445-
NumPyAPI.sig(jnp.repeat, Float[5], Int[5]),
445+
NumPyAPI.sig(jnp.repeat, Float[5], repeats=np.array([2, 3, 1, 5, 4])),
446446
NumPyAPI.sig(jnp.reshape, Float[6], shape=(2, 3)),
447447
NumPyAPI.sig(jnp.resize, Float[6], new_shape=(2, 3)),
448448
NumPyAPI.sig(jnp.right_shift, Int[5], Int[5]),

0 commit comments

Comments
 (0)