Skip to content

[CPU] i4 pack op fails to compile #16285

@dcaballe

Description

@dcaballe

It looks like some memref.subviews are not optimized away for this i4 pack op and we try to apply narrow type emulation to it:

#config = #iree_codegen.lowering_config<tile_sizes = [[20000, 16000], [1, 1]]>
#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "", cpu_features = "+neon", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", link_embedded = false, native_vector_size = 16 : index, target_triple = "aarch64-none-linux-android34", ukernels = "none"}>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>
#translation = #iree_codegen.translation_info<CPUDataTiling>
module {
  hal.executable public @pack_i4 {
    hal.executable.variant public @system_elf_arm_64 target(#executable_target_system_elf_arm_64_) {
      hal.executable.export public @pack_i4 ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} {
      ^bb0(%arg0: !hal.device):
        %x, %y, %z = flow.dispatch.workgroup_count_from_slice
        hal.return %x, %y, %z : index, index, index
      }
      builtin.module {
        func.func @pack_i4() {
          %c0_i4 = arith.constant 0 : i4
          %c0 = arith.constant 0 : index
          %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16000x32000xi4>>
          %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<200000x16000x64x1xi4>>
          %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [16000, 32000], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<16000x32000xi4>> -> tensor<16000x32000xi4>
          %3 = tensor.empty() : tensor<200000x16000x64x1xi4>
          %pack = tensor.pack %2 padding_value(%c0_i4 : i4) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [64, 1] into %3 {lowering_config = #config} : tensor<16000x32000xi4> -> tensor<200000x16000x64x1xi4>
          flow.dispatch.tensor.store %pack, %1, offsets = [0, 0, 0, 0], sizes = [200000, 16000, 64, 1], strides = [1, 1, 1, 1] : tensor<200000x16000x64x1xi4> -> !flow.dispatch.tensor<writeonly:tensor<200000x16000x64x1xi4>>
          return
        }
      }
    }
  }
}

Error:

iree-compile --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features="+neon" --iree-llvmcpu-target-triple=aarch64-none-linux-android34 --iree-opt-data-tiling=true --iree-llvmcpu-enable-ukernels=none --compile-from=executable-sources repro.mlir

repro.mlir:21:19: error: failed to legalize operation 'memref.subview' that was explicitly marked illegal
          %pack = tensor.pack %2 padding_value(%c0_i4 : i4) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [64, 1] into %3 {lowering_config = #config} : tensor<16000x32000xi4> -> tensor<200000x16000x64x1xi4>
                  ^
repro.mlir:21:19: note: see current operation: %29 = "memref.subview"(%7, %27, %28, %24) <{operandSegmentSizes = array<i32: 1, 2, 1, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, -9223372036854775808>, static_strides = array<i64: 1, 1>}> : (memref<16000x32000xi4>, index, index, index) -> memref<1x?xi4, strided<[32000, 1], offset: ?>>

IR before the compilation error:

// -----// IR Dump After FoldMemRefAliasOps (fold-memref-alias-ops) //----- //
module {
  func.func @pack_i4() {
    %c1 = arith.constant 1 : index
    %c20000 = arith.constant 20000 : index
    %c16000 = arith.constant 16000 : index
    %c200000 = arith.constant 200000 : index
    %c0_i4 = arith.constant 0 : i4
    %c0 = arith.constant 0 : index
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<16000x32000xi4>
    memref.assume_alignment %0, 64 : memref<16000x32000xi4>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<200000x16000x64x1xi4>
    memref.assume_alignment %1, 64 : memref<200000x16000x64x1xi4>
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %workgroup_count_x = hal.interface.workgroup.count[0] : index
    %workgroup_id_y = hal.interface.workgroup.id[1] : index
    %workgroup_count_y = hal.interface.workgroup.count[1] : index
    %2 = affine.apply affine_map<()[s0] -> (s0 * 20000)>()[%workgroup_id_y]
    %3 = affine.apply affine_map<()[s0] -> (s0 * 20000)>()[%workgroup_count_y]
    %4 = affine.apply affine_map<()[s0] -> (s0 * 16000)>()[%workgroup_id_x]
    %5 = affine.apply affine_map<()[s0] -> (s0 * 16000)>()[%workgroup_count_x]
    cf.br ^bb1(%2 : index)
  ^bb1(%6: index):  // 2 preds: ^bb0, ^bb11
    %7 = arith.cmpi slt, %6, %c200000 : index
    cf.cond_br %7, ^bb2, ^bb12
  ^bb2:  // pred: ^bb1
    cf.br ^bb3(%4 : index)
  ^bb3(%8: index):  // 2 preds: ^bb2, ^bb10
    %9 = arith.cmpi slt, %8, %c16000 : index
    cf.cond_br %9, ^bb4, ^bb11
  ^bb4:  // pred: ^bb3
    cf.br ^bb5(%c0 : index)
  ^bb5(%10: index):  // 2 preds: ^bb4, ^bb9
    %11 = arith.cmpi slt, %10, %c20000 : index
    cf.cond_br %11, ^bb6, ^bb10
  ^bb6:  // pred: ^bb5
    %12 = affine.min affine_map<()[s0, s1] -> (s0 * -64 - s1 * 64 + 32000, 64)>()[%10, %6]
    cf.br ^bb7(%c0 : index)
  ^bb7(%13: index):  // 2 preds: ^bb6, ^bb8
    %14 = arith.cmpi slt, %13, %c16000 : index
    cf.cond_br %14, ^bb8, ^bb9
  ^bb8:  // pred: ^bb7
    %15 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%8, %13]
    %16 = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1 * 64)>()[%6, %10]
    %subview = memref.subview %0[%15, %16] [1, %12] [1, 1] : memref<16000x32000xi4> to memref<1x?xi4, strided<[32000, 1], offset: ?>>
    %17 = vector.transfer_read %subview[%c0, %c0], %c0_i4 : memref<1x?xi4, strided<[32000, 1], offset: ?>>, vector<64xi4>
    %18 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%6, %10]
    %19 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%8, %13]
    vector.store %17, %1[%18, %19, %c0, %c0] : memref<200000x16000x64x1xi4>, vector<64xi4>
    %20 = arith.addi %13, %c1 : index
    cf.br ^bb7(%20 : index)
  ^bb9:  // pred: ^bb7
    %21 = arith.addi %10, %c1 : index
    cf.br ^bb5(%21 : index)
  ^bb10:  // pred: ^bb5
    %22 = arith.addi %8, %5 : index
    cf.br ^bb3(%22 : index)
  ^bb11:  // pred: ^bb3
    %23 = arith.addi %6, %3 : index
    cf.br ^bb1(%23 : index)
  ^bb12:  // pred: ^bb1
    return
  }
}

Metadata

Metadata

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions