-
Notifications
You must be signed in to change notification settings - Fork 728
Closed as not planned
Description
It looks like some `memref.subview` ops are not optimized away for this i4 pack op, so we end up trying to apply narrow type emulation to them:
#config = #iree_codegen.lowering_config<tile_sizes = [[20000, 16000], [1, 1]]> // Repro for "failed to legalize operation 'memref.subview'" on the i4 tensor.pack below. Outer tiles: 20000 x 16000 per workgroup; [1, 1] for the inner loops.
#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "", cpu_features = "+neon", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", link_embedded = false, native_vector_size = 16 : index, target_triple = "aarch64-none-linux-android34", ukernels = "none"}> // aarch64 Android target, +neon, ukernels disabled.
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]> // binding 0: read-only input; binding 1: packed output.
#translation = #iree_codegen.translation_info<CPUDataTiling>
module {
hal.executable public @pack_i4 {
hal.executable.variant public @system_elf_arm_64 target(#executable_target_system_elf_arm_64_) {
hal.executable.export public @pack_i4 ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @pack_i4() { // Packs %2 (16000x32000xi4): dim 1 into tiles of 64, dim 0 into tiles of 1, outer dims swapped; result is stored to binding 1.
%c0_i4 = arith.constant 0 : i4 // padding value for the pack
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16000x32000xi4>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<200000x16000x64x1xi4>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [16000, 32000], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<16000x32000xi4>> -> tensor<16000x32000xi4>
%3 = tensor.empty() : tensor<200000x16000x64x1xi4> // NOTE(review): only 32000 / 64 = 500 outer tiles along dim 0 hold source data; the remaining ones are pure padding. This oversized dest makes the per-tile read size dynamic (affine.min in the lowered IR), which may be why the subview is not folded away; confirm before shrinking the shape, as that could hide the bug.
%pack = tensor.pack %2 padding_value(%c0_i4 : i4) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [64, 1] into %3 {lowering_config = #config} : tensor<16000x32000xi4> -> tensor<200000x16000x64x1xi4>
flow.dispatch.tensor.store %pack, %1, offsets = [0, 0, 0, 0], sizes = [200000, 16000, 64, 1], strides = [1, 1, 1, 1] : tensor<200000x16000x64x1xi4> -> !flow.dispatch.tensor<writeonly:tensor<200000x16000x64x1xi4>>
return
}
}
}
}
}
Error:
iree-compile --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features="+neon" --iree-llvmcpu-target-triple=aarch64-none-linux-android34 --iree-opt-data-tiling=true --iree-llvmcpu-enable-ukernels=none --compile-from=executable-sources repro.mlir
repro.mlir:21:19: error: failed to legalize operation 'memref.subview' that was explicitly marked illegal
%pack = tensor.pack %2 padding_value(%c0_i4 : i4) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [64, 1] into %3 {lowering_config = #config} : tensor<16000x32000xi4> -> tensor<200000x16000x64x1xi4>
^
repro.mlir:21:19: note: see current operation: %29 = "memref.subview"(%7, %27, %28, %24) <{operandSegmentSizes = array<i32: 1, 2, 1, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, -9223372036854775808>, static_strides = array<i64: 1, 1>}> : (memref<16000x32000xi4>, index, index, index) -> memref<1x?xi4, strided<[32000, 1], offset: ?>>
IR before the compilation error:
// -----// IR Dump After FoldMemRefAliasOps (fold-memref-alias-ops) //----- //
// @pack_i4 after FoldMemRefAliasOps: the tensor.pack has been lowered to
// nested loops that each copy one 64-element i4 segment of an input row
// into the packed output buffer.
module {
func.func @pack_i4() {
%c1 = arith.constant 1 : index
%c20000 = arith.constant 20000 : index
%c16000 = arith.constant 16000 : index
%c200000 = arith.constant 200000 : index
%c0_i4 = arith.constant 0 : i4
%c0 = arith.constant 0 : index
// Input buffer (16000x32000xi4) and packed output buffer (200000x16000x64x1xi4).
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<16000x32000xi4>
memref.assume_alignment %0, 64 : memref<16000x32000xi4>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<200000x16000x64x1xi4>
memref.assume_alignment %1, 64 : memref<200000x16000x64x1xi4>
// Workgroup distribution: workgroup y walks output dim 0 in chunks of 20000,
// workgroup x walks output dim 1 in chunks of 16000. Each loop starts at
// id * chunk (%2, %4) and steps by count * chunk (%3, %5).
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%2 = affine.apply affine_map<()[s0] -> (s0 * 20000)>()[%workgroup_id_y]
%3 = affine.apply affine_map<()[s0] -> (s0 * 20000)>()[%workgroup_count_y]
%4 = affine.apply affine_map<()[s0] -> (s0 * 16000)>()[%workgroup_id_x]
%5 = affine.apply affine_map<()[s0] -> (s0 * 16000)>()[%workgroup_count_x]
cf.br ^bb1(%2 : index)
// Distributed loop over %6 (chunk start in output dim 0) while %6 < 200000.
^bb1(%6: index): // 2 preds: ^bb0, ^bb11
%7 = arith.cmpi slt, %6, %c200000 : index
cf.cond_br %7, ^bb2, ^bb12
^bb2: // pred: ^bb1
cf.br ^bb3(%4 : index)
// Distributed loop over %8 (chunk start in output dim 1) while %8 < 16000.
^bb3(%8: index): // 2 preds: ^bb2, ^bb10
%9 = arith.cmpi slt, %8, %c16000 : index
cf.cond_br %9, ^bb4, ^bb11
^bb4: // pred: ^bb3
cf.br ^bb5(%c0 : index)
// %10 in [0, 20000), step 1: offset within the output dim 0 chunk.
^bb5(%10: index): // 2 preds: ^bb4, ^bb9
%11 = arith.cmpi slt, %10, %c20000 : index
cf.cond_br %11, ^bb6, ^bb10
^bb6: // pred: ^bb5
// %12 = min(32000 - 64 * (%6 + %10), 64): number of source columns left for
// this 64-wide tile (zero or negative once the tile starts at or past column
// 32000). It is not a constant, so the subview below has a dynamic size.
%12 = affine.min affine_map<()[s0, s1] -> (s0 * -64 - s1 * 64 + 32000, 64)>()[%10, %6]
cf.br ^bb7(%c0 : index)
// %13 in [0, 16000), step 1: offset within the output dim 1 chunk.
^bb7(%13: index): // 2 preds: ^bb6, ^bb8
%14 = arith.cmpi slt, %13, %c16000 : index
cf.cond_br %14, ^bb8, ^bb9
^bb8: // pred: ^bb7
// Source coordinates: row %15 = %8 + %13, column %16 = 64 * (%6 + %10).
%15 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%8, %13]
%16 = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1 * 64)>()[%6, %10]
// This subview has the same source/result types as the op reported in the
// legalization error (memref<16000x32000xi4> -> memref<1x?xi4, ...>), and it
// survived FoldMemRefAliasOps. NOTE(review): presumably it was not folded into
// the transfer_read because of the dynamic size %12; confirm.
%subview = memref.subview %0[%15, %16] [1, %12] [1, 1] : memref<16000x32000xi4> to memref<1x?xi4, strided<[32000, 1], offset: ?>>
// Read 64 i4 values from the subview origin, with %c0_i4 as the padding value.
%17 = vector.transfer_read %subview[%c0, %c0], %c0_i4 : memref<1x?xi4, strided<[32000, 1], offset: ?>>, vector<64xi4>
// Store them to output[%6 + %10, %8 + %13, 0, 0] (%19 recomputes %15).
%18 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%6, %10]
%19 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%8, %13]
vector.store %17, %1[%18, %19, %c0, %c0] : memref<200000x16000x64x1xi4>, vector<64xi4>
%20 = arith.addi %13, %c1 : index
cf.br ^bb7(%20 : index)
^bb9: // pred: ^bb7
%21 = arith.addi %10, %c1 : index
cf.br ^bb5(%21 : index)
^bb10: // pred: ^bb5
%22 = arith.addi %8, %5 : index
cf.br ^bb3(%22 : index)
^bb11: // pred: ^bb3
%23 = arith.addi %6, %3 : index
cf.br ^bb1(%23 : index)
^bb12: // pred: ^bb1
return
}
}