Hi, I'm trying to update my code from
If you un-comment the lines marked with
Versions:
Here's a minimal reproducible example:
Here's the stack trace:
The difference between `jax.random.key` and `jax.random.PRNGKey` shows up in the shape of the keys they produce:
In [1]: import jax
In [2]: jax.random.split(jax.random.key(0), 10).shape
Out[2]: (10,)
In [3]: jax.random.split(jax.random.PRNGKey(0), 10).shape
Out[3]: (10, 2)
The sharding specification you used is
If you changed your sharding spec to
Does that make sense?
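A minimal sketch of the same shape difference, assuming the default threefry2x32 PRNG implementation (other implementations use a different trailing size for raw keys). `jax.random.key_data` and `jax.random.wrap_key_data` convert between the two representations, which may help when migrating code piece by piece:

```python
import jax
import jax.numpy as jnp

# New-style typed keys: the key bits live inside a key dtype,
# so a batch of 10 keys has shape (10,).
typed = jax.random.split(jax.random.key(0), 10)

# Legacy raw keys: plain uint32 arrays with a trailing dimension of 2
# (for the default threefry2x32 implementation), so shape (10, 2).
raw = jax.random.split(jax.random.PRNGKey(0), 10)

# Converting between the two representations.
raw_from_typed = jax.random.key_data(typed)     # uint32, shape (10, 2)
typed_from_raw = jax.random.wrap_key_data(raw)  # typed keys, shape (10,)

print(typed.shape, raw.shape)
print(raw_from_typed.shape, typed_from_raw.shape)

# A PartitionSpec may list at most one entry per array dimension:
# PartitionSpec("i", None) fits the rank-2 raw keys, but for the
# rank-1 typed keys only a single-entry spec such as PartitionSpec("i") fits.
```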
I suspect this issue comes from attempting to combine `pmap` and `shard_map`. `pmap` is deprecated, and is incompatible with `shard_map`. You should prefer to do the whole computation using `shard_map`.
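A minimal sketch of doing per-device sampling entirely inside `shard_map`, with no `pmap`. It assumes a JAX release recent enough that `shard_map` accepts typed key arrays; the import fallback covers both the newer `jax.shard_map` and the older `jax.experimental.shard_map` location. The mesh axis name `"i"` and the function `per_shard` are hypothetical. Because each typed key is a single array element, the input spec is `P("i")` with one entry:

```python
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

n_dev = len(jax.devices())
# One-dimensional mesh over all available devices ("i" is a hypothetical axis name).
mesh = Mesh(np.array(jax.devices()), ("i",))

def per_shard(keys):
    # `keys` is this device's block of typed keys, shape (local_n,).
    # Draw 4 standard-normal samples per key.
    return jax.vmap(lambda k: jax.random.normal(k, (4,)))(keys)

sample = jax.jit(
    shard_map(per_shard, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
)

# Two typed keys per device; the leading axis is split across the mesh.
keys = jax.random.split(jax.random.key(0), 2 * n_dev)  # shape (2 * n_dev,)
out = sample(keys)  # shape (2 * n_dev, 4)
print(out.shape)
```

Keeping the whole computation inside one `shard_map` means the key array's layout is described by a single explicit spec, rather than partly by `pmap`'s implicit leading axis.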