Skip to content

Commit d03eff2

Browse files
Added support to generate OpenMP parallel construct clauses, at this time for num_threads and proc_bind (#2944)
Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent 9dd7c4a commit d03eff2

15 files changed

+592
-9
lines changed

docs/Dialects/krnl.md

+45-1
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,35 @@ Typically it is used for optional arguments used in KrnlCallop.
929929
| :----: | ----------- |
930930
| `none_val` | none type
931931

932+
### `krnl.parallel_clause` (KrnlParallelClauseOp)
933+
934+
_Attach OpenMP clauses to an index variable_
935+
936+
937+
Syntax:
938+
939+
```
940+
operation ::= `krnl.parallel_clause` `(` $parallel_loop_index `)` (`,` `num_threads` `(` $num_threads^ `)`)?
941+
attr-dict `:` type($parallel_loop_index)
942+
```
943+
944+
Attach OpenMP clauses to an index variable. That index variable
945+
is used to uniquely associate a parallel loop with its clauses.
946+
947+
#### Attributes:
948+
949+
<table>
950+
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
951+
<tr><td><code>proc_bind</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
952+
</table>
953+
954+
#### Operands:
955+
956+
| Operand | Description |
957+
| :-----: | ----------- |
958+
| `parallel_loop_index` | index
959+
| `num_threads` | 32-bit signless integer
960+
932961
### `krnl.parallel` (KrnlParallelOp)
933962

934963
_Mark Krnl loops as parallel loops_
@@ -937,23 +966,38 @@ _Mark Krnl loops as parallel loops_
937966
Syntax:
938967

939968
```
940-
operation ::= `krnl.parallel` `(` $loops `)` attr-dict `:` type($loops)
969+
operation ::= `krnl.parallel` `(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict `:` type($loops)
941970
```
942971

943972
Parallelize the specified loops. When multiple loop specifiers are passed
944973
as parameters, these loops can be parallelized as a collapsed loop.
945974
krnl.parallel should be placed as the last operator before krnl.iterate,
946975
Since we do not want to parallelize the loop until we interpret krnl.block,
947976
krnl.permute and krnl.unroll.
977+
978+
Optionally, a value may specify the number of threads requested for the
979+
parallel loop. A proc_bind string may also be specified; valid values are
980+
"primary", "close", or "spread". Default values are used when not specified.
981+
948982
```
949983
krnl.parallel (%i0, %i1) : !Krnl.loop, !Krnl.loop
950984
```
951985

986+
Traits: `AttrSizedOperandSegments`
987+
988+
#### Attributes:
989+
990+
<table>
991+
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
992+
<tr><td><code>proc_bind</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
993+
</table>
994+
952995
#### Operands:
953996

954997
| Operand | Description |
955998
| :-----: | ----------- |
956999
| `loops` | variadic of any type
1000+
| `num_threads` | 32-bit signless integer
9571001

9581002
### `krnl.permute` (KrnlPermuteOp)
9591003

src/Compiler/CompilerPasses.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ void addKrnlToLLVMPasses(
251251
// The alloca_scope ops are somewhat fragile; canonicalize remove them when
252252
// redundant, which helps reliability of the compilation of these ops.
253253
pm.addPass(mlir::createCanonicalizerPass());
254+
pm.addPass(onnx_mlir::createProcessKrnlParallelClausePass());
254255
}
255256

256257
// The pass below is needed for subview and collapseShape.. Unfortunately,

src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,10 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder,
742742
<< parallelOp << "\n");
743743
// ToFix handle multiple parallel loop
744744
ValueRange loopRefs = parallelOp.getLoops();
745+
Value numThreads = parallelOp.getNumThreads();
746+
StringAttr procBind = parallelOp.getProcBindAttr();
747+
bool needParallelClause =
748+
numThreads || (procBind && procBind.getValue().size() > 0);
745749

746750
// Obtain the reference to the loop that needs to be parallelized
747751
for (Value loopRef : loopRefs) {
@@ -778,6 +782,23 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder,
778782
parallelLoop.getRegion().takeBody(loopToParallel.getRegion());
779783
Operation *yieldOp = &parallelLoop.getBody()->back();
780784
yieldOp->setOperands(reducedValues);
785+
if (needParallelClause) {
786+
// Use clause only for the first one (expected the outermost one).
787+
// Ideally, we would generate here a single, multi-dimensional
788+
// AffineParallelOp, and we would not need to reset the flag.
789+
needParallelClause = false;
790+
// Current approach: insert after yield and then move before it.
791+
PatternRewriter::InsertionGuard insertGuard(builder);
792+
builder.setInsertionPointAfter(yieldOp);
793+
// Get induction variable.
794+
ValueRange optionalLoopIndices = parallelLoop.getIVs();
795+
assert(optionalLoopIndices.size() >= 1 &&
796+
"expected at least one loop index");
797+
Value parallelLoopIndex = optionalLoopIndices[0];
798+
Operation *newOp = opBuilder.create<KrnlParallelClauseOp>(
799+
loc, parallelLoopIndex, numThreads, procBind);
800+
newOp->moveBefore(yieldOp);
801+
}
781802
// Replace the affine.forOp with affine.parallelOp in loopRefToTop
782803
loopRefToOp[loopRef] = parallelLoop;
783804
loopToParallel.erase();
@@ -975,6 +996,7 @@ void ConvertKrnlToAffinePass::runOnOperation() {
975996
target.addIllegalOp<KrnlCopyToBufferOp>();
976997
target.addIllegalOp<KrnlCopyFromBufferOp>();
977998
target.addIllegalOp<KrnlPrefetchOp>();
999+
target.addLegalOp<KrnlParallelClauseOp>();
9781000
target.addLegalOp<AffineYieldOp>();
9791001
target.addLegalOp<AffineLoadOp>();
9801002
target.addLegalOp<AffineStoreOp>();

src/Dialect/Krnl/DialectBuilder.cpp

+20-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,26 @@ ValueRange KrnlBuilder::getInductionVarValue(ValueRange loops) const {
155155
}
156156

157157
void KrnlBuilder::parallel(ValueRange loops) const {
158-
b().template create<KrnlParallelOp>(loc(), loops);
158+
Value noneValue;
159+
StringAttr noneStrAttr;
160+
b().template create<KrnlParallelOp>(loc(), loops, noneValue, noneStrAttr);
161+
}
162+
163+
// Mark the given Krnl loops as parallel, optionally attaching OpenMP-style
// clauses.
//   loops:      loop references to parallelize.
//   numThreads: optional value requesting a number of threads; forwarded
//               unchanged to the KrnlParallelOp (may be a null Value).
//   procBind:   optional proc_bind policy; may be a null attribute or an
//               empty string, in which case no policy is validated. When
//               non-empty, it must be "primary", "close", or "spread".
void KrnlBuilder::parallel(
    ValueRange loops, Value numThreads, StringAttr procBind) const {
  // proc_bind is an optional attribute, so procBind may be null; check it
  // before dereferencing, as the Krnl-to-Affine lowering does.
  if (procBind && !procBind.getValue().empty()) {
    const std::string bind = procBind.getValue().str();
    assert((bind == "primary" || bind == "close" || bind == "spread") &&
           "expected primary, close, or spread for proc_bind");
    (void)bind; // Silence unused-variable warnings when asserts are disabled.
  }
  b().template create<KrnlParallelOp>(loc(), loops, numThreads, procBind);
}
172+
173+
void KrnlBuilder::parallelClause(
174+
Value parallelLoopIndex, Value numThreads, StringAttr procBind) const {
175+
// No need to check procBind as its value is derived from parallel(...).
176+
b().template create<KrnlParallelClauseOp>(
177+
loc(), parallelLoopIndex, numThreads, procBind);
159178
}
160179

161180
void KrnlBuilder::iterate(ValueRange originalLoops, ValueRange optimizedLoops,

src/Dialect/Krnl/DialectBuilder.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ struct KrnlBuilder : public DialectBuilder {
6666
void permute(mlir::ValueRange loops, mlir::ArrayRef<int64_t> map) const;
6767
mlir::ValueRange getInductionVarValue(mlir::ValueRange loops) const;
6868
void parallel(mlir::ValueRange loops) const;
69+
void parallel(mlir::ValueRange loops, mlir::Value numThreads,
70+
mlir::StringAttr procBind) const;
71+
void parallelClause(mlir::Value parallelLoopIndex, mlir::Value numThreads,
72+
mlir::StringAttr procBind) const;
6973

7074
// Iterate over optimized loops given the original loops, lbs and ubs. Lambda
7175
// function implement the body of the loop, and receive a KRNL builder and the

src/Dialect/Krnl/Krnl.td

+27-3
Original file line numberDiff line numberDiff line change
@@ -514,23 +514,47 @@ def KrnlUnrollOp : Op<Krnl_Dialect, "unroll"> {
514514
}];
515515
}
516516

517-
def KrnlParallelOp : Op<Krnl_Dialect, "parallel"> {
517+
def KrnlParallelOp : Op<Krnl_Dialect, "parallel", [AttrSizedOperandSegments]> {
518518
let summary = "Mark Krnl loops as parallel loops";
519519
let description = [{
520520
Parallelize the specified loops. When multiple loop specifiers are passed
521521
as parameters, these loops can be parallelized as a collapsed loop.
522522
krnl.parallel should be placed as the last operator before krnl.iterate,
523523
Since we do not want to parallelize the loop until we interpret krnl.block,
524524
krnl.permute and krnl.unroll.
525+
526+
Optionally, a value may specify the number of threads requested for the
527+
parallel loop. A proc_bind string may also be specified; valid values are
528+
"primary", "close", or "spread". Default values are used when not specified.
529+
525530
```
526531
krnl.parallel (%i0, %i1) : !Krnl.loop, !Krnl.loop
527532
```
528533
}];
529534

530-
let arguments = (ins Variadic<AnyType>:$loops);
535+
let arguments = (ins Variadic<AnyType>:$loops,
536+
Optional<I32>:$num_threads,
537+
OptionalAttr<StrAttr>:$proc_bind);
531538

532539
let assemblyFormat = [{
533-
`(` $loops `)` attr-dict `:` type($loops)
540+
`(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict `:` type($loops)
541+
}];
542+
}
543+
544+
def KrnlParallelClauseOp : Op<Krnl_Dialect, "parallel_clause"> {
545+
let summary = "Attach OpenMP clauses to an index variable";
546+
let description = [{
547+
Attach OpenMP clauses to an index variable. That index variable
548+
is used to uniquely associate a parallel loop with its clauses.
549+
}];
550+
551+
let arguments = (ins Index: $parallel_loop_index,
552+
Optional<I32>:$num_threads,
553+
OptionalAttr<StrAttr>:$proc_bind);
554+
555+
let assemblyFormat = [{
556+
`(` $parallel_loop_index `)` (`,` `num_threads` `(` $num_threads^ `)`)?
557+
attr-dict `:` type($parallel_loop_index)
534558
}];
535559
}
536560

