Bypass LDS for scale B operand for skinny gemms #817

Open
wants to merge 4 commits into
base: shared/triton-gfx950-launch

@plognjen plognjen commented May 29, 2025

Skip LDS for the scale B tensor when warpsPerCTA is {1, numWarps} and
the load layout matches the expected layout for scale B in the dotScaled op.
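The two conditions in the description can be sketched as a single predicate. This is a minimal illustration, not the PR's code: `canBypassLDSForScaleB` and the `layoutMatchesScaleB` flag are hypothetical names, and the flag stands in for the real comparison between the load layout and the scale B layout expected by dotScaled.

```cpp
#include <vector>

// Hypothetical helper (not from the PR): returns true when the scale B load
// may skip LDS under the two conditions stated in the description.
// `layoutMatchesScaleB` is an assumed stand-in for the actual layout
// comparison performed by the compiler.
bool canBypassLDSForScaleB(const std::vector<unsigned> &warpsPerCTA,
                           unsigned numWarps, bool layoutMatchesScaleB) {
  // Condition 1: warpsPerCTA is {1, numWarps}.
  bool skinnyWarpLayout = warpsPerCTA.size() == 2 && warpsPerCTA[0] == 1 &&
                          warpsPerCTA[1] == numWarps;
  // Condition 2: the load already produces the layout dotScaled expects.
  return skinnyWarpLayout && layoutMatchesScaleB;
}
```

Both conditions must hold; if either fails, the sketch falls back to the shared-memory path.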

@plognjen plognjen marked this pull request as ready for review May 29, 2025 15:21
mlir::triton::LinearLayout scaleBLayout =
mlir::triton::gpu::toLinearLayout(scaleBTy.getShape(),
scaleBTy.getEncoding());
bypassLDS = bypassLDS ||

What is this doing here? Is it checking if bypassing LDS succeeded?

I think @plognjen wanted to restore the previous condition, i.e. bypass LDS when width < 32.
If so, maybe we can store the value of (width < 32) in a variable other than bypassLDS to avoid confusion.

Author

Yes, this was to restore the previous condition. I will change the name.
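The renaming discussed in this thread could look roughly like the sketch below. Only `bypassLDS` and the `width < 32` test come from the diff; `LoadPlan`, `planLoad`, `isSubDwordLoad`, `bypassLDSForScaleB`, and the `scaleBLayoutMatches` flag are hypothetical names chosen for illustration.

```cpp
// Hypothetical result type: records whether the load stays in registers.
struct LoadPlan {
  bool bufferInRegisters; // true: keep the load in registers and skip LDS
};

// Hypothetical sketch of keeping the two bypass reasons under separate names.
// `scaleBLayoutMatches` is an assumed stand-in for the real layout check.
LoadPlan planLoad(unsigned width, bool scaleBLayoutMatches) {
  // Previous condition: fewer than 32 contiguous bits can be read.
  bool isSubDwordLoad = width < 32;
  // New condition: the load layout already matches the scale B layout.
  bool bypassLDSForScaleB = scaleBLayoutMatches;
  return LoadPlan{isSubDwordLoad || bypassLDSForScaleB};
}
```

Naming each reason separately makes it clear at the use site why a given load avoids shared memory.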

@@ -672,7 +673,40 @@ void StreamPipeliner::assignMemoryLayouts() {
// Only use shared memory when feeding into a dot op.
loadInfo.usedByDot = true;
// If the max continugous bits we can read is < 32, buffer in registers.
if (width >= 32) {
bool bypassLDS = width < 32;

So we're only bypassing LDS when loading less than a dword, e.g. with buffer_load_short or buffer_load_ushort?
Are there other cases where bypassing LDS could be beneficial? If so, let's add a comment reminding us of those additional scenarios.

Due to preshuffling, width is guaranteed to be >= 32. Therefore, it's confusing to enable bypassLDS only when width < 32.
More generally, bypassLDS should not check width. Later it checks whether the loaded layout is the same as the scale layout, which ensures width = 32.

4 participants