Skip to content

Commit df11f6a

Browse files
added parallel back
Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent 234fc65 commit df11f6a

File tree

3 files changed

+21
-9
lines changed

3 files changed

+21
-9
lines changed

src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
16+
#include "mlir/Conversion/Passes.h"
1617
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
1718
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
1819
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
@@ -215,9 +216,10 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
215216
addKrnlToAffinePasses(pm);
216217
// Optimizations at ZLow that needs affine map in MemRef.
217218
pm.addPass(zlow::createZLowRewritePass());
218-
if (nnpaEnableCompilerStickUnstick) {
219-
pm.addPass(zlow::createZLowStickExpansionPass());
220-
}
219+
// Late generation of code for stick/unstick, needed to be after a
220+
// ZLowRewrite pass.
221+
if (nnpaEnableCompilerStickUnstick)
222+
pm.addPass(zlow::createZLowStickExpansionPass(enableParallel));
221223
pm.addPass(mlir::createCanonicalizerPass());
222224
// Normalize MemRefs.
223225
normalizeMemRefsPasses(pm);
@@ -226,6 +228,11 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
226228
addKrnlToAffinePasses(pm);
227229
// Optimizations at ZLow after normalizing MemRefs.
228230
pm.addPass(zlow::createZLowRewritePass());
231+
// The createZLowStickExpansion pass may create parallel constructs,
232+
// they need to be handled here.
233+
if (nnpaEnableCompilerStickUnstick && enableParallel)
234+
pm.addPass(mlir::createConvertSCFToOpenMPPass());
235+
229236
pm.addPass(mlir::createCanonicalizerPass());
230237
// Constant folding for std.alloc.
231238
pm.addNestedPass<func::FuncOp>(onnx_mlir::createFoldStdAllocPass());

src/Accelerators/NNPA/Pass/NNPAPasses.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ namespace zlow {
6161
std::unique_ptr<mlir::Pass> createZLowRewritePass();
6262

6363
/// Add pass for rewriting ZLow ops.
64-
std::unique_ptr<mlir::Pass> createZLowStickExpansionPass();
64+
std::unique_ptr<mlir::Pass> createZLowStickExpansionPass(
65+
bool enableParallel = false);
6566

6667
/// Add pass for rewriting ZLow ops.
6768
std::unique_ptr<mlir::Pass> createZLowDummyOpForMultiDerefPass();

src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,11 @@ class ZLowStickExpansionPass
451451
: public PassWrapper<ZLowStickExpansionPass, OperationPass<func::FuncOp>> {
452452

453453
public:
454-
bool enableParallelism = false; // hi alex, fix this.
454+
ZLowStickExpansionPass(bool enableParallel)
455+
: PassWrapper<ZLowStickExpansionPass, OperationPass<func::FuncOp>>(),
456+
enableParallel(enableParallel) {}
457+
458+
bool enableParallel;
455459

456460
StringRef getArgument() const override { return "zlow-stick-expansion"; }
457461

@@ -465,8 +469,8 @@ class ZLowStickExpansionPass
465469
llvm::SmallDenseSet<ZLowStickOp, 4> removableStickOps;
466470
ConversionTarget target(getContext());
467471
RewritePatternSet patterns(&getContext());
468-
patterns.insert<StickExpansionPattern>(&getContext(), enableParallelism);
469-
patterns.insert<UnstickExpansionPattern>(&getContext(), enableParallelism);
472+
patterns.insert<StickExpansionPattern>(&getContext(), enableParallel);
473+
patterns.insert<UnstickExpansionPattern>(&getContext(), enableParallel);
470474
// patterns.insert<UnstickExpansionPattern>(&getContext());
471475

472476
fprintf(stderr, "hi alex, apply patterns in zlow stick expansion\n");
@@ -484,8 +488,8 @@ class ZLowStickExpansionPass
484488
}
485489
};
486490

487-
std::unique_ptr<Pass> createZLowStickExpansionPass() {
488-
return std::make_unique<ZLowStickExpansionPass>();
491+
std::unique_ptr<Pass> createZLowStickExpansionPass(bool enableParallel) {
492+
return std::make_unique<ZLowStickExpansionPass>(enableParallel);
489493
}
490494

491495
} // namespace zlow

0 commit comments

Comments
 (0)