Hello. I can't figure out how to implement the following. I'd like to keep a large buffer in RAM (on the CPU) and write items to it without creating a copy, as I can with NumPy: `buffer[ptr] = data`. I also want to sample batches from it quickly and move them to the GPU. I thought I could use a JITed JAX function to handle the sampling efficiently, but I ran into several problems.

Even on the CPU, JAX cannot modify data in place. Am I correct, or did I miss something? That is not the real issue, though; my real question is the following. To work around this, I created a NumPy array as the CPU buffer, which allows in-place modification, and tried to convert it to a JAX array on the CPU for sampling without copying the whole buffer. But JAX still creates a copy of the data, whether I call `jnp.asarray` or `jnp.array`.
This code prints:
First of all, I don't quite understand where these numbers come from: if this is a copy, why is the increase 6 GB and not 12 GB?
I get this error:
This suggests that it sees backend=gpu.
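Your NumPy array is of type `float64`, and by default JAX cannot use 64-bit values, so it is converted to a `float32` array for use by JAX. If you'd like to use 64-bit values in JAX, you could set the X64 flag to True (see https://docs.jax.dev/en/latest/default_dtypes.html).

Alternatively, if you'd like to continue using float32 values, you should allocate the NumPy array in float32: `a = np.ones(int(16e8), dtype='float32')`

With this change, `jnp.asarray` should reuse the NumPy buffer as long as it is byte-aligned.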
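A rough size check may help explain the 6 GB figure. This is only a sketch, assuming the buffer in the question has the same 16e8 elements as the array above (the original snippet is not shown here). It computes the sizes arithmetically and uses small stand-in arrays (`small64` and `small32` are illustrative names) for the dtype behaviour, so nothing large is allocated:

```python
import numpy as np
import jax.numpy as jnp

n = int(16e8)  # element count taken from the example above

# Sizes are computed from the item size alone; no large array is created.
gb_float64 = n * np.dtype(np.float64).itemsize / 1e9  # original NumPy buffer
gb_float32 = n * np.dtype(np.float32).itemsize / 1e9  # converted float32 copy
print(f"float64 buffer: {gb_float64:.1f} GB, float32 copy: {gb_float32:.1f} GB")

# Small stand-ins for the real buffer, used only to inspect dtypes.
small64 = np.ones(1024)                    # NumPy defaults to float64
small32 = np.ones(1024, dtype=np.float32)

# With the X64 flag left at its default, a float64 input comes back as
# float32, which implies a conversion and therefore a new allocation.
print(jnp.asarray(small64).dtype)
# A float32 input needs no dtype conversion.
print(jnp.asarray(small32).dtype)
```

Under that assumption, the float32 copy is half the size of the float64 original, which would be consistent with an increase of about 6 GB rather than 12 GB.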
UPD: I see, there is a difference between default device and default backend.
Well, that solves the issue. It's a bit unexpected that the context manager does not affect the backend, though.

I have one related question left. I am trying to remove unnecessary data copying from my pipeline. Is there a way to efficiently sample directly from a list of arrays in JAX, without combining them into one? One way is to use a Python for loop to run through the list and sample items from each array. When JITed it is pretty fast, but JITting takes ages even for a small buffer of 100-200 items.
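For anyone landing here later, the device-versus-backend distinction mentioned in the update can be sketched like this. It assumes the context manager in question is `jax.default_device` (the original snippet is not shown, so that is a guess), and the variable names are illustrative:

```python
import jax
import jax.numpy as jnp

# The CPU device is available even when a GPU backend is also present.
cpu = jax.devices("cpu")[0]

with jax.default_device(cpu):
    # Arrays created inside the context are placed on the CPU device.
    x = jnp.ones(4)
    print(x.devices())

    # jax.default_backend() reports the platform name of the default
    # backend. Per the observation above, the context manager is not
    # expected to change this value on a GPU machine.
    print(jax.default_backend())
```

On a CPU-only machine both lines report the CPU, so the difference only shows up when an accelerator is installed.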