Skip to content

Commit e2a3de6

Browse files
cjvolzka and tungld authored
[Cherry-pick] Fixing the location of DimAnalysis in onnx-to-zhigh pass and some rules in zhigh-to-onnx pass (#2797)
The onnx-to-zhigh pass has two phases: 1) converting multiple onnx ops into a single zhigh op, and 2) converting a single onnx op to a single zhigh op, where the second phase uses DimAnalysis (Patterns in the 1st phase at this moment does not use DimAnalysis) The problem is DimAnalysis is currently called before the 1st phase, which is not good because the 1st phase may change the IR so the information from DimAnalysis is obsoleted to the 2nd phase. Correct position for DimAnalysis would be just before the 2nd phase. Other than that, this PR changes slightly the rules in zhigh-to-onnx pass so that for binary ops, only one input (instead of two) that is from stick would be enough to trigger the rule to convert a zhigh op back to an onnx op. Resolves #2789 --------- (cherry picked from commit 80a63f2) Signed-off-by: Tung D. Le <[email protected]> Signed-off-by: Charles Volzka <[email protected]> Co-authored-by: Tung D. Le <[email protected]>
1 parent 0833dac commit e2a3de6

File tree

9 files changed

+103
-48
lines changed

9 files changed

+103
-48
lines changed

