Integrate MLX SDPA kernels with mask #2820

Open
EricLBuehler wants to merge 9 commits into main

Conversation

EricLBuehler (Member)

This PR integrates kernel developments from: ml-explore/mlx#1924.

Specifically, our candle_nn::ops::sdpa function now dispatches to optimized implementations for both the prompt (prefill) and decode cases. There is also an option for causal masking, which removes the need to materialize an attention mask.

Overall, this means we can fuse the attention operation on Metal for both the prompt and decode phases!
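
As a rough usage sketch (not part of this PR's diff), calling the fused path on a Metal device might look like the following. This assumes the pre-existing `candle_nn::ops::sdpa(q, k, v, scale, softcapping)` entry point; the exact surface for the new causal-mask option may differ from what is shown here.

```rust
// Minimal sketch: exercising the Metal SDPA path from user code.
// Assumes the existing candle_nn::ops::sdpa(q, k, v, scale, softcapping) signature;
// the causal-mask option added in this PR may be exposed through a different parameter.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_metal(0)?;

    // (batch, heads, seq_len, head_dim)
    let q = Tensor::randn(0f32, 1f32, (1, 32, 128, 64), &dev)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1f32, (1, 32, 128, 64), &dev)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1f32, (1, 32, 128, 64), &dev)?.to_dtype(DType::F16)?;
    let scale = 1f32 / (64f32).sqrt();

    // Dispatches to the decode-optimized kernel for single-token queries and the
    // prompt-optimized kernel for longer sequences; softcapping of 1.0 leaves logits untouched.
    let out = candle_nn::ops::sdpa(&q, &k, &v, scale, 1.0)?;
    println!("attention output dims: {:?}", out.dims());
    Ok(())
}
```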

I will update this PR further with benchmarks, but it is tested and working in my fork through mistral.rs.

EricLBuehler marked this pull request as ready for review on March 22, 2025 at 01:38.
ivarflakstad (Member) left a comment:

Great work 🎉

I really only have two comments.

One is that at this point it's really time to start precompiling kernels. Not for performance, just for the sake of project structure and maintainability.
I have that ready to go so I'll make a PR soon.

The second is that I wonder if it makes sense to move more of the logic in candle-nn into candle-metal-kernels.
The standard asserts etc. are fine, but maybe figuring out the correct kernel to call is actually a concept that belongs inside candle-metal-kernels?
That's open for discussion obviously.

In any case this is ready to merge in my opinion - if all tests pass and it runs smoothly ☺️
