compute_norm_stats.py error with shard dimensions #481

Open
tswie opened this issue May 13, 2025 · 1 comment

tswie commented May 13, 2025

Hello, I'm currently going through the fine-tuning tutorial on LIBERO. I successfully converted the dataset I downloaded from Hugging Face, then ran into this issue:

myuser@myserver:/data/pi0/openpi$ uv run scripts/compute_norm_stats.py --config-name pi0_libero
warning: `VIRTUAL_ENV=examples/libero/.venv` does not match the project environment path `.venv` and will be ignored; use `--active` to target the active environment instead
Some kwargs in processor config are unused and will not have any effect: scale, vocab_size, action_dim, min_token, time_horizon.
Some kwargs in processor config are unused and will not have any effect: scale, vocab_size, action_dim, min_token, time_horizon.
Fetching 4 files: 100%|████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 1862.89it/s]
Fetching 4 files: 100%|████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 2680.07it/s]
Fetching 1699 files: 100%|███████████████████████████████████████████████████████| 1699/1699 [00:00<00:00, 1923.54it/s]
Resolving data files: 100%|████████████████████████████████████████████████████| 1693/1693 [00:00<00:00, 501196.83it/s]
Loading dataset shards: 100%|████████████████████████████████████████████████████████| 70/70 [00:00<00:00, 4847.23it/s]
Computing stats:   0%|                                                                      | 0/273465 [00:30<?, ?it/s]
Traceback (most recent call last):
  File "/data/pi0/openpi/scripts/compute_norm_stats.py", line 75, in <module>
    tyro.cli(main)
  File "/data/pi0/openpi/.venv/lib/python3.11/site-packages/tyro/_cli.py", line 189, in cli
    return run_with_args_from_cli()
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pi0/openpi/scripts/compute_norm_stats.py", line 62, in main
    for batch in tqdm.tqdm(data_loader, total=num_frames, desc="Computing stats"):
  File "/data/pi0/openpi/.venv/lib/python3.11/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/data/pi0/openpi/src/openpi/training/data_loader.py", line 261, in __iter__
    yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pi0/openpi/.venv/lib/python3.11/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pi0/openpi/.venv/lib/python3.11/site-packages/jax/_src/tree_util.py", line 358, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pi0/openpi/.venv/lib/python3.11/site-packages/jax/_src/tree_util.py", line 358, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/data/pi0/openpi/src/openpi/training/data_loader.py", line 261, in <lambda>
    yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pi0/openpi/.venv/lib/python3.11/site-packages/jax/_src/array.py", line 910, in make_array_from_process_local_data
    return api.device_put(local_data, sharding)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pi0/openpi/.venv/lib/python3.11/site-packages/jax/_src/api.py", line 2300, in device_put
    _check_sharding(shaped_abstractify(xf), d)
  File "/data/pi0/openpi/.venv/lib/python3.11/site-packages/jax/_src/api.py", line 2218, in _check_sharding
    pjit.pjit_check_aval_sharding(
  File "/data/pi0/openpi/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1440, in pjit_check_aval_sharding
    raise ValueError(f"One of {what_aval}{name_str} was given the sharding "
ValueError: One of device_put args was given the sharding of NamedSharding(mesh=Mesh('B': 2), spec=PartitionSpec('B',), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 2, but it is equal to 1 (full shape: (1, 10, 7))
Exception ignored in atexit callback: <function _exit_function at 0x7f801888aa20>
Traceback (most recent call last):
  File "/home/myuser/.local/share/uv/python/cpython-3.11.12-linux-x86_64-gnu/lib/python3.11/multiprocessing/util.py", line 360, in _exit_function
    p.join()
  File "/home/myuser/.local/share/uv/python/cpython-3.11.12-linux-x86_64-gnu/lib/python3.11/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/myuser/.local/share/uv/python/cpython-3.11.12-linux-x86_64-gnu/lib/python3.11/multiprocessing/popen_fork.py", line 43, in wait
    return self.poll(os.WNOHANG if timeout == 0.0 else 0)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/myuser/.local/share/uv/python/cpython-3.11.12-linux-x86_64-gnu/lib/python3.11/multiprocessing/popen_fork.py", line 27, in poll
    pid, sts = os.waitpid(self.pid, flag)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pi0/openpi/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/signal_handling.py", line 73, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 306790) is killed by signal: Terminated.

Not sure where to start with this; could this imply the dataset download was corrupted?
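
For context, the ValueError above comes from JAX's sharding check rather than from dataset corruption: jax.make_array_from_process_local_data requires the leading (batch) dimension of each array to be divisible by the number of devices in the mesh (here Mesh('B': 2), i.e. two visible GPUs), and the offending batch had a leading dimension of 1. The following is only a minimal sketch, assuming two visible GPUs, that reproduces the same constraint outside of openpi:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a mesh over all visible devices, with the axis named 'B' as in the error message.
mesh = Mesh(np.array(jax.devices()), axis_names=("B",))
sharding = NamedSharding(mesh, PartitionSpec("B"))

# A batch whose leading dimension (1) is not divisible by the device count (2 in the
# report above) triggers the same ValueError raised inside the data loader.
batch = np.zeros((1, 10, 7), dtype=np.float32)
arr = jax.make_array_from_process_local_data(sharding, batch)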

tswie commented May 13, 2025

Of course I "fix" it a couple of minutes after making the issue. I ran with CUDA_VISIBLE_DEVICES=1 to only use one of my GPUs (1 being the ID of the device). I also killed the policy server I had running in a different terminal that I had forgotten about, but I can confirm that I do need to limit the GPUs it can see in order to run this.

EDIT: I'll wait for a response to close this in case this isn't intended behavior and I should be able to run it with two GPUs.
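
For anyone hitting the same error, the workaround described above amounts to restricting the script to a single GPU before running it, for example (the device ID depends on your machine):

CUDA_VISIBLE_DEVICES=1 uv run scripts/compute_norm_stats.py --config-name pi0_libero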
