Description
When running the inference stage of AlphaFold 3, some users are hitting the following error:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CustomCall failed: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-ng30101.narval.calcul.quebec-48db254e-784897-628d6ac249c58, line 5; fatal : Unsupported .version 8.4; current version is '8.2'
ptxas fatal : Ptx assembly aborted due to errors
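For reference, the mismatch can be confirmed by comparing what the local toolchain supports with what JAX reports about itself. A minimal check, assuming the CUDA 12.2 module's ptxas is on PATH (jax.print_environment_info() is the same helper the JAX issue template asks for):

import subprocess

import jax

# Prints the jax/jaxlib versions plus the nvidia-smi report for the visible GPUs.
jax.print_environment_info()

# CUDA 12.2's ptxas only understands PTX ISA 8.2, which is why PTX emitted
# at version 8.4 is rejected with the error above.
print(subprocess.run(["ptxas", "--version"], capture_output=True, text=True).stdout)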
It is my understanding, according to the JAX documentation, that jaxlib 0.4.34 was built with CUDA 12.3 but is compatible with CUDA 12.1+.
Also, according to NVIDIA's ptxas documentation, CUDA 12.3 corresponds to PTX ISA 8.3, yet the error reports 8.4, which corresponds to CUDA 12.4.
Hence, which CUDA version was actually used to build jaxlib and its plugins?
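To see which pieces are installed on our side, the versions of jaxlib and the CUDA plugin wheels can be queried directly; a small sketch, assuming the standard jax[cuda12] package names (jax-cuda12-plugin, jax-cuda12-pjrt, nvidia-cuda-nvcc-cu12), which a site-patched install may not use:

from importlib import metadata

# List the jaxlib / CUDA plugin wheels that pip knows about; their versions
# hint at which CUDA minor release the plugin stack was built for.
for dist in ("jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt", "nvidia-cuda-nvcc-cu12"):
    try:
        print(dist, metadata.version(dist))
    except metadata.PackageNotFoundError:
        print(dist, "not installed")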
Possible alternative solution: use CUDA 12.4+.
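A minimal sketch of that workaround, assuming a newer toolkit is available at a hypothetical /path/to/cuda-12.4 and that the installed driver supports it: point XLA at that toolkit before importing jax, via the --xla_gpu_cuda_data_dir XLA flag, so its ptxas is the one used.

import os

# Hypothetical path to a CUDA 12.4+ toolkit on the cluster; its ptxas
# understands PTX ISA 8.4, unlike the CUDA 12.2 one currently picked up.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_cuda_data_dir=/path/to/cuda-12.4"

import jax

# Trigger a small GPU compilation to confirm ptxas no longer aborts.
print(jax.numpy.arange(4).sum())

Alternatively, installing JAX with the cuda12 pip extra pulls in NVIDIA wheels that ship their own ptxas, which should sidestep the system toolkit entirely.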
System info (python version, jaxlib version, accelerator, etc.)
jaxlib: v0.4.34, installed from PyPI, patched and tested to work with our installed CUDA 12.2
CUDA version: 12.2
GPU device: A100
NVIDIA driver: 550.127.08