[AMD] Added bufferOps refinement #776


Open · wants to merge 1 commit into refine-ops-pass from ravil/refine-buffer-ops

Conversation

ravil-mobile

@guacamoleo, I added the buffer ops refinement. I tested it with our GEMM kernel. The numerics are OK, but performance dropped by a factor of 2. @giuseros, what do you think?

```
❯ AMDGCN_USE_BUFFER_OPS=1 TRITON_HIP_STREAM_MAX_DEPTH=1 MLIR_ENABLE_DUMP=0 TRITON_PRINT_AUTOTUNING=0 python3 ./gemm-ex.py -f ./exp-config.yaml --dump-ir ttgir
MASKING load/store: disabled
MATRIX B TRANSPOSED: false
use_bias=False
matmul_kernel
perf: 215.41678697824958 TFLOP/s
✅ Triton and Torch match

❯ AMDGCN_USE_BUFFER_OPS=1 TRITON_HIP_STREAM_MAX_DEPTH=1 MLIR_ENABLE_DUMP=1 TRITON_PRINT_AUTOTUNING=0 python3 ./gemm-ex.py -f ./exp-config.yaml --dump-ir ttgir --trans-b
MASKING load/store: disabled
MATRIX B TRANSPOSED: true
use_bias=False
matmul_kernel
perf: 213.28888766541104 TFLOP/s
✅ Triton and Torch match

❯ AMDGCN_USE_BUFFER_OPS=1 TRITON_HIP_STREAM_MAX_DEPTH=1 MLIR_ENABLE_DUMP=0 TRITON_PRINT_AUTOTUNING=0 python3 ./gemm-ex.py -f ./exp-config.yaml --dump-ir ttgir --trans-a
MASKING load/store: disabled
MATRIX B TRANSPOSED: false
use_bias=False
matmul_kernel
perf: 230.4643251881426 TFLOP/s
✅ Triton and Torch match

❯ AMDGCN_USE_BUFFER_OPS=1 TRITON_HIP_STREAM_MAX_DEPTH=1 MLIR_ENABLE_DUMP=0 TRITON_PRINT_AUTOTUNING=0 python3 ./gemm-ex.py -f ./exp-config.yaml --dump-ir ttgir --trans-a --trans-b
MASKING load/store: disabled
MATRIX B TRANSPOSED: true
use_bias=False
matmul_kernel
perf: 207.3770854286465 TFLOP/s
✅ Triton and Torch match
```

@guacamoleo

Thanks Ravil!
I'm not worried about performance at this point, because we aren't really doing scheduling yet.

I see that refining buffer loads is a new function, separate from the one for traditional loads. We should think of a way to combine many of these refinement functions so that we consolidate this code as much as possible. We can discuss offline. Maybe there's a way of keeping the common code in one function but passing in a functor which does the unique operations.
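The functor idea could look something like the following minimal C++ sketch (all names here are hypothetical, not from this PR): the shared slicing loop is written once, and each op kind supplies a callable for its unique rewrite.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical sketch (names are not from this PR): the common refinement
// skeleton lives in one template, and each op kind (global_load,
// buffer_load, ...) passes a functor that builds its own refined sub-op
// for a given slice.
struct Slice {
  int64_t offset;
  int64_t size;
};

template <typename MakeSubOp>
std::vector<std::string> refineOp(int64_t totalSize, int64_t numSlices,
                                  MakeSubOp makeSubOp) {
  assert(numSlices > 0 && totalSize % numSlices == 0);
  const int64_t sliceSize = totalSize / numSlices;
  std::vector<std::string> subOps;
  // common part: compute the slices once, identically for every op kind
  for (int64_t i = 0; i < numSlices; ++i)
    subOps.push_back(makeSubOp(Slice{i * sliceSize, sliceSize}));
  return subOps;
}
```

For instance, `refineOp(128, 4, [](Slice s) { return "buffer_load@" + std::to_string(s.offset); })` would produce four sub-ops at offsets 0, 32, 64, and 96; a global_load variant would differ only in the functor it passes.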

@ravil-mobile (Author)

@guacamoleo, I tested the correctness of the FA kernel with refined buffer ops; the numerics are correct.

@guacamoleo

> @guacamoleo, I tested the correctness of the FA kernel with refined buffer ops; the numerics are correct.

Great. How about consolidating code? Is there any way to merge the loads into a single refinement function? I'm concerned about all the duplication that has been inherent in our support. If there's no way to do it right now, we'll want to address consolidating the refinement code before we try to upstream this.

@guacamoleo

We discussed this offline; the commit looks good after rebasing onto the base branch.

@guacamoleo

Just a reminder to look into the issue with 16-bit memory ops (https://github.com/ROCm/triton-internal/issues/699#issuecomment-2835306030) that arises in conjunction with this commit.

@ravil-mobile ravil-mobile force-pushed the ravil/refine-buffer-ops branch 2 times, most recently from 8590ec0 to 89c1f5f Compare April 30, 2025 13:46
@guacamoleo left a comment


It looks like we're just copying axisInfo from operand[0] of extract_slice to the extract_slice op itself. Is it implicit in doing this that these values get re-calculated for the smaller sizes? It seems like there are cases where we do change contiguity and divisibility. For example, if an operand was 4 vgprs which need to be allocated contiguously, and then we split that in half, the new operands are only 2 vgprs which don't need to be 4-contiguous.
Am I understanding or misunderstanding this?

@ravil-mobile (Author)

> It looks like we're just copying axisInfo from operand[0] of extract_slice to the extract_slice op itself. Is it implicit in doing this that these values get re-calculated for the smaller sizes? It seems like there are cases where we do change contiguity and divisibility. For example, if an operand was 4 vgprs which need to be allocated contiguously, and then we split that in half, the new operands are only 2 vgprs which don't need to be 4-contiguous. Am I understanding or misunderstanding this?

Yes, you are correct. The scenario that you describe may happen. Let me think about a solution.

@ravil-mobile ravil-mobile force-pushed the ravil/refine-buffer-ops branch 3 times, most recently from d9ea8ba to 240b651 Compare May 6, 2025 15:27
@ravil-mobile (Author)

ravil-mobile commented May 6, 2025

Hi @guacamoleo, I added some code to recompute the AxisInfo. Please verify:

```cpp
auto srcType = cast<RankedTensorType>(op.getOperand().getType());
auto srcShape = srcType.getShape();
auto dstType = cast<RankedTensorType>(op.getResult().getType());
auto dstShape = dstType.getShape();
auto offsets = op.getStaticOffsets();
AxisInfo opInfo = operands[0]->getValue();
auto origContiguity = opInfo.getContiguity();
auto origDivisibility = opInfo.getDivisibility();
auto origConstancy = opInfo.getConstancy();
auto recompute = [](ArrayRef<int64_t> vec, int64_t c) {
  auto result = std::numeric_limits<int64_t>::max();
  for (auto &v : vec) {
    // compute the upper bound of `v` based on `contiguity`
    auto newC = ((v + c - 1) / c) * c - v;
    // make sure that the new value is not broken because
    // of the sliced boundaries
    newC = newC == 0 ? c : newC;
    // consider the minimal value along each dimension
    result = result > newC ? newC : result;
  }
  assert(vec.size() == 2);
  const auto dimSize = vec[1] - vec[0];
  // make sure that the value doesn't exceed the dimension size
  return result > dimSize ? dimSize : result;
};
SmallVector<int64_t> contiguity(origContiguity.size());
SmallVector<int64_t> divisibility(opInfo.getDivisibility().size());
SmallVector<int64_t> constancy(opInfo.getConstancy().size());
for (size_t dim = 0; dim < opInfo.getRank(); ++dim) {
  auto start = offsets[dim];
  auto end = start + dstShape[dim];
  contiguity[dim] = recompute({start, end}, origContiguity[dim]);
  // note: contiguity cannot increase while slicing a tensor
  assert(contiguity[dim] <= origContiguity[dim]);
  constancy[dim] = recompute({start, end}, origConstancy[dim]);
  divisibility[dim] = origDivisibility[dim];
  if (contiguity[dim] != origContiguity[dim]) {
    // note: assume n is the largest power of two that divides `x` and `x + c`:
    //   1. x % n = 0 and 2. (x + c) % n = 0
    // the remainder of a sum can be calculated as:
    //   3. (x + c) % n = (x % n + c % n) % n = 0
    // because of 1. one can write 4. (c % n) % n = 0, i.e., 5. c % n = 0
    divisibility[dim] = std::min(
        origDivisibility[dim],
        int64_t(log2Int(highestPowOf2Divisor<int64_t>(contiguity[dim]))));
  }
}
```
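For intuition, the `recompute` helper can be exercised standalone. The following is a sketch mirroring the lambda above with plain standard types instead of `ArrayRef` (for illustration only; the in-tree version operates on MLIR types):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

// Standalone mirror of the `recompute` lambda: given the slice boundaries
// [start, end) of one dimension and the original contiguity (or constancy)
// `c`, return the largest value that still holds for the slice. Runs of
// length `c` begin at multiples of `c`, so each boundary caps the result
// by its distance to the next multiple of `c`.
int64_t recompute(int64_t start, int64_t end, int64_t c) {
  int64_t result = std::numeric_limits<int64_t>::max();
  for (int64_t v : {start, end}) {
    int64_t newC = ((v + c - 1) / c) * c - v; // distance to next multiple
    newC = newC == 0 ? c : newC;              // a multiple keeps a full run
    result = std::min(result, newC);
  }
  return std::min(result, end - start); // cannot exceed the slice size
}
```

For example, slicing offsets [2, 4) out of a dimension with contiguity 4 gives contiguity 2 (the 4-vgpr/2-vgpr case from the review discussion), while the aligned slice [0, 8) of a contiguity-8 dimension keeps contiguity 8.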

@ravil-mobile ravil-mobile force-pushed the ravil/refine-buffer-ops branch from 240b651 to 91ac21d Compare May 6, 2025 15:53
@@ -626,6 +626,100 @@ struct LoadOpPattern : public RefineRewritePattern<triton::LoadOp> {
}
};

struct AMDGCNBufferLoadOp


Can you add an explanatory comment about how refinement of buffer_load is more complex than that of global_load? It looks like we need to examine the refinement of masks, otherTensor, and offsets and bring it all together. This'll make the function more understandable.

@ravil-mobile (Author)


Done


// -----

#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>


Thanks for adding a test!
I recall us determining that contiguity, divisibility, and constancy can all change from extract_slice; can you add a test where all 3 change so we correctly test that behavior?

@ravil-mobile (Author)


We decided to simply propagate the AxisInfo from extract_slice; it is part of the upstream code.

@ravil-mobile ravil-mobile force-pushed the refine-ops-pass branch 2 times, most recently from 3070646 to 0341d75 Compare June 23, 2025 14:48
@ravil-mobile ravil-mobile force-pushed the ravil/refine-buffer-ops branch from d44e4d9 to 4951510 Compare June 23, 2025 15:09
@ravil-mobile ravil-mobile force-pushed the ravil/refine-buffer-ops branch from 4951510 to 62a8438 Compare June 23, 2025 15:25
@ravil-mobile (Author)

Hi @guacamoleo, I rebased this PR onto the latest refine-ops-pass branch. Could you please re-review?

@ravil-mobile ravil-mobile requested a review from guacamoleo June 23, 2025 15:30
@guacamoleo left a comment


Thanks, Ravil. I looked at this more closely and saw that there are multiple differences from global_loads, so it doesn't make sense to try to merge this with global_load.
This looks good once the tests pass and the merge conflict is fixed.
