
CUDA: batched+noncont MMQ, refactor bs>1 MoE code #13199


Merged: 1 commit merged into ggml-org:master on Apr 30, 2025

Conversation

JohannesGaessler (Collaborator) commented on Apr 29, 2025

This PR makes the following changes:

  • Extend the CUDA code for GET_ROWS to allow for type conversion during the operation.
  • Refactor the MoE CUDA code for batch sizes >1. If possible, the matrix multiplications are done batched via MMQ (see below). Otherwise, ids are calculated for sorting src1 by expert via GET_ROWS, as well as for the inverse operation on dst (a sketch of this gather is shown after this list). Since the sorting in either direction can be done in a single kernel launch, the dedicated kernels used so far can be removed.
  • Extend MMQ to support batched matrix multiplication. This makes prompt processing with quantized K cache and without FlashAttention a bit faster.
  • For MoE using MMQ, provide the kernel with information regarding which columns are used for which expert. If there is a mismatch for a tile to be calculated, skip that tile. Results are re-arranged at the end of the kernel with the provided row ids.
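As a rough sketch of the expert-sorted gather described in the second bullet (not the actual ggml kernel, and with purely illustrative names), a GET_ROWS-style copy with type conversion boils down to a single gather kernel:

```cuda
// Minimal sketch of the expert-sorted gather, NOT the actual ggml kernel;
// names (gather_rows_convert, row_ids, ...) are illustrative only.
#include <cuda_fp16.h>

template <typename src_t, typename dst_t>
__global__ void gather_rows_convert(
        const src_t * __restrict__ src,     // [n_rows_src, n_cols], original token order
        dst_t       * __restrict__ dst,     // [n_rows_dst, n_cols], expert-sorted order
        const int   * __restrict__ row_ids, // output row i is read from src row row_ids[i]
        const int64_t n_cols) {
    const int64_t row = blockIdx.x;                  // one thread block per output row
    const src_t * src_row = src + row_ids[row] * n_cols;
    dst_t       * dst_row = dst + row          * n_cols;

    for (int64_t col = threadIdx.x; col < n_cols; col += blockDim.x) {
        dst_row[col] = (dst_t) src_row[col];         // type conversion happens on copy
    }
}
```

The inverse re-arrangement of dst uses the same pattern with the id mapping applied in the opposite direction, which is why a single kernel launch per direction suffices and the dedicated sorting kernels could be dropped.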
Performance changes
| GPU | Model | Microbatch size | K type | Test | t/s master | t/s PR | Speedup |
| --- | --- | --- | --- | --- | --- | --- | --- |
| P40 | deepseek2 16B Q4_0 | 2 | f16 | pp2048 | 38.57 | 45.23 | 1.17 |
| P40 | deepseek2 16B Q4_0 | 4 | f16 | pp2048 | 61.19 | 75.49 | 1.23 |
| P40 | deepseek2 16B Q4_0 | 8 | f16 | pp2048 | 86.13 | 96.88 | 1.12 |
| P40 | deepseek2 16B Q4_0 | 16 | f16 | pp2048 | 131.16 | 167.23 | 1.28 |
| P40 | deepseek2 16B Q4_0 | 32 | f16 | pp2048 | 203.13 | 260.11 | 1.28 |
| P40 | deepseek2 16B Q4_0 | 64 | f16 | pp2048 | 307.59 | 427.49 | 1.39 |
| P40 | deepseek2 16B Q4_0 | 128 | f16 | pp2048 | 451.00 | 657.59 | 1.46 |
| P40 | deepseek2 16B Q4_0 | 256 | f16 | pp2048 | 615.30 | 896.69 | 1.46 |
| P40 | deepseek2 16B Q4_0 | 512 | f16 | pp2048 | 770.49 | 1091.92 | 1.42 |
| P40 | deepseek2 16B Q4_0 | 1024 | f16 | pp2048 | 896.85 | 1226.97 | 1.37 |
| P40 | deepseek2 16B Q4_0 | 2048 | f16 | pp2048 | 917.99 | 1172.51 | 1.28 |
| P40 | llama 8B Q4_0 | 512 | f16 | pp512 | 967.93 | 959.68 | 0.99 |
| P40 | llama 8B Q4_0 | 512 | q8_0 | pp512 | 971.74 | 983.56 | 1.01 |
| 2x P40 | deepseek2 16B F16 | 2 | f16 | pp2048 | 16.78 | 18.15 | 1.08 |
| 2x P40 | deepseek2 16B F16 | 4 | f16 | pp2048 | 24.30 | 26.88 | 1.11 |
| 2x P40 | deepseek2 16B F16 | 8 | f16 | pp2048 | 32.81 | 36.57 | 1.11 |
| 2x P40 | deepseek2 16B F16 | 16 | f16 | pp2048 | 47.75 | 53.55 | 1.12 |
| 2x P40 | deepseek2 16B F16 | 32 | f16 | pp2048 | 71.59 | 80.42 | 1.12 |
| 2x P40 | deepseek2 16B F16 | 64 | f16 | pp2048 | 110.44 | 123.02 | 1.11 |
| 2x P40 | deepseek2 16B F16 | 128 | f16 | pp2048 | 174.71 | 195.21 | 1.12 |
| 2x P40 | deepseek2 16B F16 | 256 | f16 | pp2048 | 280.20 | 314.95 | 1.12 |
| 2x P40 | deepseek2 16B F16 | 512 | f16 | pp2048 | 403.56 | 459.92 | 1.14 |
| 2x P40 | deepseek2 16B F16 | 1024 | f16 | pp2048 | 545.76 | 629.49 | 1.15 |
| 2x P40 | deepseek2 16B F16 | 2048 | f16 | pp2048 | 641.06 | 748.06 | 1.17 |
| RTX 3090 | deepseek2 16B Q4_0 | 2 | f16 | pp2048 | 125.04 | 153.63 | 1.23 |
| RTX 3090 | deepseek2 16B Q4_0 | 4 | f16 | pp2048 | 181.34 | 250.44 | 1.38 |
| RTX 3090 | deepseek2 16B Q4_0 | 8 | f16 | pp2048 | 256.66 | 373.73 | 1.46 |
| RTX 3090 | deepseek2 16B Q4_0 | 16 | f16 | pp2048 | 245.31 | 493.13 | 2.01 |
| RTX 3090 | deepseek2 16B Q4_0 | 32 | f16 | pp2048 | 396.85 | 855.84 | 2.16 |
| RTX 3090 | deepseek2 16B Q4_0 | 64 | f16 | pp2048 | 626.72 | 1279.05 | 2.04 |
| RTX 3090 | deepseek2 16B Q4_0 | 128 | f16 | pp2048 | 933.16 | 2047.66 | 2.19 |
| RTX 3090 | deepseek2 16B Q4_0 | 256 | f16 | pp2048 | 1536.15 | 3111.54 | 2.03 |
| RTX 3090 | deepseek2 16B Q4_0 | 512 | f16 | pp2048 | 2230.09 | 3963.50 | 1.78 |
| RTX 3090 | deepseek2 16B Q4_0 | 1024 | f16 | pp2048 | 2894.33 | 4444.53 | 1.54 |
| RTX 3090 | deepseek2 16B Q4_0 | 2048 | f16 | pp2048 | 3418.50 | 4543.59 | 1.33 |
| RTX 3090 | llama 8B Q4_0 | 512 | f16 | pp512 | 4882.03 | 4823.46 | 0.99 |
| RTX 3090 | llama 8B Q4_0 | 512 | q8_0 | pp512 | 4239.11 | 4739.13 | 1.12 |
| RTX 4090 | deepseek2 16B Q4_0 | 2 | f16 | pp2048 | 114.83 | 202.57 | 1.76 |
| RTX 4090 | deepseek2 16B Q4_0 | 4 | f16 | pp2048 | 171.59 | 366.63 | 2.14 |
| RTX 4090 | deepseek2 16B Q4_0 | 8 | f16 | pp2048 | 255.52 | 594.02 | 2.32 |
| RTX 4090 | deepseek2 16B Q4_0 | 16 | f16 | pp2048 | 257.61 | 857.61 | 3.33 |
| RTX 4090 | deepseek2 16B Q4_0 | 32 | f16 | pp2048 | 427.93 | 1493.68 | 3.49 |
| RTX 4090 | deepseek2 16B Q4_0 | 64 | f16 | pp2048 | 728.54 | 2317.36 | 3.18 |
| RTX 4090 | deepseek2 16B Q4_0 | 128 | f16 | pp2048 | 1207.56 | 3874.00 | 3.21 |
| RTX 4090 | deepseek2 16B Q4_0 | 256 | f16 | pp2048 | 2209.40 | 5936.60 | 2.69 |
| RTX 4090 | deepseek2 16B Q4_0 | 512 | f16 | pp2048 | 3491.50 | 7603.30 | 2.18 |
| RTX 4090 | deepseek2 16B Q4_0 | 1024 | f16 | pp2048 | 4904.70 | 8385.58 | 1.71 |
| RTX 4090 | deepseek2 16B Q4_0 | 2048 | f16 | pp2048 | 6066.02 | 8076.67 | 1.33 |
| RTX 4090 | llama 8B Q4_0 | 512 | f16 | pp512 | 11726.57 | 11830.85 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 512 | q8_0 | pp512 | 9085.42 | 11901.64 | 1.31 |
| 2x RTX 4090 | deepseek2 16B F16 | 2 | f16 | pp2048 | 102.34 | 111.36 | 1.09 |
| 2x RTX 4090 | deepseek2 16B F16 | 4 | f16 | pp2048 | 149.56 | 179.42 | 1.20 |
| 2x RTX 4090 | deepseek2 16B F16 | 8 | f16 | pp2048 | 214.69 | 279.62 | 1.30 |
| 2x RTX 4090 | deepseek2 16B F16 | 16 | f16 | pp2048 | 319.29 | 446.22 | 1.40 |
| 2x RTX 4090 | deepseek2 16B F16 | 32 | f16 | pp2048 | 488.92 | 721.06 | 1.47 |
| 2x RTX 4090 | deepseek2 16B F16 | 64 | f16 | pp2048 | 807.07 | 1223.89 | 1.52 |
| 2x RTX 4090 | deepseek2 16B F16 | 128 | f16 | pp2048 | 1323.58 | 2018.89 | 1.53 |
| 2x RTX 4090 | deepseek2 16B F16 | 256 | f16 | pp2048 | 2303.90 | 3357.35 | 1.46 |
| 2x RTX 4090 | deepseek2 16B F16 | 512 | f16 | pp2048 | 3666.98 | 4986.41 | 1.36 |
| 2x RTX 4090 | deepseek2 16B F16 | 1024 | f16 | pp2048 | 5307.58 | 6578.51 | 1.24 |
| 2x RTX 4090 | deepseek2 16B F16 | 2048 | f16 | pp2048 | 6353.03 | 7040.15 | 1.11 |
| RX 6800 | deepseek2 16B Q4_0 | 2 | f16 | pp2048 | 39.97 | 49.19 | 1.23 |
| RX 6800 | deepseek2 16B Q4_0 | 4 | f16 | pp2048 | 62.40 | 81.22 | 1.30 |
| RX 6800 | deepseek2 16B Q4_0 | 8 | f16 | pp2048 | 94.15 | 106.16 | 1.13 |
| RX 6800 | deepseek2 16B Q4_0 | 16 | f16 | pp2048 | 120.26 | 168.81 | 1.40 |
| RX 6800 | deepseek2 16B Q4_0 | 32 | f16 | pp2048 | 172.69 | 230.00 | 1.33 |
| RX 6800 | deepseek2 16B Q4_0 | 64 | f16 | pp2048 | 241.90 | 360.25 | 1.49 |
| RX 6800 | deepseek2 16B Q4_0 | 128 | f16 | pp2048 | 338.53 | 533.42 | 1.58 |
| RX 6800 | deepseek2 16B Q4_0 | 256 | f16 | pp2048 | 477.19 | 770.28 | 1.61 |
| RX 6800 | deepseek2 16B Q4_0 | 512 | f16 | pp2048 | 561.73 | 871.07 | 1.55 |
| RX 6800 | deepseek2 16B Q4_0 | 1024 | f16 | pp2048 | 742.31 | 1144.77 | 1.54 |
| RX 6800 | deepseek2 16B Q4_0 | 2048 | f16 | pp2048 | 773.94 | 1073.95 | 1.39 |
| RX 6800 | llama 8B Q4_0 | 512 | f16 | pp512 | 773.45 | 777.47 | 1.01 |
| RX 6800 | llama 8B Q4_0 | 512 | q8_0 | pp512 | 779.44 | 821.89 | 1.05 |

Performance increases most for small batch sizes and fast GPUs, where the kernel launch overhead has more impact. I think there is still a lot of potential for optimization in the MMQ kernel. For the generic MoE code there are currently still unnecessary type conversions for FP16 and BF16; eliminating them will require some changes to the cuBLAS code. I did not try cublasGemmGroupedBatchedEx because, to my disappointment, it only supports CUBLAS_COMPUTE_32F, so no tensor cores. It may instead be worthwhile to do an implementation with regular batched GEMM by padding all src1 matrices to the maximum number of tokens per expert; on modern GPUs this may end up being faster even if some of the work is wasted.
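To make that padded alternative concrete, here is a hedged sketch (not code from this PR; the function name, memory layouts, and compute/algo choices are assumptions) of how all experts could be handled with one regular strided batched GEMM once src1 has been gathered into equally sized, zero-padded per-expert slices:

```cuda
// Hypothetical sketch of the padded alternative: one strided batched GEMM over
// per-expert matrices that have all been padded to n_pad rows (the maximum
// number of tokens routed to any single expert), with unused rows zero-filled.
#include <cublas_v2.h>
#include <cuda_fp16.h>

// In row-major terms: dst[e] = src1[e] * src0[e]^T for every expert e.
static cublasStatus_t moe_padded_batched_gemm(
        cublasHandle_t handle,
        const __half * src0,   // [n_expert, n, k]     expert weights, row-major
        const __half * src1,   // [n_expert, n_pad, k] padded activations, row-major
        float        * dst,    // [n_expert, n_pad, n] outputs, row-major
        int n, int k, int n_pad, int n_expert) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;
    // cuBLAS is column-major, so the row-major [n_pad, n] output is computed
    // as an n x n_pad column-major matrix C = op(A) * op(B).
    return cublasGemmStridedBatchedEx(handle,
        CUBLAS_OP_T, CUBLAS_OP_N,
        n, n_pad, k,
        &alpha,
        src0, CUDA_R_16F, k, (long long) n     * k,
        src1, CUDA_R_16F, k, (long long) n_pad * k,
        &beta,
        dst,  CUDA_R_32F, n, (long long) n_pad * n,
        n_expert,
        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}
```

The zero-padded rows waste some FLOPs, but every expert slice then has identical dimensions, so the whole MoE layer collapses into a single GEMM launch.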

github-actions bot added the labels `testing`, `Nvidia GPU`, and `ggml` on Apr 29, 2025
slaren (Member) commented on Apr 30, 2025

I also see a good improvement on Windows:

| Model | Microbatch size | Test | t/s master | t/s cuda-moe-mmq-5 | Speedup |
| --- | --- | --- | --- | --- | --- |
| deepseek2 16B Q4_0 | 16 | pp2048 | 142.01 | 459.31 | 3.23 |
| deepseek2 16B Q4_0 | 32 | pp2048 | 248.05 | 843.66 | 3.40 |
| deepseek2 16B Q4_0 | 64 | pp2048 | 351.65 | 1349.81 | 3.84 |
| deepseek2 16B Q4_0 | 128 | pp2048 | 652.36 | 2186.80 | 3.35 |
| deepseek2 16B Q4_0 | 256 | pp2048 | 1135.34 | 3512.99 | 3.09 |
| deepseek2 16B Q4_0 | 512 | pp2048 | 1863.29 | 4589.94 | 2.46 |
| deepseek2 16B Q4_0 | 1024 | pp2048 | 2923.51 | 5335.78 | 1.83 |
| deepseek2 16B Q4_0 | 2048 | pp2048 | 3787.59 | 5622.27 | 1.48 |

| Model | K type | Test | t/s master | t/s cuda-moe-mmq-5 | Speedup |
| --- | --- | --- | --- | --- | --- |
| llama 8B Q4_0 | f16 | pp512 | 5998.92 | 5884.94 | 0.98 |
| llama 8B Q4_0 | q8_0 | pp512 | 4422.87 | 5800.51 | 1.31 |

JohannesGaessler merged commit e1e8e09 into ggml-org:master on Apr 30, 2025
48 checks passed
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request May 1, 2025
* origin/master:
sync : ggml
whisper : add check that target name exists (whisper/3103)
ggml : suppress Windows compiler warnings (whisper/3075)
mtmd : add **vision** support for Mistral Small 3.1 (ggml-org#13231)
arg : remove CURLINFO_EFFECTIVE_METHOD (ggml-org#13228)
llama-model : fix the reported size class for nomic-embed-text-v2-moe (ggml-org#13223)
sync : ggml
ggml : fix ggml_gallocr_ptr type (ggml/1205)
cuda : fix unused variable compile warning (whisper/0)
CUDA: batched+noncont MMQ, refactor bs>1 MoE code (ggml-org#13199)
arg : -hf do not fail if url mismatch (ggml-org#13219)
fix typo: `n_ctx_pre_seq` -> `n_ctx_per_seq` (ggml-org#13221)
convert : improve model arch handling (ggml-org#13122)
llava : remove duplicate include (ggml-org#13207)
common : add -jf / --json-schema-file flag (ggml-org#12011)
danielhanchen added a commit to unslothai/llama.cpp that referenced this pull request May 2, 2025
danielhanchen (Contributor) commented:

@JohannesGaessler Sadly I'm getting:

```
/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:75: CUDA error
CUDA error: invalid configuration argument
  current device: 0, in function ggml_cuda_mul_mat_id at /llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:2055
  cudaGetLastError()
```

If I revert the commit, then everything works fine.

I'm using an H100 and CUDA 12.6.

JohannesGaessler (Collaborator, Author) commented on May 2, 2025

Using which model and which exact command?

Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request May 2, 2025
Labels
  • ggml: changes relating to the ggml tensor library for machine learning
  • Nvidia GPU: Issues specific to Nvidia GPUs
  • testing: Everything test related