+1
Not really. But if you describe what you are trying to do, I can probably suggest alternatives.
I have tried:

    sharding = PositionalSharding(mesh_utils.create_device_mesh((3,), jax.devices() + jax.devices("cpu")))

However, JAX treated the CPU device as if it were a GPU, and I also couldn't put the array on it.

In Hugging Face, it is possible to split a model across the GPU and the CPU using `device_map`. Is it possible to do this in JAX?
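A mixed CPU/GPU mesh like the one above doesn't appear to be supported, but one possible alternative is to place each piece of the model yourself. The sketch below assumes the model can be split into layers that run one after another: each layer's weights are committed to a chosen device with `jax.device_put`, and the activations are moved explicitly to that device before the layer runs. The `device_map` dict, the layer names, and the `dense` layer are hypothetical and only mimic the Hugging Face idea; they are not a JAX API. On a CPU-only machine, `accel` and `cpu` are the same device, so everything simply stays on the CPU.

```python
import jax
import jax.numpy as jnp

# Default backend device (GPU/TPU if present, otherwise CPU) and the host CPU.
accel = jax.devices()[0]
cpu = jax.devices("cpu")[0]

# Hypothetical device_map-style assignment: layer name -> device.
device_map = {"layer0": accel, "layer1": cpu, "layer2": accel}

keys = jax.random.split(jax.random.PRNGKey(0), len(device_map))

# Create each layer's weights and commit them to the chosen device.
params = {
    name: jax.device_put(jax.random.normal(k, (8, 8)) * 0.1, dev)
    for (name, dev), k in zip(device_map.items(), keys)
}


@jax.jit
def dense(w, x):
    # Runs on whichever device holds its (committed) inputs.
    return jnp.tanh(x @ w)


def forward(params, x):
    for name, w in params.items():
        # Move the activation to the device holding this layer's weights,
        # so every jitted call sees inputs committed to a single device.
        x = jax.device_put(x, device_map[name])
        x = dense(w, x)
    return x


x = jnp.ones((4, 8))
y = forward(params, x)
print(y.shape)
```

Every CPU/GPU hop here is an explicit transfer, so the cost of splitting the model is visible in the code; keeping layers that run back to back on the same device reduces the number of copies.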