feat: unlock MLA attention for sm89 (L40/L40s/4090) #814


Merged · 4 commits merged into main · Feb 12, 2025

Conversation

@yzh119 (Collaborator) commented on Feb 12, 2025

This PR changes the MLA attention template to support sm89 GPUs, which have a smaller shared memory size (99 KB per SM), so we have to further reduce shared memory usage: NUM_STAGES can only be set to 1, and CTA_TILE_KV can be set to at most 16.
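
For intuition, here is a rough back-of-the-envelope sketch of the KV-tile shared-memory footprint, assuming the DeepSeek MLA head dimensions (512 compressed-KV dims plus 64 RoPE dims per token) stored in fp16; the kernel's full shared-memory layout (Q, P, and other buffers) is not accounted for here:

```cpp
#include <cstdio>

// Assumed DeepSeek MLA head dimensions: 512 compressed-KV dims + 64 RoPE dims per token, fp16.
constexpr int kHeadDimCkv = 512;
constexpr int kHeadDimKpe = 64;
constexpr int kBytesPerElem = 2;

// Shared-memory footprint (bytes) of the KV tiles for a given tile size and stage count.
constexpr int KvSmemBytes(int cta_tile_kv, int num_stages) {
  return num_stages * cta_tile_kv * (kHeadDimCkv + kHeadDimKpe) * kBytesPerElem;
}

int main() {
  // sm89 exposes ~99 KB of shared memory per SM (vs. 164 KB on A100 and 228 KB on H100),
  // so both the KV tile and the pipeline depth have to shrink.
  printf("CTA_TILE_KV=16, NUM_STAGES=1: %d KB of KV tiles\n", KvSmemBytes(16, 1) / 1024);
  printf("CTA_TILE_KV=64, NUM_STAGES=2: %d KB of KV tiles\n", KvSmemBytes(64, 2) / 1024);
  return 0;
}
```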

We add a QK_SHARD option to the KernelTraits (our previous template only supported QK_SHARD=true):

  1. If true, we use the schedule described in #804 (perf: memory efficient deepseek mla fused page-attention kernel) and shard the QK computation along the KV dimension: each warpgroup computes half of it, and we need a round of allgather through shared memory to assemble the full P for the PV computation.
  2. If false, we duplicate the QK computation on both warpgroups (redundant work), but we save the allgather step for P.

We set QK_SHARD=true for A100/H100 (whose shared memory limits are 164 KB and 228 KB per SM, respectively) and QK_SHARD=false for sm89.
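
A minimal sketch of how such a compile-time switch might be wired up; the trait and alias names here are illustrative assumptions, not FlashInfer's actual definitions, and the stage/tile values for the large-smem branch are placeholders:

```cpp
#include <type_traits>

// Illustrative sketch only: names and values are assumptions, not FlashInfer's real traits.
template <bool QK_SHARD_, int NUM_STAGES_, int CTA_TILE_KV_>
struct KernelTraitsSketch {
  // true: shard QK along KV across the two warpgroups, then allgather P in shared memory.
  // false: duplicate the QK computation on both warpgroups, no allgather needed.
  static constexpr bool QK_SHARD = QK_SHARD_;
  static constexpr int NUM_STAGES = NUM_STAGES_;    // software-pipeline depth
  static constexpr int CTA_TILE_KV = CTA_TILE_KV_;  // KV tokens per CTA tile
};

// Pick traits from the per-SM shared memory budget (in KB).
template <int SMEM_KB>
using TraitsFor = std::conditional_t<
    (SMEM_KB >= 164),
    KernelTraitsSketch</*QK_SHARD=*/true, /*NUM_STAGES=*/2, /*CTA_TILE_KV=*/64>,   // A100/H100
    KernelTraitsSketch</*QK_SHARD=*/false, /*NUM_STAGES=*/1, /*CTA_TILE_KV=*/16>>; // sm89

static_assert(TraitsFor<228>::QK_SHARD, "H100: shard QK, allgather P");
static_assert(!TraitsFor<99>::QK_SHARD, "sm89: duplicate QK, skip allgather");
```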

Reference

The effect of QK_SHARD on H100 SXM5 (peak memory bandwidth 3352 GB/s):

QK_SHARD=true (Allgather with shared memory)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 2010.78 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 2036.13 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2085.52 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 2068.62 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 2085.84 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2080.85 GB/s

QK_SHARD=false (Duplicate P)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 1610.81 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 1638.73 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 1690.86 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 1636.08 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 1651.57 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 1653.31 GB/s

The effect of QK_SHARD on A100 SXM 40GB (peak memory bandwidth 1555 GB/s):

QK_SHARD=true (Allgather with shared memory)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 891.30 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 929.65 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 954.24 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 923.07 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 933.77 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 943.48 GB/s

QK_SHARD=false (Duplicate P)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 753.89 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 780.96 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 804.61 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 785.70 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 796.87 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 808.83 GB/s
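
For reference, a rough sketch of how an achieved-bandwidth figure like the ones above can be estimated from a config; the exact byte accounting used by the FlashInfer benchmark script may differ, and this assumes the paged compressed-KV cache read dominates the traffic:

```cpp
#include <cstdio>

// Estimate achieved memory bandwidth (GB/s) for one config, assuming the dominant
// traffic is reading the compressed KV cache once: batch * seq_len * (512 + 64) fp16 values.
double EstimatedBandwidthGBps(int batch_size, int seq_len, double elapsed_ms) {
  constexpr double kBytesPerToken = (512 + 64) * 2.0;  // head_dim_ckv + head_dim_kpe, fp16
  double total_bytes = static_cast<double>(batch_size) * seq_len * kBytesPerToken;
  return total_bytes / (elapsed_ms * 1e-3) / 1e9;
}

int main() {
  // Example: batch_size=768, seq_len=1024, hypothetical 0.45 ms kernel time.
  printf("%.2f GB/s\n", EstimatedBandwidthGBps(768, 1024, 0.45));
  return 0;
}
```

Note that under this accounting the KV-cache traffic does not depend on num_heads, which is consistent with the bandwidth varying only mildly across the head counts in the tables above.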

yzh119 merged commit 3de690a into main on Feb 12, 2025
@MasterJH5574 (Collaborator) commented on Feb 12, 2025

NOTE: Numbers below are outdated. Please check #814 (comment)

Performance report on NVIDIA RTX 4090 (peak memory bandwidth 1008 GB/s):

QK_SHARD=false (Duplicate P)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 928.17 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 955.77 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 984.69 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 933.95 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 942.48 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 957.72 GB/s

zhyncs deleted the add-qk-sharding-option branch on February 12, 2025, 21:21
yzh119 added a commit that referenced this pull request Feb 13, 2025
Follow-up of #814: we found some correctness issues in the sm89 MLA kernels; this PR fixes them.
@yzh119 (Collaborator, Author) commented on Feb 13, 2025

@MasterJH5574 can you try #821? Our previous benchmarking on sm89 might not be meaningful because of the kernel bugs.

@MasterJH5574 (Collaborator) commented

> Can you try #821? Our previous benchmarking on sm89 might not be meaningful because of the kernel bugs.

Sure, here's the benchmark result after #821:


Performance report on NVIDIA RTX 4090 (peak memory bandwidth 1008 GB/s):

Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 688.69 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 703.08 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 721.74 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 694.83 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 701.85 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 704.72 GB/s

yzh119 pushed a commit that referenced this pull request Feb 14, 2025
Hi @yzh119, this is a follow-up of #766. An interesting idea came to my mind today, and I couldn't help changing a few lines to verify it.
We can use an asymmetric warp configuration to work around the register-file size limit: simply use 8 warps for the output MMA stage and keep the other parts unchanged. The limitation is on the register count per CUDA block, not on the whole SM, and there are 64K 32-bit registers per SM, which is enough for the f32 output of 64 heads.
So we now have 4 warps for the attention MMA stage, 2 warps for the softmax stage, 8 warps for the output MMA stage, and 4 warps for the data-load stage; the diagram is updated below:
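
For intuition, a quick check of the register arithmetic behind this, assuming a 64-head x 512-dim f32 output tile (the actual tile shape in the kernel may differ):

```cpp
#include <cstdio>

int main() {
  // Assumed output tile: 64 heads x 512 dims of f32 accumulators (assumption, see text).
  constexpr int kNumHeads = 64;
  constexpr int kHeadDim = 512;
  constexpr int kAccumRegs = kNumHeads * kHeadDim;   // 32768 32-bit registers
  constexpr int kRegfilePerSm = 64 * 1024;           // 64K 32-bit registers per SM

  // With 8 warps (256 threads) on the output MMA stage, the accumulators alone take
  // 128 registers per thread, which stays under the 255-registers-per-thread cap.
  constexpr int kOutputThreads = 8 * 32;
  printf("accumulators: %d regs (%d%% of the SM register file), %d per thread\n",
         kAccumRegs, 100 * kAccumRegs / kRegfilePerSm, kAccumRegs / kOutputThreads);
  return 0;
}
```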

![image](https://github.com/user-attachments/assets/2af8c5d9-d5a5-47e6-bd63-7e6b4305a529)

After the change, the output MMA stage needs more computation, so the benchmark drops a little as expected, but still looks good:

![image](https://github.com/user-attachments/assets/470ec576-ba91-4e71-9604-fcd6f0a9d691)

It seems the performance of this CuTe implementation is slightly better than that of the current FA2 implementation, according to #814:

![image](https://github.com/user-attachments/assets/9f61e2ff-4bb6-4581-a199-bb6176173192)


So I think this CuTe implementation still has its value. Considering its interesting scheduling design and better performance, maybe we can regard it as an ad hoc implementation for the (decode-only / 128 q-heads / SM80) case, and the JIT logic can accommodate this kernel.