jax[cuda] 0.5.3 packaging broken #27874
Comments
I have noticed the following warning when installing 0.5.3:
Thanks for bringing this to our attention! The CUDA errors you're seeing are caused by missing CUDA packages on your system (see the end of this comment for mitigations). Here's some context on why this is happening: JAX defines its CUDA support as an "extras" feature, which is why it can be installed using:
See setup.py for the jax package. Note the line
The reason this works with version 0.5.0 but not 0.5.3 appears to be a change in how Python packaging metadata is handled between versions 2.1 and 2.2 of the core metadata specification. In version 0.5.0, the METADATA for
However, in version 0.5.3, the METADATA for
It seems that v2.2 automatically converts underscores to hyphens in the keys of `extras_require`, but not in the values of those keys. Therefore, when installing
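The key-versus-value mismatch described above can be illustrated with a small sketch. It assumes the PEP 685 normalization rule (runs of `-`, `_` and `.` collapse to a single `-`, then lowercase) and uses a hypothetical `METADATA` excerpt with a hypothetical package `example-plugin` and a hypothetical extra `example_extra`; these are placeholders, not JAX's actual metadata. It shows that a literal string comparison between the requested extra and the declared `Provides-Extra` fails, while a comparison that normalizes both sides succeeds.

```python
import re
from email.parser import Parser

# Hypothetical METADATA excerpt: the extra's key was normalized to a hyphen,
# while the requirement that refers to it (below) keeps the underscore spelling.
PLUGIN_METADATA = """\
Metadata-Version: 2.2
Name: example-plugin
Version: 1.0
Provides-Extra: example-extra
Requires-Dist: example-cuda-libs; extra == "example-extra"
"""

# Hypothetical requirement string as it might appear in a dependent package.
APP_REQUIREMENT = "example-plugin[example_extra]>=1.0"


def normalize_extra(name: str) -> str:
    """Normalize an extra name: collapse runs of -, _ and . into '-' and lowercase."""
    return re.sub(r"[-_.]+", "-", name).lower()


def requested_extras(requirement: str) -> list:
    """Return the extras listed inside [...] in a requirement string."""
    match = re.search(r"\[([^\]]*)\]", requirement)
    if not match:
        return []
    return [e.strip() for e in match.group(1).split(",") if e.strip()]


def resolve_extras(requirement: str, metadata_text: str, normalize: bool) -> dict:
    """For each requested extra, report whether it matches a declared Provides-Extra."""
    # METADATA uses an email-header format, so the stdlib email parser can read it.
    declared = Parser().parsestr(metadata_text).get_all("Provides-Extra") or []
    result = {}
    for extra in requested_extras(requirement):
        if normalize:
            result[extra] = normalize_extra(extra) in {normalize_extra(d) for d in declared}
        else:
            result[extra] = extra in declared
    return result


print(resolve_extras(APP_REQUIREMENT, PLUGIN_METADATA, normalize=False))  # {'example_extra': False}
print(resolve_extras(APP_REQUIREMENT, PLUGIN_METADATA, normalize=True))   # {'example_extra': True}
```

In the literal case the extra silently matches nothing, so none of its dependencies would be selected; normalizing both sides restores the match. This is only an illustration of the mechanism, not a description of how pip or uv implement extras resolution.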
Mitigations:
This should install all the required CUDA packages from PyPI. Interestingly, this issue doesn't seem to occur when using uv, so another potential workaround is to install the packages with uv:
Thanks again for pointing this out!
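Whichever workaround is used, it can help to confirm which CUDA wheels are actually present in the environment. The sketch below uses the standard-library `importlib.metadata`; the distribution names in `EXPECTED` are an assumed, illustrative list following NVIDIA's `nvidia-*-cu12` naming on PyPI, not the authoritative set JAX requires.

```python
from importlib import metadata

# Assumed, illustrative list of CUDA wheels; not taken from JAX's setup.py.
EXPECTED = [
    "nvidia-cublas-cu12",
    "nvidia-cuda-runtime-cu12",
    "nvidia-cudnn-cu12",
]


def missing_distributions(names):
    """Return the subset of distribution names that are not installed."""
    missing = []
    for name in names:
        try:
            metadata.version(name)  # raises if no such distribution is installed
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing


if __name__ == "__main__":
    for name in missing_distributions(EXPECTED):
        print(f"not installed: {name}")
```

An empty result only means the listed wheels are installed; it does not check that their versions are compatible with the installed jaxlib.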
The new `METADATA` specification disallows underscores in extra names and automatically converts them to dashes: https://packaging.python.org/en/latest/specifications/core-metadata/#provides-extra-multiple-use This should prevent the error in #27874 from appearing in future JAX releases. PiperOrigin-RevId: 746145117
The new `METADATA` specification disallows underscores in extra names and automatically converts them to dashes: https://packaging.python.org/en/latest/specifications/core-metadata/#provides-extra-multiple-use This should prevent the error in #27874 from appearing in future JAX releases. PiperOrigin-RevId: 746546162
Description
Hi. I am installing jax with CUDA using venv and pip, and I have found that the packaging of the latest jax[cuda12] 0.5.3 appears to be broken. If I install 0.5.3, then 0.5.0, and then 0.5.3 again in the same environment, something extra provided by 0.5.0 fixes the 0.5.3 environment.
This produces the following error output:
Now I install jax and jaxlib 0.5.0 in the same environment:
I receive no errors. Let's now install 0.5.3 again in the same environment.
The above now runs with no errors.
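One way to pin down what "something extra" the 0.5.0 install leaves behind is to snapshot the installed distributions after each step and diff the snapshots. A minimal sketch using the standard-library `importlib.metadata`; saving each snapshot to a JSON file between pip runs is an assumed workflow, and names are only lowercased, not fully normalized.

```python
import json
from importlib import metadata


def snapshot():
    """Map each installed distribution name (lowercased) to its version."""
    versions = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        if name:  # skip broken installs that lack a Name field
            versions[name.lower()] = dist.version
    return versions


def diff_snapshots(before, after):
    """Return (added, removed, changed) between two name -> version mappings."""
    added = {n: v for n, v in after.items() if n not in before}
    removed = {n: v for n, v in before.items() if n not in after}
    changed = {
        n: (before[n], after[n])
        for n in before.keys() & after.keys()
        if before[n] != after[n]
    }
    return added, removed, changed


if __name__ == "__main__":
    # Run once per step and redirect to a file, e.g. step1.json, step2.json;
    # then load two files with json.load and pass them to diff_snapshots.
    print(json.dumps(snapshot(), indent=2, sort_keys=True))
```

Packages that appear in `added` after the 0.5.0 install and persist through the final 0.5.3 reinstall are the likely candidates for what fixes the environment.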
System info (python version, jaxlib version, accelerator, etc.)
I'm on Ubuntu 22.04.5 with python 3.10.12.
In the example, I show jax[cuda12] and jaxlib versions 0.5.0 and 0.5.3.
I have removed my hostname from the output below.