src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_onnx_mlir_library(OMONNXToZHigh
1414
OMNNPACompilerOptions
1515
OMONNXOps
1616
OMONNXToKrnl
17+
OMShapeInferencePass
1718
OMZHighOps
1819

1920
ACCEL_INCLUDE_DIRS PRIVATE

src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ void DevicePlacementPass::runOnOperation() {
196196
// Call ONNXToZHigh pass for lowering multiple ONNX ops at once to ZHigh.
197197
// E.g. `onnx.ReLu (onnx.Conv)` to zhigh.Conv.
198198
RewritePatternSet Patterns2(context);
199-
getONNXToZHighOneOpPatterns(Patterns2);
199+
getONNXToZHighMultipleOpPatterns(Patterns2);
200200
(void)applyAnalysisConversion(module, target, std::move(Patterns2),
201201
ConversionConfig{.legalizableOps = &legalizedOps2});
202202

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp

+8-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
2323
#include "src/Dialect/ONNX/ONNXOps.hpp"
2424
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
25+
#include "src/Dialect/ONNX/Transforms/ShapeInference.hpp"
2526

2627
using namespace mlir;
2728

@@ -328,16 +329,13 @@ void getONNXToZHighMultipleOpPatterns(RewritePatternSet &patterns) {
328329
patterns.insert<replaceONNXMatMulAddPattern2>(context);
329330
patterns.insert<replaceONNXReluConvPattern>(context);
330331
patterns.insert<replaceONNXLogSoftmaxPattern>(context);
332+
// Shape inference for newly-added operations.
333+
getShapeInferencePatterns(patterns);
331334
}
332335

333336
void ONNXToZHighLoweringPass::runOnOperation() {
334337
ModuleOp module = getOperation();
335338

336-
// Run the unknown dimension analysis to help check equality of unknown
337-
// dimensions at compile time.
338-
onnx_mlir::DimAnalysis dimAnalysis(module);
339-
dimAnalysis.analyze();
340-
341339
// The first thing to define is the conversion target. This will define the
342340
// final target for this lowering.
343341
ConversionTarget target(getContext());
@@ -363,6 +361,11 @@ void ONNXToZHighLoweringPass::runOnOperation() {
363361
// It's ok to fail.
364362
(void)applyPatternsAndFoldGreedily(module, std::move(combinedPatterns));
365363

364+
// Run the unknown dimension analysis to help check equality of unknown
365+
// dimensions at compile time.
366+
onnx_mlir::DimAnalysis dimAnalysis(module);
367+
dimAnalysis.analyze();
368+
366369
// Single ONNX to ZHigh operation lowering.
367370
RewritePatternSet patterns(&getContext());
368371
onnx_mlir::getONNXToZHighOneOpPatterns(patterns);

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ void ZHighToONNXLoweringPass::runOnOperation() {
5959

6060
RewritePatternSet patterns(&getContext());
6161
populateWithGenerated(patterns);
62+
zhigh::ZHighStickOp::getCanonicalizationPatterns(patterns, &getContext());
63+
zhigh::ZHighUnstickOp::getCanonicalizationPatterns(patterns, &getContext());
6264

6365
(void)applyPatternsAndFoldGreedily(function, std::move(patterns));
6466
}

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td

+70-25
Original file line numberDiff line numberDiff line change
@@ -37,62 +37,107 @@ def CreateONNXMaxOp : NativeCodeCall<"$_builder.create<ONNXMaxOp>($_loc, $0.getT
3737
// ONNXAddOp %X = ZHighUnstickOp (ZHighAddOp (ZHighStickOp %X),
3838
// (ZHighStickOp %Y))
3939
//===----------------------------------------------------------------------===//
40-
def replaceZHighAddPattern : Pat<
41-
(ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
42-
(ONNXAddOp $x, $y),
43-
[(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
40+
def replaceZHighAddPattern1 : Pat<
41+
(ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_), $y)),
42+
(ONNXAddOp $x, (ZHighUnstickOp $y)),
43+
[(NotBlockArgument:$x), (HasOneUse:$s_x)]
4444
>;
4545

46+
def replaceZHighAddPattern2 : Pat<
47+
(ZHighUnstickOp (ZHighAddOp $x, (ZHighStickOp:$s_y $y, $_))),
48+
(ONNXAddOp (ZHighUnstickOp $x), $y),
49+
[(NotBlockArgument:$y), (HasOneUse:$s_y)]
50+
>;
4651

4752
//===----------------------------------------------------------------------===//
4853
// ONNXMulOp %X = ZHighUnstickOp (ZHighMulOp (ZHighStickOp %X),
4954
// (ZHighStickOp %Y))
5055
//===----------------------------------------------------------------------===//
51-
def replaceZHighMulPattern : Pat<
52-
(ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
53-
(ONNXMulOp $x, $y),
54-
[(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
56+
def replaceZHighMulPattern1 : Pat<
57+
(ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_), $y)),
58+
(ONNXMulOp $x, (ZHighUnstickOp $y)),
59+
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
60+
(addBenefit 1)
61+
>;
62+
63+
def replaceZHighMulPattern2 : Pat<
64+
(ZHighUnstickOp (ZHighMulOp $x, (ZHighStickOp:$s_y $y, $_))),
65+
(ONNXMulOp (ZHighUnstickOp $x), $y),
66+
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [],
67+
(addBenefit 0)
5568
>;
5669

5770
//===----------------------------------------------------------------------===//
5871
// ONNXSubOp %X = ZHighUnstickOp (ZHighSubOp (ZHighStickOp %X),
5972
// (ZHighStickOp %Y))
6073
//===----------------------------------------------------------------------===//
61-
def replaceZHighSubPattern : Pat<
62-
(ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
63-
(ONNXSubOp $x, $y),
64-
[(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
74+
def replaceZHighSubPattern1 : Pat<
75+
(ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_), $y)),
76+
(ONNXSubOp $x, (ZHighUnstickOp $y)),
77+
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
78+
(addBenefit 1)
79+
>;
80+
81+
def replaceZHighSubPattern2 : Pat<
82+
(ZHighUnstickOp (ZHighSubOp $x, (ZHighStickOp:$s_y $y, $_))),
83+
(ONNXSubOp (ZHighUnstickOp $x), $y),
84+
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
85+
(addBenefit 0)
6586
>;
6687

6788
//===----------------------------------------------------------------------===//
6889
// ONNXDivOp %X = ZHighUnstickOp (ZHighDivOp (ZHighStickOp
6990
// %X),(ZHighStickOp %Y))
7091
// Note: turn off this pattern since NNPA is faster at this moment.
7192
//===----------------------------------------------------------------------===//
72-
// def replaceZHighDivPattern : Pat<
73-
// (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
74-
// (ONNXDivOp $x, $y),
75-
// [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
76-
// >;
93+
//def replaceZHighDivPattern1 : Pat<
94+
// (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), $y)),
95+
// (ONNXDivOp $x, (ZHighUnstickOp $y)),
96+
// [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
97+
// (addBenefit 1)
98+
//>;
99+
//
100+
//def replaceZHighDivPattern2 : Pat<
101+
// (ZHighUnstickOp (ZHighDivOp $x, (ZHighStickOp:$s_y $y, $_))),
102+
// (ONNXDivOp (ZHighUnstickOp $x), $y),
103+
// [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
104+
// (addBenefit 0)
105+
//>;
77106

78107
//===----------------------------------------------------------------------===//
79108
// ONNXMinOp %X = ZHighUnstickOp (ZHighMinOp (ZHighStickOp %X),
80109
// (ZHighStickOp %Y))
81110
//===----------------------------------------------------------------------===//
82-
def replaceZHighMinPattern : Pat<
83-
(ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
84-
(CreateONNXMinOp $u, $x, $y),
85-
[(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
111+
def replaceZHighMinPattern1 : Pat<
112+
(ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_), $y)),
113+
(CreateONNXMinOp $u, $x, (ZHighUnstickOp $y)),
114+
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
115+
(addBenefit 1)
116+
>;
117+
118+
def replaceZHighMinPattern2 : Pat<
119+
(ZHighUnstickOp:$u (ZHighMinOp $x, (ZHighStickOp:$s_y $y, $_))),
120+
(CreateONNXMinOp $u, (ZHighUnstickOp $x), $y),
121+
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
122+
(addBenefit 0)
86123
>;
87124

88125
//===----------------------------------------------------------------------===//
89126
// ONNXMaxOp %X = ZHighUnstickOp (ZHighMaxOp (ZHighStickOp %X),
90127
// (ZHighStickOp %Y))
91128
//===----------------------------------------------------------------------===//
92-
def replaceZHighMaxPattern : Pat<
93-
(ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
94-
(CreateONNXMaxOp $u, $x, $y),
95-
[(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
129+
def replaceZHighMaxPattern1 : Pat<
130+
(ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_), $y)),
131+
(CreateONNXMaxOp $u, $x, (ZHighUnstickOp $y)),
132+
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
133+
(addBenefit 1)
134+
>;
135+
136+
def replaceZHighMaxPattern2 : Pat<
137+
(ZHighUnstickOp:$u (ZHighMaxOp $x, (ZHighStickOp:$s_y $y, $_))),
138+
(CreateONNXMaxOp $u, (ZHighUnstickOp $x), $y),
139+
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
140+
(addBenefit 0)
96141
>;
97142

98143
//===----------------------------------------------------------------------===//

test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-zhigh-level.mlir

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
func.func @test_instrument_add_onnx_zhigh(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x1xf32>) -> tensor<*xf32> {
66
%0 = "onnx.Add"(%arg0, %arg1) {onnx_node_name = "onnx.Add1"} : (tensor<10x10xf32>, tensor<10x1xf32>) -> tensor<*xf32>
77
%1 = "onnx.Add"(%arg0, %0) {onnx_node_name = "onnx.Add2"} : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
8-
"onnx.Return"(%1) : (tensor<*xf32>) -> ()
8+
%2 = "onnx.Relu"(%1) {onnx_node_name = "onnx.Relu"} : (tensor<*xf32>) -> tensor<*xf32>
9+
"onnx.Return"(%2) : (tensor<*xf32>) -> ()
910
}
1011

1112
// CHECK-LABEL: func.func @test_instrument_add_onnx_zhigh
@@ -21,6 +22,9 @@ func.func @test_instrument_add_onnx_zhigh(%arg0 : tensor<10x10xf32>, %arg1 : ten
2122
// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Add", tag = 5 : i64} : () -> ()
2223
// CHECK: "zhigh.Add"
2324
// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Add", tag = 6 : i64} : () -> ()
25+
// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Relu", tag = 5 : i64} : () -> ()
26+
// CHECK: "zhigh.Relu"
27+
// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Relu", tag = 6 : i64} : () -> ()
2428
// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Unstick", tag = 5 : i64} : () -> ()
2529
// CHECK: "zhigh.Unstick"
2630
// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Unstick", tag = 6 : i64} : () -> ()

test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/conv.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ func.func @test_fuse_onnx_relu_conv2d(%arg0: tensor<5x3x32x32xf32>, %arg1 : tens
107107
// CHECK-NOT: separator of consecutive DAGs
108108
// CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Stick"([[VAR_2_]]) {layout = "HWCK"} : (tensor<2x2x3x2xf32>) -> tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>>
109109
// CHECK-DAG: [[VAR_4_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<2xf32>) -> tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>>
110-
// CHECK: [[VAR_5_:%.+]] = "zhigh.Conv2D"([[VAR_1_]], [[VAR_3_]], [[VAR_4_]]) {act_func = "ACT_RELU", kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [1, 1]} : (tensor<5x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16>
111-
// CHECK: [[VAR_6_:%.+]] = "zhigh.Unstick"([[VAR_5_]]) : (tensor<*xf16>) -> tensor<5x2x31x31xf32>
110+
// CHECK: [[VAR_5_:%.+]] = "zhigh.Conv2D"([[VAR_1_]], [[VAR_3_]], [[VAR_4_]]) {act_func = "ACT_RELU", kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [1, 1]} : (tensor<5x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<5x31x31x2xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
111+
// CHECK: [[VAR_6_:%.+]] = "zhigh.Unstick"([[VAR_5_]]) : (tensor<5x31x31x2xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<5x2x31x31xf32>
112112
// CHECK: return [[VAR_6_]] : tensor<5x2x31x31xf32>
113113
// CHECK: }
114114
}

test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul.mlir

+4-4
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ func.func @test_onnx_matmul_add_to_zhigh_1D_bias(
7979
// CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>
8080
// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x16xf32>) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
8181
// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<16xf32>) -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>
82-
// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16>
83-
// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<4x16xf32>
82+
// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
83+
// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32>
8484
// CHECK: return [[VAR_4_]] : tensor<4x16xf32>
8585
// CHECK: }
8686
// CHECK-NOT: "onnx.Add"
@@ -105,8 +105,8 @@ func.func @test_onnx_matmul_add_to_zhigh_1D_bias_normalized(
105105
// CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>
106106
// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x16xf32>) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
107107
// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<16xf32>) -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>
108-
// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16>
109-
// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<4x16xf32>
108+
// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
109+
// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32>
110110
// CHECK: return [[VAR_4_]] : tensor<4x16xf32>
111111
// CHECK: }
112112
// CHECK-NOT: "onnx.Add"

test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir

+10-10
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ func.func @test_onnx_logsoftmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
3939

4040
// CHECK-LABEL: func @test_onnx_logsoftmax
4141
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
42-
// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<10x10xf32>) -> tensor<*xf32>
43-
// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<*xf32>) -> tensor<*xf16>
44-
// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<*xf16>) -> tensor<*xf16>
45-
// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<*xf32>
46-
// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<*xf32>) -> tensor<10x10xf32>
42+
// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<10x10xf32>) -> tensor<1x10x10xf32>
43+
// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<1x10x10xf32>) -> tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>
44+
// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>
45+
// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x10x10xf32>
46+
// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<1x10x10xf32>) -> tensor<10x10xf32>
4747
// CHECK: return [[VAR_4_]] : tensor<10x10xf32>
4848
// CHECK: }
4949
}
@@ -57,11 +57,11 @@ func.func @test_onnx_logsoftmax_dyn(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
5757

5858
// CHECK-LABEL: func @test_onnx_logsoftmax_dyn
5959
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
60-
// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<?x?xf32>) -> tensor<*xf32>
61-
// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<*xf32>) -> tensor<*xf16>
62-
// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<*xf16>) -> tensor<*xf16>
63-
// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<*xf32>
64-
// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<*xf32>) -> tensor<?x?xf32>
60+
// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<?x?xf32>) -> tensor<1x?x?xf32>
61+
// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<1x?x?xf32>) -> tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>
62+
// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>
63+
// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x?x?xf32>
64+
// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<1x?x?xf32>) -> tensor<?x?xf32>
6565
// CHECK: return [[VAR_4_]] : tensor<?x?xf32>
6666
// CHECK: }
6767
}

0 commit comments

Comments
 (0)