Skip to content

Commit fb9544d

Browse files
[NNPA] Simplify rules in zhigh-to-onnx pass and change some pass order (#2951)
* Simplify rules in zhigh-to-onnx pass and change some pass order Signed-off-by: Tung D. Le <[email protected]> --------- Signed-off-by: Tung D. Le <[email protected]> Co-authored-by: Alexandre Eichenberger <[email protected]>
1 parent 40b607d commit fb9544d

File tree

2 files changed

+41
-85
lines changed

2 files changed

+41
-85
lines changed

src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp

+16-14
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,12 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
9090

9191
pm.addPass(onnx_mlir::createONNXToZHighPass());
9292
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
93+
9394
// There are more opportunities for const propagation once all zhigh ops were
9495
// generated.
9596
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConstPropONNXToONNXPass());
9697
pm.addPass(mlir::createCanonicalizerPass());
98+
9799
// Layout propagation at ZHighIR.
98100
pm.addNestedPass<func::FuncOp>(
99101
onnx_mlir::zhigh::createZHighLayoutPropagationPass());
@@ -110,13 +112,6 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
110112
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConstPropONNXToONNXPass());
111113
}
112114

113-
// After all optimizations, if there are still light-weight ops (e.g. add,
114-
// sub, ...) that are of `stick -> light-weight op -> unstick`, it's better to
115-
// use CPU instead of NNPA to avoid stick/unstick. CPU is efficient to handle
116-
// these ops, e.g vectorize the computation.
117-
if (nnpaEnableZHighToOnnx)
118-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createZHighToONNXPass());
119-
120115
// One more call to ONNX shape inference/canonicalization/... to update shape
121116
// if possible.
122117
if (enableONNXHybridPass) {
@@ -134,13 +129,6 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
134129
// ZHighConstPropagation currently assumes that DenseElementsAttr is used.
135130
pm.addPass(createScrubDisposablePass());
136131

137-
// Constant propagation at ZHighIR: constant stickify.
138-
// Only support BE machines.
139-
bool isBE = llvm::endianness::native == llvm::endianness::big;
140-
if (isBE)
141-
pm.addNestedPass<func::FuncOp>(
142-
onnx_mlir::zhigh::createZHighConstPropagationPass());
143-
144132
// Experimental feature: Decompose stick/unstick into two phases: layout
145133
// transform and data conversion. Do some optimizations after decomposing.
146134
// Then, recompose again layout and data conversion if they are not optimized.
@@ -152,6 +140,20 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
152140
onnx_mlir::zhigh::createZHighRecomposeToStickUnstickPass());
153141
}
154142

143+
// After all optimizations, if there are still light-weight ops (e.g. add,
144+
// sub, ...) that are of `stick -> light-weight op -> unstick`, it's better to
145+
// use CPU instead of NNPA to avoid stick/unstick. CPU is efficient to handle
146+
// these ops, e.g vectorize the computation.
147+
if (nnpaEnableZHighToOnnx)
148+
pm.addNestedPass<func::FuncOp>(onnx_mlir::createZHighToONNXPass());
149+
150+
// Constant propagation at ZHighIR: constant stickify.
151+
// Only support BE machines.
152+
bool isBE = llvm::endianness::native == llvm::endianness::big;
153+
if (isBE)
154+
pm.addNestedPass<func::FuncOp>(
155+
onnx_mlir::zhigh::createZHighConstPropagationPass());
156+
155157
// Remove common sub-expressions.
156158
pm.addPass(mlir::createCSEPass());
157159

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td

