Skip to content

Commit e4f83c9

Browse files
Fix gemm-tensor-of-ptr performance regression (#4209)
Verified all inductor tests below are passing: ``` python test/inductor/test_select_algorithm.py TestSelectAlgorithm.test_addmm_fp16 python test/inductor/test_select_algorithm.py TestSelectAlgorithm.test_convolution1 python test/inductor/test_max_autotune.py TestPrologueFusion.test_multiple_inputs_sizes2 python test/inductor/test_max_autotune.py TestPrologueFusion.test_upcast_sizes2 ``` Benchmark CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15033893513/job/42251970440 Fixes #4206. Signed-off-by: Whitney Tsang <[email protected]>
1 parent e1a432d commit e4f83c9

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

test/Triton/Intel/RemoveMasks/loop-canonical-masks.mlir

+1-2
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ module {
110110
// CHECK: }
111111

112112
tt.func public @test_kernel2(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
113-
%c7_i32 = arith.constant 7 : i32
114113
%c8_i32 = arith.constant 8 : i32
115114
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32>
116115
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x256xf16>
@@ -165,7 +164,7 @@ module {
165164
%33 = arith.addi %31, %32 : tensor<64x256xi32>
166165
%34 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>>
167166
%35 = tt.addptr %34, %33 : tensor<64x256x!tt.ptr<f16>>, tensor<64x256xi32>
168-
%36:3 = scf.for %arg3 = %c0_i32 to %c7_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %27, %arg6 = %35) -> (tensor<128x256xf32>, tensor<128x64x!tt.ptr<f16>>, tensor<64x256x!tt.ptr<f16>>) : i32 {
167+
%36:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %27, %arg6 = %35) -> (tensor<128x256xf32>, tensor<128x64x!tt.ptr<f16>>, tensor<64x256x!tt.ptr<f16>>) : i32 {
169168
%51 = arith.muli %arg3, %c64_i32 : i32
170169
%52 = arith.subi %c512_i32, %51 : i32
171170
%53 = tt.splat %52 : i32 -> tensor<1x64xi32>

third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
119119
int64_t N =
120120
cast<arith::ConstantIntOp>(maskInfo.N.getDefiningOp()).value();
121121
unsigned END = maskInfo.END;
122-
bool cond = UB <= ((N + END - 1) / END) - 1;
122+
bool cond = UB == ((N - END) / END) + 1;
123123
return builder.create<arith::ConstantIntOp>(forOp.getLoc(), cond,
124124
builder.getI1Type());
125125
}
@@ -156,7 +156,8 @@ class CanonicalMaskValidator final : public MaskValidatorBase {
156156
int64_t UB = cast<arith::ConstantIntOp>(defOp).value();
157157
int64_t N =
158158
cast<arith::ConstantIntOp>(maskInfo.N.getDefiningOp()).value();
159-
return UB == ((N + maskInfo.END - 1) / maskInfo.END) - 1;
159+
unsigned END = maskInfo.END;
160+
return UB == ((N - END) / END) + 1;
160161
}
161162

162163
if (!isa<arith::DivSIOp>(defOp))

0 commit comments

Comments
 (0)