Does `jax.lax.linalg.tridiagonal_solve` correctly use the cusparse batched implementations when appropriate? #28371
Unanswered
jpbrodrick89 asked this question in Q&A
Replies: 1 comment
-
Good question! It looks like this op currently bottoms out in the unbatched implementation (the relevant backend code is here) with a loop over the batch dimensions. It seems like a good feature request to use the batched solvers when we can, and shouldn't be too hard to implement! Perhaps it's worth opening the feature request as an issue with more details about your specific use case?
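To illustrate what "a loop over the batch dimensions" means semantically, here is a minimal sketch (not the actual backend code). It assumes a recent JAX release in which `tridiagonal_solve` accepts leading batch dimensions; the helper `make_systems` and the sizes are hypothetical and only for illustration. A single batched call should give the same result as solving each system separately:

```python
import jax
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve


def make_systems(key, batch, m):
    """Random diagonally dominant tridiagonal systems (hypothetical helper)."""
    k1, k2, k3 = jax.random.split(key, 3)
    # Sub-diagonal; the first entry of each row is unused and set to zero.
    dl = jax.random.uniform(k1, (batch, m)).at[:, 0].set(0.0)
    # Super-diagonal; the last entry of each row is unused and set to zero.
    du = jax.random.uniform(k2, (batch, m)).at[:, -1].set(0.0)
    # Main diagonal chosen large enough to keep every system well conditioned.
    d = jnp.full((batch, m), 4.0)
    # One right-hand-side column per system, shape [batch, m, 1].
    b = jax.random.normal(k3, (batch, m, 1))
    return dl, d, du, b


dl, d, du, b = make_systems(jax.random.PRNGKey(0), batch=8, m=200)

# One call with a leading batch dimension.
x_batched = tridiagonal_solve(dl, d, du, b)

# Explicit Python loop: one unbatched solve per system.
x_looped = jnp.stack(
    [tridiagonal_solve(dl[i], d[i], du[i], b[i]) for i in range(dl.shape[0])]
)

print("max abs difference:", float(jnp.max(jnp.abs(x_batched - x_looped))))
```

Both paths return the same numbers up to floating-point rounding; the question in this thread is only about which GPU kernels the batched call dispatches to.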
-
I've been trying to follow through the source code to answer this myself, but I keep getting lost/confused.
When looking at the `cusparse` library docs I noticed there are specific batched implementations, e.g. `cusparse<t>gtsv2StridedBatch()` and `cusparse<t>gtsvInterleavedBatch()`, which I assume are quite efficient. However, when profiling `tridiagonal_solve` with an array size of 200 and batch sizes up to 6400, I observe runtime scaling linearly with batch size throughout, despite not reaching the memory/kernel limits of an A100. Alternative implementations using unrolled loops show that sub-linear scaling should be possible in this regime, which makes me think that the batched implementations might not actually be used here. Is this a "bug", a "missing feature", or simply the current upstream behaviour of the batched `cusparse` routines? If not the latter, how complicated would it be to address?

Thank you!
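For anyone wanting to reproduce the scaling measurement, a minimal timing sketch might look like the following. It is not the original profiling script: the helper `time_solve`, the batch sizes, the repeat count, and the random diagonally dominant inputs are all assumptions for illustration, and it assumes a recent JAX release in which `tridiagonal_solve` accepts leading batch dimensions.

```python
import time

import jax
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve


def time_solve(batch, m=200, repeats=5):
    """Median wall-clock seconds for one batched solve (hypothetical helper)."""
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(batch), 3)
    dl = jax.random.uniform(k1, (batch, m)).at[:, 0].set(0.0)   # sub-diagonal
    du = jax.random.uniform(k2, (batch, m)).at[:, -1].set(0.0)  # super-diagonal
    d = jnp.full((batch, m), 4.0)                               # main diagonal
    b = jax.random.normal(k3, (batch, m, 1))                    # right-hand sides

    solve = jax.jit(tridiagonal_solve)
    # Warm-up call so compilation time is excluded from the measurement.
    solve(dl, d, du, b).block_until_ready()

    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        # block_until_ready waits for the asynchronous computation to finish.
        solve(dl, d, du, b).block_until_ready()
        samples.append(time.perf_counter() - start)
    return sorted(samples)[len(samples) // 2]


timings = {batch: time_solve(batch) for batch in (100, 400, 1600, 6400)}
for batch, seconds in timings.items():
    print(f"batch={batch:5d}  median={seconds * 1e3:8.3f} ms")
```

If the time grows roughly in proportion to the batch size well before the device is saturated, that is consistent with the systems being solved one after another rather than by a batched kernel.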