Fused CPU attention kernels (~4x performance increase) #2973

Open
wants to merge 4 commits into main

Conversation

EricLBuehler
Member

This introduces fused CPU attention kernels for optimized CPU inference. They remove the need to materialize the attention matrices, which dramatically improves throughput.

On an M3 Max with Llama 3.2 3B at 4-bit quantization, I measure a ~4x increase in decode T/s. This is faster than llama.cpp, even with llama.cpp's CPU FlashAttention enabled.

These kernels are loosely based on FlashAttention and on the CPU implementations in vLLM and ggml, but have been modified for higher performance.

Algorithm

run_flash_attn_cpu

  1. Choose execution path

    • Decode path: if the query length S_q == 1, invoke a specialized “single-Q” routine
    • Batched path: otherwise, invoke the general batched attention routine
  2. Compute attention

    • Parallel setup
      • Uses a custom Rayon thread-pool (FLASH_ATTN_POOL) with macOS QoS hints
      • Installs the pool via FLASH_ATTN_POOL.install(...) to isolate flash-attention tasks (a pool-setup sketch follows this list)
    • Work distribution
      • Batched: flattens the output into per-row chunks of size D (dv in the code) and calls
        out.par_chunks_mut(dv)
           .with_min_len(64)
           .enumerate()
           .for_each(...)
        to assign each (batch, head, query_pos) row to a Rayon worker
      • Decode: further splits the KV axis into cache-friendly tiles, then does
        (0..kv_tiles)
          .into_par_iter()
          .map(...)     // per-tile map
          .reduce(...)  // numerically-stable softmax reduce
        achieving nested parallelism for long KV sequences (a sketch of this tile map/reduce follows this list)
    • Per-row computation (a standalone sketch of this loop follows this list)
      1. Gather the query vector
      2. Loop over all key/value positions:
        • Apply mask and positional bias
        • Compute dot-product between query and key
        • Update an online softmax (log-sum-exp) in a streaming fashion
        • Weight and accumulate the value vectors
      3. Normalize the accumulated value sum by the softmax denominator
  3. Assemble result

    • Collect all per-row outputs into a flat buffer
    • Reshape into the final tensor of shape (B, S_q, H, D)
    • Return the result on the CPU device
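
As a companion to the outline above, here is a minimal sketch of the dedicated-pool pattern it references. The name FLASH_ATTN_POOL comes from the PR, but the thread count, thread naming, and the run_in_pool helper are illustrative assumptions; the real pool additionally applies macOS QoS hints, which are omitted here.

use std::sync::LazyLock;

// Assumption: a lazily-initialized Rayon pool. The PR's FLASH_ATTN_POOL also
// sets macOS QoS hints on its worker threads, which is not shown here.
static FLASH_ATTN_POOL: LazyLock<rayon::ThreadPool> = LazyLock::new(|| {
    rayon::ThreadPoolBuilder::new()
        // Assumption: one worker per available core.
        .num_threads(std::thread::available_parallelism().map(|n| n.get()).unwrap_or(1))
        .thread_name(|i| format!("flash-attn-{i}"))
        .build()
        .expect("failed to build flash-attention thread pool")
});

// Hypothetical helper: running work via `install` keeps the kernel's parallel
// iterators off Rayon's global pool.
fn run_in_pool<T: Send>(work: impl FnOnce() -> T + Send) -> T {
    FLASH_ATTN_POOL.install(work)
}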
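
Next, a minimal sketch of the per-row streaming computation (the three per-row steps above): one query row attends over all KV positions with an online log-sum-exp softmax, so no score matrix is ever materialized. The row-major slice layout, the additive-mask convention, and f32 accumulation are assumptions for illustration; the actual kernel also vectorizes the dot products.

// Per-row streaming softmax sketch (assumed layouts, see lead-in above).
fn attend_one_row(
    q: &[f32],            // query vector, length d
    keys: &[f32],         // k_len * d, row-major
    values: &[f32],       // k_len * dv, row-major
    mask: Option<&[f32]>, // optional additive bias per KV position
    scale: f32,
    out: &mut [f32],      // length dv, overwritten with the attention output
) {
    let d = q.len();
    let dv = out.len();
    let k_len = keys.len() / d;

    let mut running_max = f32::NEG_INFINITY; // m in the log-sum-exp recurrence
    let mut denom = 0.0f32;                  // sum of exp(score - m)
    out.iter_mut().for_each(|o| *o = 0.0);   // weighted value accumulator

    for j in 0..k_len {
        // Scaled dot product between the query and key j, plus optional bias.
        let k_row = &keys[j * d..(j + 1) * d];
        let mut score = q.iter().zip(k_row).map(|(a, b)| a * b).sum::<f32>() * scale;
        if let Some(m) = mask {
            score += m[j];
        }

        // Online softmax update: if a new maximum appears, rescale prior state.
        let new_max = running_max.max(score);
        let correction = (running_max - new_max).exp();
        let w = (score - new_max).exp();
        denom = denom * correction + w;
        for (acc, v) in out.iter_mut().zip(&values[j * dv..(j + 1) * dv]) {
            *acc = *acc * correction + w * *v;
        }
        running_max = new_max;
    }

    // Normalize the accumulated value sum by the softmax denominator.
    let inv = 1.0 / denom;
    out.iter_mut().for_each(|o| *o *= inv);
}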
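
Finally, a hedged sketch of the decode-path tile map/reduce for a single query vector: each tile builds a partial (max, denom, accumulator) state using the same streaming update as the previous sketch, and Rayon's reduce merges the partials with the usual exp-rescaling. Tile size, slice layouts, and helper names are assumptions; mask handling and the pool installation are omitted.

use rayon::prelude::*;

// Partial softmax state for one KV tile: running max, denominator, and the
// exp-weighted value accumulator (all relative to `max`).
struct TileAcc {
    max: f32,
    denom: f32,
    acc: Vec<f32>,
}

fn empty_acc(dv: usize) -> TileAcc {
    TileAcc { max: f32::NEG_INFINITY, denom: 0.0, acc: vec![0.0; dv] }
}

// Numerically stable merge of two tile partials: rescale both sides to the
// shared maximum before summing, so the result matches a sequential pass.
fn merge(a: TileAcc, b: TileAcc) -> TileAcc {
    // Empty partials (the reduce identity) pass through unchanged.
    if a.denom == 0.0 { return b; }
    if b.denom == 0.0 { return a; }
    let max = a.max.max(b.max);
    let (ca, cb) = ((a.max - max).exp(), (b.max - max).exp());
    let acc = a.acc.iter().zip(&b.acc).map(|(x, y)| x * ca + y * cb).collect();
    TileAcc { max, denom: a.denom * ca + b.denom * cb, acc }
}

// Single-query ("decode") attention: split the KV axis into tiles, build a
// partial per tile in parallel, then reduce and normalize once at the end.
fn decode_attend(q: &[f32], keys: &[f32], values: &[f32], scale: f32, dv: usize, tile: usize) -> Vec<f32> {
    let d = q.len();
    let k_len = keys.len() / d;
    let n_tiles = (k_len + tile - 1) / tile;

    let total = (0..n_tiles)
        .into_par_iter()
        .map(|t| {
            let mut part = empty_acc(dv);
            for j in t * tile..((t + 1) * tile).min(k_len) {
                let score = q.iter()
                    .zip(&keys[j * d..(j + 1) * d])
                    .map(|(a, b)| a * b)
                    .sum::<f32>() * scale;
                // Same streaming update as in the per-row sketch, per tile.
                let max = part.max.max(score);
                let c = (part.max - max).exp();
                let w = (score - max).exp();
                for (acc, v) in part.acc.iter_mut().zip(&values[j * dv..(j + 1) * dv]) {
                    *acc = *acc * c + w * *v;
                }
                part.denom = part.denom * c + w;
                part.max = max;
            }
            part
        })
        .reduce(|| empty_acc(dv), merge);

    // Normalize once after all tiles are combined.
    total.acc.iter().map(|x| x / total.denom).collect()
}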

@EricLBuehler
Member Author

@LaurentMazare could you please review this PR?

@AlpineVibrations

Should this help with quantized Qwen3 on a Mac with CPU?
Does it help on a Mac M1?
Thanks

@EricLBuehler
Member Author

This PR doesn't integrate it into any models yet, but that would be relatively easy.

Once that is done, yes. I saw a ~4x T/s increase for CPU inference.

@AlpineVibrations

OK, I see. So is there a way we can test it right now, or a sample of how you integrated it?
Thanks

@EricLBuehler
Member Author

I didn't include the model integration in this PR for ease of review, but you would replace the attention block of any model with a call to this function.

For Qwen 3, the attention block to replace currently looks like this:

// Unfused reference path: this materializes the full attention matrix, which
// is exactly what the fused kernel avoids.
let scale = 1.0 / (self.head_dim as f64).sqrt();
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
if let Some(m) = attn_mask {
    scores = scores.broadcast_add(m)?;
}
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
let ctx = probs.matmul(&v)?; // (B, H, L, D)

Note that Qwen3's q/k/v tensors have shape (b, h, seq_len, d), but this kernel requires (b, seq_len, h, d), so you need to transpose q/k/v with .transpose(1, 2).

For a real-world use case, I would explore the mistral.rs attention backend and dispatch code, and how it's used in a model (like Qwen 3).
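
For reference, a hedged sketch of what the replacement call site could look like. The exact signature of run_flash_attn_cpu in this PR may differ (treat the (q, k, v, mask, scale) argument order, and the contiguous() calls, as assumptions); the point is the layout change described above.

// Assumed call shape; check the PR for the real signature of run_flash_attn_cpu.
let scale = 1.0 / (self.head_dim as f64).sqrt();
let q = q.transpose(1, 2)?.contiguous()?; // (B, L, H, D), the layout the kernel expects
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?;
let ctx = run_flash_attn_cpu(&q, &k, &v, attn_mask, scale)?; // (B, L, H, D)
let ctx = ctx.transpose(1, 2)?; // back to (B, H, L, D) if the surrounding code expects it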

@AlpineVibrations

How's this looking? Is it done? It would be great to merge it in if so, and if it's not done, what is still needed?
