jax[cuda] 0.5.3 packaging broken #27874

whbdupree opened this issue Apr 9, 2025 · 2 comments
Labels
bug Something isn't working

@whbdupree

Description

Hi. I am installing JAX with CUDA support using venv and pip, and the latest jax[cuda12] 0.5.3 packaging appears to be broken. If I install 0.5.3, then 0.5.0, and then 0.5.3 again in the same environment, the final 0.5.3 install works, so 0.5.0 seems to provide something extra that 0.5.3 is missing.

python3 -m venv jax-gpu
source jax-gpu/bin/activate
pip install jax[cuda12]==0.5.3 jaxlib==0.5.3
python
>>> from jax import numpy as jnp
>>> a=jnp.array((3,4,5))-0.123

This produces the following error output:

E0409 09:38:35.711191   28819 cuda_dnn.cc:534] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E0409 09:38:35.711256   28819 cuda_dnn.cc:538] Memory usage: 50544508928 bytes free, 50919505920 bytes total.
E0409 09:38:35.711574   28819 cuda_dnn.cc:534] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E0409 09:38:35.711625   28819 cuda_dnn.cc:538] Memory usage: 50544508928 bytes free, 50919505920 bytes total.
E0409 09:38:35.741310   28819 cuda_dnn.cc:534] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E0409 09:38:35.741356   28819 cuda_dnn.cc:538] Memory usage: 50544508928 bytes free, 50919505920 bytes total.
E0409 09:38:35.741753   28819 cuda_dnn.cc:534] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E0409 09:38:35.741777   28819 cuda_dnn.cc:538] Memory usage: 50544508928 bytes free, 50919505920 bytes total.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/whbarnet/jax-gpu/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 5555, in array
    out_array: Array = lax_internal._convert_element_type(
  File "/home/whbarnet/jax-gpu/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1614, in _convert_element_type
    return convert_element_type_p.bind(
  File "/home/whbarnet/jax-gpu/lib/python3.10/site-packages/jax/_src/core.py", line 502, in bind
    return self._true_bind(*args, **params)
  File "/home/whbarnet/jax-gpu/lib/python3.10/site-packages/jax/_src/core.py", line 520, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
  File "/home/whbarnet/jax-gpu/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 4701, in _convert_element_type_bind_with_trace
    operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
  File "/home/whbarnet/jax-gpu/lib/python3.10/site-packages/jax/_src/core.py", line 525, in bind_with_trace
    return trace.process_primitive(self, args, params)
  File "/home/whbarnet/jax-gpu/lib/python3.10/site-packages/jax/_src/core.py", line 1029, in process_primitive
    return primitive.impl(*args, **params)
  File "/home/whbarnet/jax-gpu/lib/python3.10/site-packages/jax/_src/dispatch.py", line 88, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

Now I install jax and jaxlib 0.5.0 in the same environment:

pip install jax[cuda12]==0.5.0 jaxlib==0.5.0
python
>>> from jax import numpy as jnp
>>> a=jnp.array((3,4,5))-0.123

I receive no errors. Let's now install 0.5.3 again in the same environment.

pip install jax[cuda12]==0.5.3 jaxlib==0.5.3
python
>>> from jax import numpy as jnp
>>> a=jnp.array((3,4,5))-0.123

The above now runs with no errors.

System info (python version, jaxlib version, accelerator, etc.)

I'm on Ubuntu 22.04.5 with python 3.10.12.
The example uses jax[cuda12] and jaxlib versions 0.5.0 and 0.5.3.
I have removed my hostname from the output below.

python
Python 3.10.12 (main, Feb  4 2025, 14:57:36) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax:    0.5.3
jaxlib: 0.5.3
numpy:  2.2.4
python: 3.10.12 (main, Feb  4 2025, 14:57:36) [GCC 11.4.0]
device info: NVIDIA RTX A6000-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='xyz', release='6.8.0-57-generic', version='#59~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Mar 19 17:07:41 UTC 2', machine='x86_64')


$ nvidia-smi
Wed Apr  9 09:46:06 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.05              Driver Version: 560.35.05      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX A6000               Off |   00000000:B3:00.0 Off |                  Off |
| 30%   35C    P2             28W /  300W |     362MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      1635      G   /usr/lib/xorg/Xorg                             64MiB |
|    0   N/A  N/A      1986      G   /usr/bin/gnome-shell                            8MiB |
|    0   N/A  N/A     29281      C   python                                        266MiB |
+-----------------------------------------------------------------------------------------+

@whbdupree whbdupree added the bug Something isn't working label Apr 9, 2025
@whbdupree whbdupree changed the title jax 0.5.3 packaing broken jax[cuda] 0.5.3 packaing broken Apr 9, 2025
@whbdupree whbdupree changed the title jax[cuda] 0.5.3 packaing broken jax[cuda] 0.5.3 packaging broken Apr 9, 2025
@whbdupree
Author

whbdupree commented Apr 9, 2025

I have noticed the following warning when installing 0.5.3:

$ pip install jax[cuda12]==0.5.3 jaxlib==0.5.3
Collecting jax[cuda12]==0.5.3
  Using cached jax-0.5.3-py3-none-any.whl (2.4 MB)
Collecting jaxlib==0.5.3
  Using cached jaxlib-0.5.3-cp310-cp310-manylinux2014_x86_64.whl (105.1 MB)
Requirement already satisfied: opt_einsum in ./jax-gpu/lib/python3.10/site-packages (from jax[cuda12]==0.5.3) (3.4.0)
Requirement already satisfied: ml_dtypes>=0.4.0 in ./jax-gpu/lib/python3.10/site-packages (from jax[cuda12]==0.5.3) (0.5.1)
Requirement already satisfied: numpy>=1.25 in ./jax-gpu/lib/python3.10/site-packages (from jax[cuda12]==0.5.3) (2.2.4)
Requirement already satisfied: scipy>=1.11.1 in ./jax-gpu/lib/python3.10/site-packages (from jax[cuda12]==0.5.3) (1.15.2)
Collecting jax-cuda12-plugin[with_cuda]<=0.5.3,>=0.5.3
  Using cached jax_cuda12_plugin-0.5.3-cp310-cp310-manylinux2014_x86_64.whl (16.7 MB)
WARNING: jax-cuda12-plugin 0.5.3 does not provide the extra 'with_cuda'

@nitins17
Collaborator

Thanks for bringing this to our attention!

The CUDA errors you're seeing are caused by missing CUDA packages in your environment (see the end of this comment for mitigations).

Here's some context on why this is happening: JAX defines its CUDA support as an "extras" feature, which is why we can install it using:

pip install jax[cuda12]

See setup.py for the jax package and note the entry jax-cuda12-plugin[with_cuda]. This instructs pip to install the jax-cuda12-plugin package along with its with_cuda extra (see setup.py for the plugin package).

The reason this works with version 0.5.0 but not 0.5.3 appears to be a change in how extra names are handled between versions 2.1 and 2.2 of the Python packaging core metadata.

In version 0.5.0, the METADATA for jax-cuda12-plugin was generated as (METADATA version 2.1):

Provides-Extra: with_cuda

However, in version 0.5.3, the METADATA for jax-cuda12-plugin is generated as (METADATA version 2.2):

Provides-Extra: with-cuda
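To see which spelling an installed plugin actually declares, the Provides-Extra fields can be read from the environment with the standard library. This is a minimal sketch, not part of the original comment; the helper name provided_extras is hypothetical, and it assumes Python 3.9 or newer.

```python
from importlib.metadata import PackageNotFoundError, metadata


def provided_extras(dist_name: str) -> list[str]:
    """Return the Provides-Extra values recorded for an installed distribution.

    Hypothetical helper: returns an empty list when the distribution is not
    installed or declares no extras.
    """
    try:
        md = metadata(dist_name)
    except PackageNotFoundError:
        return []  # not installed in this environment
    # get_all returns None when the field is absent, so fall back to [].
    return md.get_all("Provides-Extra") or []


if __name__ == "__main__":
    # Per the comment above, 0.5.0 is expected to list with_cuda and
    # 0.5.3 to list with-cuda.
    print(provided_extras("jax-cuda12-plugin"))
```

Run inside the affected virtual environment, this shows whether the name jax asks for (with_cuda) matches what the plugin advertises.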

It seems that v2.2 automatically converts underscores to hyphens in the keys of extras_require, but not in the values that reference those extras. Therefore, when installing jax, pip still looks for the with_cuda extra in order to install jax-cuda12-plugin[with_cuda], but jax-cuda12-plugin no longer provides an extra under that name, which is why we get the warning:

WARNING: jax-cuda12-plugin 0.5.3 does not provide the extra 'with_cuda'
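The mismatch can be illustrated with the extra-name normalization rule from PEP 685 (runs of "-", "_" and "." collapse to a single "-", lowercased). The sketch below only shows why the two spellings differ as raw strings yet agree once normalized. The literal string comparison stands in for an installer that does not normalize; it is an assumption for illustration, not a description of pip's actual code.

```python
import re


def normalize_extra(name: str) -> str:
    """Normalize an extra name using the rule given in PEP 685."""
    return re.sub(r"[-_.]+", "-", name).lower()


requested = "with_cuda"  # extra named in jax's dependency on the plugin
provided = "with-cuda"   # Provides-Extra in the 0.5.3 plugin METADATA

# A literal comparison treats these as two different extras.
print(requested == provided)  # False

# After normalization both spellings refer to the same extra.
print(normalize_extra(requested) == normalize_extra(provided))  # True
```

An installer that normalizes both sides before comparing would resolve the extra regardless of which spelling is used.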

Mitigations:
We should have this fixed in our next release. In the meantime, you can use the following command as a temporary solution:

pip install jax==0.5.3 jax-cuda12-plugin[with-cuda]==0.5.3

This should install all the required CUDA packages from PyPI.
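One way to confirm that the CUDA wheels actually landed in the environment is to list the installed distributions by name prefix. This is a rough sketch: the helper installed_with_prefix is hypothetical, and the "nvidia-" prefix is an assumption about how the CUDA wheels on PyPI are named (for example nvidia-cudnn-cu12).

```python
from importlib.metadata import distributions


def installed_with_prefix(prefix: str) -> list[str]:
    """List installed distributions whose name starts with the given prefix.

    Hypothetical helper; entries are formatted as name==version.
    """
    found = set()
    for dist in distributions():
        name = dist.metadata["Name"]
        # Skip entries with missing metadata rather than failing.
        if name and name.lower().startswith(prefix.lower()):
            found.add(f"{name}=={dist.version}")
    return sorted(found)


if __name__ == "__main__":
    # Assumption: CUDA runtime wheels use the "nvidia-" name prefix.
    for line in installed_with_prefix("nvidia-"):
        print(line)
```

An empty listing in the broken 0.5.3 environment, and a populated one after the workaround, would be consistent with the explanation above.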

Interestingly, this issue doesn't seem to occur when using uv. Therefore, another potential workaround is to install the packages using uv:

pip install uv
uv pip install jax[cuda12]==0.5.3 jaxlib==0.5.3

Thanks again for pointing this out!

copybara-service bot pushed a commit that referenced this issue Apr 10, 2025
The new `METADATA` specification disallows use of underscore and automatically converts any usage of them to dash.

https://packaging.python.org/en/latest/specifications/core-metadata/#provides-extra-multiple-use

This should fix the following error: #27874  from appearing in future JAX releases

PiperOrigin-RevId: 746145117
copybara-service bot pushed a commit that referenced this issue Apr 11, 2025
The new `METADATA` specification disallows use of underscore and automatically converts any usage of them to dash.

https://packaging.python.org/en/latest/specifications/core-metadata/#provides-extra-multiple-use

This should fix the following error: #27874  from appearing in future JAX releases

PiperOrigin-RevId: 746145117
copybara-service bot pushed a commit that referenced this issue Apr 11, 2025
The new `METADATA` specification disallows use of underscore and automatically converts any usage of them to dash.

https://packaging.python.org/en/latest/specifications/core-metadata/#provides-extra-multiple-use

This should fix the following error: #27874  from appearing in future JAX releases

PiperOrigin-RevId: 746546162
charleshofer pushed a commit to ROCm/jax that referenced this issue Apr 30, 2025
The new `METADATA` specification disallows use of underscore and automatically converts any usage of them to dash.

https://packaging.python.org/en/latest/specifications/core-metadata/#provides-extra-multiple-use

This should fix the following error: jax-ml#27874  from appearing in future JAX releases

PiperOrigin-RevId: 746546162