
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: #486


Open
vybhav-ibr opened this issue May 19, 2025 · 0 comments

@vybhav-ibr

I installed the package successfully, but I can't run inference using scripts/serve_policy.py --env ALOHA_SIM.
I am installing the packages on a VM with Ubuntu 22.04 and Python 3.11.
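To attach exact versions to a report like this, a short script along these lines can dump what JAX sees. It is a sketch, not part of openpi: report_environment is a hypothetical helper, and it assumes it is run with the same .venv interpreter that runs serve_policy.py. jax.print_environment_info(return_string=True) returns the jax, jaxlib, numpy and Python versions, the visible devices, and nvidia-smi output when that command is available.

```python
import importlib.util


def report_environment() -> str:
    """Return JAX's version/device summary as text, ready to paste into an issue."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed in this interpreter"
    import jax

    # Includes jax/jaxlib/numpy/Python versions, jax.devices(), and nvidia-smi
    # output if the nvidia-smi command can be run.
    return jax.print_environment_info(return_string=True)


if __name__ == "__main__":
    print(report_environment())
```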

INFO:root:Loading model...
INFO:2025-05-19 16:26:54,912:jax._src.xla_bridge:945: Unable to initialize backend 'rocm': module 'jaxlib.xla_extens>
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAll>
INFO:2025-05-19 16:26:54,913:jax._src.xla_bridge:945: Unable to initialize backend 'tpu': INTERNAL: Failed to open l>
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot o>
INFO:absl:orbax-checkpoint version: 0.11.1
INFO:absl:Created BasePyTreeCheckpointHandler: pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=Fals>
INFO:absl:Restoring checkpoint from /root/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params.
INFO:absl:[thread=MainThread] Failed to get flag value for EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.
INFO:absl:[process=0] /jax/checkpoint/read/bytes_per_sec: 398.4 MiB/s (total bytes: 6.0 GiB) (time elapsed: 15 secon>
INFO:absl:Finished restoring checkpoint from /root/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params.
INFO:absl:[process=0][thread=MainThread] Skipping global process sync, barrier name: Checkpointer:restore
2025-05-19 16:27:11.111685: W external/xla/xla/stream_executor/cuda/subprocess_compilation.cc:237] Falling back to t>
2025-05-19 16:27:11.111710: W external/xla/xla/stream_executor/cuda/subprocess_compilation.cc:240] Used ptxas at /wo>
Traceback (most recent call last):
  File "/workspace/openpi/scripts/serve_policy.py", line 122, in <module>
    main(tyro.cli(Args))
  File "/workspace/openpi/scripts/serve_policy.py", line 100, in main
    policy = create_policy(args)
             ^^^^^^^^^^^^^^^^^^^
  File "/workspace/openpi/scripts/serve_policy.py", line 96, in create_policy
    return create_default_policy(args.env, default_prompt=args.default_prompt)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/openpi/scripts/serve_policy.py", line 82, in create_default_policy
    return _policy_config.create_trained_policy(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/openpi/src/openpi/policies/policy_config.py", line 56, in create_trained_policy
    model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/openpi/src/openpi/models/model.py", line 228, in load
    model = nnx.eval_shape(self.create, jax.random.key(0))
  File "/workspace/openpi/.venv/lib/python3.11/site-packages/jax/_src/random.py", line 218, in key
    return _key('key', seed, impl)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/openpi/.venv/lib/python3.11/site-packages/jax/_src/random.py", line 198, in _key
    return prng.random_seed(seed, impl=impl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/openpi/.venv/lib/python3.11/site-packages/jax/_src/prng.py", line 534, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 463, in bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 468, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/openpi/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 954, in process_primitive
    return primitive.impl(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/openpi/.venv/lib/python3.11/site-packages/jax/_src/prng.py", line 546, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/openpi/.venv/lib/python3.11/site-packages/jax/_src/prng.py", line 551, in random_seed_impl_base
    return seed(seeds)
           ^^^^^^^^^^^
  File "/workspace/openpi/.venv/lib/python3.11/site-packages/jax/_src/prng.py", line 767, in threefry_seed
    return _threefry_seed(seed)
           ^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: /workspace/openpi/.venv/lib/python3.11/site-packages/nvidia/cud>

For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK>
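The traceback ends in jax.random.key(0) (called from model.py, line 228), so the failure can probably be reproduced without loading the checkpoint at all. The sketch below, assuming it is run with the same .venv, performs only that call, once on JAX's default backend and once pinned to the CPU. If the first attempt fails with the same UNIMPLEMENTED error while the CPU attempt succeeds, the problem likely lies in the GPU/CUDA toolchain rather than in openpi itself. try_random_key is a hypothetical helper name.

```python
import importlib.util
from typing import Optional


def try_random_key(platform: Optional[str] = None) -> str:
    """Run jax.random.key(0), the call that fails in the traceback, in isolation.

    platform=None uses JAX's default backend; a name such as "cpu" pins the
    computation to the first device of that platform.
    """
    if importlib.util.find_spec("jax") is None:
        return "skipped: jax not installed"
    import jax

    try:
        if platform is None:
            key = jax.random.key(0)
        else:
            with jax.default_device(jax.devices(platform)[0]):
                key = jax.random.key(0)
        # Force the result so that an asynchronous failure is raised here.
        jax.random.key_data(key).block_until_ready()
    except Exception as exc:  # e.g. the XlaRuntimeError from the report
        return f"failed: {type(exc).__name__}: {exc}"
    return f"ok on {platform or jax.default_backend()}"


if __name__ == "__main__":
    print("default backend:", try_random_key())
    print("cpu:", try_random_key("cpu"))
```

Running the CPU variant alongside the default one separates "JAX cannot seed a PRNG at all" from "JAX cannot compile or run the seeding kernel on this GPU setup".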
