@@ -223,30 +223,32 @@ class KrnlMatmulLowering : public ConversionPattern {
223
223
if (matVectorProduct) {
224
224
// clang-format off
225
225
create.affineKMem .ifThenElseIE (indexScope, allFullTiles,
226
- /* then full tiles */ [&](AffineBuilderKrnlMem &createAffine) {
226
+ /* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
227
227
genSimdMatVect (createAffine, matmulOp, elementType, aStart, bStart,
228
228
cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize ,
229
229
vectorLen, fullUnrollAndJam);
230
- }, /* else has partial tiles */ [&](AffineBuilderKrnlMem &createAffine) {
230
+ }, /* else has partial tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
231
231
genScalar (createAffine, matmulOp, elementType, aStart, bStart, cStart,
232
232
iTrip, jTrip, kTrip , /* unroll*/ false );
233
233
});
234
234
// clang-format on
235
235
} else {
236
236
// clang-format off
237
237
create.affineKMem .ifThenElseIE (indexScope, allFullTiles,
238
- /* then full tiles */ [&](AffineBuilderKrnlMem &createAffine) {
238
+ /* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) {
239
239
genSimdMatMat (createAffine, matmulOp, elementType, aStart, bStart,
240
240
cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize ,
241
241
vectorLen, fullUnrollAndJam);
242
- }, /* has some partial tiles */ [&](AffineBuilderKrnlMem &createAffine) {
242
+ },
243
+ /* Else has some partial tiles */
244
+ [&](const AffineBuilderKrnlMem &createAffine) {
243
245
// Trip regardless of full/partial for N & K
244
246
// Test if SIMD dim (M) is full.
245
247
createAffine.ifThenElseIE (indexScope, jFullTiles,
246
- /* full SIMD */ [&](AffineBuilderKrnlMem &createAffine) {
248
+ /* full SIMD */ [&](const AffineBuilderKrnlMem &createAffine) {
247
249
genSimdMatMat (createAffine, matmulOp, elementType, aStart, bStart,
248
250
cStart, iTrip, jComputeTileSize, kTrip , vectorLen, /* unroll*/ false );
249
- }, /* else partial SIMD */ [&](AffineBuilderKrnlMem &createAffine) {
251
+ }, /* else partial SIMD */ [&](const AffineBuilderKrnlMem &createAffine) {
250
252
// TODO: evaluate if get performance from partial SIMD
251
253
if (false && jPartialTrip.isLiteral () && jPartialTrip.getLiteral () >=2 ) {
252
254
// has a known trip count along the simd dimension of at least 2
@@ -265,11 +267,11 @@ class KrnlMatmulLowering : public ConversionPattern {
265
267
// Scalar code generator.
266
268
// clang-format off
267
269
create.affineKMem .ifThenElseIE (indexScope, allFullTiles,
268
- /* then full */ [&](AffineBuilderKrnlMem &createAffine) {
270
+ /* then full */ [&](const AffineBuilderKrnlMem &createAffine) {
269
271
genScalar (createAffine, matmulOp, elementType, aStart, bStart, cStart,
270
272
iComputeTileSize, jComputeTileSize, kComputeTileSize ,
271
273
fullUnrollAndJam);
272
- }, /* else partial */ [&](AffineBuilderKrnlMem &createAffine) {
274
+ }, /* else partial */ [&](const AffineBuilderKrnlMem &createAffine) {
273
275
genScalar (createAffine, matmulOp, elementType, aStart, bStart, cStart,
274
276
iTrip, jTrip, kTrip , false );
275
277
});
@@ -280,7 +282,7 @@ class KrnlMatmulLowering : public ConversionPattern {
280
282
}
281
283
282
284
private:
283
- void genScalar (AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
285
+ void genScalar (const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
284
286
Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
285
287
ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
286
288
bool unrollJam) const {
@@ -300,10 +302,11 @@ class KrnlMatmulLowering : public ConversionPattern {
300
302
LiteralIndexExpr zeroIE (0 );
301
303
Value jSaved;
302
304
createAffine.forLoopIE (zeroIE, I, 1 ,
303
- [&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
305
+ [&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
304
306
Value i = loopInd[0 ];
305
307
createAffine.forLoopIE (zeroIE, J, 1 ,
306
- [&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
308
+ [&](const AffineBuilderKrnlMem &createAffine,
309
+ ValueRange loopInd) {
307
310
MathBuilder createMath (createAffine);
308
311
Value j = loopInd[0 ];
309
312
// Defines induction variables, and possibly initialize C.
@@ -315,7 +318,7 @@ class KrnlMatmulLowering : public ConversionPattern {
315
318
createAffine.store (initVal, TmpC, tmpCAccess);
316
319
// Sum over k.
317
320
createAffine.forLoopIE (zeroIE, K, 1 ,
318
- [&](AffineBuilderKrnlMem &createAffine,
321
+ [&](const AffineBuilderKrnlMem &createAffine,
319
322
ValueRange loopInd) {
320
323
MathBuilder createMath (createAffine);
321
324
Value k = loopInd[0 ];
@@ -340,7 +343,7 @@ class KrnlMatmulLowering : public ConversionPattern {
340
343
}
341
344
342
345
// Initially, simdize with full K vector length.
343
- void genSimdMatVect (AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
346
+ void genSimdMatVect (const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
344
347
Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
345
348
ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
346
349
IndexExpr vectorLen, bool unrollJam) const {
@@ -384,7 +387,7 @@ class KrnlMatmulLowering : public ConversionPattern {
384
387
Value iZero = create.math .constantIndex (0 );
385
388
386
389
create.affineKMem .forLoopIE (zeroIE, K, VL,
387
- [&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
390
+ [&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
388
391
MultiDialectBuilder<MathBuilder, VectorBuilder> create (createAffine);
389
392
Value k = loopInd[0 ];
390
393
// Iterates over the I indices (K is SIMD dim).
@@ -431,7 +434,7 @@ class KrnlMatmulLowering : public ConversionPattern {
431
434
}
432
435
433
436
// Simdize along J / memory rows in B and C.
434
- void genSimdMatMat (AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
437
+ void genSimdMatMat (const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op,
435
438
Type elementType, ArrayRef<IndexExpr> aStart, ArrayRef<IndexExpr> bStart,
436
439
ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
437
440
IndexExpr vectorLen, bool unrollJam) const {
@@ -466,7 +469,7 @@ class KrnlMatmulLowering : public ConversionPattern {
466
469
Value iZero = create.math .constantIndex (0 );
467
470
468
471
createAffine.forLoopIE (zeroIE, I, 1 ,
469
- [&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
472
+ [&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
470
473
MultiDialectBuilder<MathBuilder, VectorBuilder> create (createAffine);
471
474
Value i = loopInd[0 ];
472
475
iSaved = i; // Saved for unroll and jam.
@@ -476,7 +479,8 @@ class KrnlMatmulLowering : public ConversionPattern {
476
479
createAffine.store (initVal, TmpC, tmpCAccess);
477
480
// Sum over k.
478
481
createAffine.forLoopIE (zeroIE, K, 1 ,
479
- [&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) {
482
+ [&](const AffineBuilderKrnlMem &createAffine,
483
+ ValueRange loopInd) {
480
484
MultiDialectBuilder<MathBuilder, VectorBuilder> create (
481
485
createAffine);
482
486
Value k = loopInd[0 ];
0 commit comments