llama-bench : Add --override-tensors arg #12922

Open · 4onen wants to merge 1 commit into master

Conversation

4onen commented Apr 12, 2025

A small group over at BeaverAI has been making extensive use of the --override-tensors (-ot) flag for running massive MoE models faster by keeping attention on the GPU and offloading the expert FFNs to the CPU. Informal experimentation in llama-server or llama-cli doesn't compare to a proper llama-bench run, though, so this PR adds the --override-tensors argument (and its -ot short form) to llama-bench.
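
For anyone who hasn't used the flag: as I understand it, each -ot value is a comma-separated list of <regex>=<buffer type> pairs, and any tensor whose name matches the regex is placed in the named backend buffer. The snippet below is only an illustration of that matching idea, assuming the usual GGUF-style tensor names (blk.N.ffn_gate_exps.weight and friends); the real matching happens inside llama.cpp at model-load time.

```cpp
// Illustration only (not llama.cpp code): shows which GGUF-style tensor names a
// pattern like "\d+\.ffn_.*exp.=CPU" would send to the CPU buffer. The part
// before '=' is the regex; the part after names the target backend buffer.
#include <cstdio>
#include <regex>
#include <string>
#include <vector>

int main() {
    const std::regex experts_to_cpu("\\d+\\.ffn_.*exp.");  // pattern half of "\d+\.ffn_.*exp.=CPU"
    const std::vector<std::string> names = {
        "blk.7.attn_q.weight",         // attention weights: no match, stays on the GPU
        "blk.7.ffn_gate_exps.weight",  // routed experts: matches, overridden to the CPU
        "blk.7.ffn_gate_shexp.weight", // shared expert: "exp." also matches this; "exps." would not
    };
    for (const std::string & name : names) {
        std::printf("%-28s -> %s\n", name.c_str(),
                    std::regex_search(name, experts_to_cpu) ? "CPU" : "GPU");
    }
    return 0;
}
```

Note the difference between "exp." and "exps.": the latter leaves the shexp (shared expert) tensors alone, which appears to be what the "Full CPU Experts (except shared)" run further down relies on.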

I noticed the // FIXME about leaking memory in args.cpp when copying the --override-tensors argument parsing, and chose to stamp null terminators into the argv, rather than accept the memory leak, as llama-bench calls parse_cmd_params only once. Let me know if you'd like that swapped out for the memory-leaking version from the common arg parser, as it's only a handful of user-entered bytes leaked.
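
For reference, the idea is roughly the sketch below (not the actual patch, and split_overrides is a made-up name): because the -ot value string lives in argv for the whole run, the ',' and '=' separators can simply be overwritten with '\0', so each pattern and buffer-type name points into the existing argv storage instead of into a leaked copy. The actual parser still has to map each buffer-type name onto a ggml backend buffer type before storing the override, which the sketch skips.

```cpp
// Minimal sketch of the in-place split, assuming the value string comes
// straight from argv and therefore outlives parse_cmd_params. The helper name
// split_overrides is made up for illustration.
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>

// "regexA=CPU,regexB=CUDA0" -> {("regexA","CPU"), ("regexB","CUDA0")},
// with every pointer aimed directly into the argv storage (no copies, no leaks).
static std::vector<std::pair<char *, char *>> split_overrides(char * value) {
    std::vector<std::pair<char *, char *>> out;
    char * pair = value;
    while (pair && *pair) {
        char * next = std::strchr(pair, ',');
        if (next) { *next = '\0'; ++next; }   // terminate this pair in place
        char * eq = std::strchr(pair, '=');
        if (eq) {
            *eq = '\0';                       // split <regex> from <buffer type>
            out.emplace_back(pair, eq + 1);
        } else {
            std::fprintf(stderr, "invalid override '%s', expected <regex>=<buffer type>\n", pair);
        }
        pair = next;
    }
    return out;
}

int main() {
    char value[] = "\\d+\\.ffn_.*exp.=CPU,output\\.weight=CUDA0";  // mutable, like argv contents
    for (const auto & ov : split_overrides(value)) {
        std::printf("pattern '%s' -> buffer '%s'\n", ov.first, ov.second);
    }
    return 0;
}
```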

Also planning to do some documentation of --override-tensors a little later on, as it's proving very useful and we'd love to spread the word.

4onen commented Apr 12, 2025

A sketchy performance comparison on my laptop to show why --override-tensors helps MoE models. I set longer context lengths than the llama-bench defaults to emphasize why keeping the attention operations on the GPU is important.

My hardware is an ASUS TUF A14 gaming laptop, i.e. a Ryzen AI 9 HX 370 with 7500MHz LPDDR5 and an RTX 4060 Mobile. For these tests I run it in the ASUS-standard "Turbo" mode.

First, a CPU-only test on my hardware (used 0.3 GB of VRAM during prompt processing)

```
.\build\bin\Release\llama-bench.exe -m ..\models\OLMoE-1B-7B-0924-Instruct-Q8_0.gguf -t 8 -ngl 0 -p 4096 -n 4096
```

| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,RPC | 0 | 8 | pp4096 | 631.85 ± 15.23 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,RPC | 0 | 8 | tg4096 | 44.04 ± 1.76 |

Next, running with -ngl 4 to offload some layers. I use such a low offload count to limit the VRAM use to just 2.2 GB, i.e. pretending this is a massive model that doesn't fit. I didn't have time to keep re-running the tests until I got exactly the same VRAM use as with -ot below.

```
.\build\bin\Release\llama-bench.exe -m ..\models\OLMoE-1B-7B-0924-Instruct-Q8_0.gguf -t 8 -ngl 4 -p 4096 -n 4096
```

| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,RPC | 4 | 8 | pp4096 | 750.98 ± 4.15 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,RPC | 4 | 8 | tg4096 | 36.27 ± 0.19 |

Next, enabling --override-tensors via the -ot short form. Because the expert tensors are overridden to the CPU, we can set -ngl 99 and still use only 1.3 GB of VRAM.

```
.\build\bin\Release\llama-bench.exe -m ..\models\OLMoE-1B-7B-0924-Instruct-Q8_0.gguf -t 8 -ngl 99 -ot "\d+\.ffn_.*exp.=CPU" -p 4096 -n 4096
```

| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,RPC | 99 | 8 | pp4096 | 736.91 ± 2.13 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,RPC | 99 | 8 | tg4096 | 46.26 ± 0.93 |

The effects are significantly more pronounced in larger MoE models, especially those with more experts and shared experts that are re-used on every pass (e.g. Llama 4 Scout and Maverick, although those models are beyond my devices' capabilities). I tried to demonstrate with DeepSeek-V2-Lite, but ran into CUDA errors when I applied flash attention, cache quantization, or override-tensors. I don't have enough experience with llama.cpp's codebase to track those down, but another Beaver has suggested it may be related to #12798

4onen changed the title from "Add --override-tensors option to llama-bench" to "llama-bench : Add --override-tensors arg" on Apr 13, 2025

4onen commented Apr 14, 2025

PR #12891 has resolved my issue running flash attention and override-tensors with DeepSeek-V2-Lite. Here are some performance numbers for that, on the same hardware as my last set:

CPU Only (Used 0.8GB of VRAM during prompt processing)

```
.\build\bin\Release\llama-bench.exe -m ..\models\DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf ^
    -p 4096 -n 4096 -t 8 ^
    -fa 1 -ctk q8_0 -ctv q8_0 -ngl 0
```

| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 0 | 8 | q8_0 | q8_0 | 1 | pp4096 | 76.48 ± 2.94 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 0 | 8 | q8_0 | q8_0 | 1 | tg4096 | 20.13 ± 1.65 |

Completely Filled GPU (Used 8.0GB of VRAM during prompt processing)

```
.\build\bin\Release\llama-bench.exe -m ..\models\DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf ^
    -p 4096 -n 4096 -t 8 ^
    -fa 1 -ctk q8_0 -ctv q8_0 -ngl 14
```

| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 14 | 8 | q8_0 | q8_0 | 1 | pp4096 | 102.89 ± 0.54 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 14 | 8 | q8_0 | q8_0 | 1 | tg4096 | 15.36 ± 1.39 |

Comparable VRAM GPU (Used 2.8GB of VRAM during prompt processing)

```
.\build\bin\Release\llama-bench.exe -m ..\models\DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf ^
    -p 4096 -n 4096 -t 8 ^
    -fa 1 -ctk q8_0 -ctv q8_0 -ngl 4
```

| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 4 | 8 | q8_0 | q8_0 | 1 | pp4096 | 61.07 ± 10.01 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 4 | 8 | q8_0 | q8_0 | 1 | tg4096 | 13.25 ± 0.36 |

Override-Tensors Run (Used 1.8GB of VRAM during prompt processing)

```
.\build\bin\Release\llama-bench.exe -m ..\models\DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf ^
    -p 4096 -n 4096 -t 8 ^
    -fa 1 -ctk q8_0 -ctv q8_0 -ngl 99 -ot "\d+\.ffn_.*exp.=CPU"
```

| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 8 | q8_0 | q8_0 | 1 | pp4096 | 100.06 ± 1.92 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 8 | q8_0 | q8_0 | 1 | tg4096 | 13.11 ± 0.21 |

Tuned Override-Tensors (Used 6.3GB of VRAM during prompt processing)

For this run, I'm leaving 6 of the 26 layers' conditional experts on the GPU, as well as all the shexp (shared expert) tensors, to better fill the VRAM and hopefully get the best of both worlds.

```
.\build\bin\Release\llama-bench.exe -m ..\models\DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf ^
    -p 4096 -n 4096 -t 8 ^
    -fa 1 -ctk q8_0 -ctv q8_0 -ngl 99 -ot "[12]\d\.ffn_.*exps.=CPU"
```

| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 8 | q8_0 | q8_0 | 1 | pp4096 | 63.12 ± 0.37 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 8 | q8_0 | q8_0 | 1 | tg4096 | 13.98 ± 0.13 |

Turns out my GPU was far more underpowered than I expected, but y'all can see the point of being able to benchmark this kind of thing.

4onen commented Apr 14, 2025

I ran another set of experiments on another device (an RTX 3070 and an AMD Ryzen 7 5800X 8-core with two sticks of 2133MHz DDR4).

CPU Only (Used 836MB of VRAM during prompt processing)

```
./build/bin/llama-bench -m ../models/DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf \
    -p 4096 -n 4096 -t 4 \
    -fa 1 -ctk q8_0 -ctv q8_0 -ngl 0
```

| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 0 | 4 | q8_0 | q8_0 | 1 | pp4096 | 62.50 ± 0.09 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 0 | 4 | q8_0 | q8_0 | 1 | tg4096 | 9.51 ± 0.19 |

Full GPU (Used 7626MB of VRAM during prompt processing)

```
./build/bin/llama-bench -m ../models/DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf \
    -p 4096 -n 4096 -t 4 \
    -fa 1 -ctk q8_0 -ctv q8_0 -ngl 13
```

| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 13 | 4 | q8_0 | q8_0 | 1 | pp4096 | 67.20 ± 0.11 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 13 | 4 | q8_0 | q8_0 | 1 | tg4096 | 11.80 ± 0.03 |

Comparable VRAM GPU (Used 2930MB of VRAM during prompt processing)

```
./build/bin/llama-bench -m ../models/DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf \
    -p 4096 -n 4096 -t 4 \
    -fa 1 -ctk q8_0 -ctv q8_0 -ngl 4
```

| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 4 | 4 | q8_0 | q8_0 | 1 | pp4096 | 62.74 ± 0.14 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 4 | 4 | q8_0 | q8_0 | 1 | tg4096 | 10.13 ± 0.01 |

Override-Tensors Full CPU Experts (except shared) (Used 2276MB of VRAM during prompt processing)

```
./build/bin/llama-bench -m ../models/DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf \
    -p 4096 -n 4096 -t 4 \
    -fa 1 -ctk q8_0 -ctv q8_0 -ngl 99 -ot "\d+.ffn_.*exps.=CPU"
```

| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 4 | q8_0 | q8_0 | 1 | pp4096 | 62.79 ± 0.13 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 4 | q8_0 | q8_0 | 1 | tg4096 | 11.80 ± 0.03 |

Override-Tensors Tuned (Used 7034MB of VRAM during prompt processing)

```
./build/bin/llama-bench -m ../models/DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf \
    -p 4096 -n 4096 -t 4 \
    -fa 1 -ctk q8_0 -ctv q8_0 -ngl 99 -ot "[2.]\d.ffn_.*exps.=CPU"
```

| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 4 | q8_0 | q8_0 | 1 | pp4096 | 66.80 ± 0.06 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 4 | q8_0 | q8_0 | 1 | tg4096 | 14.05 ± 0.02 |

Now, since this processor has neither AVX512 nor relatively high-bandwidth memory, we see the GPU eking out a performance boost and override-tensors helping significantly.

ddh0 commented Apr 15, 2025

You can also use this to offload the entire KV cache to the GPU while keeping the model weights on the CPU: -ngl 999 -ot "^.*$=CPU"
