
How to avoid copying data when working with NumPy arrays and JAX on the CPU? #27706

Answered by jakevdp
r-aristov asked this question in Q&A

Your NumPy array has dtype float64, and by default JAX does not use 64-bit values, so the array is converted into a new float32 array (that is, a copy) before JAX uses it.
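A minimal sketch of the conversion described above, assuming a recent JAX release running with its default configuration (jax_enable_x64 left off). The array names are illustrative. Because the dtype changes, JAX has to materialize a separate float32 buffer, so a later write to the NumPy array is not visible in the JAX array.

```python
import numpy as np
import jax.numpy as jnp

a64 = np.ones(4)       # NumPy allocates float64 by default
x = jnp.asarray(a64)   # JAX canonicalizes the dtype to float32 (x64 disabled)

print(a64.dtype, x.dtype)  # float64 float32

# The float32 data lives in a separate buffer, so this write does not reach x.
a64[0] = 5.0
print(float(x[0]))  # 1.0
```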

If you'd like to use 64-bit values in JAX, you can set the jax_enable_x64 flag to True (see https://docs.jax.dev/en/latest/default_dtypes.html).
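A minimal sketch of the 64-bit route, assuming the flag is set at program start, before any JAX arrays are created (setting the JAX_ENABLE_X64 environment variable is an alternative). With x64 enabled, the float64 dtype is preserved, so no dtype conversion is needed.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Enable 64-bit types at startup, before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

a64 = np.ones(4)        # float64 NumPy array
x = jnp.asarray(a64)    # dtype stays float64 now that x64 is enabled
print(x.dtype)          # float64
```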

Alternatively, if you'd like to keep using float32 values, allocate the NumPy array as float32 from the start:

import numpy as np
a = np.ones(int(16e8), dtype='float32')

With this change, jnp.asarray should be able to reuse the NumPy buffer instead of copying it, as long as the buffer is suitably aligned in memory.
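One way to check whether the buffer was actually reused is to compare memory addresses, sketched here under the assumption of a single-device CPU backend. It uses NumPy's ndarray.ctypes.data and jax.Array.unsafe_buffer_pointer(); the smaller array size only keeps the example quick. Whether the two addresses match depends on the JAX version and on how the NumPy allocation happens to be aligned, so the script reports the result rather than assuming it.

```python
import numpy as np
import jax.numpy as jnp

# Allocate directly in float32 so no dtype conversion is required.
a32 = np.ones(1_000_000, dtype=np.float32)
x = jnp.asarray(a32)

np_ptr = a32.ctypes.data              # start address of the NumPy buffer
jax_ptr = x.unsafe_buffer_pointer()   # start address of the JAX buffer (single device)
shared = np_ptr == jax_ptr

print("dtype:", x.dtype)
print("buffer shared:", shared)
print("NumPy buffer address mod 64:", np_ptr % 64)  # informational only
```

If shared is True, the JAX array aliases the NumPy memory, so writing to a32 in place afterward would also change the values x refers to, which conflicts with JAX's assumption that arrays are immutable. Leaving the NumPy array untouched after the conversion avoids that.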

Answer selected by r-aristov