Description
Describe the bug
Flex attention uses a tensor of pointers to load the GEMM C matrix, and this load hits an `unreachable` error on the TritonXPU main branch; support for this path needs to be enabled. The problematic IR is the load through a pointer tensor carrying a `block_io` annotation:

```mlir
%83 = tt.load %82 {triton_intel_gpu.block_io = "row_major"} : tensor<128x64x!tt.ptr<f16>, #mma>
```
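For context, here is a minimal Triton sketch (hypothetical kernel and variable names, not the actual flex attention kernel) of the pattern that produces such a load: building the address tensor from `tl.arange` offsets instead of `tl.make_block_ptr` lowers to a `tt.load` on a tensor of pointers, which the Intel backend may then tag with `triton_intel_gpu.block_io = "row_major"` as in the IR above.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def bias_load_kernel(bias_ptr, out_ptr, stride_m,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # Build a BLOCK_M x BLOCK_N tensor of pointers; after lowering this
    # becomes a tt.load on tensor<BLOCK_MxBLOCK_Nx!tt.ptr<f16>> rather
    # than a load through a block pointer.
    ptrs = bias_ptr + offs_m[:, None] * stride_m + offs_n[None, :]
    bias = tl.load(ptrs)  # the tensor-of-pointers load in question
    tl.store(out_ptr + offs_m[:, None] * stride_m + offs_n[None, :], bias)


def run():
    # Assumes a PyTorch build with the Intel XPU device available.
    bias = torch.randn(128, 64, device="xpu", dtype=torch.float16)
    out = torch.empty_like(bias)
    bias_load_kernel[(1,)](bias, out, bias.stride(0),
                           BLOCK_M=128, BLOCK_N=64)
```

The full MLIR reproducer follows: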
```mlir
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 2], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [8], order = [0]}>
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
#mma1 = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 1], A = [16, 16], B = [16, 16], C = [16, 16]}>
module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
tt.func public @triton_tem_fused_4(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<1.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
%cst_1 = arith.constant dense<2.500000e-01> : tensor<128x64xf32, #mma>
%cst_2 = arith.constant dense<1.44269502> : tensor<128x64xf32, #mma>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
%cst_4 = arith.constant dense<16> : tensor<128x1xi32, #blocked>
%cst_5 = arith.constant dense<16> : tensor<1x16xi32, #blocked>
%cst_6 = arith.constant dense<256> : tensor<128x1xi32, #blocked>
%cst_7 = arith.constant dense<0.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%cst_8 = arith.constant dense<0xFF800000> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%c1073741760_i32 = arith.constant 1073741760 : i32
%c64_i32 = arith.constant 64 : i32
%c1_i64 = arith.constant 1 : i64
%c256_i64 = arith.constant 256 : i64
%c256_i32 = arith.constant 256 : i32
%c4_i32 = arith.constant 4 : i32
%c2_i32 = arith.constant 2 : i32
%c1_i32 = arith.constant 1 : i32
%c4096_i32 = arith.constant 4096 : i32
%c16384_i32 = arith.constant 16384 : i32
%c128_i32 = arith.constant 128 : i32
%c8388608_i32 = arith.constant 8388608 : i32
%c16_i64 = arith.constant 16 : i64
%c0_i32 = arith.constant 0 : i32
%c1073741824_i32 = arith.constant 1073741824 : i32
%c16777216_i32 = arith.constant 16777216 : i32
%cst_9 = arith.constant dense<256> : tensor<128x1xi32, #mma>
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
%2 = arith.divsi %1, %c4_i32 : i32
%3 = arith.remsi %1, %c4_i32 : i32
%4 = arith.remsi %2, %c2_i32 : i32
%5 = arith.muli %2, %c16384_i32 : i32
%6 = arith.muli %3, %c4096_i32 : i32
%7 = arith.addi %5, %6 : i32
%8 = arith.muli %4, %c16384_i32 : i32
%9 = arith.addi %8, %6 : i32
%10 = tt.addptr %arg0, %7 : !tt.ptr<f16>, i32
%11 = tt.addptr %arg1, %9 : !tt.ptr<f16>, i32
%12 = tt.addptr %arg2, %9 : !tt.ptr<f16>, i32
%13 = arith.muli %0, %c128_i32 : i32
%14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked1>
%17 = tt.splat %13 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%18 = tt.splat %13 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%19 = tt.splat %13 : i32 -> tensor<128xi32, #blocked1>
%20 = arith.addi %17, %14 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%21 = arith.addi %18, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%22 = arith.addi %19, %16 : tensor<128xi32, #blocked1>
%23 = arith.divsi %0, %c8388608_i32 : i32
%24 = tt.make_tensor_ptr %10, [%c256_i64, %c16_i64], [%c16_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
%25 = tt.load %24 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
%26 = tt.addptr %arg5, %23 : !tt.ptr<i32>, i32
%27 = tt.load %26 : !tt.ptr<i32>
%28 = arith.muli %27, %c1073741824_i32 : i32
%29 = tt.addptr %arg4, %23 : !tt.ptr<i32>, i32
%30 = tt.load %29 : !tt.ptr<i32>
%31 = arith.muli %30, %c16777216_i32 : i32
%32 = arith.minsi %31, %c4_i32 : i32
%33 = tt.make_tensor_ptr %11, [%c16_i64, %c256_i64], [%c1_i64, %c16_i64], [%c0_i32, %28] {order = array<i32: 0, 1>} : <tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
%34 = tt.make_tensor_ptr %12, [%c256_i64, %c16_i64], [%c16_i64, %c1_i64], [%28, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>>>
%35 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%36 = tt.splat %28 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%37 = arith.addi %36, %35 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%38 = tt.expand_dims %20 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma>
%39 = tt.expand_dims %21 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
%40 = tt.expand_dims %37 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma>
%41 = tt.splat %arg8 : !tt.ptr<f16> -> tensor<1x64x!tt.ptr<f16>, #mma>
%42 = arith.muli %38, %cst_9 : tensor<128x1xi32, #mma>
%43 = tt.broadcast %42 : tensor<128x1xi32, #mma> -> tensor<128x64xi32, #mma>
%44:7 = scf.for %arg10 = %c0_i32 to %32 step %c1_i32 iter_args(%arg11 = %cst_0, %arg12 = %cst_7, %arg13 = %cst_8, %arg14 = %cst_8, %arg15 = %33, %arg16 = %40, %arg17 = %34) -> (tensor<128x16xf32, #mma1>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>, tensor<1x64xi32, #mma>, !tt.ptr<tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>>>) : i32 {
%77 = tt.load %arg15 {triton_intel_gpu.block_io = "column_major"} : !tt.ptr<tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
%78 = tt.dot %25, %77, %cst_3 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
%79 = arith.mulf %78, %cst_1 : tensor<128x64xf32, #mma>
%80 = tt.addptr %41, %arg16 : tensor<1x64x!tt.ptr<f16>, #mma>, tensor<1x64xi32, #mma>
%81 = tt.broadcast %80 : tensor<1x64x!tt.ptr<f16>, #mma> -> tensor<128x64x!tt.ptr<f16>, #mma>
%82 = tt.addptr %81, %43 : tensor<128x64x!tt.ptr<f16>, #mma>, tensor<128x64xi32, #mma>
%83 = tt.load %82 {triton_intel_gpu.block_io = "row_major"} : tensor<128x64x!tt.ptr<f16>, #mma>
%84 = arith.extf %83 : tensor<128x64xf16, #mma> to tensor<128x64xf32, #mma>
%85 = arith.addf %79, %84 : tensor<128x64xf32, #mma>
%86 = arith.mulf %85, %cst_2 : tensor<128x64xf32, #mma>
%87 = "tt.reduce"(%86) <{axis = 1 : i32}> ({
^bb0(%arg18: f32, %arg19: f32):
%133 = arith.maxnumf %arg18, %arg19 : f32
tt.reduce.return %133 : f32
}) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%88 = arith.maxnumf %arg14, %87 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%89 = arith.maxnumf %arg13, %87 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%90 = arith.cmpf oeq, %88, %cst_8 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%91 = arith.cmpf oeq, %89, %cst_8 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%92 = arith.select %90, %cst_7, %88 : tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%93 = arith.select %91, %cst_7, %89 : tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%94 = arith.subf %arg13, %93 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%95 = math.exp2 %94 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%96 = tt.expand_dims %92 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma>
%97 = tt.broadcast %96 : tensor<128x1xf32, #mma> -> tensor<128x64xf32, #mma>
%98 = arith.subf %86, %97 : tensor<128x64xf32, #mma>
%99 = math.exp2 %98 : tensor<128x64xf32, #mma>
%100 = arith.mulf %arg12, %95 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%101 = "tt.reduce"(%99) <{axis = 1 : i32}> ({
^bb0(%arg18: f32, %arg19: f32):
%133 = arith.addf %arg18, %arg19 : f32
tt.reduce.return %133 : f32
}) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%102 = arith.addf %100, %101 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%103 = tt.expand_dims %95 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma>
%104 = ttg.convert_layout %103 : tensor<128x1xf32, #mma> -> tensor<128x1xf32, #mma1>
%105 = tt.broadcast %104 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1>
%106 = arith.mulf %arg11, %105 : tensor<128x16xf32, #mma1>
%107 = tt.load %arg17 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>>>
%108 = arith.truncf %99 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
%109 = ttg.convert_layout %108 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 1}>>
%110 = tt.dot %109, %107, %106 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 1}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>> -> tensor<128x16xf32, #mma1>
%111 = arith.divsi %arg10, %c16777216_i32 : i32
%112 = tt.addptr %26, %111 : !tt.ptr<i32>, i32
%113 = tt.load %112 evictionPolicy = evict_last : !tt.ptr<i32>
%114 = arith.addi %111, %c1_i32 : i32
%115 = arith.cmpi slt, %114, %30 : i32
%116 = tt.addptr %112, %c1_i32 : !tt.ptr<i32>, i32
%117 = tt.load %116, %115 evictionPolicy = evict_last : !tt.ptr<i32>
%118 = arith.addi %arg10, %c1_i32 : i32
%119 = arith.remsi %118, %c16777216_i32 : i32
%120 = arith.cmpi eq, %119, %c0_i32 : i32
%121 = arith.subi %117, %113 : i32
%122 = arith.muli %121, %c1073741824_i32 : i32
%123 = arith.subi %122, %c1073741760_i32 : i32
%124 = arith.extui %120 : i1 to i32
%125 = arith.muli %123, %124 : i32
%126 = arith.subi %c1_i32, %124 : i32
%127 = arith.muli %126, %c64_i32 : i32
%128 = arith.addi %125, %127 : i32
%129 = tt.advance %arg17, [%128, %c0_i32] : <tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>>>
%130 = tt.advance %arg15, [%c0_i32, %128] : <tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
%131 = tt.splat %128 : i32 -> tensor<1x64xi32, #mma>
%132 = arith.addi %arg16, %131 : tensor<1x64xi32, #mma>
scf.yield %110, %102, %89, %88, %130, %132, %129 : tensor<128x16xf32, #mma1>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>, tensor<1x64xi32, #mma>, !tt.ptr<tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>>>
}
%45 = arith.cmpf oeq, %44#1, %cst_7 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%46 = arith.select %45, %cst, %44#1 : tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%47 = tt.expand_dims %46 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma>
%48 = ttg.convert_layout %47 : tensor<128x1xf32, #mma> -> tensor<128x1xf32, #mma1>
%49 = tt.broadcast %48 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1>
%50 = arith.divf %44#0, %49 : tensor<128x16xf32, #mma1>
%51 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%52 = tt.expand_dims %51 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked>
%53 = arith.cmpi slt, %39, %cst_6 : tensor<128x1xi32, #blocked>
%54 = arith.cmpi slt, %52, %cst_5 : tensor<1x16xi32, #blocked>
%55 = tt.broadcast %53 : tensor<128x1xi1, #blocked> -> tensor<128x16xi1, #blocked>
%56 = tt.broadcast %54 : tensor<1x16xi1, #blocked> -> tensor<128x16xi1, #blocked>
%57 = arith.andi %55, %56 : tensor<128x16xi1, #blocked>
%58 = arith.muli %39, %cst_4 : tensor<128x1xi32, #blocked>
%59 = tt.broadcast %52 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked>
%60 = tt.broadcast %58 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked>
%61 = arith.addi %59, %60 : tensor<128x16xi32, #blocked>
%62 = tt.splat %6 : i32 -> tensor<128x16xi32, #blocked>
%63 = arith.addi %61, %62 : tensor<128x16xi32, #blocked>
%64 = tt.splat %5 : i32 -> tensor<128x16xi32, #blocked>
%65 = arith.addi %63, %64 : tensor<128x16xi32, #blocked>
%66 = tt.splat %arg9 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
%67 = tt.addptr %66, %65 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
%68 = arith.truncf %50 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1>
%69 = ttg.convert_layout %68 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #blocked>
tt.store %67, %69, %57 : tensor<128x16x!tt.ptr<f16>, #blocked>
%70 = arith.muli %1, %c256_i32 : i32
%71 = tt.addptr %arg3, %70 : !tt.ptr<f32>, i32
%72 = tt.splat %71 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #blocked1>
%73 = tt.addptr %72, %22 : tensor<128x!tt.ptr<f32>, #blocked1>, tensor<128xi32, #blocked1>
%74 = math.log2 %46 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%75 = arith.addf %44#2, %74 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%76 = ttg.convert_layout %75 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128xf32, #blocked1>
tt.store %73, %76 : tensor<128x!tt.ptr<f32>, #blocked1>
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(convert-scf-to-cf, convert-index-to-llvm{index-bitwidth=0}, allocate-shared-memory, tritongpu-global-scratch-memory-allocation, convert-triton-intel-gpu-to-llvm{advanced_path=false one_matrix_per_load_for_bt=false use_tile_load_linear_layout=true}, tritonintelgpu-rewrite-stack-ptr, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-arith-to-llvm{index-bitwidth=0}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)",
disable_threading: false,
verify_each: true
}
}
#-}
```
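For reference, the trailing `external_resources` section is MLIR's crash-reproducer metadata; assuming the standard MLIR reproducer tooling applies to `triton-opt`, the embedded pipeline should be replayable with `triton-opt --run-reproducer` on this file, or by passing the `pipeline` string directly to `--pass-pipeline`.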
Environment details
TritonXPU: main
GPU.