Pipelined calculation involving scan and pure_callback
#25232
Unanswered
mfschubert asked this question in Q&A
Replies: 0 comments
This question is about pipelined calculations involving jax.lax.scan and jax.pure_callback. I have an expensive calculation with two parts: one that runs through pure_callback and one that runs as ordinary JAX code.
Both calculations are used in a scan operation, as shown in the example below.
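A minimal sketch of this kind of setup, not the original example. The bodies of slow_callback_fn and dummy_jax_fn, the shapes, and the way the callback result feeds into the carry are all assumptions made for illustration; only the two function names come from the question. It also assumes the default legacy uint32 PRNG keys.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def slow_callback_fn(key):
    # Hypothetical host-side work: sleep, then derive a small float from the key.
    time.sleep(0.01)
    return np.asarray(np.asarray(key)[0] % 7, dtype=np.float32)


def dummy_jax_fn(carry):
    # Hypothetical stand-in for the expensive pure-JAX part.
    return jnp.sin(carry) + 1.0


def step(carry, key):
    # The callback result is consumed in the same step, so within one
    # iteration the JAX part has to wait for the host callback.
    value = jax.pure_callback(
        slow_callback_fn, jax.ShapeDtypeStruct((), jnp.float32), key
    )
    carry = dummy_jax_fn(carry + value)
    return carry, value


@jax.jit
def run(carry0, keys):
    return jax.lax.scan(step, carry0, keys)


keys = jax.random.split(jax.random.PRNGKey(0), 4)
final_carry, values = run(jnp.zeros(3), keys)
```

Here each scan iteration calls the host callback once and then applies the JAX function to the carry.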
I am hoping to pipeline the calculation to speed things up, as follows:
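One way to write such a pipeline, sketched under assumptions: the callback result for the next step is fetched while the current step's JAX work is applied, so the two calls inside the scan body have no data dependency on each other. The function bodies, shapes, and the prologue/epilogue structure are hypothetical and not taken from the original post. The sketch only shows that the rewrite is numerically equivalent to the sequential version; it does not show that the two parts actually overlap at run time.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def slow_callback_fn(key):
    # Hypothetical host-side work: sleep, then derive a small float from the key.
    time.sleep(0.01)
    return np.asarray(np.asarray(key)[0] % 7, dtype=np.float32)


def dummy_jax_fn(carry):
    # Hypothetical stand-in for the expensive pure-JAX part.
    return jnp.sin(carry) + 1.0


def callback(key):
    return jax.pure_callback(
        slow_callback_fn, jax.ShapeDtypeStruct((), jnp.float32), key
    )


def sequential_step(state, key):
    # Reference version: the JAX part depends on this step's callback result.
    return dummy_jax_fn(state + callback(key)), None


def pipelined_step(carry, next_key):
    state, pending = carry
    # These two lines are independent: one fetches the value for the next
    # step, the other consumes the value fetched during the previous step.
    next_pending = callback(next_key)
    state = dummy_jax_fn(state + pending)
    return (state, next_pending), None


@jax.jit
def run_sequential(state0, keys):
    state, _ = jax.lax.scan(sequential_step, state0, keys)
    return state


@jax.jit
def run_pipelined(state0, keys):
    first = callback(keys[0])  # prologue: fetch the first value up front
    (state, last), _ = jax.lax.scan(pipelined_step, (state0, first), keys[1:])
    return dummy_jax_fn(state + last)  # epilogue: consume the last value


keys = jax.random.split(jax.random.PRNGKey(0), 4)
state0 = jnp.zeros(3)
sequential = run_sequential(state0, keys)
pipelined = run_pipelined(state0, keys)
```

Whether the runtime overlaps the host callback with the device computation is a separate matter from the absence of a data dependency in the traced program.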
I would expect the compute time to be cut in half here, since slow_callback_fn(key) and dummy_jax_fn(carry) take an equal amount of time and can run independently. However, this doesn't seem to be the case in practice.

Is this expected? Is there some other way I can force these two calculations to run in parallel?