Type annotation for PRNG keys #27577
What's the recommended type annotation for a PRNG key (as opposed to a generic Array)? I think it used to be PRNGKey, but that's legacy.
The right annotation is `jax.Array`. This was decided here: https://docs.jax.dev/en/latest/jep/9263-typed-keys.html#type-annotations-for-prng-keys

New-style typed keys (produced by `jax.random.key`) are defined by their dtype, and JAX does not offer any dtype-specific array annotations, so `Array` is the only appropriate annotation. Old-style keys (produced by `jax.random.PRNGKey`) are just arrays with dtype `uint32` and a trailing dimension of a particular size depending on the default PRNG impl, and so `Array` is the only appropriate annotation.
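As a rough illustration of the point above, here is a minimal sketch of annotating a key parameter with `jax.Array`. The helper names `sample_noise` and `is_typed_key` are hypothetical and exist only for this example. Since the annotation cannot tell a key apart from any other array, the sketch also shows one way to distinguish typed keys at runtime by checking the dtype, assuming a JAX version recent enough to provide `jax.random.key` and `jax.dtypes.prng_key`.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper: the key parameter is annotated as a plain jax.Array,
# which covers both new-style typed keys and legacy uint32 keys.
def sample_noise(key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Draw standard-normal noise with the given shape."""
    return jax.random.normal(key, shape)


# Hypothetical helper: the annotation cannot express "this array is a key",
# so a runtime dtype check is used to recognize typed keys.
def is_typed_key(key: jax.Array) -> bool:
    """Return True if `key` has a PRNG-key dtype (a new-style typed key)."""
    return bool(jnp.issubdtype(key.dtype, jax.dtypes.prng_key))


new_key = jax.random.key(0)      # new-style typed key
old_key = jax.random.PRNGKey(0)  # legacy key: an ordinary uint32 array

noise_from_new = sample_noise(new_key, (3,))
noise_from_old = sample_noise(old_key, (3,))
```

Both calls type-check against the same signature. If a function should accept only typed keys, a check like `is_typed_key` at the top of the function is one option, since the annotation alone cannot enforce it.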