Bypass LDS for scale B operand for skinny gemms #817
base: shared/triton-gfx950-launch
mlir::triton::LinearLayout scaleBLayout =
    mlir::triton::gpu::toLinearLayout(scaleBTy.getShape(),
                                      scaleBTy.getEncoding());
bypassLDS = bypassLDS ||
What is this doing here? Is it checking if bypassing LDS succeeded?
I think @plognjen wanted to restore the previous condition, i.e. width < 32 should bypass LDS.
If so, maybe we can store the value of (width < 32) in a separate variable rather than in bypassLDS,
to avoid any confusion.
Yes, this was to restore the previous condition. I will change the name.
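To make the naming suggestion concrete, here is a minimal standalone sketch of keeping the two conditions in separately named variables. The function and variable names (`shouldBypassLDS`, `isSubDwordLoad`, `scaleBLayoutMatches`) are hypothetical and not taken from the PR; the real code lives inside `StreamPipeliner::assignMemoryLayouts()` and may be structured differently.

```cpp
// Hypothetical sketch: names are illustrative, not the PR's actual code.
// widthBits is the maximum number of contiguous bits a load can read.
bool shouldBypassLDS(unsigned widthBits, bool scaleBLayoutMatches) {
  // Previous condition: loads narrower than a dword are buffered in registers.
  const bool isSubDwordLoad = widthBits < 32;
  // Additional condition: the scale B load already has the layout the
  // consumer expects, so staging it through LDS is unnecessary.
  return isSubDwordLoad || scaleBLayoutMatches;
}
```

With the sub-dword check under its own name, a reader can see that `bypassLDS` is the union of two independent reasons rather than a flag that is being re-tested.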
@@ -672,7 +673,40 @@ void StreamPipeliner::assignMemoryLayouts() {
// Only use shared memory when feeding into a dot op.
loadInfo.usedByDot = true;
// If the max continugous bits we can read is < 32, buffer in registers.
if (width >= 32) {
bool bypassLDS = width < 32;
So, we're only bypassing LDS when we're loading less than a dword, such as with buffer_load_short or buffer_load_ushort?
Are there other cases where bypassing LDS could be beneficial? If so, let's add a comment reminding us of those additional scenarios.
Due to preshuffling, width is guaranteed to be >= 32, so it's confusing to enable bypassLDS only when width < 32.
More generally, bypassLDS should not check width. Later, the code checks whether the loaded layout is the same as the scale layout, which ensures width = 32.
Skip LDS for the scale B tensor when warpsPerCTA is {1, numWarps} and
the load layout matches the expected layout for scale B in the dotScaled op.
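The condition described above can be sketched as a standalone predicate. This is only an illustration under simplifying assumptions: `ToyLayout` is a hypothetical stand-in for a linear layout (a plain list of basis vectors), and `canSkipLDSForScaleB` is an invented name. The actual pass compares `mlir::triton::LinearLayout` objects obtained from `toLinearLayout`.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a layout: one list of basis vectors.
// The real code uses mlir::triton::LinearLayout instead.
using ToyLayout = std::vector<std::vector<int32_t>>;

// Returns true when the scale B load can skip LDS: all warps are laid out
// along the second dimension and the load layout already equals the layout
// the scaled dot expects for scale B.
bool canSkipLDSForScaleB(const std::vector<unsigned> &warpsPerCTA,
                         unsigned numWarps, const ToyLayout &loadLayout,
                         const ToyLayout &expectedScaleBLayout) {
  const bool warpsAlongN = warpsPerCTA.size() == 2 && warpsPerCTA[0] == 1 &&
                           warpsPerCTA[1] == numWarps;
  return warpsAlongN && loadLayout == expectedScaleBLayout;
}
```

Note that the sketch does not look at the load width at all; as discussed in the review, the layout comparison is what guards the bypass.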