[AMD] Added canonicalization pattern to propagate DotOp attrs #792

Open
wants to merge 50 commits into
base: refine-ops-pass

ravil-mobile
ravil-mobile and others added 30 commits April 28, 2025 09:49
guacamoleo and others added 20 commits April 28, 2025 09:49
This PR:

    makes Refine pass FuncOp based, instead of ModuleOp
    replaces walkers with PatternRewriters

Co-authored-by: ravil-mobile <[email protected]>
[AMD] Added `local_alloc` refinement
- Added a lit-test to `elementwise.mlir` which tests refinement of `convert_layout` from one `mma` layout to another
- Added a bug fix
@guacamoleo
Alright, I'm getting back to this after our May sprint.

For this PR, can you add a clarifying comment explaining in which direction the attributes are propagated, and why? You could add something like the explanation below, but fix up the syntax a bit.

// DotOpMFMAConverter() adds a dotAttr to the refined dot ops and to the extract_slice ops of the dot operands.
// When the canonicalizer matches the concat of local_loads with the extract_slice of the dot operands,
// the dotAttr needs to be copied from the extract_slice of the dot operand to the local_load.
// This allows Triton to match refined local_loads to refined dots. For example:
regA = local_load addrA
regB = local_load addrB
dot regA regB

--- after refinement ---

addrA0 = extract_slice addrA, 0
addrA1 = extract_slice addrA, 1
addrB0 = extract_slice addrB, 0
addrB1 = extract_slice addrB, 1

regA0 = local_load addrA0
regA1 = local_load addrA1
regB0 = local_load addrB0
regB1 = local_load addrB1

regA = concat regA0 regA1
regB = concat regB0 regB1

regA0 = extract_slice regA, 0  dotAttr00
regA1 = extract_slice regA, 1  dotAttr10
regB0 = extract_slice regB, 0  dotAttr00
regB1 = extract_slice regB, 1  dotAttr01

dot regA0, regB0  dotAttr00
dot regA0, regB1  dotAttr01
dot regA1, regB0  dotAttr10
dot regA1, regB1  dotAttr11

--- after canonicalization ---

addrA0 = extract_slice addrA, 0
addrA1 = extract_slice addrA, 1
addrB0 = extract_slice addrB, 0
addrB1 = extract_slice addrB, 1

regA0 = local_load addrA0  dotAttr00
regA1 = local_load addrA1  dotAttr10
regB0 = local_load addrB0  dotAttr00
regB1 = local_load addrB1  dotAttr01

dot regA0, regB0  dotAttr00
dot regA0, regB1  dotAttr01
dot regA1, regB0  dotAttr10
dot regA1, regB1  dotAttr11
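
To make the direction of propagation concrete, here is a minimal, self-contained sketch of the fold described above. It is not the code in this PR and does not use the MLIR API: `Op`, `sliceIndex`, `dotAttrs`, and `foldSliceOfConcat` are hypothetical stand-ins chosen for illustration. It only shows the idea that, when `extract_slice(concat(x0, x1, ...), i)` is folded to `xi`, the slice's dot attributes are copied onto `xi` (here, the `local_load`), skipping any that are already present.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for an IR operation (hypothetical, NOT mlir::Operation).
struct Op {
  std::string kind;                   // e.g. "local_load", "concat", "extract_slice"
  std::vector<Op *> operands;         // defining ops of the operands
  int sliceIndex = -1;                // only meaningful for "extract_slice"
  std::vector<std::string> dotAttrs;  // labels such as "dotAttr00"
};

// Folds extract_slice(concat(x0, x1, ...), i) to xi and copies the slice's
// dot attributes onto xi. Returns xi, or nullptr if the pattern does not match.
Op *foldSliceOfConcat(Op &slice) {
  if (slice.kind != "extract_slice" || slice.operands.size() != 1)
    return nullptr;
  Op *concat = slice.operands[0];
  if (concat == nullptr || concat->kind != "concat")
    return nullptr;
  if (slice.sliceIndex < 0 ||
      static_cast<std::size_t>(slice.sliceIndex) >= concat->operands.size())
    return nullptr;

  Op *source = concat->operands[static_cast<std::size_t>(slice.sliceIndex)];
  assert(source != nullptr && "concat operand must have a defining op");

  // Propagate from the slice (consumer side) to the source (producer side),
  // appending each attribute only if it is not already there.
  for (const std::string &attr : slice.dotAttrs) {
    auto &dst = source->dotAttrs;
    if (std::find(dst.begin(), dst.end(), attr) == dst.end())
      dst.push_back(attr);
  }
  return source;
}
```

In this toy model the attributes flow from the consumer (the `extract_slice` feeding the dot) back to the producer (the `local_load`), which is the direction the comment above asks to have documented.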


auto definingOp = operand.getDefiningOp();
if (definingOp->getBlock() != currBlock)
  continue;

I'm not sure what this is prohibiting; we do want to allow these attributes to be carried across basic blocks. For example, they could be loop-carried, or we could prefetch a tile of data in the prologue, in which case we want the order of loads in the prologue to match the dot ordering in the loop.

allowedDialects |=
    mlir::isa<amdgpu::TritonAMDGPUDialect>(definingOp->getDialect());
if (!allowedDialects)
  continue;

What is this prohibiting and allowing?

Comment on lines +468 to +480
SmallVector<Attribute> updatedAttrs(operandArrayAttr.getValue());
auto result = operandArrayAttr.walk([&targetAttr](AttrType itemAttr) {
  return itemAttr == targetAttr ? WalkResult::interrupt()
                                : WalkResult::advance();
});

if (!result.wasInterrupted()) {
  updatedAttrs.push_back(targetAttr);
}

if (!updatedAttrs.empty())
  definingOp->setAttr(
      attrName, mlir::ArrayAttr::get(op->getContext(), updatedAttrs));

Is this block of code ensuring we only add a dotAttr if one doesn't already exist?
If so, this can also be used as a recursion exit criterion. For example, after we've labeled all the refined local_loads with their dotAttr, we recursively label all their respective subviews with the dot attribute, which is good.
But when we then recur from all the subviews to the same parent, we can't place all the children's dotAttrs onto that one parent (since the labels only refer to refinement), so we can exit the recursion at that point. In other words, when an op already has a dotAttr, don't recur to its parent.
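
One possible reading of this exit criterion, sketched as a self-contained toy rather than with the MLIR API: `Node`, `parent`, and `labelUpwards` are hypothetical names, and each node is assumed to have a single parent. Under this literal reading, a shared parent keeps the first label that reaches it and every later child stops there; whether the shared parent should be labeled at all is a separate question that the comment above leaves open.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for an op with a single defining parent (hypothetical).
struct Node {
  std::string name;
  Node *parent = nullptr;             // e.g. subview -> local_alloc
  std::vector<std::string> dotAttrs;  // labels such as "dotAttr00"
};

// Labels `op` and its ancestors with `attr`, stopping at the first op that
// already carries any dotAttr. Returns the number of ops labeled.
int labelUpwards(Node *op, const std::string &attr) {
  // Exit criterion: no op, or the op is already labeled, so do not recur
  // to its parent.
  if (op == nullptr || !op->dotAttrs.empty())
    return 0;
  op->dotAttrs.push_back(attr);
  return 1 + labelUpwards(op->parent, attr);
}
```

With two subviews that share one allocation, the first call labels both its subview and the allocation, while the second call labels only its own subview and then stops at the already-labeled allocation.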

@ravil-mobile ravil-mobile force-pushed the refine-ops-pass branch 2 times, most recently from 3070646 to 0341d75 Compare June 23, 2025 14:48
3 participants