
use the enable_gqa param in torch.nn.functional.scaled_dot_product_at… #39412


Open

wants to merge 3 commits into base: main

Conversation

sywangyi
Contributor

GQA can be accelerated in torch.nn.functional.scaled_dot_product_attention: this PyTorch API offers an enable_gqa parameter. See https://docs.pytorch.org/docs/2.7/generated/torch.nn.functional.scaled_dot_product_attention.html#torch-nn-functional-scaled-dot-product-attention

Signed-off-by: Wang, Yi A <[email protected]>
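
For illustration, a minimal sketch of the API the description points to (not code from this PR; the shapes and head counts are made up): with enable_gqa=True, key/value tensors keep their smaller number of KV heads and SDPA broadcasts them internally, so the caller no longer has to repeat the KV heads up to the query head count.

```python
import torch
import torch.nn.functional as F

# Illustrative GQA shapes: 32 query heads sharing 8 key/value heads.
batch, q_heads, kv_heads, seq, head_dim = 2, 32, 8, 128, 64

query = torch.randn(batch, q_heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

# Without enable_gqa the caller has to expand K/V to the query head count first.
key_rep = key.repeat_interleave(q_heads // kv_heads, dim=1)
value_rep = value.repeat_interleave(q_heads // kv_heads, dim=1)
out_repeat = F.scaled_dot_product_attention(query, key_rep, value_rep, is_causal=True)

# With enable_gqa=True (available in recent torch releases) SDPA broadcasts the
# KV heads itself, avoiding the materialized repeat above.
out_gqa = F.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

torch.testing.assert_close(out_repeat, out_gqa, rtol=1e-4, atol=1e-4)
```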
@liangan1

@LuFinch please help to review this PR.

@sywangyi
Contributor Author

FAILED tests/models/nougat/test_image_processing_nougat.py::NougatImageProcessingTest::test_slow_fast_equivalence_batched - AssertionError: 0.005013074725866318 not less than or equal to 0.005

This failure is unrelated to this PR; the failing test does not use SDPA attention.

@vasqu
Contributor

vasqu commented Jul 15, 2025

Please see #35235 (comment)

The enable_gqa kwarg is pretty restrictive and would need proper checks around it (version, mask) to ensure we do not fall back to the math kernel / use unsupported features of older torch.
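
To make the concern concrete, here is a rough sketch of the kind of guard being asked for (the helper names and the 2.5 version cutoff are illustrative assumptions, not the PR's actual code): only pass enable_gqa when the installed torch is new enough and the head counts actually call for it, and keep the existing repeat-KV path as the fallback.

```python
from packaging import version

import torch

# Assumption in this sketch: enable_gqa requires torch >= 2.5; older versions
# raise a TypeError for the unknown keyword argument.
_TORCH_SUPPORTS_ENABLE_GQA = version.parse(torch.__version__.split("+")[0]) >= version.parse("2.5")


def should_use_enable_gqa(query: torch.Tensor, key: torch.Tensor) -> bool:
    """Illustrative check: request GQA broadcasting only when it is safe and useful."""
    num_q_heads, num_kv_heads = query.shape[1], key.shape[1]
    # A real implementation would also inspect the attention mask, since some
    # mask shapes can push SDPA off the fused kernels onto the math backend.
    return (
        _TORCH_SUPPORTS_ENABLE_GQA
        and num_q_heads != num_kv_heads
        and num_q_heads % num_kv_heads == 0
    )
```

The caller would then pass enable_gqa=should_use_enable_gqa(query, key) to F.scaled_dot_product_attention and skip the manual KV repeat when it returns True.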

Signed-off-by: Wang, Yi A <[email protected]>
@sywangyi
Contributor Author

Please see #35235 (comment)

The enable_gqa kwarg is pretty restrictive and would need proper checks around it (version, mask) to ensure we do not fall back to the math kernel / use unsupported features of older torch.

Thanks for the review; check added.

3 participants