Skip to content

Commit cf1fc54

Browse files
Merge branch 'main' into zhigh-to-onnx-simplify
2 parents 5b0cc3f + 40b607d commit cf1fc54

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+291
-316
lines changed

docs/LoweringCode.md

+5-5
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ struct KrnlBuilder : public DialectBuilder {
105105

106106
void iterate(ValueRange originalLoops, ValueRange optimizedLoops,
107107
ValueRange lbs, ValueRange ubs,
108-
function_ref<void(KrnlBuilder &createKrnl, ValueRange indices)>
108+
function_ref<void(const KrnlBuilder &createKrnl, ValueRange indices)>
109109
bodyBuilderFn);
110110
};
111111
```
@@ -128,7 +128,7 @@ ValueRange loopDef = createKrnl.defineLoops(2);
128128
129129
// Create the loop.
130130
createKrnl.iterate(loopDef, loopDef, {zero, zero}, {ub0, ub1},
131-
[&](KrnlBuilder &createKrnl, ValueRange loopInd){
131+
[&](const KrnlBuilder &createKrnl, ValueRange loopInd){
132132
// Loop body.
133133
createKrnl.store(zero, array, loopInd);
134134
});
@@ -183,7 +183,7 @@ ValueRange loopBlockDef = createKrnl.block(loopDef, 4);
183183
createKrnl.permute({loopBlockDef[0], loopBlockDef[1]}, {0,1});
184184
// Create the loop iterating over the blocks.
185185
createKrnl.iterate(loopDef, {loopBlockDef[0], loopBlockDef[0]}, {zero}, {ub0},
186-
[&](KrnlBuilder &createKrnl, ValueRange blockLoopInd){
186+
[&](const KrnlBuilder &createKrnl, ValueRange blockLoopInd){
187187
// Loop body.
188188
createKrnl.store(zero, array, loopInd);
189189
});
@@ -209,10 +209,10 @@ We now consider tiling our original 2-dimensional example below.
209209
// Create the loop iterating over the blocks.
210210
createKrnl.iterate(loopDef, {outerLoopBlockDef[0], innerLoopBlockDef[0]},
211211
{zero, zero}, {ub0, ub1},
212-
[&](KrnlBuilder &createKrnl, ValueRange blockLoopInd){
212+
[&](const KrnlBuilder &createKrnl, ValueRange blockLoopInd){
213213
// Create the loop iterating inside the blocks.
214214
createKrnl.iterate({}, {outerLoopBlockDef[1], innerLoopBlockDef[1]},
215-
{}, {}, [&](KrnlBuilder &createKrnl, ValueRange loopInd) {
215+
{}, {}, [&](const KrnlBuilder &createKrnl, ValueRange loopInd) {
216216
// Loop body.
217217
createKrnl.store(zero, array, loopInd);
218218
});

src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp

+10-13
Original file line numberDiff line numberDiff line change
@@ -1315,7 +1315,7 @@ struct ZHighToZLowFixGRUYOpLowering : public ConversionPattern {
13151315
Value iZero = create.math.constantIndex(0);
13161316
ValueRange batchLoop = create.krnl.defineLoops(1);
13171317
create.krnl.iterate(batchLoop, batchLoop, {iZero}, {create.mem.dim(Y, 2)},
1318-
[&](KrnlBuilder &createKrnl, ValueRange batchIndices) {
1318+
[&](const KrnlBuilder &createKrnl, ValueRange batchIndices) {
13191319
MathBuilder createMath(createKrnl);
13201320
IndexExprScope ieScope(createKrnl);
13211321
Value bs = batchIndices[0];
@@ -1338,7 +1338,7 @@ struct ZHighToZLowFixGRUYOpLowering : public ConversionPattern {
13381338
rewriter.setInsertionPointToStart(&regionOp.getBodyRegion().front());
13391339
ValueRange loops = create.krnl.defineLoops(yRank - 1);
13401340
create.krnl.iterate(loops, loops, yLbs, yUbs,
1341-
[&](KrnlBuilder &createKrnl, ValueRange indices) {
1341+
[&](const KrnlBuilder &createKrnl, ValueRange indices) {
13421342
Value sequenceIV(indices[0]);
13431343
Value directionIV(indices[1]);
13441344
Value hs(indices[2]);
@@ -1366,7 +1366,7 @@ struct ZHighToZLowFixGRUYOpLowering : public ConversionPattern {
13661366

13671367
ValueRange loops = create.krnl.defineLoops(yRank);
13681368
create.krnl.iterate(loops, loops, yLbs, yUbs,
1369-
[&](KrnlBuilder &createKrnl, ValueRange indices) {
1369+
[&](const KrnlBuilder &createKrnl, ValueRange indices) {
13701370
MathBuilder createMath(createKrnl);
13711371
IndexExprScope ieScope(createKrnl);
13721372
Value sequenceIV(indices[0]);
@@ -1435,7 +1435,7 @@ struct ZHighToZLowFixGRUYhOpLowering : public ConversionPattern {
14351435
Value seqSize = create.mem.dim(Y, 0);
14361436
ValueRange loops = create.krnl.defineLoops(htRank);
14371437
create.krnl.iterate(loops, loops, htLbs, htUbs,
1438-
[&](KrnlBuilder &createKrnl, ValueRange indices) {
1438+
[&](const KrnlBuilder &createKrnl, ValueRange indices) {
14391439
MathBuilder createMath(createKrnl);
14401440
IndexExprScope ieScope(createKrnl);
14411441
Value bs(indices[1]), hs(indices[2]);
@@ -1612,7 +1612,7 @@ struct ZHighToZLowStickifiedConstantOfShapeOpLowering
16121612
SmallVector<IndexExpr, 4> lbs(rank, LitIE(0));
16131613
SmallVector<IndexExpr, 4> ubs = shapeHelper.getOutputDims();
16141614
create.krnl.iterateIE(loopDef, loopDef, lbs, ubs,
1615-
[&](KrnlBuilder &createKrnl, ValueRange indices) {
1615+
[&](const KrnlBuilder &createKrnl, ValueRange indices) {
16161616
// Keep this load inside the loop to tweak LLVM.
16171617
Value valueF16 = createKrnl.load(memrefF16);
16181618
createKrnl.store(valueF16, res, indices);
@@ -1701,13 +1701,10 @@ struct ZHighToZLowDataConversionLowering
17011701
SmallVector<IndexExpr, 4> flattenedOutputDims;
17021702
Value flatOutput = create.mem.reshapeToFlatInnermost(
17031703
alloc, outputDims, flattenedOutputDims, collapsedInnermostLoops);
1704-
DimsExpr lbs(1, LitIE(0));
17051704

17061705
// Create loop iteration (flattened to 1D) and block it by totVL.
1707-
ValueRange loopDef = create.krnl.defineLoops(1);
1708-
ValueRange blockedLoopDef = create.krnl.block(loopDef[0], totVL);
1709-
SmallVector<Value, 1> optimizedLoopDef(1, blockedLoopDef[0]);
1710-
1706+
DimsExpr lbs = {LitIE(0)};
1707+
bool useParallel = false;
17111708
if (enableParallel) {
17121709
int64_t parId;
17131710
int64_t tripCount = flattenedOutputDims[0].isLiteral()
@@ -1716,7 +1713,7 @@ struct ZHighToZLowDataConversionLowering
17161713
: -1;
17171714
if (findSuitableParallelDimension(lbs, flattenedOutputDims, 0, 1, parId,
17181715
/*min iter for going parallel*/ 1024)) {
1719-
create.krnl.parallel(blockedLoopDef[0]);
1716+
useParallel = true;
17201717
onnxToKrnlParallelReport(op, /*successful*/ true, 0, tripCount,
17211718
"dlf16-f32 conversion fully parallelized");
17221719
} else {
@@ -1729,8 +1726,8 @@ struct ZHighToZLowDataConversionLowering
17291726
: -1,
17301727
"dlf16-f32 conversion fully flattened");
17311728

1732-
create.krnl.iterateIE(loopDef, optimizedLoopDef, lbs, flattenedOutputDims,
1733-
[&](KrnlBuilder &b, ValueRange loopInd) {
1729+
create.krnl.forLoopIE(lbs[0], flattenedOutputDims[0], totVL, useParallel,
1730+
[&](const KrnlBuilder &b, ValueRange loopInd) {
17341731
MDBuilder create(b);
17351732
// Manually unrolled loop, add archVL offset at each iteration.
17361733
for (int64_t u = 0; u < unrollVL; ++u) {

src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
149149
create.mem.reinterpretCast(input, litZero.getValue(), reallocTileDims);
150150

151151
// Outer loop (E4, E3, E2, E1 iterates over tiles of 64 elements)
152-
create.krnl.iterateIE(
153-
loopDefs, loopDefs, lbs, ubs, [&](KrnlBuilder &b, ValueRange loopInd) {
152+
create.krnl.iterateIE(loopDefs, loopDefs, lbs, ubs,
153+
[&](const KrnlBuilder &b, ValueRange loopInd) {
154154
MDBuilder create(b);
155155
IndexExprScope outerScope(create.krnl, &allocScope);
156156
DimsExpr outerIndices = DimListIE(loopInd);
@@ -192,14 +192,14 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
192192
// Condition
193193
isFullLogical.getValue(),
194194
// Then (is full).
195-
[&](SCFBuilder b) {
195+
[&](const SCFBuilder b) {
196196
MDBuilder create(b);
197197
// Loop (tried unroll of 2 and 8, 4 was best).
198198
const int64_t unrollVL = 4;
199199
const int64_t totVL = unrollVL * archVL;
200200
assert(totVL <= 64 && "bad unroll");
201201
create.scf.forLoop(litZero.getValue(), lit64.getValue(), totVL,
202-
[&](SCFBuilder b, ValueRange loopInd) {
202+
[&](const SCFBuilder b, ValueRange loopInd) {
203203
MDBuilder create(b);
204204
IndexExprScope innerScope(b, &outerScope);
205205
Value loopIndex = loopInd[0];
@@ -430,8 +430,8 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
430430
create.mem.reinterpretCast(alloc, litZero.getValue(), reallocTileDims);
431431

432432
// Outer loop (E1 iterates over tiles of 64 elements).
433-
create.krnl.iterateIE(
434-
loopDefs, loopDefs, lbs, ubs, [&](KrnlBuilder &b, ValueRange loopInd) {
433+
create.krnl.iterateIE(loopDefs, loopDefs, lbs, ubs,
434+
[&](const KrnlBuilder &b, ValueRange loopInd) {
435435
MDBuilder create(b);
436436
IndexExprScope outerScope(create.krnl, &allocScope);
437437
DimsExpr outerIndices;
@@ -458,7 +458,7 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
458458
#endif
459459

460460
create.affine.forLoopIE(litZero, simdLoopUB, totVL,
461-
[&](AffineBuilder &b, ValueRange loopInd) {
461+
[&](const AffineBuilder &b, ValueRange loopInd) {
462462
MDBuilder create(b);
463463
DimsExpr inputAF;
464464
IndexExprScope innerScope(create.krnl, &outerScope);

src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class KrnlCopyFromBufferLowering : public ConversionPattern {
9090
return success();
9191
}
9292

93-
void genCopyLoops(AffineBuilderKrnlMem &createAffine,
93+
void genCopyLoops(const AffineBuilderKrnlMem &createAffine,
9494
IndexExprScope *enclosingScope, Value buffMemref, Value destMemref,
9595
IndexExpr zeroIE, SmallVectorImpl<IndexExpr> &starts,
9696
SmallVectorImpl<IndexExpr> &writeUBs, SmallVectorImpl<Value> &loopIndices,
@@ -125,7 +125,7 @@ class KrnlCopyFromBufferLowering : public ConversionPattern {
125125
} else {
126126
// Loop to copy the data.
127127
createAffine.forLoopIE(zeroIE, writeUBs[i], 1, false /*parallel*/,
128-
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
128+
[&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
129129
loopIndices.emplace_back(loopInd[0]);
130130
genCopyLoops(createAffine, enclosingScope, buffMemref, destMemref,
131131
zeroIE, starts, writeUBs, loopIndices, i + 1, buffRank);

src/Conversion/KrnlToAffine/KrnlCopyToBuffer.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ class KrnlCopyToBufferLowering : public ConversionPattern {
129129
return success();
130130
}
131131

132-
void genCopyLoops(AffineBuilderKrnlMem &createAffine,
132+
void genCopyLoops(const AffineBuilderKrnlMem &createAffine,
133133
IndexExprScope *enclosingScope, Value buffMemref, Value sourceMemref,
134134
SmallVectorImpl<int64_t> &srcLoopMap, Value padVal, IndexExpr zeroIE,
135135
SmallVectorImpl<IndexExpr> &starts, SmallVectorImpl<IndexExpr> &readUBs,
@@ -169,7 +169,7 @@ class KrnlCopyToBufferLowering : public ConversionPattern {
169169
// Nothing to read, skip.
170170
} else {
171171
createAffine.forLoopIE(zeroIE, readUBs[i], 1,
172-
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
172+
[&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
173173
loopIndices.emplace_back(loopInd[0]);
174174
genCopyLoops(createAffine, enclosingScope, buffMemref,
175175
sourceMemref, srcLoopMap, padVal, zeroIE, starts, readUBs,
@@ -182,7 +182,7 @@ class KrnlCopyToBufferLowering : public ConversionPattern {
182182
// No padding needed.
183183
} else {
184184
createAffine.forLoopIE(readUBs[i], padUBs[i], 1,
185-
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
185+
[&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
186186
loopIndices.emplace_back(loopInd[0]);
187187
genCopyLoops(createAffine, enclosingScope, buffMemref,
188188
sourceMemref, srcLoopMap, padVal, zeroIE, starts, readUBs,

src/Conversion/KrnlToAffine/KrnlMatmul.cpp

+21-17
Original file line numberDiff line numberDiff line change
@@ -223,30 +223,32 @@ class KrnlMatmulLowering : public ConversionPattern {
223223
if (matVectorProduct) {
224224
// clang-format off
225225
create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
226-
/* then full tiles */ [&](AffineBuilderKrnlMem &createAffine) {
226+
/* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
227227
genSimdMatVect(createAffine, matmulOp, elementType, aStart, bStart,
228228
cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize,
229229
vectorLen, fullUnrollAndJam);
230-
}, /* else has partial tiles */ [&](AffineBuilderKrnlMem &createAffine) {
230+
}, /* else has partial tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
231231
genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
232232
iTrip, jTrip, kTrip, /*unroll*/ false);
233233
});
234234
// clang-format on
235235
} else {
236236
// clang-format off
237237
create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
238-
/* then full tiles */ [&](AffineBuilderKrnlMem &createAffine) {
238+
/* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
239239
genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart,
240240
cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize,
241241
vectorLen, fullUnrollAndJam);
242-
}, /* has some partial tiles */ [&](AffineBuilderKrnlMem &createAffine) {
242+
},
243+
/* Else has some partial tiles */
244+
[&](const AffineBuilderKrnlMem &createAffine) {
243245
// Trip regardless of full/partial for N & K
244246
// Test if SIMD dim (M) is full.
245247
createAffine.ifThenElseIE(indexScope, jFullTiles,
246-
/* full SIMD */ [&](AffineBuilderKrnlMem &createAffine) {
248+
/* full SIMD */ [&](const AffineBuilderKrnlMem &createAffine) {
247249
genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart,
248250
cStart, iTrip, jComputeTileSize, kTrip, vectorLen, /*unroll*/ false);
249-
}, /* else partial SIMD */ [&](AffineBuilderKrnlMem &createAffine) {
251+
}, /* else partial SIMD */ [&](const AffineBuilderKrnlMem &createAffine) {
250252
// TODO: evaluate whether partial SIMD improves performance.
251253
if (false && jPartialTrip.isLiteral() && jPartialTrip.getLiteral() >=2) {
252254
// has a known trip count along the simd dimension of at least 2
@@ -265,11 +267,11 @@ class KrnlMatmulLowering : public ConversionPattern {
265267
// Scalar code generator.
266268
// clang-format off
267269
create.affineKMem.ifThenElseIE(indexScope, allFullTiles,
268-
/* then full */ [&](AffineBuilderKrnlMem &createAffine) {
270+
/* then full */ [&](const AffineBuilderKrnlMem &createAffine) {
269271
genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
270272
iComputeTileSize, jComputeTileSize, kComputeTileSize,
271273
fullUnrollAndJam);
272-
}, /* else partial */ [&](AffineBuilderKrnlMem &createAffine) {
274+
}, /* else partial */ [&](const AffineBuilderKrnlMem &createAffine) {
273275
genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart,
274276
iTrip, jTrip, kTrip, false);
275277
});
@@ -280,7 +282,7 @@ class KrnlMatmulLowering : public ConversionPattern {
280282
}
281283

282284
private:
283-
void genScalar(AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
285+
void genScalar(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
284286
Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
285287
ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
286288
bool unrollJam) const {
@@ -300,10 +302,11 @@ class KrnlMatmulLowering : public ConversionPattern {
300302
LiteralIndexExpr zeroIE(0);
301303
Value jSaved;
302304
createAffine.forLoopIE(zeroIE, I, 1,
303-
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
305+
[&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
304306
Value i = loopInd[0];
305307
createAffine.forLoopIE(zeroIE, J, 1,
306-
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
308+
[&](const AffineBuilderKrnlMem &createAffine,
309+
ValueRange loopInd) {
307310
MathBuilder createMath(createAffine);
308311
Value j = loopInd[0];
309312
// Defines induction variables, and possibly initialize C.
@@ -315,7 +318,7 @@ class KrnlMatmulLowering : public ConversionPattern {
315318
createAffine.store(initVal, TmpC, tmpCAccess);
316319
// Sum over k.
317320
createAffine.forLoopIE(zeroIE, K, 1,
318-
[&](AffineBuilderKrnlMem &createAffine,
321+
[&](const AffineBuilderKrnlMem &createAffine,
319322
ValueRange loopInd) {
320323
MathBuilder createMath(createAffine);
321324
Value k = loopInd[0];
@@ -340,7 +343,7 @@ class KrnlMatmulLowering : public ConversionPattern {
340343
}
341344

342345
// Initially, simdize with full K vector length.
343-
void genSimdMatVect(AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
346+
void genSimdMatVect(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
344347
Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
345348
ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
346349
IndexExpr vectorLen, bool unrollJam) const {
@@ -384,7 +387,7 @@ class KrnlMatmulLowering : public ConversionPattern {
384387
Value iZero = create.math.constantIndex(0);
385388

386389
create.affineKMem.forLoopIE(zeroIE, K, VL,
387-
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
390+
[&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
388391
MultiDialectBuilder<MathBuilder, VectorBuilder> create(createAffine);
389392
Value k = loopInd[0];
390393
// Iterates over the I indices (K is SIMD dim).
@@ -431,7 +434,7 @@ class KrnlMatmulLowering : public ConversionPattern {
431434
}
432435

433436
// Simdize along J / memory rows in B and C.
434-
void genSimdMatMat(AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
437+
void genSimdMatMat(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
435438
Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
436439
ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
437440
IndexExpr vectorLen, bool unrollJam) const {
@@ -466,7 +469,7 @@ class KrnlMatmulLowering : public ConversionPattern {
466469
Value iZero = create.math.constantIndex(0);
467470

468471
createAffine.forLoopIE(zeroIE, I, 1,
469-
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
472+
[&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
470473
MultiDialectBuilder<MathBuilder, VectorBuilder> create(createAffine);
471474
Value i = loopInd[0];
472475
iSaved = i; // Saved for unroll and jam.
@@ -476,7 +479,8 @@ class KrnlMatmulLowering : public ConversionPattern {
476479
createAffine.store(initVal, TmpC, tmpCAccess);
477480
// Sum over k.
478481
createAffine.forLoopIE(zeroIE, K, 1,
479-
[&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
482+
[&](const AffineBuilderKrnlMem &createAffine,
483+
ValueRange loopInd) {
480484
MultiDialectBuilder<MathBuilder, VectorBuilder> create(
481485
createAffine);
482486
Value k = loopInd[0];

src/Conversion/KrnlToAffine/KrnlMemset.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class KrnlMemsetLowering : public ConversionPattern {
5959
SmallVector<int64_t, 4> steps(rank, 1);
6060
// Copy data.
6161
create.affineKMem.forLoopsIE(lbs, ubs, steps,
62-
[&](AffineBuilderKrnlMem &createAffine, ValueRange indices) {
62+
[&](const AffineBuilderKrnlMem &createAffine, ValueRange indices) {
6363
createAffine.store(destVal, destMemRef, indices);
6464
});
6565
rewriter.eraseOp(op);

0 commit comments

Comments
 (0)