Where is TPU device classification defined in JAX? #26874
Unanswered
samixyzdev asked this question in Q&A
Hi JAX team and contributors,
I'm looking to contribute to JAX by modifying how TPU devices are classified.
Currently, multi-slice TPU devices are labeled as `MegaScalePjRtDevice`, but I believe they should remain `TpuDevice`.

I've tried tracing the classification process and followed these steps:
1. Searched for `MegaScalePjRtDevice` in the JAX codebase.
2. Traced device creation to `make_tpu_client()`, which is called in `xla_bridge.py`.
3. `make_tpu_client()` appears to be inside `jaxlib`, likely in a compiled `.so` file.
4. Searched for `libtpu.so` but didn't find it on my system.
5. Looked at `pjrt_plugin`, but couldn't confirm its exact behavior.

Question: Where exactly in the JAX codebase is the TPU device classification (`MegaScalePjRtDevice` vs. `TpuDevice`) handled? I'd love to contribute a fix, but I need to understand where this happens.

Any guidance is appreciated! Thanks in advance. 😊
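For anyone repeating the search steps above, here is a minimal sketch, using only the Python standard library, that checks which installed shared libraries contain the strings `MegaScalePjRtDevice` or `TpuDevice`. It assumes the TPU runtime was installed as a Python package named `libtpu` (as `pip install "jax[tpu]"` typically does). If `libtpu.so` lives elsewhere on your machine, pass that directory as `root`. The helper names `package_dir`, `file_contains`, and `find_string_in_binaries` are hypothetical and not part of JAX.

```python
import importlib.util
import mmap
import os
from pathlib import Path
from typing import List, Optional, Tuple

# Hypothetical helpers for illustration; only standard-library calls are used.


def package_dir(name: str) -> Optional[Path]:
    """Return the install directory of a top-level package without importing it."""
    spec = importlib.util.find_spec(name)
    if spec is None or not spec.submodule_search_locations:
        return None  # not installed, or a plain module rather than a package
    return Path(next(iter(spec.submodule_search_locations)))


def file_contains(path: Path, needle: bytes) -> bool:
    """Memory-map `path` and report whether its raw bytes contain `needle`."""
    try:
        if path.stat().st_size == 0:
            return False  # mmap cannot map an empty file
        with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            return mm.find(needle) != -1
    except (OSError, ValueError):
        return False  # unreadable or unmappable file: treat as no match


def find_string_in_binaries(
    root: Path, needle: bytes, suffixes: Tuple[str, ...] = (".so", ".pyd", ".dylib")
) -> List[Path]:
    """Walk `root` recursively and list shared libraries whose bytes contain `needle`."""
    hits = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for fname in filenames:
            path = Path(dirpath) / fname
            if fname.endswith(suffixes) and file_contains(path, needle):
                hits.append(path)
    return sorted(hits)


if __name__ == "__main__":
    # Assumption: the TPU runtime is importable as the `libtpu` package.
    for pkg in ("jaxlib", "libtpu"):
        root = package_dir(pkg)
        if root is None:
            print(f"{pkg}: not found as an installed package")
            continue
        for needle in (b"MegaScalePjRtDevice", b"TpuDevice"):
            for hit in find_string_in_binaries(root, needle):
                print(f"{pkg}: {needle.decode()!r} appears in {hit}")
```

A match only shows that a name is embedded in a binary. It doesn't by itself identify the code path that assigns the label, so treat the result as a pointer to which library to investigate next.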