LIBERO inference #407

Jerryisqx opened this issue Mar 29, 2025 · 2 comments
Jerryisqx commented Mar 29, 2025

Hi Team,
I'm trying to run inference and evaluation on the LIBERO dataset. However, while loading the pi0-fast-libero checkpoint, I hit the following cuda_dnn error:

username@my_host_machine$ uv run scripts/serve_policy.py --env LIBERO
      Built draccus @ git+https://github.com/dlwh/draccus.git@9b690730ca108930519f48cc5dead72a72fd27cb
Uninstalled 1 package in 6.23s
Installed 1 package in 9.61s
INFO:root:Loading model...
E0330 18:23:43.985579   78452 cuda_dnn.cc:502] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
E0330 18:23:43.985789   78452 cuda_dnn.cc:502] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
INFO:2025-03-30 18:23:43,986:jax._src.xla_bridge:945: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-03-30 18:23:43,987:jax._src.xla_bridge:945: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:absl:orbax-checkpoint version: 0.11.1
INFO:absl:Created BasePyTreeCheckpointHandler: pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=None
INFO:absl:Restoring checkpoint from /scratch_net/biwidl313_second/chenqing/openpi/openpi-assets/checkpoints/pi0_fast_libero/params.
INFO:absl:[thread=MainThread] Failed to get flag value for EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.
INFO:absl:[process=0] /jax/checkpoint/read/bytes_per_sec: 327.5 MiB/s (total bytes: 5.4 GiB) (time elapsed: 17 seconds) (per-host)
INFO:absl:Finished restoring checkpoint from /scratch_net/biwidl313_second/chenqing/openpi/openpi-assets/checkpoints/pi0_fast_libero/params.
INFO:absl:[process=0][thread=MainThread] Skipping global process sync, barrier name: Checkpointer:restore
E0330 18:24:01.090760   78452 cuda_dnn.cc:502] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
E0330 18:24:01.091413   78452 cuda_dnn.cc:502] There was an error before creating cudnn handle (35): cudaErrorInsufficientDriver : CUDA driver version is insufficient for CUDA runtime version
Traceback (most recent call last):
  File "/scratch_net/biwidl313_second/openpi/scripts/serve_policy.py", line 123, in <module>
    main(tyro.cli(Args))
  File "/scratch_net/biwidl313_second/openpi/scripts/serve_policy.py", line 101, in main
    policy = create_policy(args)
             ^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/scripts/serve_policy.py", line 97, in create_policy
    return create_default_policy(args.env, default_prompt=args.default_prompt)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/scripts/serve_policy.py", line 83, in create_default_policy
    return _policy_config.create_trained_policy(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/src/openpi/policies/policy_config.py", line 56, in create_trained_policy
    model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/src/openpi/models/model.py", line 228, in load
    model = nnx.eval_shape(self.create, jax.random.key(0))
                                        ^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/random.py", line 218, in key
    return _key('key', seed, impl)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/random.py", line 198, in _key
    return prng.random_seed(seed, impl=impl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/prng.py", line 529, in random_seed
    seeds_arr = jnp.asarray(np.int64(seeds))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5820, in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5653, in array
    out_array: Array = lax_internal._convert_element_type(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 612, in _convert_element_type
    return convert_element_type_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 463, in bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 3254, in _convert_element_type_bind_with_trace
    operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 468, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 954, in process_primitive
    return primitive.impl(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/dispatch.py", line 89, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 340, in cache_miss
    pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 198, in _python_pjit_helper
    out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1669, in _pjit_call_impl_python
    ).compile()
      ^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2419, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2922, in from_hlo
    xla_executable = _cached_compilation(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2723, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 464, in compile_or_get_cached
    return _compile_and_write_cache(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 665, in _compile_and_write_cache
    executable = backend_compile(
                 ^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 321, in backend_compile
    raise e
  File "/scratch_net/biwidl313_second/openpi/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 315, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

I installed the dependencies correctly and didn't run into any problems there. I'm using CUDA 11.4 with driver version 470.82.01 on the cluster.
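For context on why this driver/runtime pairing fails: per NVIDIA's CUDA compatibility table, the 470.x driver branch only supports CUDA runtimes up to 11.4, while CUDA 12 wheels (which recent jaxlib builds target) require driver >= 525.60.13. A minimal, stdlib-only sketch of that check (the helper name and the assumption that jaxlib here was built against CUDA 12 are mine, not from openpi):

```python
# Illustrative check: does the installed NVIDIA driver meet the minimum
# driver version required by CUDA 12 wheels? (525.60.13 per NVIDIA's
# Linux compatibility table; 470.x only supports CUDA <= 11.4.)

def parse_version(v: str) -> tuple[int, ...]:
    """Turn a dotted driver version string into a comparable tuple."""
    return tuple(int(x) for x in v.split("."))

MIN_DRIVER_FOR_CUDA12 = "525.60.13"

def driver_supports_cuda12(driver_version: str) -> bool:
    return parse_version(driver_version) >= parse_version(MIN_DRIVER_FOR_CUDA12)

# Driver version reported by nvidia-smi in this issue:
print(driver_supports_cuda12("470.256.02"))  # → False
```

If that returns False, the `cudaErrorInsufficientDriver` message is expected: the driver would need to be upgraded (or a jaxlib built against CUDA 11.x used) before the cuDNN handle can be created.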

@Jerryisqx Jerryisqx changed the title pi_0 quantization LIBERO inference Mar 30, 2025
Jerryisqx (Author) commented:

See current output:

nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.256.02   Driver Version: 470.256.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA TITAN X ...  On   | 00000000:04:00.0 Off |                  N/A |
| 23%   27C    P8     8W / 250W |      1MiB / 12196MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Mon_Oct_11_21:27:02_PDT_2021
Cuda compilation tools, release 11.4, V11.4.152
Build cuda_11.4.r11.4/compiler.30521435_0

@Jerryisqx Jerryisqx reopened this Mar 30, 2025
uzhilinsky (Collaborator) commented:

Have you tried running with Docker?
If it works, JAX is likely picking up the wrong cuda dependencies.
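One way to see which CUDA libraries JAX might be picking up is to look at the `nvidia` namespace packages pip/uv installed into the environment's site-packages; if these are present but an older system CUDA is found first (e.g. via `LD_LIBRARY_PATH`), you can get exactly this kind of mismatch. A stdlib-only sketch (the function name is mine; the `nvidia/` directory layout is the convention used by pip CUDA wheels such as `nvidia-cudnn-cu12`):

```python
# Sketch: list pip-installed NVIDIA CUDA component packages in site-packages.
# CUDA wheels install under a shared "nvidia" namespace directory, e.g.
# nvidia/cudnn, nvidia/cuda_runtime, nvidia/cublas.
import pathlib
import site

def pip_cuda_libs() -> list[str]:
    found = []
    for sp in site.getsitepackages():
        nvidia_dir = pathlib.Path(sp) / "nvidia"
        if nvidia_dir.is_dir():
            found.extend(sorted(p.name for p in nvidia_dir.iterdir() if p.is_dir()))
    return found

print(pip_cuda_libs())  # e.g. ['cublas', 'cuda_runtime', 'cudnn', ...] or []
```

Running this inside the Docker container versus on the host can help confirm whether the host environment is shadowing the bundled CUDA dependencies.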

@uzhilinsky uzhilinsky self-assigned this Apr 10, 2025