Skip to content

Commit e9282c6

Browse files
committed
Merge branch 'main' into parallelop_forkop_with_omp_pr1_opdef
2 parents 338e42a + 893cf89 commit e9282c6

38 files changed

+1209
-973
lines changed

.github/workflows/macos-amd64-build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on: [push, pull_request]
44

55
jobs:
66
build:
7-
runs-on: macos-latest
7+
runs-on: macos-13
88
steps:
99
- uses: actions/checkout@v3
1010
with:

src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
16+
#include "mlir/Conversion/Passes.h"
1617
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
1718
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
1819
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
@@ -215,6 +216,10 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
215216
addKrnlToAffinePasses(pm);
216217
// Optimizations at ZLow that needs affine map in MemRef.
217218
pm.addPass(zlow::createZLowRewritePass());
219+
// Late generation of code for stick/unstick, needed to be after a
220+
// ZLowRewrite pass.
221+
if (nnpaEnableCompilerStickUnstick)
222+
pm.addPass(zlow::createZLowStickExpansionPass(enableParallel));
218223
pm.addPass(mlir::createCanonicalizerPass());
219224
// Normalize MemRefs.
220225
normalizeMemRefsPasses(pm);
@@ -223,6 +228,11 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
223228
addKrnlToAffinePasses(pm);
224229
// Optimizations at ZLow after normalizing MemRefs.
225230
pm.addPass(zlow::createZLowRewritePass());
231+
// The createZLowStickExpansion pass may create parallel constructs,
232+
// they need to be handled here.
233+
if (nnpaEnableCompilerStickUnstick && enableParallel)
234+
pm.addPass(mlir::createConvertSCFToOpenMPPass());
235+
226236
pm.addPass(mlir::createCanonicalizerPass());
227237
// Constant folding for std.alloc.
228238
pm.addNestedPass<func::FuncOp>(onnx_mlir::createFoldStdAllocPass());

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,3 +981,11 @@ bool isSuitableForZDNN<ONNXBatchNormalizationInferenceModeOp>(
981981

982982
return true;
983983
}
984+
985+
/// Check legality for ONNXReshapeOp.
986+
template <>
987+
bool isSuitableForZDNN<ONNXReshapeOp>(
988+
ONNXReshapeOp op, const DimAnalysis *dimAnalysis) {
989+
// Noop Reshape is suitable for zAIU as this pass removes such reshape ops.
990+
return isIdentityReshape(op, dimAnalysis);
991+
}

src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "src/Dialect/ONNX/ElementsAttr/WideNum.hpp"
3535
#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
3636
#include "src/Dialect/ONNX/ONNXOps.hpp"
37+
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
3738
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
3839
#include "src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp"
3940
#include "src/Support/TypeUtilities.hpp"
@@ -467,6 +468,31 @@ class AddSubWithRHSZeroExpandPattern : public OpRewritePattern<OP_TYPE> {
467468
}
468469
};
469470

471+
class RemoveReshapeWithIdentityPattern
472+
: public OpRewritePattern<ONNXReshapeOp> {
473+
public:
474+
using OpRewritePattern<ONNXReshapeOp>::OpRewritePattern;
475+
476+
DimAnalysis *dimAnalysis;
477+
478+
RemoveReshapeWithIdentityPattern(
479+
MLIRContext *context, DimAnalysis *dimAnalysis)
480+
: OpRewritePattern<ONNXReshapeOp>(context, 1001),
481+
dimAnalysis(dimAnalysis) {}
482+
483+
LogicalResult matchAndRewrite(
484+
ONNXReshapeOp reshapeOp, PatternRewriter &rewriter) const override {
485+
if (!isIdentityReshape(reshapeOp, dimAnalysis))
486+
return failure();
487+
488+
// Rewrite
489+
Operation *op = reshapeOp.getOperation();
490+
Value data = reshapeOp.getData();
491+
rewriter.replaceOp(op, data);
492+
return success();
493+
}
494+
};
495+
470496
//===----------------------------------------------------------------------===//
471497
// Rewrite ONNX ops to ZHigh ops and ONNX ops for ZHigh.
472498
//===----------------------------------------------------------------------===//
@@ -482,6 +508,8 @@ void getRewriteONNXForZHighPatterns(
482508
patterns.getContext(), dimAnalysis);
483509
patterns.insert<AddSubWithRHSZeroExpandPattern<ONNXSubOp>>(
484510
patterns.getContext(), dimAnalysis);
511+
patterns.insert<RemoveReshapeWithIdentityPattern>(
512+
patterns.getContext(), dimAnalysis);
485513
}
486514

487515
void getRewriteONNXForZHighDynamicallyLegal(
@@ -643,6 +671,13 @@ void getRewriteONNXForZHighDynamicallyLegal(
643671
return isSuitableForZDNN<ONNXConvOp>(op) ||
644672
!canInferencePadsForNNPAConv(op);
645673
});
674+
addDynamicallyLegalOpFor<ONNXReshapeOp>(target, dimAnalysis,
675+
[](ONNXReshapeOp op, const DimAnalysis *dimAnalysis) {
676+
// Get rid of identity reshape here, as it impacts stick/unstick.
677+
// So all reshape are legal, unless it is an identity reshape, in which
678+
// case there is a rule here to remove it.
679+
return !isIdentityReshape(op, dimAnalysis);
680+
});
646681
}
647682

648683
struct RewriteONNXForZHighPass

0 commit comments

Comments
 (0)