Commit 8759c27

moved most iterateIE with 1 loop to forLoopIE new interface
Signed-off-by: Alexandre Eichenberger <[email protected]>
Parent: a4666da

8 files changed: +24, -117 lines

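The pattern applied throughout this commit, shown as a minimal before/after sketch. The builder calls (defineLoops, iterateIE, forLoopIE, LitIE) are the ones visible in the diffs below; the surrounding op-lowering context and the bound ubIE are illustrative placeholders, not taken from any one file.

    // Before: build a single krnl loop explicitly with defineLoops + iterateIE.
    ValueRange loopDef = createKrnl.defineLoops(1);
    createKrnl.iterateIE(loopDef, loopDef, {LitIE(0)}, {ubIE},
        [&](const KrnlBuilder &createKrnl, ValueRange loopInd) {
          // ... body reads/writes with loopInd[0] ...
        });

    // After: the same loop through the forLoopIE interface kept in
    // DialectBuilder.hpp: forLoopIE(lb, ub, step, useParallel, body).
    createKrnl.forLoopIE(LitIE(0), ubIE, /*step*/ 1, /*par*/ false,
        [&](const KrnlBuilder &createKrnl, ValueRange loopInd) {
          // ... same body ...
        });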
src/Conversion/ONNXToKrnl/NN/Normalization.cpp  (+4, -8)

@@ -941,17 +941,14 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern<OP_TYPE> {
         invStdDevFlatMemRef);
     // Alloc mem for reductions (should be private if parallel)
     MemRefType tmpRedType = MemRefType::get({B, totVL}, elementType);
-    // Iterate over 1st dim by block
-    ValueRange loopDefs = create.krnl.defineLoops(1);
-    IndexExpr zero = LitIE(0);
-    ValueRange blockedLoopDefs = create.krnl.block(loopDefs[0], B);
-    Value blockedLoopDef = blockedLoopDefs[0];
+    // Iterate over 1st dim by block B.
+    bool useParallel = false;
     if (enableParallel) {
       int64_t parId;
       SmallVector<IndexExpr, 1> lb(1, LitIE(0)), ub(1, XFlatDims[0]);
       if (findSuitableParallelDimension(lb, ub, 0, 1, parId,
               /*min iter for going parallel*/ 4)) {
-        create.krnl.parallel(blockedLoopDef);
+        useParallel = true;
         onnxToKrnlParallelReport(op, true, 0, lb[0], ub[0], "in layer-norm");
       } else {
         onnxToKrnlParallelReport(
@@ -960,8 +957,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern<OP_TYPE> {
     } else {
       onnxToKrnlParallelReport(op, false, -1, -1, "no parallel in layer norm");
     }
-    create.krnl.iterateIE({loopDefs[0]}, {blockedLoopDef}, {zero},
-        {XFlatDims[0]},
+    create.krnl.forLoopIE(LitIE(0), XFlatDims[0], /*step*/ B, useParallel,
         [&](const KrnlBuilder &ck, ValueRange blockedLoopIndices) {
           MDBuilder create(ck);
           IndexExprScope innerScope(ck);

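For the blocked, possibly parallel loop above, the mapping is the same idea with a non-unit step: the explicit block/parallel/iterateIE sequence becomes a single forLoopIE call whose step is the block size B and whose parallel decision travels in the useParallel flag. A condensed sketch using the names from the diff; in the actual file the flag is set inside the findSuitableParallelDimension branch, which is elided here.

    // Before (removed lines): block the loop by B, optionally mark the blocked
    // loop parallel, then iterate over the blocked dimension.
    ValueRange loopDefs = create.krnl.defineLoops(1);
    ValueRange blockedLoopDefs = create.krnl.block(loopDefs[0], B);
    if (useParallel)
      create.krnl.parallel(blockedLoopDefs[0]);
    create.krnl.iterateIE({loopDefs[0]}, {blockedLoopDefs[0]}, {LitIE(0)},
        {XFlatDims[0]},
        [&](const KrnlBuilder &ck, ValueRange blockedLoopIndices) { /* body */ });

    // After (added lines): blocking (step == B) and the parallel choice are
    // plain arguments of forLoopIE.
    create.krnl.forLoopIE(LitIE(0), XFlatDims[0], /*step*/ B, useParallel,
        [&](const KrnlBuilder &ck, ValueRange blockedLoopIndices) { /* body */ });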
src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp  (+10, -12)

@@ -160,14 +160,13 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern<RNNOp> {
 
     if (direction == FORWARD || direction == BIDIRECTIONAL) {
       IndexExprScope childScope(create.krnl);
-      mlir::ValueRange loopDef = create.krnl.defineLoops(1);
-      llvm::SmallVector<IndexExpr, 4> lbs(1, LitIE(0));
-      llvm::SmallVector<IndexExpr, 4> ubs;
+      IndexExpr lb = LitIE(0);
+      IndexExpr ub;
       if (!mlir::ShapedType::isDynamic(sequenceDimSize))
-        ubs.emplace_back(LitIE(sequenceDimSize));
+        ub = LitIE(sequenceDimSize);
       else
-        ubs.emplace_back(create.krnlIE.getShapeAsDim(X, 0));
-      create.krnl.iterateIE(loopDef, loopDef, lbs, ubs,
+        ub = create.krnlIE.getShapeAsDim(X, 0);
+      create.krnl.forLoopIE(lb, ub, /*step*/ 1, /*par*/ false,
           [&](const KrnlBuilder &createKrnl, mlir::ValueRange loopInd) {
             MathBuilder createMath(createKrnl);
             mlir::Value directionIV =
@@ -185,14 +184,13 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern<RNNOp> {
 
     if (direction == REVERSE || direction == BIDIRECTIONAL) {
       IndexExprScope childScope(create.krnl);
-      mlir::ValueRange loopDef = create.krnl.defineLoops(1);
-      llvm::SmallVector<IndexExpr, 4> lbs(1, LitIE(0));
-      llvm::SmallVector<IndexExpr, 4> ubs;
+      IndexExpr lb = LitIE(0);
+      IndexExpr ub;
       if (!mlir::ShapedType::isDynamic(sequenceDimSize))
-        ubs.emplace_back(LitIE(sequenceDimSize));
+        ub = LitIE(sequenceDimSize);
       else
-        ubs.emplace_back(create.krnlIE.getShapeAsDim(X, 0));
-      create.krnl.iterateIE(loopDef, loopDef, lbs, ubs,
+        ub = create.krnlIE.getShapeAsDim(X, 0);
+      create.krnl.forLoopIE(lb, ub, /*step*/ 1, /*par*/ false,
           [&](const KrnlBuilder &ck, mlir::ValueRange loopInd) {
             MultiDialectBuilder<MemRefBuilder, MathBuilder> create(ck);
 

src/Conversion/ONNXToKrnl/Sequence/SequenceErase.cpp  (+2, -12)

@@ -64,12 +64,7 @@ struct ONNXSequenceEraseOpLowering
 
     // Copy the elements before the position
     KrnlBuilder createKrnl(rewriter, loc);
-    SmallVector<IndexExpr, 1> lbs;
-    lbs.emplace_back(LitIE(0));
-    SmallVector<IndexExpr, 1> ubs;
-    ubs.emplace_back(positionIE);
-    ValueRange firstLoopDef = createKrnl.defineLoops(1);
-    createKrnl.iterateIE(firstLoopDef, firstLoopDef, lbs, ubs,
+    createKrnl.forLoopIE(LitIE(0), positionIE, /*step*/ 1, /*par*/ false,
        [&](const KrnlBuilder createKrnl, ValueRange indicesLoopInd) {
          Value element =
              createKrnl.load(adaptor.getInputSequence(), indicesLoopInd[0]);
@@ -78,12 +73,7 @@ struct ONNXSequenceEraseOpLowering
        });
 
    // Copy the elements after the position
-    SmallVector<IndexExpr, 1> lbs1;
-    lbs1.emplace_back(positionIE + 1);
-    SmallVector<IndexExpr, 1> ubs1;
-    ubs1.emplace_back(boundIE);
-    ValueRange secondLoopDef = createKrnl.defineLoops(1);
-    createKrnl.iterateIE(secondLoopDef, secondLoopDef, lbs1, ubs1,
+    createKrnl.forLoopIE(positionIE + 1, boundIE, /*step*/ 1, /*par*/ false,
        [&](const KrnlBuilder createKrnl, ValueRange indicesLoopInd) {
          Value element =
              createKrnl.load(adaptor.getInputSequence(), indicesLoopInd[0]);

src/Conversion/ONNXToKrnl/Sequence/SequenceInsert.cpp  (+2, -12)

@@ -77,12 +77,7 @@ struct ONNXSequenceInsertOpLowering
       // compilation problem due to the unranked tensor even though
       // the loop will not be reached at runtime.
     } else {
-      SmallVector<IndexExpr, 1> lbs;
-      lbs.emplace_back(LitIE(0));
-      SmallVector<IndexExpr, 1> ubs;
-      ubs.emplace_back(positionIE);
-      ValueRange firstLoopDef = createKrnl.defineLoops(1);
-      createKrnl.iterateIE(firstLoopDef, firstLoopDef, lbs, ubs,
+      createKrnl.forLoopIE(LitIE(0), positionIE, /*step*/ 1, /*par*/ false,
          [&](const KrnlBuilder createKrnl, ValueRange indicesLoopInd) {
            auto element =
                createKrnl.load(adaptor.getInputSequence(), indicesLoopInd[0]);
@@ -91,12 +86,7 @@ struct ONNXSequenceInsertOpLowering
          });
 
      // Copy the elements after the position
-      SmallVector<IndexExpr, 1> lbs1;
-      lbs1.emplace_back(positionIE + 1);
-      SmallVector<IndexExpr, 1> ubs1;
-      ubs1.emplace_back(boundIE);
-      ValueRange secondLoopDef = createKrnl.defineLoops(1);
-      createKrnl.iterateIE(secondLoopDef, secondLoopDef, lbs1, ubs1,
+      createKrnl.forLoopIE(positionIE + 1, boundIE, /*step*/ 1, /*par*/ false,
          [&](const KrnlBuilder createKrnl, ValueRange indicesLoopInd) {
            auto element =
                createKrnl.load(adaptor.getInputSequence(), indicesLoopInd[0]);

src/Conversion/ONNXToKrnl/Tensor/Compress.cpp  (+3, -5)

@@ -60,8 +60,7 @@ struct ONNXCompressOpLowering : public OpConversionPattern<ONNXCompressOp> {
     // Now create a loop to iterate over all conditions.
     Value condMemRef = adaptor.getCondition();
     IndexExpr condShapeFirstRank = create.krnlIE.getShapeAsDim(condMemRef, 0);
-    ValueRange loopDef = create.krnl.defineLoops(1);
-    create.krnl.iterateIE(loopDef, loopDef, {zeroIE}, {condShapeFirstRank},
+    create.krnl.forLoopIE(zeroIE, condShapeFirstRank, /*step*/ 1, /*par*/ false,
        [&](const KrnlBuilder createKrnl, ValueRange loopInd) {
          MathBuilder createMath(createKrnl);
          // Load the condition
@@ -215,9 +214,8 @@ struct ONNXCompressOpLowering : public OpConversionPattern<ONNXCompressOp> {
        innerLbs.emplace_back(inputLbs[i]);
        innerUbs.emplace_back(inputUbs[i]);
      }
-    ValueRange axisLoopDef = create.krnl.defineLoops(1);
-    create.krnl.iterateIE(axisLoopDef, axisLoopDef, {inputLbs[axisValue]},
-        {inputUbs[axisValue]},
+    create.krnl.forLoopIE(inputLbs[axisValue], inputUbs[axisValue],
+        /*step*/ 1, /*par*/ false,
        [&](const KrnlBuilder createKrnl, ValueRange axisLoopInd) {
          MultiDialectBuilder<KrnlBuilder, MathBuilder, SCFBuilder> create(
              createKrnl);

src/Conversion/ONNXToKrnl/Tensor/OneHot.cpp  (+3, -3)

@@ -62,7 +62,8 @@ struct ONNXOneHotOpLowering : public OpConversionPattern<ONNXOneHotOp> {
     create.krnlIE.getShapeAsDims(indices, indicesUbs);
     ValueRange indicesLoopDef = create.krnl.defineLoops(indicesRank);
     create.krnl.iterateIE(indicesLoopDef, indicesLoopDef, indicesLbs,
-        indicesUbs, [&](const KrnlBuilder createKrnl, ValueRange indicesLoopInd) {
+        indicesUbs,
+        [&](const KrnlBuilder createKrnl, ValueRange indicesLoopInd) {
          // Loop for all input values.
          MathBuilder createMath(createKrnl);
          // Input val is allowed to be any integer/float. Read and convert to
@@ -89,8 +90,7 @@ struct ONNXOneHotOpLowering : public OpConversionPattern<ONNXOneHotOp> {
          Value onValueIndexVal = onValueIndex.getValue();
          // Now we have the index that is on, iterate over the depth values
          // along axis, and set the right one to the value on.
-          ValueRange depthLoopDef = createKrnl.defineLoops(1);
-          createKrnl.iterateIE(depthLoopDef, depthLoopDef, {zeroIE}, {depth},
+          createKrnl.forLoopIE(zeroIE, depth, /*step*/ 1, /*par*/ false,
              [&](const KrnlBuilder createBuilder, ValueRange depthLoopInd) {
                MathBuilder createMath(createKrnl);
                Value onCond = createMath.eq(depthLoopInd[0], onValueIndexVal);

src/Dialect/Krnl/DialectBuilder.cpp  (-49)

@@ -286,55 +286,6 @@ void KrnlBuilder::simdReduce2DIE(IndexExpr lb, IndexExpr ub, int64_t VL,
       reductionBodyFn, postReductionBodyFn);
 }
 
-void KrnlBuilder::forExplicitlyParallelLoopIE(IndexExpr lb, IndexExpr ub,
-    int64_t stepModifier, IndexExpr numThreads, StringAttr procBind,
-    KrnlLoopBodyFn builderFn) const {
-  IndexExpr zero = LitIE(0);
-  if (numThreads.isLiteralAndIdenticalTo(1)) {
-    // Noop. Invoke function with (0, lb, ub).
-    SmallVector<Value, 4> params = {
-        zero.getValue(), lb.getValue(), ub.getValue()};
-    MultiDialectBuilder<KrnlBuilder> create(*this);
-    builderFn(create.krnl, params);
-    return;
-  }
-  if (numThreads.isLiteralAndIdenticalTo(-1)) {
-    // Dynamic, get value from OMP.
-    llvm_unreachable("not implemented yet"); // hi alex.
-  }
-  IndexExpr trip = ub - lb;
-  if (stepModifier > 1)
-    trip = trip.ceilDiv(stepModifier);
-  IndexExpr block = trip.ceilDiv(numThreads);
-  if (stepModifier > 1)
-    block = block * stepModifier;
-  // Create parallel loop with numThreads.
-  ValueRange originalLoopDef = defineLoops(1);
-  llvm::SmallVector<Value, 1> optLoopDef(1, originalLoopDef[0]);
-  parallel(optLoopDef[0], numThreads.getValue(), procBind);
-  iterateIE(originalLoopDef, optLoopDef, {lb}, {ub},
-      [&](const KrnlBuilder &kb, ValueRange loopInd) {
-        // Compute current LB/UB for this thread.
-        IndexExprScope currScope(kb);
-        IndexExpr tid = DimIE(loopInd[0]);
-        IndexExpr currLB = tid * SymIE(block);
-        IndexExpr currUB = currLB + SymIE(block);
-        currUB = IndexExpr::max(currUB, SymIE(ub));
-        SmallVector<Value, 4> params = {
-            tid.getValue(), currLB.getValue(), currUB.getValue()};
-        // Invoke function with (tid, currLB, currUB).
-        builderFn(kb, params);
-      });
-}
-
-void KrnlBuilder::forExplicitlyParallelLoopIE(IndexExpr lb, IndexExpr ub,
-    int64_t stepModifier, IndexExpr numThreads,
-    KrnlLoopBodyFn builderFn) const {
-  StringAttr procBind; // Empty == default, unspecified.
-  forExplicitlyParallelLoopIE(
-      lb, ub, stepModifier, numThreads, procBind, builderFn);
-}
-
 void KrnlBuilder::yield(ValueRange iterArgs) const {
   b().create<KrnlYieldOp>(loc(), iterArgs);
 }

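The commit also drops the forExplicitlyParallelLoopIE helper shown above. For reference, the removed implementation partitioned [lb, ub) into one chunk per thread; here is a worked instance of that arithmetic with assumed illustrative values (none of these numbers appear in the diff):

    // Assumed values: lb = 0, ub = 128, stepModifier = 4, numThreads = 8.
    // trip  = ceilDiv(ub - lb, stepModifier)           = ceilDiv(128, 4) = 32
    // block = ceilDiv(trip, numThreads) * stepModifier = ceilDiv(32, 8) * 4 = 16
    // Thread tid is then invoked with (tid, currLB, currUB), where
    // currLB = 16 * tid and currUB = currLB + 16, i.e. eight 16-iteration
    // chunks covering [0, 128) -- the static-schedule assignment described by
    // the header comment removed from DialectBuilder.hpp below.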
src/Dialect/Krnl/DialectBuilder.hpp  (-16)

@@ -106,22 +106,6 @@ struct KrnlBuilder : public DialectBuilder {
   void forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step, bool useParallel,
       KrnlLoopBodyFn builderFn) const;
 
-  // Create a parallel loop iterating from tid=0 to tNum. When tNum== -1, the
-  // maximum number of thread is extracted at runtime (via is obtained from
-  // omp_get_num_threads). The parameters passed to the builder functions are
-  // (tid, currLB, currUB) where the currLB and currUB define the work assigned
-  // to a given thread tid. If follows the OpenMP static schedule, assigning
-  // roughly ceil((ub - lb)/tNum) iterations per thread. When stepModifier>1,
-  // assigned chunks are multiple of stepModifier.
-  // When tNum == 1, this is a nop, essentially calling the builder function
-  // with (0, lb, ub).
-  void forExplicitlyParallelLoopIE(IndexExpr lb, IndexExpr ub,
-      int64_t stepModifier, IndexExpr numThreads,
-      KrnlLoopBodyFn builderFn) const;
-  void forExplicitlyParallelLoopIE(IndexExpr lb, IndexExpr ub,
-      int64_t stepModifier, IndexExpr numThreads, mlir::StringAttr procBind,
-      KrnlLoopBodyFn builderFn) const;
-
   // Common simd loop interface (krnl/affine/scf).
   /*
      Iterate over a loop executing the loop body in SIMD mode (of vector length