+25-71
Original file line numberDiff line numberDiff line change
@@ -37,107 +37,61 @@ def CreateONNXMaxOp : NativeCodeCall<"$_builder.create<ONNXMaxOp>($_loc, $0.getT
3737
// ONNXAddOp %X = ZHighUnstickOp (ZHighAddOp (ZHighStickOp %X),
3838
// (ZHighStickOp %Y))
3939
//===----------------------------------------------------------------------===//
40-
def replaceZHighAddPattern1 : Pat<
41-
(ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
42-
(ONNXAddOp $x, (ZHighUnstickOp $y)),
43-
[(NotBlockArgument:$x), (HasOneUse:$s_x)]
44-
>;
45-
46-
def replaceZHighAddPattern2 : Pat<
47-
(ZHighUnstickOp (ZHighAddOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
48-
(ONNXAddOp (ZHighUnstickOp $x), $y),
49-
[(NotBlockArgument:$y), (HasOneUse:$s_y)]
40+
def replaceZHighAddPattern : Pat<
41+
(ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))),
42+
(ONNXAddOp $x, $y),
43+
[(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)]
5044
>;
5145

5246
//===----------------------------------------------------------------------===//
5347
// ONNXMulOp %X = ZHighUnstickOp (ZHighMulOp (ZHighStickOp %X),
5448
// (ZHighStickOp %Y))
5549
//===----------------------------------------------------------------------===//
56-
def replaceZHighMulPattern1 : Pat<
57-
(ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
58-
(ONNXMulOp $x, (ZHighUnstickOp $y)),
59-
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
60-
(addBenefit 1)
61-
>;
62-
63-
def replaceZHighMulPattern2 : Pat<
64-
(ZHighUnstickOp (ZHighMulOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
65-
(ONNXMulOp (ZHighUnstickOp $x), $y),
66-
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [],
67-
(addBenefit 0)
50+
def replaceZHighMulPattern : Pat<
51+
(ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))),
52+
(ONNXMulOp $x, $y),
53+
[(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)]
6854
>;
6955

7056
//===----------------------------------------------------------------------===//
7157
// ONNXSubOp %X = ZHighUnstickOp (ZHighSubOp (ZHighStickOp %X),
7258
// (ZHighStickOp %Y))
7359
//===----------------------------------------------------------------------===//
74-
def replaceZHighSubPattern1 : Pat<
75-
(ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
76-
(ONNXSubOp $x, (ZHighUnstickOp $y)),
77-
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
78-
(addBenefit 1)
79-
>;
80-
81-
def replaceZHighSubPattern2 : Pat<
82-
(ZHighUnstickOp (ZHighSubOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
83-
(ONNXSubOp (ZHighUnstickOp $x), $y),
84-
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
85-
(addBenefit 0)
60+
def replaceZHighSubPattern : Pat<
61+
(ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))),
62+
(ONNXSubOp $x, $y),
63+
[(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)]
8664
>;
8765

8866
//===----------------------------------------------------------------------===//
8967
// ONNXDivOp %X = ZHighUnstickOp (ZHighDivOp (ZHighStickOp
9068
// %X),(ZHighStickOp %Y))
9169
// Note: turn off this pattern since NNPA is faster at this moment.
9270
//===----------------------------------------------------------------------===//
93-
//def replaceZHighDivPattern1 : Pat<
94-
// (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), $y)),
95-
// (ONNXDivOp $x, (ZHighUnstickOp $y)),
96-
// [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
97-
// (addBenefit 1)
98-
//>;
99-
//
100-
//def replaceZHighDivPattern2 : Pat<
101-
// (ZHighUnstickOp (ZHighDivOp $x, (ZHighStickOp:$s_y $y, $_))),
102-
// (ONNXDivOp (ZHighUnstickOp $x), $y),
103-
// [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
104-
// (addBenefit 0)
105-
//>;
71+
// def replaceZHighDivPattern : Pat<
72+
// (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))),
73+
// (ONNXDivOp $x, $y),
74+
// [(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)]
75+
// >;
10676

10777
//===----------------------------------------------------------------------===//
10878
// ONNXMinOp %X = ZHighUnstickOp (ZHighMinOp (ZHighStickOp %X),
10979
// (ZHighStickOp %Y))
11080
//===----------------------------------------------------------------------===//
111-
def replaceZHighMinPattern1 : Pat<
112-
(ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
113-
(CreateONNXMinOp $u, $x, (ZHighUnstickOp $y)),
114-
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
115-
(addBenefit 1)
116-
>;
117-
118-
def replaceZHighMinPattern2 : Pat<
119-
(ZHighUnstickOp:$u (ZHighMinOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
120-
(CreateONNXMinOp $u, (ZHighUnstickOp $x), $y),
121-
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
122-
(addBenefit 0)
81+
def replaceZHighMinPattern : Pat<
82+
(ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))),
83+
(CreateONNXMinOp $u, $x, $y),
84+
[(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)]
12385
>;
12486

12587
//===----------------------------------------------------------------------===//
12688
// ONNXMaxOp %X = ZHighUnstickOp (ZHighMaxOp (ZHighStickOp %X),
12789
// (ZHighStickOp %Y))
12890
//===----------------------------------------------------------------------===//
129-
def replaceZHighMaxPattern1 : Pat<
130-
(ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_, $_), $y)),
131-
(CreateONNXMaxOp $u, $x, (ZHighUnstickOp $y)),
132-
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
133-
(addBenefit 1)
134-
>;
135-
136-
def replaceZHighMaxPattern2 : Pat<
137-
(ZHighUnstickOp:$u (ZHighMaxOp $x, (ZHighStickOp:$s_y $y, $_, $_))),
138-
(CreateONNXMaxOp $u, (ZHighUnstickOp $x), $y),
139-
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
140-
(addBenefit 0)
91+
def replaceZHighMaxPattern : Pat<
92+
(ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))),
93+
(CreateONNXMaxOp $u, $x, $y),
94+
[(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)]
14195
>;
14296

14397
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)