Skip to content

Commit 0883c71

Browse files
committed
jnp.repeat: don't cast repeats to array, as they must be static.
1 parent e02faab commit 0883c71

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

jax/_src/numpy/lax_numpy.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -6989,9 +6989,10 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
69896989
[3, 3, 4, 4, 4, 4, 4]], dtype=int32)
69906990
"""
69916991
if core.is_dim(repeats):
6992-
arr = util.ensure_arraylike("repeat", a)
6992+
util.check_arraylike("repeat", a)
69936993
else:
6994-
arr, repeats = util.ensure_arraylike("repeat", a, repeats)
6994+
util.check_arraylike("repeat", a, repeats)
6995+
arr = asarray(arr)
69956996

69966997
if axis is None:
69976998
arr = arr.ravel()

0 commit comments

Comments
 (0)