-
Notifications
You must be signed in to change notification settings - Fork 62
Do not use extractvalue if the inserted value is directly reachable #4212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Can you provide some description here to include the why, not just the what or the how. |
Sure, I'll add when it's ready for review. It's a draft and will definitely be changed (or closed, because I'm not quite sure about it). |
I'm not quite sure about it either :) some more context w/ the draft might help generate ideas for how to proceed. |
This is an attempt to minimise the number of insert/extract_value operations - #4136. Here #4062 (comment) is an example with a huge amount of such operations, where constants are inserted and then extracted from structures. Later all these operations are optimised out by the canonicalizer, but before that a huger IR is created. |
A short example: #blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
tt.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) {
%0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
%1 = tt.broadcast %0 : tensor<256x1xf32,#blocked2> -> tensor<256x4xf32, #blocked2>
tt.return
}
} lowered to: module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
llvm.func spir_kernelcc @basic_view_broadcast(%arg0: !llvm.struct<(f32, f32)>) attributes {intel_reqd_sub_group_size = 32 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>} {
%0 = llvm.extractvalue %arg0[0] : !llvm.struct<(f32, f32)>
%1 = llvm.extractvalue %arg0[1] : !llvm.struct<(f32, f32)>
%2 = llvm.mlir.undef : !llvm.struct<(f32, f32)>
%3 = llvm.insertvalue %0, %2[0] : !llvm.struct<(f32, f32)>
%4 = llvm.insertvalue %1, %3[1] : !llvm.struct<(f32, f32)>
%5 = llvm.extractvalue %4[0] : !llvm.struct<(f32, f32)>
%6 = llvm.extractvalue %4[1] : !llvm.struct<(f32, f32)>
%7 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32)>
%8 = llvm.insertvalue %5, %7[0] : !llvm.struct<(f32, f32, f32, f32)>
%9 = llvm.insertvalue %6, %8[1] : !llvm.struct<(f32, f32, f32, f32)>
%10 = llvm.insertvalue %5, %9[2] : !llvm.struct<(f32, f32, f32, f32)>
%11 = llvm.insertvalue %6, %10[3] : !llvm.struct<(f32, f32, f32, f32)>
llvm.return
}
} The lines %3 = llvm.insertvalue %0, %2[0] : !llvm.struct<(f32, f32)>
%4 = llvm.insertvalue %1, %3[1] : !llvm.struct<(f32, f32)>
%5 = llvm.extractvalue %4[0] : !llvm.struct<(f32, f32)>
%6 = llvm.extractvalue %4[1] : !llvm.struct<(f32, f32)> are redundant. This fix checks if the structure is |
8ee11fc
to
4cd33e9
Compare
4cd33e9
to
e738d9c
Compare
No description provided.