
Commit 4cd33e9

Do not use extractvalue if the inserted value is directly reachable
1 parent fff1773 commit 4cd33e9

8 files changed: +145 -260 lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp (+16 -5)

@@ -562,11 +562,22 @@ SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
     return {llvmStruct};
   ArrayRef<Type> types =
       cast<LLVM::LLVMStructType>(llvmStruct.getType()).getBody();
-  SmallVector<Value> results(types.size());
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  for (unsigned i = 0; i < types.size(); ++i) {
-    Type type = types[i];
-    results[i] = b.extract_val(type, llvmStruct, i);
+  unsigned remaining = types.size();
+  SmallVector<Value> results(remaining);
+  // If llvmStruct is an InsertValueOp, iterate up over the chain of
+  // InsertValueOps and get the inserted values instead of extracting
+  // from the struct.
+  for (auto ins = llvmStruct.getDefiningOp<LLVM::InsertValueOp>();
+       ins && ins.getPosition()[0] == remaining - 1;
+       ins = ins.getContainer().getDefiningOp<LLVM::InsertValueOp>()) {
+    results[--remaining] = ins.getValue();
+  }
+  if (remaining) {
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+    for (unsigned i = 0; i < remaining; ++i) {
+      Type type = types[i];
+      results[i] = b.extract_val(type, llvmStruct, i);
+    }
   }
   return results;
 }
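The effect is easiest to see on the LLVM-dialect IR that unpackLLElements consumes. Below is a minimal, hypothetical sketch (the function name and the values %a, %b are illustrative, not taken from the diff): when the struct being unpacked was just assembled by a complete llvm.insertvalue chain, the new code walks the chain from the last inserted index backwards and reuses the inserted values directly; llvm.extractvalue is only emitted as a fallback for whatever leading elements the chain does not cover.

// Hypothetical pack-then-unpack pattern (illustrative names):
llvm.func @pack_then_unpack(%a: f32, %b: f32) -> f32 {
  %undef = llvm.mlir.undef : !llvm.struct<(f32, f32)>
  %s0 = llvm.insertvalue %a, %undef[0] : !llvm.struct<(f32, f32)>
  %s1 = llvm.insertvalue %b, %s0[1] : !llvm.struct<(f32, f32)>
  // Before this commit, unpacking %s1 re-extracted every element:
  //   %e0 = llvm.extractvalue %s1[0] : !llvm.struct<(f32, f32)>
  //   %e1 = llvm.extractvalue %s1[1] : !llvm.struct<(f32, f32)>
  // After this commit, unpackLLElements returns {%a, %b} directly,
  // so the add below uses the original values.
  %sum = llvm.fadd %a, %b : f32
  llvm.return %sum : f32
}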

test/Conversion/amd/buffer_load_store.mlir (+4 -4)

@@ -28,7 +28,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
     %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0>
     %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0>
     %7 = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0>
-    // CHECK: %[[mask:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)>
+    // CHECK: %[[mask:.*]] = llvm.icmp "slt"
     // CHECK: %[[offset:.*]] = llvm.select %[[mask]]
     // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[offset]]
     %ret = amdgpu.buffer_load %arg0[%offset], %7 stride = %c256_i32 : tensor<128xf32, #blocked0>
@@ -51,7 +51,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
     %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0>
     %7 = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0>
     %other = arith.constant dense<0.00e+00> : tensor<128xf32, #blocked0>
-    // CHECK: %[[mask:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)>
+    // CHECK: %[[mask:.*]] = llvm.icmp "slt"
     // CHECK: %[[offset:.*]] = llvm.select %[[mask]]
     // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[offset]]
     // CHECK: llvm.select
@@ -90,7 +90,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
     %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0>
     %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0>
     %7 = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0>
-    // CHECK: %[[mask0:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)>
+    // CHECK: %[[mask0:.*]] = llvm.icmp "slt"
     // CHECK: %[[mask1:.*]] = llvm.mlir.constant(true) : i1
     // CHECK: %[[mask2:.*]] = llvm.and %[[mask1]], %[[mask0]]
     // CHECK: %[[offset:.*]] = llvm.select %[[mask2]]
@@ -216,7 +216,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
     %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0>
     %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0>
     %mask = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0>
-    // CHECK: %[[mask0:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)>
+    // CHECK: %[[mask0:.*]] = llvm.icmp "slt"
     // There should be a single release fence before any atomics
     // CHECK: llvm.fence syncscope("agent") release
     // CHECK: %[[mask1:.*]] = llvm.mlir.constant(true) : i1
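The updated CHECK lines reflect that the buffer-load mask no longer takes a detour through the packed !llvm.struct<(i1, i1, i1, i1)>: it is reachable straight from the comparison. A rough sketch of the shape the checks now match, with illustrative value names that are not part of the test:

llvm.func @masked_offset(%idx: i32, %bound: i32, %off: i32, %oob: i32) -> i32 {
  // The mask is the result of the comparison itself ...
  %mask = llvm.icmp "slt" %idx, %bound : i32
  // ... and feeds the select that picks the offset handed to the buffer load.
  %offset = llvm.select %mask, %off, %oob : i1, i32
  llvm.return %offset : i32
}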

test/Conversion/intel/dot_layout_offset.mlir (+1 -5)

@@ -6,7 +6,6 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32}
   // CHECK-LABEL: llvm.func spir_kernelcc @dot_layout_emit_offset()
   tt.func public @dot_layout_emit_offset() {
     %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #dot_operand_a>
-    // CHECK-COUNT-64: {{.*}} = llvm.extractvalue {{.*}}

     // COM: Base index of the dot layout.
     // CHECK: %[[THREAD_ID_I64:.*]] = llvm.call spir_funccc @_Z12get_local_idj
@@ -327,11 +326,8 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.thr
   // CHECK-LABEL: llvm.func spir_kernelcc @dot_layout_emit_offset()
   tt.func public @dot_layout_emit_offset() {
     %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #dot_operand_b>
-    // CHECK-COUNT-64: {{.*}} = llvm.extractvalue {{.*}}
-    // CHECK: %[[VAL_142:.*]] = llvm.mlir.constant(0 : i32) : i32
-
     // COM: Base index of the dot layout.
-    // CHECK: %[[THREAD_ID_I64:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[VAL_142]])
+    // CHECK: %[[THREAD_ID_I64:.*]] = llvm.call spir_funccc @_Z12get_local_idj
     // CHECK: %[[THREAD_ID_I32:.*]] = llvm.trunc %[[THREAD_ID_I64]] : i64 to i32
     // CHECK: %[[VAL_145:.*]] = llvm.mlir.constant(16 : i32) : i32
     // CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_I32]], %[[VAL_145]] : i32

test/Conversion/intel/tritongpu_to_gen.mlir (+6 -12)

@@ -499,13 +499,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // CHECK-LABEL: basic_view_broadcast
   tt.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) {
-    // CHECK: [[ARG0_0:%.*]] = llvm.extractvalue %arg0[0]
-    // CHECK-NEXT: [[ARG0_1:%.*]] = llvm.extractvalue %arg0[1]
-    // CHECK-NEXT: [[STRUCT:%.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)>
-    // CHECK-NEXT: [[STRUCT1:%.*]] = llvm.insertvalue [[ARG0_0]], [[STRUCT]][0]
-    // CHECK-NEXT: [[STRUCT2:%.*]] = llvm.insertvalue [[ARG0_1]], [[STRUCT1]][1]
-    // CHECK-NEXT: [[T0:%.*]] = llvm.extractvalue [[STRUCT2]][0]
-    // CHECK-NEXT: [[T1:%.*]] = llvm.extractvalue [[STRUCT2]][1]
+    // CHECK: [[T0:%.*]] = llvm.extractvalue %arg0[0]
+    // CHECK-NEXT: [[T1:%.*]] = llvm.extractvalue %arg0[1]
     %0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
     // CHECK: [[RES:%.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
     // CHECK-NEXT: [[RES1:%.*]] = llvm.insertvalue [[T0]], [[RES]][0]
@@ -1889,13 +1884,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
 #blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: convert_single_element_and_add
-  // CHECK-NOT: llvm.store
-  // CHECK-NOT: llvm.load
-  // CHECK: llvm.insertvalue
-  // CHECK: llvm.extractvalue
+  // CHECK: llvm.mlir.constant(1.000000e+03 : f32) : f32
+  // CHECK: llvm.mlir.constant(2.000000e+03 : f32) : f32
+  // CHECK: llvm.fadd %{{.*}}, %{{.*}} : f32
   tt.func public @convert_single_element_and_add() attributes {noinline = false} {
     %cst = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked1>
-    %cst2 = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked>
+    %cst2 = arith.constant dense<2.000000e+03> : tensor<1xf32, #blocked>
     %0 = ttg.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked>
     %1 = arith.addf %0, %cst2 : tensor<1xf32, #blocked>
     tt.return
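This hunk also changes %cst2 from 1.0e3 to 2.0e3, presumably so the two addends lower to distinguishable constants; the new checks then describe the lowered arithmetic directly instead of requiring an insertvalue/extractvalue round-trip. A hypothetical sketch of the shape the updated CHECK lines match (function and value names are illustrative):

llvm.func @single_element_add() -> f32 {
  // Each operand becomes its own scalar constant ...
  %c1 = llvm.mlir.constant(1.000000e+03 : f32) : f32
  %c2 = llvm.mlir.constant(2.000000e+03 : f32) : f32
  // ... and the add consumes them without any struct packing in between.
  %sum = llvm.fadd %c1, %c2 : f32
  llvm.return %sum : f32
}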

test/Conversion/tritongpu_to_llvm.mlir (+2 -2)

@@ -426,7 +426,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // CHECK-LABEL: basic_view_broadcast
   tt.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) {
-    // CHECK: llvm.mlir.undef
     // CHECK: %[[T0:.*]] = llvm.extractvalue
     // CHECK: %[[T1:.*]] = llvm.extractvalue
     %0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
@@ -1967,8 +1966,9 @@ module attributes {"ttg.target" = "cuda:75", "ttg.num-ctas" = 1 : i32, "ttg.num-
   // CHECK-LABEL: convert_single_element_and_add
   // CHECK-NOT: llvm.store
   // CHECK-NOT: llvm.load
+  // CHECK: llvm.fadd
+  // CHECK: llvm.mlir.undef
   // CHECK: llvm.insertvalue
-  // CHECK: llvm.extractvalue
   tt.func public @convert_single_element_and_add() attributes {noinline = false} {
     %cst = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked1>
     %cst2 = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked>
