Bad error message when shardings are non-divisible #27808

Open
yashk2810 opened this issue Apr 7, 2025 · 0 comments
@yashk2810
Collaborator

Description

Repro:

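In [10]: import jax, jax.numpy as jnp
    ...: from functools import partial
    ...: from jax.sharding import NamedSharding, PartitionSpec as P

In [11]: mesh = jax.make_mesh((4, 2), ('x', 'y'))

In [12]: s = NamedSharding(mesh, P('x'))

In [13]: @partial(jax.jit, out_shardings=s)
    ...: def f(x):
    ...:     return x
    ...:

In [14]: f(jnp.arange(5))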
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[14], line 1
----> 1 f(jnp.arange(5))

    [... skipping hidden 9 frame]

File ~/venv/lib/python3.12/site-packages/jax/_src/pjit.py:1417, in pjit_check_aval_sharding(shardings, flat_avals, names, what_aval, allow_uneven_sharding)
   1415 for i, size in enumerate(num_ways_dim_sharded):
   1416   if not allow_uneven_sharding and shape[i] % size != 0:
-> 1417     raise ValueError(f"One of {what_aval}{name_str} was given the sharding "
   1418                      f"of {s}, which implies that "
   1419                      f"the global size of its dimension {i} should be "
   1420                      f"divisible by {size}, but it is equal to {shape[i]} "
   1421                      f"(full shape: {shape})")

ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('x': 4, 'y': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('x',), memory_kind=unpinned_host), which implies that the global size of its dimension 0 should be divisible by 4, but it is equal to 5 (full shape: (5,))
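
For readers unfamiliar with the check that fires here, the divisibility rule can be sketched without any devices. This is only an illustration: the dict of axis sizes and the tuple of axis names below are stand-ins for Mesh and PartitionSpec, and uneven_dims is a hypothetical helper, not a JAX function.

```python
from math import prod

def uneven_dims(shape, mesh_shape, spec):
    """Return (dim, ways, size) for every dimension whose size is not
    divisible by the number of shards its spec entry implies."""
    bad = []
    for i, entry in enumerate(spec):
        if entry is None:  # dimension is replicated, nothing to check
            continue
        # A spec entry may name one mesh axis or a tuple of mesh axes.
        axes = entry if isinstance(entry, tuple) else (entry,)
        ways = prod(mesh_shape[a] for a in axes)
        if shape[i] % ways != 0:
            bad.append((i, ways, shape[i]))
    return bad

# Stand-in for the (4, 2) mesh with axes 'x' and 'y' from the repro.
mesh_shape = {"x": 4, "y": 2}

print(uneven_dims((5,), mesh_shape, ("x",)))  # [(0, 4, 5)]
print(uneven_dims((8,), mesh_shape, ("x",)))  # []
```

With shape (5,) and spec ('x',), dimension 0 would be split 4 ways, and 5 is not divisible by 4, which matches the numbers in the error above.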

The phrase "with pytree key path result" in the error message is confusing because the user never gave the output a name. We should omit the key path in this case, since printing it only adds confusion.
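
One possible shape for the fix, as a rough sketch rather than the actual patch: build the message so that the key path clause is dropped when the output has no user-visible name. The function sharding_error_message and its parameters are hypothetical, and the sketch assumes that a single unnamed output is reported under the placeholder key path "result", as the message above suggests.

```python
def sharding_error_message(what_aval, key_path, sharding, dim, ways, shape):
    """Format the non-divisible sharding error, omitting the key path
    when it carries no information for the user."""
    # Assumption: a lone, unnamed output shows up as the placeholder
    # "result"; treat that (and an empty path) as "no name".
    has_name = bool(key_path) and key_path != "result"
    name_str = f" with pytree key path {key_path}" if has_name else ""
    return (
        f"One of {what_aval}{name_str} was given the sharding of {sharding}, "
        f"which implies that the global size of its dimension {dim} should be "
        f"divisible by {ways}, but it is equal to {shape[dim]} "
        f"(full shape: {shape})"
    )

# Unnamed single output: the key path clause is dropped.
print(sharding_error_message(
    "pjit outputs", "result", "PartitionSpec('x',)", 0, 4, (5,)))

# Named leaf inside a pytree: the key path is still useful, so keep it.
print(sharding_error_message(
    "pjit outputs", "result['a']", "PartitionSpec('x',)", 0, 4, (5,)))
```

Keeping the clause for real pytree paths such as result['a'] preserves the useful case, where the path tells the user which leaf of a structured output has the bad sharding.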

System info (python version, jaxlib version, accelerator, etc.)

N/A

@yashk2810 yashk2810 added the bug Something isn't working label Apr 7, 2025