Integrate MLX SDPA kernels with mask #2820
Conversation
Great work 🎉
I really only have two comments.
One is that, at this point, it's really time to start precompiling kernels: not for performance, but for the sake of project structure and maintainability. I have that ready to go, so I'll open a PR soon.
The second is that I wonder whether it makes sense to move more of the logic from candle-nn into candle-metal-kernels. The standard asserts etc. are fine where they are, but figuring out the correct kernel to call is arguably a concept that belongs inside candle-metal-kernels.
That's open for discussion obviously.
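For concreteness, here is a minimal sketch of what such a selection helper inside candle-metal-kernels could look like. The function name, its inputs, and the kernel-naming scheme are all invented for illustration and do not reflect the actual candle-metal-kernels API; the point is only that shape-based kernel selection could live behind one entry point rather than in candle-nn:

```rust
// Hypothetical sketch: names and the kernel-naming scheme are invented
// for illustration and are not the real candle-metal-kernels API.

/// Pick a fused SDPA kernel name from the problem shape alone, so callers
/// (e.g. candle-nn) never hard-code kernel-selection logic themselves.
fn select_sdpa_kernel(seq_len: usize, head_dim: usize, causal: bool) -> Option<String> {
    // Decode (a single query token) uses the vector kernel; prompts use
    // the full kernel, optionally its causal variant.
    let variant = if seq_len == 1 {
        "vector"
    } else if causal {
        "full_causal"
    } else {
        "full"
    };
    // Only head dims with specialized kernels are supported here.
    match head_dim {
        64 | 96 | 128 => Some(format!("sdpa_{variant}_hd{head_dim}")),
        _ => None, // caller falls back to the unfused attention path
    }
}

fn main() {
    assert_eq!(
        select_sdpa_kernel(1, 128, false).as_deref(),
        Some("sdpa_vector_hd128")
    );
    assert_eq!(select_sdpa_kernel(512, 80, true), None);
}
```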
In any case, this is ready to merge in my opinion, provided all tests pass and it runs smoothly.
This PR integrates kernel developments from ml-explore/mlx#1924.
Specifically, our candle_nn::ops::sdpa function now dispatches to optimized implementations for both the prompt (prefill) and decode cases. There is also an option for causal masking, removing the need to materialize the mask. Overall, this means that we can fuse the attention operation on Metal for both the prompt and decode phases!
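For reference, a minimal usage sketch of the existing entry point, assuming the pre-PR signature `sdpa(q, k, v, scale, softcapping)`; how this PR exposes the new causal-mask option may differ, so treat the call below as illustrative rather than the final API:

```rust
// Minimal sketch, assuming candle_nn::ops::sdpa(q, k, v, scale, softcapping);
// the causal-mask option added by this PR is not shown because its exact
// surface here is an assumption.
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_metal(0)?;
    // (batch, num_heads, seq_len, head_dim)
    let q = Tensor::randn(0f32, 1., (1, 32, 128, 64), &dev)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1., (1, 32, 128, 64), &dev)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1., (1, 32, 128, 64), &dev)?.to_dtype(DType::F16)?;
    let scale = 1.0 / (64f32).sqrt();
    // softcapping = 1.0 disables logit softcapping.
    let out = candle_nn::ops::sdpa(&q, &k, &v, scale, 1.0)?;
    println!("{:?}", out.dims()); // expected: [1, 32, 128, 64]
    Ok(())
}
```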
I will update this PR further with benchmarks, but it is tested and working in my fork through mistral.rs.