src/Pass/Passes.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ void configureOnnxToKrnlLoweringPass(bool reportOnParallel,
9191
bool parallelIsEnabled, std::string specificParallelOps, bool reportOnSimd,
9292
bool simdIsEnabled);
9393
std::unique_ptr<mlir::Pass> createProcessScfParallelPrivatePass();
94+
std::unique_ptr<mlir::Pass> createProcessKrnlParallelClausePass();
9495

9596
#ifdef ONNX_MLIR_ENABLE_STABLEHLO
9697
/// Add pass for lowering to Stablehlo IR.

src/Tools/onnx-mlir-opt/RegisterPasses.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ void registerOMPasses(int optLevel) {
9797
return createProcessScfParallelPrivatePass();
9898
});
9999

100+
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
101+
return createProcessKrnlParallelClausePass();
102+
});
103+
100104
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
101105
return krnl::createConvertSeqToMemrefPass();
102106
});

src/Transform/CMakeLists.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ add_onnx_mlir_library(OMLowerKrnlRegion
88
MLIRTransformUtils
99
)
1010

11-
add_onnx_mlir_library(OMScfParallelPrivateRegion
11+
add_onnx_mlir_library(OMScfParallelPrivateRegion
1212
ProcessScfParallelPrivate.cpp
13+
ProcessKrnlParallelClause.cpp
1314

1415
LINK_LIBS PUBLIC
1516
OMSupport
1617
MLIRTransformUtils
18+
MLIROpenMPToLLVM
1719
)
1820

1921
add_onnx_mlir_library(OMInstrument

0 commit comments

Comments
 (0)