[AMD] Added Refine-Reschedule hint for rescheduling #791

Status: Open. Wants to merge 49 commits into base branch refine-ops-pass.
Commits (49, showing changes from all commits):
c0d4c27
[AMD] moved membar analysis to its dedicated pass
ravil-mobile Feb 3, 2025
e6813be
[AMD] Added a skeleton for the `RefineOpsPass`
ravil-mobile Jan 20, 2025
0be1454
[AMD] Added a prototype of convertDot to RefineOps pass
ravil-mobile Jan 23, 2025
ee2a1d5
[AMD] Added a proto of the ConcatOp to the AMDGPU dialect
ravil-mobile Jan 24, 2025
b5e4c94
[AMD] Added ConcatOpToLLVM rewrite pattern
ravil-mobile Jan 24, 2025
be7cf7f
[AMD] Added tt.LoadOp and tt.LocalStore refinement
ravil-mobile Jan 27, 2025
a9ffb03
[AMD] restructured `refineOps.cpp` and added `reschedule` pass
ravil-mobile Feb 3, 2025
368565e
[AMD] added `refine-ops` option as a sched.hint variant
ravil-mobile Feb 4, 2025
75b0d30
[AMD] Added a proto of the dependency-graph
ravil-mobile Feb 5, 2025
5ed7104
[AMD] refactored dependency-graph builders
ravil-mobile Feb 6, 2025
3113263
[AMD] moved ttg.local_load rewrite in refineOps to its own place
ravil-mobile Feb 7, 2025
b1598e1
[AMD] Implemented machine model for the rescheduling pass
ravil-mobile Feb 10, 2025
5d0c8b3
Adding support for dot-tiling
guacamoleo Feb 11, 2025
b97c62a
Adding DotTileAttr to extract_slices of local_loads
guacamoleo Feb 12, 2025
dc03b4d
Fixing dot-tiling for different shapes and orders.
guacamoleo Feb 12, 2025
c6375e0
Reformat from clang-format.
guacamoleo Feb 13, 2025
84a29d2
Addressing review improvements.
guacamoleo Feb 17, 2025
e2f023c
Applied clang-format to DotTiling.h and RefineOps.cpp
ravil-mobile Feb 19, 2025
d497e95
[AMD] Fixed lit-tests
ravil-mobile Feb 19, 2025
d0e61ad
[AMD] Improved re-scheduling
ravil-mobile Feb 17, 2025
016e44d
[AMD] adjusted calculations coming from calcDotTileShape
ravil-mobile Feb 19, 2025
6a067c0
[AMD] improved canonicalizer for extract and concat ops
ravil-mobile Feb 25, 2025
109ae4c
[AMD] bug-fix in reschedule-ops pass: barriers for local_alloc ops
ravil-mobile Feb 25, 2025
e0979b1
[AMD] Changed machine model op selector in resched. pass
ravil-mobile Feb 26, 2025
9083b7d
DotTiling to handle invalid configurations more robustly.
guacamoleo Feb 28, 2025
64e996b
addressing review comments
guacamoleo Feb 28, 2025
2a06f1f
adding unverified/incorrect support for ttg.extract_slice(ttg.slice)
guacamoleo Mar 3, 2025
07b4ef5
Preparing extract_slice support for checkin
guacamoleo Mar 3, 2025
63eb1cd
Adding tests and removing patch.
guacamoleo Mar 4, 2025
80c1872
Adding empty rewrite function
guacamoleo Mar 5, 2025
186daf6
refined ReduceOp working
guacamoleo Mar 6, 2025
5662f5c
preparing refine reduce for review
guacamoleo Mar 7, 2025
5c7b83d
addressing issues and running clang-format
guacamoleo Mar 10, 2025
f1b18b9
re-running clang-format on files outside this commit
guacamoleo Mar 10, 2025
b9e0079
Merging 1d and 2d into combined function.
guacamoleo Mar 5, 2025
994330d
[AMD] fixed calculations in `fastPathComputeOffsets`
ravil-mobile Mar 12, 2025
9f5b93e
[AMD] Added changes required after rebasing
ravil-mobile Mar 17, 2025
58b3355
[AMD] Added bugfixes: canonicalizer, extract_slice
ravil-mobile Mar 28, 2025
6d97bf6
[AMD] Fixed/Extended ConcatOp lowering to LLVM
ravil-mobile Mar 31, 2025
d89549a
[AMD] Added elementwise broadcast ops refinement
ravil-mobile Apr 1, 2025
7b58b4d
[AMD] Fixed MFMA to Linear Layout conversion for 1D tensors
ravil-mobile Apr 3, 2025
e857d6c
[AMD] Adapted the code due to upstream changes
ravil-mobile Apr 17, 2025
a9b093a
Refine pattern matching rework (#783)
binarman Apr 29, 2025
387125d
[AMD] Removed redundant headers from `RefineOps.cpp`
ravil-mobile Apr 30, 2025
dd0a348
[AMD] Inlined refine-ops functions
ravil-mobile Apr 30, 2025
fb859b6
Merge pull request #786 from ROCm/ravil/refine-ops-inline
ravil-mobile Apr 30, 2025
064ff46
[AMD] Added `local_alloc` refinement
ravil-mobile May 2, 2025
51a62ff
Merge pull request #787 from ROCm/ravil/local-alloc-refine
ravil-mobile May 5, 2025
14573e7
[AMD] Added Refine-Reschedule hint for rescheduling
ravil-mobile May 5, 2025
3 changes: 3 additions & 0 deletions bin/RegisterTritonDialects.h
@@ -63,6 +63,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::registerOptimizeAMDLDSUsage();

// TritonAMDGPUTransforms passes
+ mlir::registerTritonAMDGPUMembarAnalysis();
mlir::registerTritonAMDGPUAccelerateMatmul();
mlir::registerTritonAMDGPUOptimizeEpilogue();
mlir::registerTritonAMDGPUHoistLayoutConversions();
@@ -80,6 +81,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {

// NVWS passes
mlir::registerNVWSTransformsPasses();
+ mlir::registerTritonAMDGPURefineOps();
+ mlir::registerTritonAMDGPURescheduleOps();

registry.insert<
mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
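With these registrations, the refine and reschedule passes become available to triton-opt, which is how the lit tests in this PR drive them; for instance (kernel.mlir is a placeholder input file, and the flag spellings are taken from the RUN lines below): triton-opt kernel.mlir --allocate-shared-memory --triton-amdgpu-membar-analysis -triton-amdgpu-refine-ops='arch=gfx942' -canonicalize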
10 changes: 0 additions & 10 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -394,16 +394,6 @@ class SharedMemoryObject {
return offsets[dim];
}

- // TODO(Keren): deprecate the method once AMD backend has cleaned up
- Value getBaseBeforeSlice(int dim, Location loc,
-                          RewriterBase &rewriter) const {
-   auto b = TritonLLVMOpBuilder(loc, rewriter);
-   Value cSwizzleOffset = getCSwizzleOffset(dim);
-   Value offset = b.sub(b.i32_val(0), cSwizzleOffset);
-   Type type = base.getType();
-   return b.gep(type, baseElemType, base, offset);
- }

private:
static SmallVector<unsigned>
getOrderForShape(ArrayRef<int64_t> shape, ArrayRef<unsigned> layoutOrder) {
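This deletion resolves the TODO above: judging by this PR, the AMD backend no longer has call sites for getBaseBeforeSlice, so the method can go.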
17 changes: 16 additions & 1 deletion lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -457,7 +457,22 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
identityStandardND(S("warp"), getWarpsPerCTA(), order);
LinearLayout ctaLayout = tileLayout * warpLayout;

-  return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
+  auto combinedLayout =
+      combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
+
+  auto bases = combinedLayout.getBases();
+  std::vector<std::vector<int>> newRegBases;
+  for (const auto &basis : bases[S("register")]) {
+    if (llvm::any_of(basis, [](int b) { return b != 0; })) {
+      newRegBases.push_back(basis);
+    }
+  }
+  bases[S("register")] = newRegBases;
+
+  auto result = LinearLayout(std::move(bases),
+                             llvm::to_vector(combinedLayout.getOutDimNames()));
+
+  return result;
}

LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,
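The added loop filters out register basis vectors that are all zero; under a linear layout such a basis makes two register indices alias the same tensor element (a broadcast), and these showed up for 1D tensors in the MFMA conversion (cf. commit 7b58b4d). Below is a minimal standalone sketch of just the filtering step; the helper name pruneZeroBases is ours, std::any_of stands in for llvm::any_of, and a plain vector-of-vectors stands in for LinearLayout::getBases():

#include <algorithm>
#include <cassert>
#include <vector>

// Keep only basis vectors that have at least one nonzero component;
// an all-zero register basis encodes a broadcast register.
std::vector<std::vector<int>>
pruneZeroBases(const std::vector<std::vector<int>> &bases) {
  std::vector<std::vector<int>> kept;
  for (const auto &basis : bases)
    if (std::any_of(basis.begin(), basis.end(), [](int b) { return b != 0; }))
      kept.push_back(basis);
  return kept;
}

int main() {
  // One broadcast basis ({0, 0}) among two real ones.
  auto kept = pruneZeroBases({{1, 0}, {0, 0}, {0, 2}});
  assert(kept.size() == 2);
  return 0;
}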
4 changes: 2 additions & 2 deletions test/Conversion/amd/async_ops_to_llvm.mlir
@@ -1,5 +1,5 @@
- // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950
- // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s
+ // RUN: triton-opt %s -split-input-file --allocate-shared-memory --triton-amdgpu-membar-analysis --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950
+ // RUN: triton-opt %s -split-input-file --allocate-shared-memory --triton-amdgpu-membar-analysis --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
37 changes: 37 additions & 0 deletions test/TritonGPU/amd/amd-extractslice-op.mlir
@@ -12,3 +12,40 @@ module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32,
tt.return
}
}

#blocked3 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
tt.func @extract_slice_slice_1(%arg0: tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> {tt.divisibility = 16 : i32}) {
// CHECK: llvm.func @extract_slice_slice_1
// CHECK-COUNT-8: %{{[0-9]*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: %8 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)>
// CHECK-COUNT-4: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32)>
%1 = amdgpu.extract_slice %arg0 [128] : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> to tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked3}>>
tt.return
}
}

#blocked4 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
tt.func @extract_slice_slice_0(%arg0: tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked4}>> {tt.divisibility = 16 : i32}) {
// CHECK: llvm.func @extract_slice_slice_0
// CHECK-COUNT-8: %{{[0-9]*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: %8 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)>
// CHECK-COUNT-4: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32)>
%0 = amdgpu.extract_slice %arg0 [0] : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked4}>> to tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked4}>>
tt.return
}
}

#blocked5 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
tt.func @extract_slice_slice_2() {
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked5}>>
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked5}>>
// CHECK-COUNT-4: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32)>
%2 = amdgpu.extract_slice %0 [0] : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked5}>> to tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked5}>>
// CHECK-COUNT-4: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32)>
%3 = amdgpu.extract_slice %1 [0] : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked5}>> to tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked5}>>
tt.return
}
}
127 changes: 127 additions & 0 deletions test/TritonGPU/amd/ops-refinement/elementwise.mlir
@@ -0,0 +1,127 @@
// RUN: triton-opt %s -split-input-file -triton-amdgpu-refine-ops='arch=gfx942' | FileCheck %s

// CHECK-LABEL: @exp_kernel
// CHECK-DAG: [[VALUE_1:%.*]] = amdgpu.extract_slice {{.*}} [0, 0]
// CHECK-DAG: [[VALUE_2:%.*]] = math.exp2 [[VALUE_1]]
// CHECK-DAG: [[VALUE_3:%.*]] = amdgpu.extract_slice {{.*}} [0, 16]
// CHECK-DAG: [[VALUE_4:%.*]] = math.exp2 [[VALUE_3]]
// CHECK-DAG: [[VALUE_5:%.*]] = amdgpu.extract_slice {{.*}} [64, 0]
// CHECK-DAG: [[VALUE_6:%.*]] = math.exp2 [[VALUE_5]]
// CHECK-DAG: [[VALUE_7:%.*]] = amdgpu.extract_slice {{.*}} [64, 16]
// CHECK-DAG: [[VALUE_8:%.*]] = math.exp2 [[VALUE_7]]
// CHECK-DAG: [[VALUE_9:%.*]] = amdgpu.concat [[VALUE_2]], [[VALUE_4]], [[VALUE_6]], [[VALUE_8]]
// CHECK-DAG: tt.return [[VALUE_9]]
#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @exp_kernel(%arg0: tensor<128x32xf32, #blocked>) -> tensor<128x32xf32, #blocked> attributes {noinline = false} {
amdgpu.refine_reschedule_ops_hint
%0 = math.exp2 %arg0 : tensor<128x32xf32, #blocked>
tt.return %0 : tensor<128x32xf32, #blocked>
}
}
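
For reference, the offsets in the CHECK lines follow from the #blocked layout: one CTA tile spans (sizePerThread × threadsPerWarp × warpsPerCTA) = (2·8·4, 2·8·1) = 64×16 elements, so the 128×32 operand is refined into 2×2 = 4 slices at offsets [0,0], [0,16], [64,0], and [64,16], then reassembled with amdgpu.concat.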

// -----

// CHECK-LABEL: mul_kernel
// CHECK-DAG: [[VALUE_1:%.*]] = amdgpu.extract_slice {{.*}} [0, 0]
// CHECK-DAG: [[VALUE_2:%.*]] = amdgpu.extract_slice {{.*}} [0, 0]
// CHECK-DAG: [[VALUE_3:%.*]] = arith.mulf [[VALUE_1]], [[VALUE_2]]
// CHECK-DAG: [[VALUE_4:%.*]] = amdgpu.extract_slice {{.*}} [0, 16]
// CHECK-DAG: [[VALUE_5:%.*]] = amdgpu.extract_slice {{.*}} [0, 16]
// CHECK-DAG: [[VALUE_6:%.*]] = arith.mulf [[VALUE_4]], [[VALUE_5]]
// CHECK-DAG: [[VALUE_7:%.*]] = amdgpu.extract_slice {{.*}} [64, 0]
// CHECK-DAG: [[VALUE_8:%.*]] = amdgpu.extract_slice {{.*}} [64, 0]
// CHECK-DAG: [[VALUE_9:%.*]] = arith.mulf [[VALUE_7]], [[VALUE_8]]
// CHECK-DAG: [[VALUE_10:%.*]] = amdgpu.extract_slice {{.*}} [64, 16]
// CHECK-DAG: [[VALUE_11:%.*]] = amdgpu.extract_slice {{.*}} [64, 16]
// CHECK-DAG: [[VALUE_12:%.*]] = arith.mulf [[VALUE_10]], [[VALUE_11]]
// CHECK-DAG: [[VALUE_13:%.*]] = amdgpu.concat [[VALUE_3]], [[VALUE_6]], [[VALUE_9]], [[VALUE_12]]
// CHECK-DAG: tt.return [[VALUE_13]]
#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @mul_kernel(%arg0: tensor<128x32xf32, #blocked>, %arg1: tensor<128x32xf32, #blocked>) -> tensor<128x32xf32, #blocked> attributes {noinline = false} {
amdgpu.refine_reschedule_ops_hint
%0 = arith.mulf %arg0, %arg1 : tensor<128x32xf32, #blocked>
tt.return %0 : tensor<128x32xf32, #blocked>
}
}

// -----

// CHECK-LABEL: @multiple_operations_kernel

// CHECK-COUNT-4: amdgpu.extract_slice {{.*}}
// CHECK: [[OP1:%.*]] = amdgpu.concat
// CHECK-COUNT-4: amdgpu.extract_slice [[OP1]]
// CHECK: [[OP2:%.*]] = amdgpu.concat
// CHECK-COUNT-4: amdgpu.extract_slice [[OP2]]
// CHECK: [[OP3:%.*]] = amdgpu.concat
// CHECK: tt.return [[OP3]]
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @multiple_operations_kernel(%arg0: tensor<128x32xf32, #mma>, %arg1: tensor<128x32xf32, #mma>) -> tensor<128x32xf32, #mma> attributes {noinline = false} {
amdgpu.refine_reschedule_ops_hint
%0 = math.exp2 %arg0 : tensor<128x32xf32, #mma>
%1 = math.exp2 %0 : tensor<128x32xf32, #mma>
%2 = math.exp2 %1 : tensor<128x32xf32, #mma>
tt.return %2 : tensor<128x32xf32, #mma>
}
}

// -----

// CHECK-LABEL: @nested_operations_kernel
// CHECK-COUNT-8: amdgpu.extract_slice
// CHECK: mulf
// CHECK: amdgpu.concat
// CHECK: scf.for
// CHECK-COUNT-4: amdgpu.extract_slice
// CHECK: math.exp2
// CHECK: amdgpu.concat
// CHECK: }
#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @nested_operations_kernel(%arg0: tensor<128x32xf32, #blocked>, %arg1: tensor<128x32xf32, #blocked>) -> tensor<128x32xf32, #blocked> attributes {noinline = false} {
amdgpu.refine_reschedule_ops_hint
%0 = arith.mulf %arg0, %arg1 : tensor<128x32xf32, #blocked>
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%c4 = arith.constant 4 : i32
%1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %0) -> (tensor<128x32xf32, #blocked>) : i32 {
%2 = math.exp2 %0 : tensor<128x32xf32, #blocked>
scf.yield %2 : tensor<128x32xf32, #blocked>
}
tt.return %1 : tensor<128x32xf32, #blocked>
}
}

// -----

// CHECK-LABEL: @peer_operations_kernel
// CHECK: scf.for
// CHECK-COUNT-4: amdgpu.extract_slice
// CHECK: math.exp2
// CHECK: amdgpu.concat
// CHECK: scf.for
// CHECK-NOT: amdgpu.extract_slice
// CHECK: math.exp2
// CHECK-NOT: amdgpu.concat
// CHECK: }
#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @peer_operations_kernel(%arg0: tensor<128x32xf32, #blocked>) -> tensor<128x32xf32, #blocked> attributes {noinline = false} {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%c4 = arith.constant 4 : i32
%1 = scf.for %arg1 = %c0 to %c4 step %c1 iter_args(%arg2 = %arg0) -> (tensor<128x32xf32, #blocked>) : i32 {
amdgpu.refine_reschedule_ops_hint
%2 = math.exp2 %arg2 : tensor<128x32xf32, #blocked>
scf.yield %2 : tensor<128x32xf32, #blocked>
}
%3 = scf.for %arg3 = %c0 to %c4 step %c1 iter_args(%arg4 = %1) -> (tensor<128x32xf32, #blocked>) : i32 {
%4 = math.exp2 %arg4 : tensor<128x32xf32, #blocked>
scf.yield %4 : tensor<128x32xf32, #blocked>
}
tt.return %3 : tensor<128x32xf32, #blocked>
}
}
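
Note that amdgpu.refine_reschedule_ops_hint is scoped: only the first loop, whose body carries the hint, gets refined; the CHECK-NOT lines verify that the exp2 in the second loop is left whole.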
35 changes: 35 additions & 0 deletions test/TritonGPU/amd/ops-refinement/local_alloc.mlir
@@ -0,0 +1,35 @@
// RUN: triton-opt %s -split-input-file -triton-amdgpu-refine-ops='arch=gfx942' -canonicalize | FileCheck %s

#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
#smem = #ttg.shared_memory


// CHECK-LABEL: @local_alloc_refinement
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 16384 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @local_alloc_refinement(%arg0: tensor<64x16xf16, #blocked>) attributes {noinline = false} {

// CHECK: [[OFFSET_12:%.*]] = arith.constant 12 : i32
// CHECK: [[OFFSET_8:%.*]] = arith.constant 8 : i32
// CHECK: [[OFFSET_4:%.*]] = arith.constant 4 : i32
// CHECK: [[OFFSET_0:%.*]] = arith.constant 0 : i32
// CHECK: [[ALLOC:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable>
// CHECK: [[SUBVIEW_0:%.*]] = ttg.memdesc_subview [[ALLOC]][[[OFFSET_0]], [[OFFSET_0]], [[OFFSET_0]]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
// CHECK: [[SLICE_0:%.*]] = amdgpu.extract_slice %arg0 [0, 0] : tensor<64x16xf16, #blocked> to tensor<64x4xf16, #blocked>
// CHECK: ttg.local_store [[SLICE_0]], [[SUBVIEW_0]] : tensor<64x4xf16, #blocked> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
// CHECK: [[SUBVIEW_1:%.*]] = ttg.memdesc_subview [[ALLOC]][[[OFFSET_0]], [[OFFSET_0]], [[OFFSET_4]]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
// CHECK: [[SLICE_1:%.*]] = amdgpu.extract_slice %arg0 [0, 4] : tensor<64x16xf16, #blocked> to tensor<64x4xf16, #blocked>
// CHECK: ttg.local_store [[SLICE_1]], [[SUBVIEW_1]] : tensor<64x4xf16, #blocked> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
// CHECK: [[SUBVIEW_2:%.*]] = ttg.memdesc_subview [[ALLOC]][[[OFFSET_0]], [[OFFSET_0]], [[OFFSET_8]]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
// CHECK: [[SLICE_2:%.*]] = amdgpu.extract_slice %arg0 [0, 8] : tensor<64x16xf16, #blocked> to tensor<64x4xf16, #blocked>
// CHECK: ttg.local_store [[SLICE_2]], [[SUBVIEW_2]] : tensor<64x4xf16, #blocked> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
// CHECK: [[SUBVIEW_3:%.*]] = ttg.memdesc_subview [[ALLOC]][[[OFFSET_0]], [[OFFSET_0]], [[OFFSET_12]]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
// CHECK: [[SLICE_3:%.*]] = amdgpu.extract_slice %arg0 [0, 12] : tensor<64x16xf16, #blocked> to tensor<64x4xf16, #blocked>
// CHECK: ttg.local_store [[SLICE_3]], [[SUBVIEW_3]] : tensor<64x4xf16, #blocked> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
// CHECK: amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant<refine_ops>}
// CHECK: ttg.local_dealloc [[ALLOC]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable>
%0 = ttg.local_alloc %arg0 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem>
amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant<refine_ops>}
tt.return
}
}
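
Here the split runs along dim 1: the #blocked tile is (1·64·1) × (1·1·4) = 64×4, so the 64×16 local_alloc is refined into four 64×4 stores, each pairing an amdgpu.extract_slice at column offset 0, 4, 8, or 12 with a ttg.memdesc_subview at the matching offset.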
42 changes: 42 additions & 0 deletions test/TritonGPU/amd/ops-refinement/simple-dot.mlir
@@ -0,0 +1,42 @@
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline='num_stages=2' -cse -canonicalize -triton-amdgpu-refine-ops='arch=gfx942' -canonicalize | FileCheck %s

#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
// CHECK: @matmul_kernel
tt.func public @matmul_kernel(
%arg0: tensor<256x64x!tt.ptr<f16>, #blocked> {tt.contiguity=16 : i32, tt.divisibility=16: i32, tt.constancy=16: i32},
%arg1: tensor<64x128x!tt.ptr<f16>, #blocked> {tt.contiguity=16 : i32, tt.divisibility=16: i32, tt.constancy=16: i32}) -> tensor<256x128xf32, #mma> attributes {noinline = false} {

%output = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c64_i32 = arith.constant 64 : i32

%shift0 = arith.constant dense<64> : tensor<256x64xi32, #blocked>
%shift1 = arith.constant dense<64> : tensor<64x128xi32, #blocked>

%0:3 = scf.for %arg2 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(
%loop_arg0 = %output,
%loop_arg1 = %arg0,
%loop_arg2 = %arg1) -> (
tensor<256x128xf32, #mma>,
tensor<256x64x!tt.ptr<f16>, #blocked>,
tensor<64x128x!tt.ptr<f16>, #blocked>) : i32 {
%1 = tt.load %loop_arg1 : tensor<256x64x!tt.ptr<f16>, #blocked>
%2 = tt.load %loop_arg2 : tensor<64x128x!tt.ptr<f16>, #blocked>
%3 = ttg.convert_layout %1 : tensor<256x64xf16, #blocked> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%4 = ttg.convert_layout %2 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%5 = tt.dot %3, %4, %loop_arg0 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma>
%6 = tt.addptr %loop_arg1, %shift0 : tensor<256x64x!tt.ptr<f16>, #blocked>, tensor<256x64xi32, #blocked>
%7 = tt.addptr %loop_arg2, %shift1 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
scf.yield %5, %6, %7 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked>, tensor<64x128x!tt.ptr<f16>, #blocked>
}

tt.return %0#0 : tensor<256x128xf32, #mma>
}
}


// TODO: add TT GEMM case to the test
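As written, the test only verifies that @matmul_kernel survives the combined stream-pipeline and refine-ops pipeline; the TODO above tracks adding GEMM-specific CHECK lines.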
4 changes: 4 additions & 0 deletions third_party/amd/backend/compiler.py
@@ -299,6 +299,10 @@ def make_llir(src, metadata, options):
passes.convert.add_index_to_llvmir(pm)

passes.ttgpuir.add_allocate_shared_memory(pm)
+ amd.passes.ttgpuir.add_membar_analysis(pm)
+ amd.passes.ttgpuir.add_refine_amdgpu_ops(pm, options.arch)
+ passes.common.add_canonicalizer(pm)
+ amd.passes.ttgpuir.add_reschedule_amdgpu_ops(pm, options.arch)
## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
## of the value of kernel arg `allow_flush_denorm`.
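Pipeline ordering note: membar analysis, moved out into a dedicated pass in this PR (commit c0d4c27), now runs before refinement so barriers are in place before ops are split, and the canonicalizer between add_refine_amdgpu_ops and add_reschedule_amdgpu_ops presumably folds the redundant extract_slice/concat pairs that refinement introduces (cf. commit 6a067c0).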