Commit 36aa0ba

Transpose NCHW to NHWC if compiler stick/unstick is enabled and enable model zoo tests on NNPA (#2893)
* Transpose NCHW to NHWC if compiler stick/unstick is enabled

Signed-off-by: Tung D. Le <[email protected]>

* disable createZHighClipToDLFloatPass

Signed-off-by: Tung D. Le <[email protected]>

---------

Signed-off-by: Tung D. Le <[email protected]>
1 parent b754bec commit 36aa0ba

8 files changed, +156 -24 lines changed

src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp

+3-2
@@ -29,7 +29,8 @@ llvm::cl::opt<NNPAEmissionTargetType> nnpaEmissionTarget(
 llvm::cl::opt<bool> nnpaClipToDLFloatRange("nnpa-clip-to-dlfloat-range",
     llvm::cl::desc("Clip CPU tensors to dlfloat range before stickification to "
                    "avoid out-of-range. Only clip Softmax inputs at this "
-                   "moment. Default is true."),
+                   "moment. Default is true. This option will be removed and "
+                   "replaced by --nnpa-saturation in the future."),
     llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions));

 llvm::cl::opt<bool> nnpaEnableZHighToOnnx("enable-zhigh-to-onnx",

@@ -55,7 +56,7 @@ llvm::cl::opt<bool> nnpaEnableCompilerStickUnstick(
     "enable-compiler-stick-unstick",
     llvm::cl::desc("[Experimental feature] Enable the compiler generate some "
                    "stick/unstick code. Default is true."),
-    llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions));
+    llvm::cl::init(true), llvm::cl::cat(OnnxMlirCommonOptions));

 llvm::cl::opt<bool> nnpaEnableScalarBcastBinary(
     "nnpa-enable-scalar-bcast-binary",

src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp

+2-1
@@ -103,7 +103,8 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
   // Clip zhigh.Stick inputs if required. This is to avoid out-of-range of
   // dlfloat. Do constant propagation after clipping to remove ONNX ops used for
   // clipping such as ONNXMax if applicable.
-  if (nnpaClipToDLFloatRange) {
+  // This pass will be removed and replaced by nnpa-saturation in the future.
+  if (!nnpaEnableSaturation && nnpaClipToDLFloatRange) {
     pm.addNestedPass<func::FuncOp>(
         onnx_mlir::zhigh::createZHighClipToDLFloatPass());
     pm.addNestedPass<func::FuncOp>(onnx_mlir::createConstPropONNXToONNXPass());

src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp

+48-13
@@ -507,8 +507,9 @@ struct ZHighToZLowStickOpLowering : public ConversionPattern {
     StringAttr layout = stickOp.getLayoutAttr();
     IntegerAttr saturation = stickOp.getSaturationAttr();

-    IndexExprBuilderForKrnl createKrnlIE(rewriter, loc);
-    ZHighStickOpShapeHelper shapeHelper(op, operands, &createKrnlIE);
+    MultiDialectBuilder<OnnxBuilder, IndexExprBuilderForKrnl> create(
+        rewriter, loc);
+    ZHighStickOpShapeHelper shapeHelper(op, operands, &create.krnlIE);
     shapeHelper.computeShapeAndAssertOnFailure();

     // Convert ZTensor type to MemRefType.

@@ -518,9 +519,17 @@ struct ZHighToZLowStickOpLowering : public ConversionPattern {
     // Allocate a buffer for the result MemRef.
     Value alloc = insertAllocForZMemRef(
         zMemRefType, shapeHelper.getOutputDims(), op, rewriter);
-    // Set pre-transformed layout: if NHWC, we can directly stickify from NCHW.
-    if (isNHWCLayout(layout))
-      layout = getNCHWLayoutAttr(rewriter);
+    if (isNHWCLayout(layout)) {
+      if (nnpaEnableCompilerStickUnstick) {
+        // Compiler-generated stick hasn't supported NCHW yet.
+        // Explicitly transpose NCHW to NHWC.
+        input = create.onnx.toMemref(
+            create.onnx.transposeInt64(input, ArrayRef<int64_t>({0, 2, 3, 1})));
+      } else
+        // Otherwise, we can directly stickify from NCHW.
+        // Set pre-transformed layout to NCHW.
+        layout = getNCHWLayoutAttr(rewriter);
+    }

     // Else, emit a ZLow operation.
     rewriter.create<ZLowStickOp>(loc, input, alloc, layout, saturation);
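For intuition: the transposeInt64 call above emits an ONNX transpose with permutation (0, 2, 3, 1), which reorders an NCHW tensor into NHWC before the compiler-generated stick code runs. A minimal NumPy sketch of that permutation (a standalone illustration, not onnx-mlir code; the tensor values are made up):

import numpy as np

# Toy NCHW tensor: batch=1, channels=3, height=5, width=7
# (same shape as the new lit test below).
x_nchw = np.arange(1 * 3 * 5 * 7, dtype=np.float32).reshape(1, 3, 5, 7)

# Permutation (0, 2, 3, 1): result axis i takes input axis perm[i],
# so (N, C, H, W) becomes (N, H, W, C).
x_nhwc = x_nchw.transpose(0, 2, 3, 1)

assert x_nhwc.shape == (1, 5, 7, 3)
# Element (n, c, h, w) of the NCHW tensor lands at (n, h, w, c) in NHWC.
assert x_nchw[0, 2, 4, 6] == x_nhwc[0, 4, 6, 2]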
@@ -625,24 +634,50 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern {
     StringAttr layout =
         getZTensorLayoutAttr(rewriter, op->getOperand(0).getType());

-    IndexExprBuilderForKrnl createKrnlIE(rewriter, loc);
-    ZHighUnstickOpShapeHelper shapeHelper(op, operands, &createKrnlIE);
+    MultiDialectBuilder<OnnxBuilder, IndexExprBuilderForKrnl> create(
+        rewriter, loc);
+    ZHighUnstickOpShapeHelper shapeHelper(op, operands, &create.krnlIE);
     shapeHelper.computeShapeAndAssertOnFailure();

     // Convert ZTensor type to MemRefType.
     ZMemRefType zMemRefType =
         convertZTensorToMemRefType(*op->result_type_begin());

     // Allocate a buffer for the result MemRef.
-    Value alloc = insertAllocForZMemRef(
-        zMemRefType, shapeHelper.getOutputDims(), op, rewriter);
-
-    // Set layout: if NHWC, we can directly unstickify to NCHW.
-    if (isNHWCLayout(layout))
-      layout = getNCHWLayoutAttr(rewriter);
+    Value alloc = nullptr;
+    if (isNHWCLayout(layout)) {
+      if (nnpaEnableCompilerStickUnstick) {
+        // Compiler-generated unstick hasn't supported NCHW yet.
+        // This code allocates a NHWC buffer. It gets dims from the NCHW input.
+        SmallVector<IndexExpr> dimList;
+        dimList.emplace_back(shapeHelper.getOutputDims()[0]);
+        dimList.emplace_back(shapeHelper.getOutputDims()[2]);
+        dimList.emplace_back(shapeHelper.getOutputDims()[3]);
+        dimList.emplace_back(shapeHelper.getOutputDims()[1]);
+        MultiDialectBuilder<MemRefBuilder> create(rewriter, loc);
+        MemRefType resType = zMemRefType.value;
+        ArrayRef<int64_t> shape = resType.getShape();
+        alloc = create.mem.alignedAlloc(
+            MemRefType::get({shape[0], shape[2], shape[3], shape[1]},
+                resType.getElementType()),
+            dimList);
+      } else {
+        // Otherwise, we can directly stickify from NCHW.
+        // Set pre-transformed layout to NCHW.
+        layout = getNCHWLayoutAttr(rewriter);
+      }
+    }
+    if (alloc == nullptr)
+      alloc = insertAllocForZMemRef(
+          zMemRefType, shapeHelper.getOutputDims(), op, rewriter);

     // Emit a ZLow operation.
     rewriter.create<ZLowUnstickOp>(loc, input, alloc, layout);
+    if (isNHWCLayout(layout) && nnpaEnableCompilerStickUnstick)
+      // Compiler-generated unstick hasn't supported NCHW yet.
+      // Explicitly transpose NHWC to NCHW.
+      alloc =
+          create.onnx.transposeInt64(alloc, ArrayRef<int64_t>({0, 3, 1, 2}));
     rewriter.replaceOp(op, alloc);
     return success();
   }
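The unstick path mirrors the stick path: it allocates an NHWC scratch buffer whose dimensions are the NCHW output dims reordered by (0, 2, 3, 1), lets zlow.unstick fill it, then transposes back with permutation (0, 3, 1, 2). A quick NumPy check (illustration only) that (0, 3, 1, 2) is the inverse of (0, 2, 3, 1), so the round trip restores the original layout:

import numpy as np

x_nchw = np.random.rand(2, 3, 5, 7).astype(np.float32)

# NCHW -> NHWC (what the stick path emits) ...
x_nhwc = x_nchw.transpose(0, 2, 3, 1)
# ... and NHWC -> NCHW (what the unstick path emits afterwards).
roundtrip = x_nhwc.transpose(0, 3, 1, 2)

# (0, 3, 1, 2) undoes (0, 2, 3, 1), so the round trip is the identity.
assert np.array_equal(x_nchw, roundtrip)

# The NHWC scratch buffer shape is the NCHW output dims reordered by
# (0, 2, 3, 1), mirroring the dimList built in the lowering above.
assert x_nhwc.shape == tuple(x_nchw.shape[i] for i in (0, 2, 3, 1))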

test/accelerators/NNPA/backend/CMakeLists.txt

+18-6
@@ -362,12 +362,12 @@ set(NNPA_TEST_LIST
   # ==LIM== Input tensor must be less than or equal to 4 dimensions.

   # Model
-  # test_densenet121_cpu # accurary error
-  #test_inception_v1_cpu,zdnn_conv2d
-  #test_resnet50_cpu,zdnn_conv2d
-  #test_shufflenet_cpu,zdnn_matmul_op_ext
-  #test_squeezenet_cpu,zdnn_conv
-  #test_vgg19_cpu,zdnn_conv
+  test_densenet121_cpu,zdnn_conv2d
+  test_inception_v1_cpu,zdnn_conv2d
+  test_resnet50_cpu,zdnn_conv2d
+  test_shufflenet_cpu,zdnn_matmul_op_ext
+  # test_squeezenet_cpu,zdnn_conv # got NaN results
+  test_vgg19_cpu,zdnn_conv
 )
 set(ENV_TEST_CASE_BY_USER "")
 foreach(test_name IN LISTS NNPA_TEST_LIST)

@@ -394,6 +394,9 @@ add_custom_target(check-onnx-backend-nnpa
   COMMAND
     TEST_INSTRUCTION_CHECK=true
     ONNX_HOME=${FILE_GENERATE_DIR}/check-onnx-backend-nnpa
+    # Needed for convolution models to avoid NaN outputs.
+    # Remove this if saturation is enabled by default.
+    TEST_COMPILE_ARGS="--nnpa-saturation=true"
     ${NNPA_TESTS_ENVS} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py
   DEPENDS
     ${FILE_GENERATE_DIR}/test.py

@@ -405,6 +408,9 @@ add_custom_target(check-onnx-backend-dynamic-nnpa
     ONNX_HOME=${FILE_GENERATE_DIR}/check-onnx-backend-dynamic-nnpa
     TEST_INSTRUCTION_CHECK=true
    TEST_DYNAMIC=true
+    # Needed for convolution models to avoid NaN outputs.
+    # Remove this if saturation is enabled by default.
+    TEST_COMPILE_ARGS="--nnpa-saturation=true"
     ${NNPA_TESTS_ENVS_DYNAMIC} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py
   DEPENDS
     ${FILE_GENERATE_DIR}/test.py

@@ -418,6 +424,9 @@ add_custom_target(check-onnx-backend-constant-nnpa
     # TEST_INSTRUCTION_CHECK=true
     ONNX_HOME=${FILE_GENERATE_DIR}/check-onnx-backend-constant-nnpa
     TEST_CONSTANT=true
+    # Needed for convolution models to avoid NaN outputs.
+    # Remove this if saturation is enabled by default.
+    TEST_COMPILE_ARGS="--nnpa-saturation=true"
     ${NNPA_TESTS_ENVS} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py
   DEPENDS
     ${FILE_GENERATE_DIR}/test.py

@@ -427,6 +436,9 @@
 add_custom_target(check-onnx-backend-compilerlib-nnpa
   COMMAND
     TEST_COMPILERLIB=true ONNX_HOME=${CMAKE_CURRENT_BINARY_DIR}
+    # Needed for convolution models to avoid NaN outputs.
+    # Remove this if saturation is enabled by default.
+    TEST_COMPILE_ARGS="--nnpa-saturation=true"
     ${NNPA_TESTS_ENVS} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py
   DEPENDS
     ${FILE_GENERATE_DIR}/test.py
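For context on the "NaN outputs" comments: stickification converts f32 values to dlfloat16, and values outside the representable range can otherwise turn into non-numeric results; --nnpa-saturation=true clamps them to the range limits instead. A rough NumPy sketch of the idea, where the dlfloat16 maximum used, (1 - 2^-9) * 2^33 ≈ 8.57e9, is an assumption for illustration only:

import numpy as np

# Assumed approximate dlfloat16 maximum magnitude: (1 - 2**-9) * 2**33.
DLF16_MAX = (1 - 2**-9) * 2**33  # ~8.57e9

def saturate_to_dlfloat_range(x: np.ndarray) -> np.ndarray:
    """Clamp to the dlfloat16-representable range, sketching how saturation
    keeps out-of-range values from becoming Inf/NaN on stickification."""
    return np.clip(x, -DLF16_MAX, DLF16_MAX)

x = np.array([1.0, 1e12, -1e12], dtype=np.float32)
print(saturate_to_dlfloat_range(x))
# [1.0, 8.573157e9, -8.573157e9] -- clamped rather than overflowing.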

test/backend/common.py

+7
@@ -142,6 +142,13 @@ def compile_model(model, emit):
     command_list.append(model_name)
     command_list.append("-o=" + exec_base)

+    # Additional args passed in by TEST_COMPILE_ARGS
+    # Args are separated by ';'
+    additional_args = os.getenv("TEST_COMPILE_ARGS")
+    if additional_args is not None:
+        compile_args = additional_args.split(";")
+        command_list += compile_args
+
     # Call frontend to process model_name.onnx, bit code will be generated.
     dynamic_inputs_dims = determine_dynamic_parameters(name)
     if args.verbose:
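With this hook, any semicolon-separated flags placed in TEST_COMPILE_ARGS by the CMake targets above are appended to the onnx-mlir invocation. A small self-contained demo of the same parsing (the command and flag values are illustrative only):

import os

# Simulate what the CMake targets above export.
os.environ["TEST_COMPILE_ARGS"] = "--nnpa-saturation=true"

command_list = ["onnx-mlir", "model.onnx", "-o=model"]

# Mirrors the new code in compile_model(): split on ';' and append.
additional_args = os.getenv("TEST_COMPILE_ARGS")
if additional_args is not None:
    command_list += additional_args.split(";")

print(command_list)
# ['onnx-mlir', 'model.onnx', '-o=model', '--nnpa-saturation=true']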
New test file (+76)

@@ -0,0 +1,76 @@
+// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=true --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s
+
+func.func @should_lower_to_zlow(%arg0: tensor<1x3x5x7xf32>) -> tensor<*xf32> {
+  %0 = "zhigh.Stick"(%arg0) {layout = "NHWC"} : (tensor<1x3x5x7xf32>) -> tensor<*xf16>
+  %1 = "zhigh.Unstick"(%0) : (tensor<*xf16>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+
+// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3 floordiv 64, d1, d2 floordiv 32, d2 mod 32, d3 mod 64)>
+// CHECK-LABEL: func.func @should_lower_to_zlow
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x3x5x7xf32>) -> memref<1x3x5x7xf32> {
+// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1x5x7x3xf16, #map>
+// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<1x5x7x3xf32>
+// CHECK-DAG: [[LOOP_0_:%.+]]:4 = krnl.define_loops 4
+// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 5, [[LOOP_0_]]#3 -> [[I_3_:%.+]] = 0 to 7){
+// CHECK: [[VAR_2_:%.+]]:4 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index)
+// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2, [[VAR_2_]]#3] : memref<1x3x5x7xf32>
+// CHECK: krnl.store [[LOAD_PARAM_0_MEM_]], [[RES_1_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#2, [[VAR_2_]]#3, [[VAR_2_]]#1] : memref<1x5x7x3xf32>
+// CHECK: }
+// CHECK: "zlow.stick"([[RES_1_]], [[RES_]]) {layout = "NHWC"} : (memref<1x5x7x3xf32>, memref<1x5x7x3xf16, #map>) -> ()
+// CHECK: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1x5x7x3xf32>
+// CHECK: "zlow.unstick"([[RES_]], [[RES_]]_1) {layout = "NHWC"} : (memref<1x5x7x3xf16, #map>, memref<1x5x7x3xf32>) -> ()
+// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1x3x5x7xf32>
+// CHECK-DAG: [[LOOP_1_:%.+]]:4 = krnl.define_loops 4
+// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2, [[LOOP_1_]]#3) with ([[LOOP_1_]]#0 -> [[I_4_:%.+]] = 0 to 1, [[LOOP_1_]]#1 -> [[I_5_:%.+]] = 0 to 5, [[LOOP_1_]]#2 -> [[I_6_:%.+]] = 0 to 7, [[LOOP_1_]]#3 -> [[I_7_:%.+]] = 0 to 3){
+// CHECK: [[VAR_2_1_:%.+]]:4 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2, [[LOOP_1_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index)
+// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_2_]]{{.}}[[VAR_2_1_]]#0, [[VAR_2_1_]]#1, [[VAR_2_1_]]#2, [[VAR_2_1_]]#3] : memref<1x5x7x3xf32>
+// CHECK: krnl.store [[LOAD_PARAM_0_MEM_1_]], [[RES_3_]]{{.}}[[VAR_2_1_]]#0, [[VAR_2_1_]]#3, [[VAR_2_1_]]#1, [[VAR_2_1_]]#2] : memref<1x3x5x7xf32>
+// CHECK: }
+// CHECK: return [[RES_3_]] : memref<1x3x5x7xf32>
+// CHECK: }
+}

+// -----

+func.func @should_lower_to_zlow_unknown_dims(%arg0: tensor<1x?x?x7xf32>) -> tensor<*xf32> {
+  %0 = "zhigh.Stick"(%arg0) {layout = "NHWC"} : (tensor<1x?x?x7xf32>) -> tensor<*xf16>
+  %1 = "zhigh.Unstick"(%0) : (tensor<*xf16>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+
+// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3 floordiv 64, d1, d2 floordiv 32, d2 mod 32, d3 mod 64)>
+// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-LABEL: func.func @should_lower_to_zlow_unknown_dims
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x?x?x7xf32>) -> memref<1x?x?x7xf32> {
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref<1x?x?x7xf32>
+// CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<1x?x?x7xf32>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_0_]], [[VAR_dim_]]) {{.*}}: memref<1x?x7x?xf16, #map>
+// CHECK-DAG: [[VAR_dim_1_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<1x?x?x7xf32>
+// CHECK-DAG: [[VAR_dim_2_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref<1x?x?x7xf32>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc([[VAR_dim_1_]], [[VAR_dim_2_]]) {{.*}}: memref<1x?x7x?xf32>
+// CHECK-DAG: [[LOOP_0_:%.+]]:4 = krnl.define_loops 4
+// CHECK-DAG: [[VAR_dim_4_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref<1x?x?x7xf32>
+// CHECK-DAG: [[VAR_dim_5_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<1x?x?x7xf32>
+// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_4_]]), [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to [[MAP_2_]]([[VAR_dim_4_]], [[VAR_dim_5_]]), [[LOOP_0_]]#3 -> [[I_3_:%.+]] = 0 to 7){
+// CHECK: [[VAR_2_:%.+]]:4 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index)
+// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2, [[VAR_2_]]#3] : memref<1x?x?x7xf32>
+// CHECK: krnl.store [[LOAD_PARAM_0_MEM_]], [[RES_1_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#2, [[VAR_2_]]#3, [[VAR_2_]]#1] : memref<1x?x7x?xf32>
+// CHECK: }
+// CHECK: "zlow.stick"([[RES_1_]], [[RES_]]) {layout = "NHWC"} : (memref<1x?x7x?xf32>, memref<1x?x7x?xf16, #map>) -> ()
+// CHECK: [[RES_2_:%.+]] = memref.alloc([[VAR_dim_0_]], [[VAR_dim_]]) {{.*}}: memref<1x?x7x?xf32>
+// CHECK: "zlow.unstick"([[RES_]], [[RES_]]_6) {layout = "NHWC"} : (memref<1x?x7x?xf16, #map>, memref<1x?x7x?xf32>) -> ()
+// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc([[VAR_dim_]], [[VAR_dim_]]_0) {{.*}}: memref<1x?x?x7xf32>
+// CHECK-DAG: [[LOOP_1_:%.+]]:4 = krnl.define_loops 4
+// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2, [[LOOP_1_]]#3) with ([[LOOP_1_]]#0 -> [[I_4_:%.+]] = 0 to 1, [[LOOP_1_]]#1 -> [[I_5_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_0_]]), [[LOOP_1_]]#2 -> [[I_6_:%.+]] = 0 to 7, [[LOOP_1_]]#3 -> [[I_7_:%.+]] = 0 to [[MAP_2_]]([[VAR_dim_0_]], [[VAR_dim_]])){
+// CHECK: [[VAR_2_1_:%.+]]:4 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2, [[LOOP_1_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index)
+// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_2_]]{{.}}[[VAR_2_1_]]#0, [[VAR_2_1_]]#1, [[VAR_2_1_]]#2, [[VAR_2_1_]]#3] : memref<1x?x7x?xf32>
+// CHECK: krnl.store [[LOAD_PARAM_0_MEM_1_]], [[RES_3_]]{{.}}[[VAR_2_1_]]#0, [[VAR_2_1_]]#3, [[VAR_2_1_]]#1, [[VAR_2_1_]]#2] : memref<1x?x?x7xf32>
+// CHECK: }
+// CHECK: return [[RES_3_]] : memref<1x?x?x7xf32>
+// CHECK: }
+}

test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stick-unstick.mlir

+1-1
@@ -1,4 +1,4 @@
-// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s

 func.func @should_lower_to_zlow(%arg0: tensor<1x3x5x7xf32>) -> tensor<*xf32> {
   %0 = "zhigh.Stick"(%arg0) {layout = "NHWC"} : (tensor<1x3x5x7xf32>) -> tensor<*xf16>

test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/test-datalayout.mlir

+1-1
@@ -1,4 +1,4 @@
-// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s

 func.func @should_lower_to_zlow_1d(%arg0: tensor<7xf32>) -> tensor<*xf16> {
   %0 = "zhigh.Stick"(%arg0) {layout = "1D"} : (tensor<7xf32>) -> tensor<*xf16>
