Question on Load Imbalance Using JAX with MPI #27949
Unanswered · kousuke-nakano asked this question in Q&A
Dear users and developers,
I'm currently performing parallel computations using JAX combined with MPI (via mpi4py) on a CPU cluster. I'm encountering load imbalance among the MPI processes, which is puzzling given that each process executes exactly the same computation.

System Info:

Python Info:
Below is a simplified version of my JAX-MPI Python program. Since MPI communication is needed outside the jitted functions in my code, I use mpi4py, not mpi4jax.
When running this script on 32 nodes (3840 MPI processes), with each MPI process bound to a core (via `--bind-to-core`) and multithreading switched off (via `export XLA_FLAGS="--xla_force_host_platform_device_count=1 --xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"`), I'm observing significant load imbalance.
Given the theoretical expectation that all processes should complete simultaneously before the MPI barrier (since they execute identical workloads), the high barrier time indicates severe load imbalance.
Questions:
Is my timing method (using `.block_until_ready()`) correct for accurately measuring JAX computation times?
Any advice or suggestions would be greatly appreciated!