Commit 6171ea0

Merge branch 'main' into hamptonm/feature/llvm

2 parents 771b20f + 80a63f2

9 files changed: 103 additions, 48 deletions
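In brief, the NNPA-side changes in this merge: DevicePlacement now probes with the multiple-op fusion patterns its comment already described; the ONNX-to-ZHigh pass registers shape-inference patterns alongside those fusions and defers the unknown-dimension analysis until after they run; the ZHigh-to-ONNX pass adds Stick/Unstick canonicalizations; the ZHigh-to-ONNX elementwise patterns are split so they fire when only one operand is stickified; and the lit tests are updated for the now fully inferred result types.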

src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ add_onnx_mlir_library(OMONNXToZHigh
   OMNNPACompilerOptions
   OMONNXOps
   OMONNXToKrnl
+  OMShapeInferencePass
   OMZHighOps

   ACCEL_INCLUDE_DIRS PRIVATE
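Presumably needed because ONNXToZHigh.cpp (below) now calls getShapeInferencePatterns, which lives in the OMShapeInferencePass library.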

src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp

Lines changed: 1 addition & 1 deletion
@@ -196,7 +196,7 @@ void DevicePlacementPass::runOnOperation() {
   // Call ONNXToZHigh pass for lowering multiple ONNX ops at once to ZHigh.
   // E.g. `onnx.ReLu (onnx.Conv)` to zhigh.Conv.
   RewritePatternSet Patterns2(context);
-  getONNXToZHighOneOpPatterns(Patterns2);
+  getONNXToZHighMultipleOpPatterns(Patterns2);
   (void)applyAnalysisConversion(module, target, std::move(Patterns2),
       ConversionConfig{.legalizableOps = &legalizedOps2});
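This one-liner fixes a mismatch between comment and code: the probe for multiple-op fusions (e.g. `onnx.ReLu (onnx.Conv)` to `zhigh.Conv`) was registering the single-op patterns. A sketch of how this call site behaves, with the declaration of `legalizedOps2` reconstructed here (its exact container type is an assumption):

// Sketch only: applyAnalysisConversion dry-runs the conversion and records
// which ops *would* legalize, without rewriting the module. Device placement
// must therefore probe with the same pattern set the real lowering uses.
llvm::DenseSet<mlir::Operation *> legalizedOps2; // container type assumed
mlir::RewritePatternSet Patterns2(context);
getONNXToZHighMultipleOpPatterns(Patterns2);
(void)mlir::applyAnalysisConversion(module, target, std::move(Patterns2),
    mlir::ConversionConfig{.legalizableOps = &legalizedOps2});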

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp

Lines changed: 8 additions & 5 deletions
@@ -22,6 +22,7 @@
 #include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
+#include "src/Dialect/ONNX/Transforms/ShapeInference.hpp"

 using namespace mlir;

@@ -328,16 +329,13 @@ void getONNXToZHighMultipleOpPatterns(RewritePatternSet &patterns) {
   patterns.insert<replaceONNXMatMulAddPattern2>(context);
   patterns.insert<replaceONNXReluConvPattern>(context);
   patterns.insert<replaceONNXLogSoftmaxPattern>(context);
+  // Shape inference for newly-added operations.
+  getShapeInferencePatterns(patterns);
 }

 void ONNXToZHighLoweringPass::runOnOperation() {
   ModuleOp module = getOperation();

-  // Run the unknown dimension analysis to help check equality of unknown
-  // dimensions at compile time.
-  onnx_mlir::DimAnalysis dimAnalysis(module);
-  dimAnalysis.analyze();
-
   // The first thing to define is the conversion target. This will define the
   // final target for this lowering.
   ConversionTarget target(getContext());

@@ -363,6 +361,11 @@ void ONNXToZHighLoweringPass::runOnOperation() {
   // It's ok to fail.
   (void)applyPatternsAndFoldGreedily(module, std::move(combinedPatterns));

+  // Run the unknown dimension analysis to help check equality of unknown
+  // dimensions at compile time.
+  onnx_mlir::DimAnalysis dimAnalysis(module);
+  dimAnalysis.analyze();
+
   // Single ONNX to ZHigh operation lowering.
   RewritePatternSet patterns(&getContext());
   onnx_mlir::getONNXToZHighOneOpPatterns(patterns);
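Beyond the new include, two behavioral changes: shape-inference patterns now run in the same greedy set as the multiple-op fusions, so ops the fusions create get concrete result types instead of `tensor<*xf16>` (the lit-test updates below reflect this), and the unknown-dimension analysis moves after that rewrite so it analyzes the fused IR the one-op lowering actually consults. A minimal sketch of the resulting ordering, reconstructed from this diff (the helper name `runOnOperationSketch` is hypothetical):

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical condensation of ONNXToZHighLoweringPass::runOnOperation().
void runOnOperationSketch(mlir::ModuleOp module, mlir::MLIRContext *ctx) {
  // 1. Multiple-op fusions; getONNXToZHighMultipleOpPatterns now also
  //    registers shape-inference patterns, so newly created ZHigh ops get
  //    ranked, stickified result types as soon as they are built.
  mlir::RewritePatternSet combinedPatterns(ctx);
  onnx_mlir::getONNXToZHighMultipleOpPatterns(combinedPatterns);
  (void)mlir::applyPatternsAndFoldGreedily(module, std::move(combinedPatterns));

  // 2. DimAnalysis runs only now, so unknown-dimension equalities are
  //    computed on the post-fusion IR.
  onnx_mlir::DimAnalysis dimAnalysis(module);
  dimAnalysis.analyze();

  // 3. The single-op ONNX-to-ZHigh lowering follows, as before.
  mlir::RewritePatternSet patterns(ctx);
  onnx_mlir::getONNXToZHighOneOpPatterns(patterns);
}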

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,8 @@ void ZHighToONNXLoweringPass::runOnOperation() {

   RewritePatternSet patterns(&getContext());
   populateWithGenerated(patterns);
+  zhigh::ZHighStickOp::getCanonicalizationPatterns(patterns, &getContext());
+  zhigh::ZHighUnstickOp::getCanonicalizationPatterns(patterns, &getContext());

   (void)applyPatternsAndFoldGreedily(function, std::move(patterns));
 }
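The ZHighToONNX patterns added in the next file create fresh `zhigh.Unstick` ops for the operand that was not stickified, so the Stick/Unstick canonicalizations are registered in the same greedy set, presumably to fold away any stick/unstick round trips those rewrites leave behind. This relies on the standard MLIR idiom that an op's canonicalizations are ordinary rewrite patterns and can be mixed into any `RewritePatternSet`.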

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td

Lines changed: 70 additions & 25 deletions
@@ -37,62 +37,107 @@ def CreateONNXMaxOp : NativeCodeCall<"$_builder.create<ONNXMaxOp>($_loc, $0.getT
 // ONNXAddOp %X = ZHighUnstickOp (ZHighAddOp (ZHighStickOp %X),
 //                                (ZHighStickOp %Y))
 //===----------------------------------------------------------------------===//
-def replaceZHighAddPattern : Pat<
-  (ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (ONNXAddOp $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighAddPattern1 : Pat<
+  (ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (ONNXAddOp $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)]
 >;

+def replaceZHighAddPattern2 : Pat<
+  (ZHighUnstickOp (ZHighAddOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (ONNXAddOp (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)]
+>;

 //===----------------------------------------------------------------------===//
 // ONNXMulOp %X = ZHighUnstickOp (ZHighMulOp (ZHighStickOp %X),
 //                                (ZHighStickOp %Y))
 //===----------------------------------------------------------------------===//
-def replaceZHighMulPattern : Pat<
-  (ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (ONNXMulOp $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighMulPattern1 : Pat<
+  (ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (ONNXMulOp $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [],
+  (addBenefit 1)
+>;
+
+def replaceZHighMulPattern2 : Pat<
+  (ZHighUnstickOp (ZHighMulOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (ONNXMulOp (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [],
+  (addBenefit 0)
 >;

 //===----------------------------------------------------------------------===//
 // ONNXSubOp %X = ZHighUnstickOp (ZHighSubOp (ZHighStickOp %X),
 //                                (ZHighStickOp %Y))
 //===----------------------------------------------------------------------===//
-def replaceZHighSubPattern : Pat<
-  (ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (ONNXSubOp $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighSubPattern1 : Pat<
+  (ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (ONNXSubOp $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [],
+  (addBenefit 1)
+>;
+
+def replaceZHighSubPattern2 : Pat<
+  (ZHighUnstickOp (ZHighSubOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (ONNXSubOp (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [],
+  (addBenefit 0)
 >;

 //===----------------------------------------------------------------------===//
 // ONNXDivOp %X = ZHighUnstickOp (ZHighDivOp (ZHighStickOp
 //                                %X),(ZHighStickOp %Y))
 // Note: turn off this pattern since NNPA is faster at this moment.
 //===----------------------------------------------------------------------===//
-// def replaceZHighDivPattern : Pat<
-//   (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-//   (ONNXDivOp $x, $y),
-//   [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
-// >;
+//def replaceZHighDivPattern1 : Pat<
+//  (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), $y)),
+//  (ONNXDivOp $x, (ZHighUnstickOp $y)),
+//  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [],
+//  (addBenefit 1)
+//>;
+//
+//def replaceZHighDivPattern2 : Pat<
+//  (ZHighUnstickOp (ZHighDivOp $x, (ZHighStickOp:$s_y $y, $_))),
+//  (ONNXDivOp (ZHighUnstickOp $x), $y),
+//  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [],
+//  (addBenefit 0)
+//>;

 //===----------------------------------------------------------------------===//
 // ONNXMinOp %X = ZHighUnstickOp (ZHighMinOp (ZHighStickOp %X),
 //                                (ZHighStickOp %Y))
 //===----------------------------------------------------------------------===//
-def replaceZHighMinPattern : Pat<
-  (ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (CreateONNXMinOp $u, $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighMinPattern1 : Pat<
+  (ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (CreateONNXMinOp $u, $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [],
+  (addBenefit 1)
+>;
+
+def replaceZHighMinPattern2 : Pat<
+  (ZHighUnstickOp:$u (ZHighMinOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (CreateONNXMinOp $u, (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [],
+  (addBenefit 0)
 >;

 //===----------------------------------------------------------------------===//
 // ONNXMaxOp %X = ZHighUnstickOp (ZHighMaxOp (ZHighStickOp %X),
 //                                (ZHighStickOp %Y))
 //===----------------------------------------------------------------------===//
-def replaceZHighMaxPattern : Pat<
-  (ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (CreateONNXMaxOp $u, $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighMaxPattern1 : Pat<
+  (ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (CreateONNXMaxOp $u, $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [],
+  (addBenefit 1)
+>;
+
+def replaceZHighMaxPattern2 : Pat<
+  (ZHighUnstickOp:$u (ZHighMaxOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (CreateONNXMaxOp $u, (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [],
+  (addBenefit 0)
+>;

 //===----------------------------------------------------------------------===//
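The net effect: the old patterns only fired when both operands of the ZHigh elementwise op came from `zhigh.Stick`; each is now a `...Pattern1`/`...Pattern2` pair that fires when at least one operand does, inserting an explicit `ZHighUnstickOp` for the other. The `(addBenefit 1)` / `(addBenefit 0)` pair biases the greedy driver toward the stick-on-the-left variant, so the choice is deterministic when both operands would match. As a reading aid, here is a rough hand-written C++ analogue of `replaceZHighMulPattern1`; it is a sketch only (the real pattern is generated from the DRR above), and the accessor names (`getIn`, `getX`, `getY`) and builder forms are assumptions about the ZHigh op definitions:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"

namespace zhigh = onnx_mlir::zhigh; // as used in ZHighToONNX.cpp

// Sketch: hand-written analogue of replaceZHighMulPattern1. Accessor names
// and builder signatures are assumed, not taken from the real op definitions.
struct ReplaceZHighMulPattern1Sketch
    : public mlir::OpRewritePattern<zhigh::ZHighUnstickOp> {
  ReplaceZHighMulPattern1Sketch(mlir::MLIRContext *ctx)
      // Higher benefit than the Pattern2 analogue: plays the role of
      // (addBenefit 1) vs (addBenefit 0).
      : mlir::OpRewritePattern<zhigh::ZHighUnstickOp>(ctx, /*benefit=*/2) {}

  mlir::LogicalResult
  matchAndRewrite(zhigh::ZHighUnstickOp unstick,
                  mlir::PatternRewriter &rewriter) const override {
    // Match (ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_), $y)).
    auto mul = unstick.getIn().getDefiningOp<zhigh::ZHighMulOp>();
    if (!mul)
      return mlir::failure();
    auto stickX = mul.getX().getDefiningOp<zhigh::ZHighStickOp>();
    if (!stickX)
      return mlir::failure();
    mlir::Value x = stickX.getIn();
    // DRR constraints: HasOneUse:$s_x and NotBlockArgument:$x.
    if (!stickX->hasOneUse() || mlir::isa<mlir::BlockArgument>(x))
      return mlir::failure();
    // Rewrite to (ONNXMulOp $x, (ZHighUnstickOp $y)); the unstick's result
    // type is left unranked here and would be refined by shape inference.
    auto unstickY = rewriter.create<zhigh::ZHighUnstickOp>(unstick.getLoc(),
        mlir::UnrankedTensorType::get(rewriter.getF32Type()), mul.getY());
    rewriter.replaceOpWithNewOp<mlir::ONNXMulOp>(
        unstick, unstick.getType(), x, unstickY.getResult());
    return mlir::success();
  }
};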

test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-zhigh-level.mlir

Lines changed: 5 additions & 1 deletion
@@ -5,7 +5,8 @@
 func.func @test_instrument_add_onnx_zhigh(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x1xf32>) -> tensor<*xf32> {
   %0 = "onnx.Add"(%arg0, %arg1) {onnx_node_name = "onnx.Add1"} : (tensor<10x10xf32>, tensor<10x1xf32>) -> tensor<*xf32>
   %1 = "onnx.Add"(%arg0, %0) {onnx_node_name = "onnx.Add2"} : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
-  "onnx.Return"(%1) : (tensor<*xf32>) -> ()
+  %2 = "onnx.Relu"(%1) {onnx_node_name = "onnx.Relu"} : (tensor<*xf32>) -> tensor<*xf32>
+  "onnx.Return"(%2) : (tensor<*xf32>) -> ()
 }

 // CHECK-LABEL: func.func @test_instrument_add_onnx_zhigh

@@ -21,6 +22,9 @@ func.func @test_instrument_add_onnx_zhigh(%arg0 : tensor<10x10xf32>, %arg1 : ten
 // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Add", tag = 5 : i64} : () -> ()
 // CHECK: "zhigh.Add"
 // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Add", tag = 6 : i64} : () -> ()
+// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Relu", tag = 5 : i64} : () -> ()
+// CHECK: "zhigh.Relu"
+// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Relu", tag = 6 : i64} : () -> ()
 // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Unstick", tag = 5 : i64} : () -> ()
 // CHECK: "zhigh.Unstick"
 // CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Unstick", tag = 6 : i64} : () -> ()

test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/conv.mlir

Lines changed: 2 additions & 2 deletions
@@ -107,8 +107,8 @@ func.func @test_fuse_onnx_relu_conv2d(%arg0: tensor<5x3x32x32xf32>, %arg1 : tens
 // CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Stick"([[VAR_2_]]) {layout = "HWCK"} : (tensor<2x2x3x2xf32>) -> tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>>
 // CHECK-DAG: [[VAR_4_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<2xf32>) -> tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>>
-// CHECK: [[VAR_5_:%.+]] = "zhigh.Conv2D"([[VAR_1_]], [[VAR_3_]], [[VAR_4_]]) {act_func = "ACT_RELU", kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [1, 1]} : (tensor<5x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16>
-// CHECK: [[VAR_6_:%.+]] = "zhigh.Unstick"([[VAR_5_]]) : (tensor<*xf16>) -> tensor<5x2x31x31xf32>
+// CHECK: [[VAR_5_:%.+]] = "zhigh.Conv2D"([[VAR_1_]], [[VAR_3_]], [[VAR_4_]]) {act_func = "ACT_RELU", kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [1, 1]} : (tensor<5x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<5x31x31x2xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: [[VAR_6_:%.+]] = "zhigh.Unstick"([[VAR_5_]]) : (tensor<5x31x31x2xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<5x2x31x31xf32>
 // CHECK: return [[VAR_6_]] : tensor<5x2x31x31xf32>
 // CHECK: }
 }
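These FileCheck updates, and the matching ones in matmul.mlir and softmax.mlir below, follow from the shape-inference patterns now running with the ONNX-to-ZHigh rewrites: intermediate ZHigh results that used to be left as `tensor<*xf16>` now carry fully inferred shapes and `#zhigh.layout` encodings.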

test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul.mlir

Lines changed: 4 additions & 4 deletions
@@ -79,8 +79,8 @@ func.func @test_onnx_matmul_add_to_zhigh_1D_bias(
 // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>
 // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x16xf32>) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
 // CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<16xf32>) -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>
-// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16>
-// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<4x16xf32>
+// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
+// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32>
 // CHECK: return [[VAR_4_]] : tensor<4x16xf32>
 // CHECK: }
 // CHECK-NOT: "onnx.Add"

@@ -105,8 +105,8 @@ func.func @test_onnx_matmul_add_to_zhigh_1D_bias_normalized(
 // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>
 // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x16xf32>) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
 // CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<16xf32>) -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>
-// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16>
-// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<4x16xf32>
+// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
+// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32>
 // CHECK: return [[VAR_4_]] : tensor<4x16xf32>
 // CHECK: }
 // CHECK-NOT: "onnx.Add"

test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir

Lines changed: 10 additions & 10 deletions
@@ -39,11 +39,11 @@ func.func @test_onnx_logsoftmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {

 // CHECK-LABEL: func @test_onnx_logsoftmax
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
-// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<10x10xf32>) -> tensor<*xf32>
-// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<*xf32>) -> tensor<*xf16>
-// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<*xf16>) -> tensor<*xf16>
-// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<*xf32>
-// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<*xf32>) -> tensor<10x10xf32>
+// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<10x10xf32>) -> tensor<1x10x10xf32>
+// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<1x10x10xf32>) -> tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x10x10xf32>
+// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<1x10x10xf32>) -> tensor<10x10xf32>
 // CHECK: return [[VAR_4_]] : tensor<10x10xf32>
 // CHECK: }
 }

@@ -57,11 +57,11 @@ func.func @test_onnx_logsoftmax_dyn(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {

 // CHECK-LABEL: func @test_onnx_logsoftmax_dyn
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<?x?xf32>) -> tensor<*xf32>
-// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<*xf32>) -> tensor<*xf16>
-// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<*xf16>) -> tensor<*xf16>
-// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<*xf32>
-// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<*xf32>) -> tensor<?x?xf32>
+// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<?x?xf32>) -> tensor<1x?x?xf32>
+// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<1x?x?xf32>) -> tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x?x?xf32>
+// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<1x?x?xf32>) -> tensor<?x?xf32>
 // CHECK: return [[VAR_4_]] : tensor<?x?xf32>
 // CHECK: }
 }
