Skip to content

Commit 46c8985

Browse files
committed
Merge branch 'main' into pr_extend_dynamic_backend_test
2 parents 27f3892 + 5b1d90b commit 46c8985

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+1418
-197
lines changed

docs/BuildOnLinuxOSX.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
1515
``` bash
1616
git clone -n https://github.com/llvm/llvm-project.git
1717
# Check out a specific branch that is known to work with ONNX-MLIR.
18-
cd llvm-project && git checkout a4ca07f13b560b4f6fa5459eef7159e4f9ee9a6b && cd ..
18+
cd llvm-project && git checkout 7ac7d418ac2b16fd44789dcf48e2b5d73de3e715 && cd ..
1919
```
2020

2121
[same-as-file]: <> (utils/build-mlir.sh)

docs/BuildOnWindows.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project):
5252
```shell
5353
git clone -n https://github.com/llvm/llvm-project.git
5454
# Check out a specific branch that is known to work with ONNX-MLIR.
55-
cd llvm-project && git checkout a4ca07f13b560b4f6fa5459eef7159e4f9ee9a6b && cd ..
55+
cd llvm-project && git checkout 7ac7d418ac2b16fd44789dcf48e2b5d73de3e715 && cd ..
5656
```
5757

5858
[same-as-file]: <> (utils/build-mlir.cmd)

src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ llvm::cl::opt<bool> nnpaEnableCompilerStickUnstick(
5757

5858
llvm::cl::opt<bool> nnpaEnableScalarBcastBinary(
5959
"nnpa-enable-scalar-bcast-binary",
60-
llvm::cl::desc("Enable the lowering to NNPA the broadcasting binary ops "
61-
"whose one of the operands is scalar. Currently support "
62-
"ONNXDiv only. Default is false."),
60+
llvm::cl::desc("Enable the lowering to NNPA of binary operations with "
61+
"broadcasting of a scalar operand."
62+
"Currently only enable ONNXDiv. Default is false."),
6363
llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions));
6464

