Skip to content

Commit 1d4ed1b

Browse files
committed
Merge branch 'main' into mem_reduction_stickified
2 parents 53b99c1 + bf905d1 commit 1d4ed1b

33 files changed

+2424
-859
lines changed

docs/Dialects/krnl.md

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,35 @@ Typically it is used for optional arguments used in KrnlCallop.
929929
| :----: | ----------- |
930930
| `none_val` | none type
931931

932+
### `krnl.parallel_clause` (KrnlParallelClauseOp)
933+
934+
_Attach OpenMP clauses to an index varialbe_
935+
936+
937+
Syntax:
938+
939+
```
940+
operation ::= `krnl.parallel_clause` `(` $parallel_loop_index `)` (`,` `num_threads` `(` $num_threads^ `)`)?
941+
attr-dict `:` type($parallel_loop_index)
942+
```
943+
944+
Attach OpenMP clauses to an index variable. That index variable
945+
is used to uniquely associate a parallel loop with its clauses.
946+
947+
#### Attributes:
948+
949+
<table>
950+
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
951+
<tr><td><code>proc_bind</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
952+
</table>
953+
954+
#### Operands:
955+
956+
| Operand | Description |
957+
| :-----: | ----------- |
958+
| `parallel_loop_index` | index
959+
| `num_threads` | 32-bit signless integer
960+
932961
### `krnl.parallel` (KrnlParallelOp)
933962

934963
_Mark Krnl loops as parallel loops_
@@ -937,23 +966,38 @@ _Mark Krnl loops as parallel loops_
937966
Syntax:
938967

