feat: unlock MLA attention for sm89 (L40/L40s/4090) #814
Merged
Conversation
NOTE: Numbers below are outdated. Please check #814 (comment)
yzh119 added a commit that referenced this pull request on Feb 13, 2025
Follow-up of #814: we found some correctness issues in the sm89 MLA kernels, and this PR fixes them.
@MasterJH5574, can you try #821? Our previous benchmarking on sm89 might not be meaningful because of the kernel bugs.
Sure, here's the benchmark result after #821. Performance report on NVIDIA RTX 4090 (peak memory bandwidth 1008 GB/s):
yzh119 pushed a commit that referenced this pull request on Feb 14, 2025
Hi @yzh119, this is a follow-up of #766. An interesting idea came to my mind today, and I couldn't help changing a few lines to verify it: we can use an asymmetric warp configuration to solve the register-file size limit issue. The solution is simply to use 8 warps for the output MMA stage and keep the other parts unchanged, because the limitation is on the register count per CUDA block, not the whole SM; there are 64K 32-bit registers per SM, which is enough for the f32 output of 64 heads. So we now have 4 warps for the attention MMA stage, 2 warps for the softmax stage, 8 warps for the output MMA stage, and 4 warps for the data-load stage (updated diagram omitted).

After the change the output MMA stage needs more computation, so the benchmark drops a little as expected, but it still looks good (benchmark figure omitted). The performance of this CuTe implementation seems slightly better than the current FA2 implementation according to #814.

So I think this CuTe implementation still has its value: considering its interesting scheduling design and better performance, maybe we can regard it as an ad hoc implementation for the (decode-only / 128 q-heads / SM80) case, and the JIT logic can accommodate this kernel.
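To make the register pressure concrete, here is a back-of-envelope sketch of the output-accumulator budget. The 64 heads, the 64K 32-bit registers per SM, and the 255-registers-per-thread cap come from the comment above and CUDA's documented limits; the 512-wide fp32 accumulator per head is an assumed value for illustration only, not taken from this PR.

```cpp
// Rough register-budget check for the output accumulator (host-side sketch).
// Assumption: each CTA accumulates 64 heads x 512 fp32 values in registers.
#include <cstdio>

int main() {
  constexpr int kNumHeads = 64;           // heads whose fp32 output stays in registers
  constexpr int kHeadDim = 512;           // assumed output head dim (illustrative)
  constexpr int kRegsPerSM = 64 * 1024;   // 64K 32-bit registers per SM
  constexpr int kMaxRegsPerThread = 255;  // hardware cap per thread

  constexpr int acc_regs = kNumHeads * kHeadDim;  // one 32-bit register per fp32
  std::printf("accumulator: %d regs vs. %d regs per SM\n", acc_regs, kRegsPerSM);

  const int warp_options[] = {4, 8};
  for (int warps : warp_options) {
    const int threads = warps * 32;
    std::printf("%d output warps -> %d regs/thread for the accumulator (cap %d)\n",
                warps, acc_regs / threads, kMaxRegsPerThread);
  }
  return 0;
}
```

With these assumed numbers, 4 output warps would need 256 accumulator registers per thread, already over the 255 cap before counting anything else, while 8 warps bring it down to 128, which is why widening only the output MMA stage relieves the pressure.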
This PR changes the MLA attention template to support sm89 GPUs, which have a small shared memory size (99 KB per SM), so we have to further reduce shared memory usage: NUM_STAGES can only be set to 1, and CTA_TILE_KV can be at most 16.

We add an option QK_SHARD to the KernelTraits (our previous template only supports QK_SHARD=true). We set QK_SHARD=true for A100/H100 (whose shared memory limits are 164 KB and 228 KB, respectively) and QK_SHARD=false for sm89.
Reference

The effect of QK_SHARD on H100 SXM5 (3352 GB/s):

The effect of QK_SHARD on A100 SXM 40GB (1555 GB/s):