Skip to content

[BACKEND] Need to support load regular tensor pointers of DPAS layout. #4059

Closed
@chengjunlu

Description

@chengjunlu

Describe the bug

Flex attention uses a tensor of pointers to load the GEMM C matrix.

It hits an unreachable error on the TritonXPU main branch.

This case needs to be enabled. Below is the MLIR reproducer; the problematic op is %83 = tt.load %82 {triton_intel_gpu.block_io = "row_major"} : tensor<128x64x!tt.ptr<f16>, #mma>.

// Layout attribute aliases used by the reproducer below. #mma / #mma1 are
// Intel DPAS (systolic) layouts (see systolicDepth/executionSize params);
// #blocked / #blocked1 are plain blocked layouts used for the final stores.
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 2], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [8], order = [0]}>
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
#mma1 = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 1], A = [16, 16], B = [16, 16], C = [16, 16]}>
module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
  tt.func public @triton_tem_fused_4(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<1.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
    %cst_1 = arith.constant dense<2.500000e-01> : tensor<128x64xf32, #mma>
    %cst_2 = arith.constant dense<1.44269502> : tensor<128x64xf32, #mma>
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
    %cst_4 = arith.constant dense<16> : tensor<128x1xi32, #blocked>
    %cst_5 = arith.constant dense<16> : tensor<1x16xi32, #blocked>
    %cst_6 = arith.constant dense<256> : tensor<128x1xi32, #blocked>
    %cst_7 = arith.constant dense<0.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
    %cst_8 = arith.constant dense<0xFF800000> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
    %c1073741760_i32 = arith.constant 1073741760 : i32
    %c64_i32 = arith.constant 64 : i32
    %c1_i64 = arith.constant 1 : i64
    %c256_i64 = arith.constant 256 : i64
    %c256_i32 = arith.constant 256 : i32
    %c4_i32 = arith.constant 4 : i32
    %c2_i32 = arith.constant 2 : i32
    %c1_i32 = arith.constant 1 : i32
    %c4096_i32 = arith.constant 4096 : i32
    %c16384_i32 = arith.constant 16384 : i32
    %c128_i32 = arith.constant 128 : i32
    %c8388608_i32 = arith.constant 8388608 : i32
    %c16_i64 = arith.constant 16 : i64
    %c0_i32 = arith.constant 0 : i32
    %c1073741824_i32 = arith.constant 1073741824 : i32
    %c16777216_i32 = arith.constant 16777216 : i32
    %cst_9 = arith.constant dense<256> : tensor<128x1xi32, #mma>
    %0 = tt.get_program_id x : i32
    %1 = tt.get_program_id y : i32
    %2 = arith.divsi %1, %c4_i32 : i32
    %3 = arith.remsi %1, %c4_i32 : i32
    %4 = arith.remsi %2, %c2_i32 : i32
    %5 = arith.muli %2, %c16384_i32 : i32
    %6 = arith.muli %3, %c4096_i32 : i32
    %7 = arith.addi %5, %6 : i32
    %8 = arith.muli %4, %c16384_i32 : i32
    %9 = arith.addi %8, %6 : i32
    %10 = tt.addptr %arg0, %7 : !tt.ptr<f16>, i32
    %11 = tt.addptr %arg1, %9 : !tt.ptr<f16>, i32
    %12 = tt.addptr %arg2, %9 : !tt.ptr<f16>, i32
    %13 = arith.muli %0, %c128_i32 : i32
    %14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked1>
    %17 = tt.splat %13 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %18 = tt.splat %13 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %19 = tt.splat %13 : i32 -> tensor<128xi32, #blocked1>
    %20 = arith.addi %17, %14 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %21 = arith.addi %18, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %22 = arith.addi %19, %16 : tensor<128xi32, #blocked1>
    %23 = arith.divsi %0, %c8388608_i32 : i32
    // Block-pointer loads (tt.make_tensor_ptr + tt.load with block_io) work;
    // the failure below is specific to the tensor-of-pointers form.
    %24 = tt.make_tensor_ptr %10, [%c256_i64, %c16_i64], [%c16_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
    %25 = tt.load %24 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
    %26 = tt.addptr %arg5, %23 : !tt.ptr<i32>, i32
    %27 = tt.load %26 : !tt.ptr<i32>
    %28 = arith.muli %27, %c1073741824_i32 : i32
    %29 = tt.addptr %arg4, %23 : !tt.ptr<i32>, i32
    %30 = tt.load %29 : !tt.ptr<i32>
    %31 = arith.muli %30, %c16777216_i32 : i32
    %32 = arith.minsi %31, %c4_i32 : i32
    %33 = tt.make_tensor_ptr %11, [%c16_i64, %c256_i64], [%c1_i64, %c16_i64], [%c0_i32, %28] {order = array<i32: 0, 1>} : <tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
    %34 = tt.make_tensor_ptr %12, [%c256_i64, %c16_i64], [%c16_i64, %c1_i64], [%28, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>>>
    %35 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %36 = tt.splat %28 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %37 = arith.addi %36, %35 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %38 = tt.expand_dims %20 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma>
    %39 = tt.expand_dims %21 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
    %40 = tt.expand_dims %37 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma>
    %41 = tt.splat %arg8 : !tt.ptr<f16> -> tensor<1x64x!tt.ptr<f16>, #mma>
    %42 = arith.muli %38, %cst_9 : tensor<128x1xi32, #mma>
    %43 = tt.broadcast %42 : tensor<128x1xi32, #mma> -> tensor<128x64xi32, #mma>
    // Main loop: two tt.dot chains with running max (%arg13/%arg14) and
    // running sum (%arg12) accumulators -- appears to be online-softmax
    // style; per the issue report this kernel comes from flex attention.
    %44:7 = scf.for %arg10 = %c0_i32 to %32 step %c1_i32 iter_args(%arg11 = %cst_0, %arg12 = %cst_7, %arg13 = %cst_8, %arg14 = %cst_8, %arg15 = %33, %arg16 = %40, %arg17 = %34) -> (tensor<128x16xf32, #mma1>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>, tensor<1x64xi32, #mma>, !tt.ptr<tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>>>)  : i32 {
      %77 = tt.load %arg15 {triton_intel_gpu.block_io = "column_major"} : !tt.ptr<tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
      %78 = tt.dot %25, %77, %cst_3 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
      %79 = arith.mulf %78, %cst_1 : tensor<128x64xf32, #mma>
      %80 = tt.addptr %41, %arg16 : tensor<1x64x!tt.ptr<f16>, #mma>, tensor<1x64xi32, #mma>
      %81 = tt.broadcast %80 : tensor<1x64x!tt.ptr<f16>, #mma> -> tensor<128x64x!tt.ptr<f16>, #mma>
      %82 = tt.addptr %81, %43 : tensor<128x64x!tt.ptr<f16>, #mma>, tensor<128x64xi32, #mma>
      // NOTE(report): the op below is the one that hits the unreachable error
      // on the TritonXPU main branch -- a tt.load of a regular tensor of
      // pointers carrying the DPAS (#mma) layout with block_io = "row_major".
      %83 = tt.load %82 {triton_intel_gpu.block_io = "row_major"} : tensor<128x64x!tt.ptr<f16>, #mma>
      %84 = arith.extf %83 : tensor<128x64xf16, #mma> to tensor<128x64xf32, #mma>
      %85 = arith.addf %79, %84 : tensor<128x64xf32, #mma>
      %86 = arith.mulf %85, %cst_2 : tensor<128x64xf32, #mma>
      %87 = "tt.reduce"(%86) <{axis = 1 : i32}> ({
      ^bb0(%arg18: f32, %arg19: f32):
        %133 = arith.maxnumf %arg18, %arg19 : f32
        tt.reduce.return %133 : f32
      }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      %88 = arith.maxnumf %arg14, %87 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      %89 = arith.maxnumf %arg13, %87 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      %90 = arith.cmpf oeq, %88, %cst_8 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      %91 = arith.cmpf oeq, %89, %cst_8 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      %92 = arith.select %90, %cst_7, %88 : tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      %93 = arith.select %91, %cst_7, %89 : tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      %94 = arith.subf %arg13, %93 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      %95 = math.exp2 %94 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      %96 = tt.expand_dims %92 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma>
      %97 = tt.broadcast %96 : tensor<128x1xf32, #mma> -> tensor<128x64xf32, #mma>
      %98 = arith.subf %86, %97 : tensor<128x64xf32, #mma>
      %99 = math.exp2 %98 : tensor<128x64xf32, #mma>
      %100 = arith.mulf %arg12, %95 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      %101 = "tt.reduce"(%99) <{axis = 1 : i32}> ({
      ^bb0(%arg18: f32, %arg19: f32):
        %133 = arith.addf %arg18, %arg19 : f32
        tt.reduce.return %133 : f32
      }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      %102 = arith.addf %100, %101 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
      %103 = tt.expand_dims %95 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma>
      %104 = ttg.convert_layout %103 : tensor<128x1xf32, #mma> -> tensor<128x1xf32, #mma1>
      %105 = tt.broadcast %104 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1>
      %106 = arith.mulf %arg11, %105 : tensor<128x16xf32, #mma1>
      %107 = tt.load %arg17 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>>>
      %108 = arith.truncf %99 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
      %109 = ttg.convert_layout %108 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 1}>>
      %110 = tt.dot %109, %107, %106 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 1}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>> -> tensor<128x16xf32, #mma1>
      %111 = arith.divsi %arg10, %c16777216_i32 : i32
      %112 = tt.addptr %26, %111 : !tt.ptr<i32>, i32
      %113 = tt.load %112 evictionPolicy = evict_last : !tt.ptr<i32>
      %114 = arith.addi %111, %c1_i32 : i32
      %115 = arith.cmpi slt, %114, %30 : i32
      %116 = tt.addptr %112, %c1_i32 : !tt.ptr<i32>, i32
      %117 = tt.load %116, %115 evictionPolicy = evict_last : !tt.ptr<i32>
      %118 = arith.addi %arg10, %c1_i32 : i32
      %119 = arith.remsi %118, %c16777216_i32 : i32
      %120 = arith.cmpi eq, %119, %c0_i32 : i32
      %121 = arith.subi %117, %113 : i32
      %122 = arith.muli %121, %c1073741824_i32 : i32
      %123 = arith.subi %122, %c1073741760_i32 : i32
      %124 = arith.extui %120 : i1 to i32
      %125 = arith.muli %123, %124 : i32
      %126 = arith.subi %c1_i32, %124 : i32
      %127 = arith.muli %126, %c64_i32 : i32
      %128 = arith.addi %125, %127 : i32
      // Advance both block pointers and the pointer-offset tensor (%arg16)
      // by the same computed stride %128 for the next iteration.
      %129 = tt.advance %arg17, [%128, %c0_i32] : <tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>>>
      %130 = tt.advance %arg15, [%c0_i32, %128] : <tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
      %131 = tt.splat %128 : i32 -> tensor<1x64xi32, #mma>
      %132 = arith.addi %arg16, %131 : tensor<1x64xi32, #mma>
      scf.yield %110, %102, %89, %88, %130, %132, %129 : tensor<128x16xf32, #mma1>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>, tensor<1x64xi32, #mma>, !tt.ptr<tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>>>
    }
    // Epilogue: normalize the accumulator and store results via #blocked /
    // #blocked1 layouts.
    %45 = arith.cmpf oeq, %44#1, %cst_7 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
    %46 = arith.select %45, %cst, %44#1 : tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
    %47 = tt.expand_dims %46 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma>
    %48 = ttg.convert_layout %47 : tensor<128x1xf32, #mma> -> tensor<128x1xf32, #mma1>
    %49 = tt.broadcast %48 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1>
    %50 = arith.divf %44#0, %49 : tensor<128x16xf32, #mma1>
    %51 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %52 = tt.expand_dims %51 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked>
    %53 = arith.cmpi slt, %39, %cst_6 : tensor<128x1xi32, #blocked>
    %54 = arith.cmpi slt, %52, %cst_5 : tensor<1x16xi32, #blocked>
    %55 = tt.broadcast %53 : tensor<128x1xi1, #blocked> -> tensor<128x16xi1, #blocked>
    %56 = tt.broadcast %54 : tensor<1x16xi1, #blocked> -> tensor<128x16xi1, #blocked>
    %57 = arith.andi %55, %56 : tensor<128x16xi1, #blocked>
    %58 = arith.muli %39, %cst_4 : tensor<128x1xi32, #blocked>
    %59 = tt.broadcast %52 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked>
    %60 = tt.broadcast %58 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked>
    %61 = arith.addi %59, %60 : tensor<128x16xi32, #blocked>
    %62 = tt.splat %6 : i32 -> tensor<128x16xi32, #blocked>
    %63 = arith.addi %61, %62 : tensor<128x16xi32, #blocked>
    %64 = tt.splat %5 : i32 -> tensor<128x16xi32, #blocked>
    %65 = arith.addi %63, %64 : tensor<128x16xi32, #blocked>
    %66 = tt.splat %arg9 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
    %67 = tt.addptr %66, %65 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
    %68 = arith.truncf %50 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1>
    %69 = ttg.convert_layout %68 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #blocked>
    tt.store %67, %69, %57 : tensor<128x16x!tt.ptr<f16>, #blocked>
    %70 = arith.muli %1, %c256_i32 : i32
    %71 = tt.addptr %arg3, %70 : !tt.ptr<f32>, i32
    %72 = tt.splat %71 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #blocked1>
    %73 = tt.addptr %72, %22 : tensor<128x!tt.ptr<f32>, #blocked1>, tensor<128xi32, #blocked1>
    %74 = math.log2 %46 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
    %75 = arith.addf %44#2, %74 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
    %76 = ttg.convert_layout %75 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128xf32, #blocked1>
    tt.store %73, %76 : tensor<128x!tt.ptr<f32>, #blocked1>
    tt.return
  }
}

{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "builtin.module(convert-scf-to-cf, convert-index-to-llvm{index-bitwidth=0}, allocate-shared-memory, tritongpu-global-scratch-memory-allocation, convert-triton-intel-gpu-to-llvm{advanced_path=false one_matrix_per_load_for_bt=false use_tile_load_linear_layout=true}, tritonintelgpu-rewrite-stack-ptr, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-arith-to-llvm{index-bitwidth=0}, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)",
      disable_threading: false,
      verify_each: true
    }
  }
#-}

Environment details

TritonXPU: main
GPU.

Metadata

Metadata

Assignees

Type

Projects

No projects

Relationships

None yet

Development

No branches or pull requests

Issue actions