6565
llvm::cl::opt<std::string> nnpaLoadDevicePlacementFile{

src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_onnx_mlir_library(OMONNXToZHigh
1414
OMNNPACompilerOptions
1515
OMONNXOps
1616
OMONNXToKrnl
17+
OMShapeInferencePass
1718
OMZHighOps
1819

1920
ACCEL_INCLUDE_DIRS PRIVATE

src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ void DevicePlacementPass::runOnOperation() {
196196
// Call ONNXToZHigh pass for lowering multiple ONNX ops at once to ZHigh.
197197
// E.g. `onnx.ReLu (onnx.Conv)` to zhigh.Conv.
198198
RewritePatternSet Patterns2(context);
199-
getONNXToZHighOneOpPatterns(Patterns2);
199+
getONNXToZHighMultipleOpPatterns(Patterns2);
200200
(void)applyAnalysisConversion(module, target, std::move(Patterns2),
201201
ConversionConfig{.legalizableOps = &legalizedOps2});
202202

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
2323
#include "src/Dialect/ONNX/ONNXOps.hpp"
2424
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
25+
#include "src/Dialect/ONNX/Transforms/ShapeInference.hpp"
2526

2627
using namespace mlir;
2728

@@ -328,16 +329,13 @@ void getONNXToZHighMultipleOpPatterns(RewritePatternSet &patterns) {
328329
patterns.insert<replaceONNXMatMulAddPattern2>(context);
329330
patterns.insert<replaceONNXReluConvPattern>(context);
330331
patterns.insert<replaceONNXLogSoftmaxPattern>(context);
332+
// Shape inference for newly-added operations.
333+
getShapeInferencePatterns(patterns);
331334
}
332335

333336
void ONNXToZHighLoweringPass::runOnOperation() {
334337
ModuleOp module = getOperation();
335338

336-
// Run the unknown dimension analysis to help check equality of unknown
337-
// dimensions at compile time.
338-
onnx_mlir::DimAnalysis dimAnalysis(module);
339-
dimAnalysis.analyze();
340-
341339
// The first thing to define is the conversion target. This will define the
342340
// final target for this lowering.
343341
ConversionTarget target(getContext());
@@ -363,6 +361,11 @@ void ONNXToZHighLoweringPass::runOnOperation() {
363361
// It's ok to fail.
364362
(void)applyPatternsAndFoldGreedily(module, std::move(combinedPatterns));
365363

364+
// Run the unknown dimension analysis to help check equality of unknown
365+
// dimensions at compile time.
366+
onnx_mlir::DimAnalysis dimAnalysis(module);
367+
dimAnalysis.analyze();
368+
366369
// Single ONNX to ZHigh operation lowering.
367370
RewritePatternSet patterns(&getContext());
368371
onnx_mlir::getONNXToZHighOneOpPatterns(patterns);

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ void ZHighToONNXLoweringPass::runOnOperation() {
5959

6060
RewritePatternSet patterns(&getContext());
6161
populateWithGenerated(patterns);
62+
zhigh::ZHighStickOp::getCanonicalizationPatterns(patterns, &getContext());
63+
zhigh::ZHighUnstickOp::getCanonicalizationPatterns(patterns, &getContext());
6264

6365
(void)applyPatternsAndFoldGreedily(function, std::move(patterns));
6466
}

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td

Lines changed: 70 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -37,62 +37,107 @@ def CreateONNXMaxOp : NativeCodeCall<"$_builder.create<ONNXMaxOp>($_loc, $0.getT
3737
// ONNXAddOp %X = ZHighUnstickOp (ZHighAddOp (ZHighStickOp %X),
3838
// (ZHighStickOp %Y))
3939
//===----------------------------------------------------------------------===//
40-
def replaceZHighAddPattern : Pat<
41-
(ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
42-
(ONNXAddOp $x, $y),
43-
[(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
40+
def replaceZHighAddPattern1 : Pat<
41+
(ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_), $y)),
42+
(ONNXAddOp $x, (ZHighUnstickOp $y)),
43+
[(NotBlockArgument:$x), (HasOneUse:$s_x)]
4444
>;
4545

46+
def replaceZHighAddPattern2 : Pat<
47+
(ZHighUnstickOp (ZHighAddOp $x, (ZHighStickOp:$s_y $y, $_))),
48+
(ONNXAddOp (ZHighUnstickOp $x), $y),
49+
[(NotBlockArgument:$y), (HasOneUse:$s_y)]
50+
>;
4651

4752
//===----------------------------------------------------------------------===//
4853
// ONNXMulOp %X = ZHighUnstickOp (ZHighMulOp (ZHighStickOp %X),
4954
// (ZHighStickOp %Y))
5055
//===----------------------------------------------------------------------===//
51-
def replaceZHighMulPattern : Pat<
52-
(ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
53-
(ONNXMulOp $x, $y),
54-
[(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
56+
def replaceZHighMulPattern1 : Pat<
57+
(ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_), $y)),
58+
(ONNXMulOp $x, (ZHighUnstickOp $y)),
59+
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
60+
(addBenefit 1)
61+
>;
62+
63+
def replaceZHighMulPattern2 : Pat<
64+
(ZHighUnstickOp (ZHighMulOp $x, (ZHighStickOp:$s_y $y, $_))),
65+
(ONNXMulOp (ZHighUnstickOp $x), $y),
66+
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [],
67+
(addBenefit 0)
5568
>;
5669

5770
//===----------------------------------------------------------------------===//
5871
// ONNXSubOp %X = ZHighUnstickOp (ZHighSubOp (ZHighStickOp %X),
5972
// (ZHighStickOp %Y))
6073
//===----------------------------------------------------------------------===//
61-
def replaceZHighSubPattern : Pat<
62-
(ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
63-
(ONNXSubOp $x, $y),
64-
[(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
74+
def replaceZHighSubPattern1 : Pat<
75+
(ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_), $y)),
76+
(ONNXSubOp $x, (ZHighUnstickOp $y)),
77+
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
78+
(addBenefit 1)
79+
>;
80+
81+
def replaceZHighSubPattern2 : Pat<
82+
(ZHighUnstickOp (ZHighSubOp $x, (ZHighStickOp:$s_y $y, $_))),
83+
(ONNXSubOp (ZHighUnstickOp $x), $y),
84+
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
85+
(addBenefit 0)
6586
>;
6687

6788
//===----------------------------------------------------------------------===//
6889
// ONNXDivOp %X = ZHighUnstickOp (ZHighDivOp (ZHighStickOp
6990
// %X),(ZHighStickOp %Y))
7091
// Note: turn off this pattern since NNPA is faster at this moment.
7192
//===----------------------------------------------------------------------===//
72-
// def replaceZHighDivPattern : Pat<
73-
// (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
74-
// (ONNXDivOp $x, $y),
75-
// [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
76-
// >;
93+
//def replaceZHighDivPattern1 : Pat<
94+
// (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), $y)),
95+
// (ONNXDivOp $x, (ZHighUnstickOp $y)),
96+
// [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
97+
// (addBenefit 1)
98+
//>;
99+
//
100+
//def replaceZHighDivPattern2 : Pat<
101+
// (ZHighUnstickOp (ZHighDivOp $x, (ZHighStickOp:$s_y $y, $_))),
102+
// (ONNXDivOp (ZHighUnstickOp $x), $y),
103+
// [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
104+
// (addBenefit 0)
105+
//>;
77106

78107
//===----------------------------------------------------------------------===//
79108
// ONNXMinOp %X = ZHighUnstickOp (ZHighMinOp (ZHighStickOp %X),
80109
// (ZHighStickOp %Y))
81110
//===----------------------------------------------------------------------===//
82-
def replaceZHighMinPattern : Pat<
83-
(ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
84-
(CreateONNXMinOp $u, $x, $y),
85-
[(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
111+
def replaceZHighMinPattern1 : Pat<
112+
(ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_), $y)),
113+
(CreateONNXMinOp $u, $x, (ZHighUnstickOp $y)),
114+
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
115+
(addBenefit 1)
116+
>;
117+
118+
def replaceZHighMinPattern2 : Pat<
119+
(ZHighUnstickOp:$u (ZHighMinOp $x, (ZHighStickOp:$s_y $y, $_))),
120+
(CreateONNXMinOp $u, (ZHighUnstickOp $x), $y),
121+
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
122+
(addBenefit 0)
86123
>;
87124

88125
//===----------------------------------------------------------------------===//
89126
// ONNXMaxOp %X = ZHighUnstickOp (ZHighMaxOp (ZHighStickOp %X),
90127
// (ZHighStickOp %Y))
91128
//===----------------------------------------------------------------------===//
92-
def replaceZHighMaxPattern : Pat<
93-
(ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
94-
(CreateONNXMaxOp $u, $x, $y),
95-
[(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
129+
def replaceZHighMaxPattern1 : Pat<
130+
(ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_), $y)),
131+
(CreateONNXMaxOp $u, $x, (ZHighUnstickOp $y)),
132+
[(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
133+
(addBenefit 1)
134+
>;
135+
136+
def replaceZHighMaxPattern2 : Pat<
137+
(ZHighUnstickOp:$u (ZHighMaxOp $x, (ZHighStickOp:$s_y $y, $_))),
138+
(CreateONNXMaxOp $u, (ZHighUnstickOp $x), $y),
139+
[(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
140+
(addBenefit 0)
96141
>;
97142

98143
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)