939968
```
940-
operation ::= `krnl.parallel` `(` $loops `)` attr-dict `:` type($loops)
969+
operation ::= `krnl.parallel` `(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict `:` type($loops)
941970
```
942971

943972
Parallelize the specified loops. When multiple loop specifiers are passed
944973
as parameters, there loops can be parallelized as a collapsed loop.
945974
krnl.parallel should be placed as the last operator before krnl.iterate,
946975
Since we do not want to parallelize the loop until we interpret krnl.block,
947976
krnl.permute and krnl.unroll.
977+
978+
Optionally, a value may specifiy the number of threads requested for the
979+
parallel loop. A proc_bind string may also be specified; valid values are
980+
"primary", "close", or "spread". Default values are used when not specified.
981+
948982
```
949983
krnl.parallel (%i0, %i1) : !Krnl.loop, !Krnl.loop
950984
```
951985

986+
Traits: `AttrSizedOperandSegments`
987+
988+
#### Attributes:
989+
990+
<table>
991+
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
992+
<tr><td><code>proc_bind</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
993+
</table>
994+
952995
#### Operands:
953996

954997
| Operand | Description |
955998
| :-----: | ----------- |
956999
| `loops` | variadic of any type
1000+
| `num_threads` | 32-bit signless integer
9571001

9581002
### `krnl.permute` (KrnlPermuteOp)
9591003

src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
261261
// Store f32 values back to the (normal layout) output.
262262
DimsExpr outputAF = SymListIE(inputAF);
263263
outputAF[E1] = outputAF[E1] + l;
264-
create.vec.storeIE(vecF32H, alloc, outputAF, {});
264+
create.vec.storeIE(vecF32H, alloc, outputAF);
265265
create.vec.storeIE(
266266
vecF32L, alloc, outputAF, {litArchVLHalf.getValue()});
267267
});
@@ -277,8 +277,8 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
277277
Value vecF32L = convertOp.getResult(1);
278278
// Save into archVL value buffer.
279279
Value bufferF32 = create.mem.alignedAlloca(bufferType);
280-
create.vec.storeIE(vecF32H, bufferF32, {litZero}, {});
281-
create.vec.storeIE(vecF32L, bufferF32, {litArchVLHalf}, {});
280+
create.vec.storeIE(vecF32H, bufferF32, {litZero});
281+
create.vec.storeIE(vecF32L, bufferF32, {litArchVLHalf});
282282
// Save the remaining values as scalars.
283283
create.scf.forLoop(litZero.getValue(),
284284
remainingScalarValues.getValue(), 1,

src/Compiler/CompilerOptions.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ bool enableONNXHybridPass; // common for both
4242
std::vector<std::string> functionsToDecompose; // common for both
4343
std::string opsForCall; // common for both
4444
bool disableKrnlOpFusion; // common for both
45+
bool disableQuantZeroPoint; // common for both
4546
bool enableKrnlBufferReuse; // common for both
4647
bool disableMemRefPrefetch; // common for both
4748
EmissionTargetType emissionTarget; // onnx-mlir only
@@ -195,7 +196,7 @@ static llvm::cl::list<std::string, std::vector<std::string>>
195196
llvm::cl::cat(OnnxMlirCommonOptions));
196197

197198
static llvm::cl::opt<bool, true> enableONNXHybridPassOpt("onnx-hybrid-pass",
198-
llvm::cl::desc("Enable ONNX hybrid pass (default=true)\n"
199+
llvm::cl::desc("Enable ONNX hybrid pass (default=true).\n"
199200
"Set to 'false' if you want to disable ONNX hybrid pass."),
200201
llvm::cl::location(enableONNXHybridPass), llvm::cl::init(true),
201202
llvm::cl::cat(OnnxMlirCommonOptions));
@@ -208,11 +209,20 @@ static llvm::cl::list<std::string, std::vector<std::string>>
208209

209210
static llvm::cl::opt<bool, true> disableKrnlOpFusionOpt(
210211
"disable-krnl-op-fusion",
211-
llvm::cl::desc("disable op fusion in onnx-to-krnl pass (default=false)\n"
212+
llvm::cl::desc("Disable op fusion in onnx-to-krnl pass (default=false).\n"
212213
"Set to 'true' if you want to disable fusion."),
213214
llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false),
214215
llvm::cl::cat(OnnxMlirCommonOptions));
215216

217+
static llvm::cl::opt<bool, true> disable_quantization_zero_point(
218+
"disable-quantization-zero-point",
219+
llvm::cl::desc(
220+
"Disable the use of zero-point in quantization (default=false).\n"
221+
"Set to 'true' if you want to disable the use of zero-point\n"
222+
"in dyn/static quantization/dequantization."),
223+
llvm::cl::location(disableQuantZeroPoint), llvm::cl::init(false),
224+
llvm::cl::cat(OnnxMlirCommonOptions));
225+
216226
static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
217227
"enable-krnl-buffer-reuse",
218228
llvm::cl::desc("enable buffer reuse within an op in onnx-to-krnl pass"
@@ -223,7 +233,7 @@ static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
223233

224234
static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
225235
"disable-memref-prefetch",
226-
llvm::cl::desc("disable generation of memref.prefetch (default=false)\n"
236+
llvm::cl::desc("Disable generation of memref.prefetch (default=false).\n"
227237
"Set to 'true' if you want to disable prefetch."),
228238
llvm::cl::location(disableMemRefPrefetch), llvm::cl::init(false),
229239
llvm::cl::cat(OnnxMlirCommonOptions));
@@ -1145,7 +1155,6 @@ std::string getLibraryPath() {
11451155
// as lrodataScript.
11461156
std::string getToolPath(
11471157
const std::string &tool, bool flag /*false by default*/) {
1148-
11491158
if (!flag) {
11501159
std::string execDir = llvm::sys::path::parent_path(getExecPath()).str();
11511160
llvm::SmallString<8> toolPath(execDir);

src/Compiler/CompilerOptions.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ extern bool enableONNXHybridPass; // common for both
8787
extern std::vector<std::string> functionsToDecompose; // common for both
8888
extern std::string opsForCall; // common for both
8989
extern bool disableKrnlOpFusion; // common for both
90+
extern bool disableQuantZeroPoint; // common for both
9091
extern bool enableKrnlBufferReuse; // common for both
9192
extern bool disableMemRefPrefetch; // common for both
9293
extern EmissionTargetType emissionTarget; // onnx-mlir only

src/Compiler/CompilerPasses.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ void addKrnlToLLVMPasses(
251251
// The alloca_scope ops are somewhat fragile; canonicalize remove them when
252252
// redundant, which helps reliability of the compilation of these ops.
253253
pm.addPass(mlir::createCanonicalizerPass());
254+
pm.addPass(onnx_mlir::createProcessKrnlParallelClausePass());
254255
}
255256

256257
// The pass below is needed for subview and collapseShape.. Unfortunately,

src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,10 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder,
742742
<< parallelOp << "\n");
743743
// ToFix handle multiple parallel loop
744744
ValueRange loopRefs = parallelOp.getLoops();
745+
Value numThreads = parallelOp.getNumThreads();
746+
StringAttr procBind = parallelOp.getProcBindAttr();
747+
bool needParallelClause =
748+
numThreads || (procBind && procBind.getValue().size() > 0);
745749

746750
// Obtain the the reference the loop that needs to be parallelized
747751
for (Value loopRef : loopRefs) {
@@ -778,6 +782,23 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder,
778782
parallelLoop.getRegion().takeBody(loopToParallel.getRegion());
779783
Operation *yieldOp = &parallelLoop.getBody()->back();
780784
yieldOp->setOperands(reducedValues);
785+
if (needParallelClause) {
786+
// Use clause only for the first one (expected the outermost one).
787+
// Ideally, we would generate here a single, multi-dimensional
788+
// AffineParallelOp, and we would not need to reset the flag.
789+
needParallelClause = false;
790+
// Currently approach: insert after yield and then move before it.
791+
PatternRewriter::InsertionGuard insertGuard(builder);
792+
builder.setInsertionPointAfter(yieldOp);
793+
// Get induction variable.
794+
ValueRange optionalLoopIndices = parallelLoop.getIVs();
795+
assert(optionalLoopIndices.size() >= 1 &&
796+
"expected at least one loop index");
797+
Value parallelLoopIndex = optionalLoopIndices[0];
798+
Operation *newOp = opBuilder.create<KrnlParallelClauseOp>(
799+
loc, parallelLoopIndex, numThreads, procBind);
800+
newOp->moveBefore(yieldOp);
801+
}
781802
// Replace the affine.forOp with affine.parallelOp in loopRefToTop
782803
loopRefToOp[loopRef] = parallelLoop;
783804
loopToParallel.erase();
@@ -975,6 +996,7 @@ void ConvertKrnlToAffinePass::runOnOperation() {
975996
target.addIllegalOp<KrnlCopyToBufferOp>();
976997
target.addIllegalOp<KrnlCopyFromBufferOp>();
977998
target.addIllegalOp<KrnlPrefetchOp>();
999+
target.addLegalOp<KrnlParallelClauseOp>();
9781000
target.addLegalOp<AffineYieldOp>();
9791001
target.addLegalOp<AffineLoadOp>();
9801002
target.addLegalOp<AffineStoreOp>();

src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class KrnlCopyFromBufferLowering : public ConversionPattern {
124124
// Nothing to write.
125125
} else {
126126
// Loop to copy the data.
127-
createAffine.forLoopIE(zeroIE, writeUBs[i], 1,
127+
createAffine.forLoopIE(zeroIE, writeUBs[i], 1, false /*parallel*/,
128128
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
129129
loopIndices.emplace_back(loopInd[0]);
130130
genCopyLoops(createAffine, enclosingScope, buffMemref, destMemref,

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,9 +1358,15 @@ Value emitScalarOpFor<ONNXDequantizeLinearOp>(
13581358
Value scaleFloat = scalarOperands[1];
13591359
Value zeroPointInt = scalarOperands[2];
13601360

1361-
Value zeroPointFloat = create.math.cast(elementType, zeroPointInt);
13621361
Value xFloat = create.math.cast(elementType, XInt);
1363-
Value sub = create.math.sub(xFloat, zeroPointFloat);
1362+
1363+
Value sub;
1364+
if (!disableQuantZeroPoint && !isNoneValue(zeroPointInt)) {
1365+
Value zeroPointFloat = create.math.cast(elementType, zeroPointInt);
1366+
sub = create.math.sub(xFloat, zeroPointFloat);
1367+
} else {
1368+
sub = xFloat;
1369+
}
13641370
Value res = create.math.mul(sub, scaleFloat);
13651371
return res;
13661372
}
@@ -1521,8 +1527,7 @@ static LogicalResult getPartiallyFlattenedSimdCode(
15211527

15221528
create.krnl.simdIterateIE(zero, SymIE(simdUb), VL, simdOnly,
15231529
useParallelInSimdLoop, inputs, inputAFs, {output}, {outputAF},
1524-
[&](KrnlBuilder &kb, ArrayRef<Value> inputVals,
1525-
SmallVectorImpl<Value> &resVals, int64_t VL) {
1530+
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
15261531
MultiDialectBuilder<MathBuilder> create(kb);
15271532
Type currElementType = outputElementType;
15281533
if (VL > 1)
@@ -1551,9 +1556,9 @@ static LogicalResult getPartiallyFlattenedSimdCode(
15511556
res = emitPostProcessingFor<OP_TYPE>(rewriter, create.getLoc(),
15521557
op, currElementType, accumulated);
15531558
}
1554-
resVals.emplace_back(res);
1555-
}); // SIMD kernel.
1556-
}); // Outer loops.
1559+
return res;
1560+
}}); // SIMD kernel.
1561+
}); // Outer loops.
15571562

15581563
rewriter.replaceOp(op, alloc);
15591564
return success();

0 commit comments

Comments
 (0)