@@ -58,55 +58,26 @@ static StringRef getFormat(const Type &inputType) {
58
58
59
59
Value KrnlBuilder::load (
60
60
Value memref, ValueRange indices, ValueRange offsets) const {
61
- // Handle offsets.
62
- SmallVector<Value, 4 > computedIndices;
63
- MathBuilder createMath (*this );
64
- createMath.addOffsetToLeastSignificant (indices, offsets, computedIndices);
65
- // Perform load.
66
- if (computedIndices.size () == 0 ) {
67
- // case memref<1xdtype>
68
- MemRefType type = dyn_cast_or_null<MemRefType>(memref.getType ());
69
- assert (type && " Not MemRefType" );
70
- if (type.getRank () == 1 && type.getShape ()[0 ] == 1 ) {
71
- MultiDialectBuilder<MathBuilder> create (*this );
72
- Value iZero = create.math .constantIndex (0 );
73
- return b ().create <KrnlLoadOp>(loc (), memref, ValueRange ({iZero}));
74
- }
75
- }
76
- return b ().create <KrnlLoadOp>(loc (), memref, computedIndices);
61
+ return onnx_mlir::impl::load<KrnlBuilder, KrnlLoadOp>(
62
+ *this , memref, indices, offsets);
77
63
}
78
64
79
65
Value KrnlBuilder::loadIE (
80
66
Value memref, ArrayRef<IndexExpr> indices, ValueRange offsets) const {
81
- SmallVector<Value, 4 > indexValues;
82
- IndexExpr::getValues (indices, indexValues);
83
- return load (memref, indexValues, offsets);
67
+ return onnx_mlir::impl::loadIE<KrnlBuilder, KrnlLoadOp>(
68
+ *this , memref, indices, offsets);
84
69
}
85
70
86
71
void KrnlBuilder::store (
87
72
Value val, Value memref, ValueRange indices, ValueRange offsets) const {
88
- SmallVector<Value, 4 > computedIndices;
89
- MathBuilder createMath (*this );
90
- createMath.addOffsetToLeastSignificant (indices, offsets, computedIndices);
91
- if (computedIndices.size () == 0 ) {
92
- // case memref<1xdtype>
93
- MemRefType type = dyn_cast_or_null<MemRefType>(memref.getType ());
94
- assert (type && " Not MemRefType" );
95
- if (type.getRank () == 1 && type.getShape ()[0 ] == 1 ) {
96
- MultiDialectBuilder<MathBuilder> create (*this );
97
- Value iZero = create.math .constantIndex (0 );
98
- b ().create <KrnlStoreOp>(loc (), val, memref, ValueRange ({iZero}));
99
- return ;
100
- }
101
- }
102
- b ().create <KrnlStoreOp>(loc (), val, memref, computedIndices);
73
+ onnx_mlir::impl::store<KrnlBuilder, KrnlStoreOp>(
74
+ *this , val, memref, indices, offsets);
103
75
}
104
76
105
77
void KrnlBuilder::storeIE (Value val, Value memref, ArrayRef<IndexExpr> indices,
106
78
ValueRange offsets) const {
107
- SmallVector<Value, 4 > indexValues;
108
- IndexExpr::getValues (indices, indexValues);
109
- store (val, memref, indexValues, offsets);
79
+ onnx_mlir::impl::storeIE<KrnlBuilder, KrnlStoreOp>(
80
+ *this , val, memref, indices, offsets);
110
81
}
111
82
112
83
Value KrnlBuilder::getLinearOffsetIndex (
@@ -246,228 +217,37 @@ KrnlIterateOp KrnlBuilder::iterateIE(ValueRange originalLoops,
246
217
});
247
218
}
248
219
249
- /*
250
- Example of how to use the interface:
251
- Say you have a loop of i=0..256, j=0..128 and want to exploit r[i,j] = a[i,j] +
252
- b[j] + c. For the loops, we will need access functions for a, b, and r.
253
-
254
- Say we already have the loop for the outer loop of i
255
-
256
- krnl.iterate(loop i from 0 to 256) {
257
- ii is the loop index.
258
-
259
- // 1) compute access function for a, b, c
260
- // 2) launch simd loop with
261
- // 3) simd kernel
262
- }
263
-
264
- 1) Access functions
265
- Assuming here that we are not blocking the j loop, namely the simd iteration
266
- goes over all j values, the access functions should be defined as follows.
267
-
268
- aAF = {ii, 0}
269
- bAF = {0}
270
- rAF = {ii, 0}
271
-
272
- If the j loop was blocked (say j=0 to 128 by 16), then instead of `0` in the
273
- last dim, we would have 'blocked_jj'
274
-
275
- 2) Launch simd loop
276
-
277
- create.krnl.simdIterateIE(
278
- lb=LitIE(0), ub=litIE(128), totVL=8, // loop params
279
- fullySimd=true, useParallel=false, // loop options
280
- inputs={A, B}, inputAFs={aAF, bAF}, // inputs
281
- outputs={R}, outputAFs={rAF}, // outputs
282
- krnl) // lambda function for kernel
283
-
284
- 3) Krnl for SIMD loop
285
-
286
- The kernel functions has 4 inputs:
287
- a) krnl builder to further build code
288
- b) list of loaded input values, in the same order as in inputs
289
- c) list of results values, that must be enqueued by the kernel
290
- d) totVL used for the loop (VL for simd, 1 for scalar)
291
-
292
- The same kernel will be used in a SIMD context, in which the inputs and
293
- outputs must be vectors of VL elements, or in a scalar context, in which the
294
- inputs and outputs must be scalars.
295
-
296
- In our example, the kernel is as follows
297
-
298
- [&](KrnlBuilder &kb, ArrayRef<Value> inputVals,
299
- SmallVectorImpl<Value> &resVals, int64_t VL) {
300
- MultiDialectBuilder<KrnlBuilder, MathBuilder> create(kb);
301
- Value aVal = inputVals[0]; // simd or scalar
302
- Value bVal = inputVals[1]; // simd or scalar
303
- Value cVal = create.krnl.load(C); // scalar always
304
- Value newVal = create.math.add(aVal, bVal); // simd or scalar
305
- newVal = create.math.add(newVal, cVal); // if newVal is simd, cVal is
306
- // splatted
307
- res.emplace_back(newVal); // Save simd or scalar result.
308
- }
309
-
310
- The krnl.simdIterateIE will be in charge of loading and saving the values in
311
- memory. The create.math functions have been extended so that when a SIMD
312
- value is computed with a scalar, that scalar will be automaticaly splatted
313
- (aka promoted to a vector of identical values). As a result, the kernel can
314
- be written in a SIMD agnostic value. However, in rare situations, we may
315
- want to know if we are in SIMD mode or not. VL will give the totVL used here
316
- (either totVL>1 or 1).
317
- */
318
-
319
- // Determine if an access has one element from the innermost dimensions up to
320
- // innerDim.
321
- bool static hasOneElementInInnermostDims (Value value, int64_t innerDim) {
322
- // Get info.
323
- ShapedType type = mlir::dyn_cast<ShapedType>(value.getType ());
324
- assert (type && " expected shaped type" );
325
- int64_t rank = type.getRank ();
326
- ArrayRef<int64_t > shape = type.getShape ();
327
- for (int64_t i = std::max ((int64_t )0 , rank - innerDim); i < rank; ++i)
328
- if (shape[i] != 1 )
329
- return false ;
330
- return true ;
331
- }
332
-
333
220
void KrnlBuilder::simdIterateIE (IndexExpr lb, IndexExpr ub, int64_t VL,
334
221
bool fullySimd, bool useParallel, ArrayRef<Value> inputs,
335
222
ArrayRef<DimsExpr> inputAFs, ArrayRef<Value> outputs,
336
223
ArrayRef<DimsExpr> outputAFs,
337
- function_ref<void (KrnlBuilder &kb , ArrayRef<Value> inputVals,
224
+ function_ref<void (KrnlBuilder &b , ArrayRef<Value> inputVals,
338
225
llvm::SmallVectorImpl<Value> &resultVals, int64_t VL)>
339
- bodyBuilderFn) {
340
- int64_t inputNum = inputs.size ();
341
- assert (inputAFs.size () == inputs.size () && " expected same size" );
342
- int64_t outputNum = outputs.size ();
343
- assert (outputAFs.size () == outputs.size () && " expected same size" );
344
- MultiDialectBuilder<VectorBuilder> create (*this );
345
-
346
- if (VL > 1 ) {
347
- // Want SIMD, execute full SIMD loops blocked by VL.
348
- ValueRange loopDef = defineLoops (1 );
349
- ValueRange blockedLoopDef = block (loopDef[0 ], VL);
350
- if (useParallel)
351
- parallel ({blockedLoopDef[0 ]});
352
-
353
- // If we are not guaranteed that every iterations are SIMD iterations, then
354
- // we need to reduce the trip count by a bit so as to not over compute.
355
- // If we are not guaranteed that every iterations are SIMD iterations, then
356
- IndexExpr simdUb = ub;
357
- if (!fullySimd)
358
- simdUb = simdUb - (VL - 1 );
359
- iterateIE (loopDef, {blockedLoopDef[0 ]}, {lb}, {simdUb},
360
- [&](KrnlBuilder &ck, ValueRange loopInd) {
361
- IndexExprScope scope (ck);
362
- MultiDialectBuilder<KrnlBuilder, VectorBuilder> create (ck);
363
- IndexExpr ind = DimIE (loopInd[0 ]);
364
- // Load all the inputs as vectors of VL values, with a few exceptions.
365
- // One is if the value is a "none value", leave as is. Another one is
366
- // if the innermost dim is a scalar (ie dim[rank-1] == 1), then we
367
- // just load the scalar.
368
- llvm::SmallVector<Value, 4 > vecInputVals;
369
- for (int64_t i = 0 ; i < inputNum; ++i) {
370
- Value input = inputs[i];
371
- if (isNoneValue (input)) {
372
- // Simply enqueue the none value.
373
- vecInputVals.emplace_back (input);
374
- continue ;
375
- }
376
- MemRefType type = mlir::cast<MemRefType>(input.getType ());
377
- int64_t rank = type.getRank ();
378
- DimsExpr AF = SymListIE (inputAFs[i]);
379
- assert (rank == (int64_t )AF.size () && " AF expected input rank refs" );
380
- if (hasOneElementInInnermostDims (input, 1 )) {
381
- // Has a reference with a scalar innermost dim, just load as a
382
- // scalar. No need to add the induction variable.
383
- Value scalarVal = create.krnl .loadIE (input, AF);
384
- vecInputVals.emplace_back (scalarVal);
385
- } else {
386
- // Have a vector.
387
- VectorType vecType = VectorType::get ({VL}, type.getElementType ());
388
- AF[rank - 1 ] = AF[rank - 1 ] + ind; // Add induction var.
389
- Value vecVal = create.vec .loadIE (vecType, input, AF);
390
- vecInputVals.emplace_back (vecVal);
391
- }
392
- }
393
- // Call the method to compute the values.
394
- llvm::SmallVector<Value, 4 > vecResVals;
395
- bodyBuilderFn (create.krnl , vecInputVals, vecResVals, VL);
396
- assert ((int64_t )vecResVals.size () == outputNum &&
397
- " loop body with incorrect number of results" );
398
- // Store all the outputs as vectors of VL values,
399
- for (int64_t i = 0 ; i < outputNum; ++i) {
400
- MemRefType type = mlir::cast<MemRefType>(outputs[i].getType ());
401
- DimsExpr AF = SymListIE (outputAFs[i]);
402
- int64_t rank = type.getRank ();
403
- assert (rank == (int64_t )AF.size () && " AF expected ouput rank refs" );
404
- AF[rank - 1 ] = AF[rank - 1 ] + ind;
405
- create.vec .storeIE (vecResVals[i], outputs[i], AF);
406
- }
407
- });
408
- if (fullySimd)
409
- // Asserted that we only have SIMD iterations, we are done.
410
- return ;
411
- // Account for the loop iterations performed above.
412
- IndexExpr tripCount = ub - lb;
413
- IndexExpr missingIters = tripCount % VL;
414
- IndexExpr completedIters = tripCount - missingIters;
415
- if (missingIters.isLiteralAndIdenticalTo (0 )) {
416
- // Detect that we only have SIMD iterations, we are also done.
417
- return ;
418
- }
419
- // We may have additional iterations to perform, adjust lb to skip the
420
- // completed iterations.
421
- lb = lb + completedIters;
422
- }
423
- // Handle remaining scalar values (from lb to ub without unrolling).
424
- ValueRange loopDef = defineLoops (1 );
425
- iterateIE (
426
- loopDef, loopDef, {lb}, {ub}, [&](KrnlBuilder &ck, ValueRange loopInd) {
427
- IndexExprScope scope (ck);
428
- MultiDialectBuilder<KrnlBuilder> create (ck);
429
- IndexExpr ind = DimIE (loopInd[0 ]);
430
- // Load all the inputs as scalar values,
431
- llvm::SmallVector<Value, 4 > scalarInputVals;
432
- for (int64_t i = 0 ; i < inputNum; ++i) {
433
- Value input = inputs[i];
434
- if (isNoneValue (input)) {
435
- // Simply enqueue the none value.
436
- scalarInputVals.emplace_back (input);
437
- continue ;
438
- }
439
- MemRefType type = mlir::cast<MemRefType>(input.getType ());
440
- int64_t rank = type.getRank ();
441
- DimsExpr AF = SymListIE (inputAFs[i]);
442
- if (hasOneElementInInnermostDims (input, 1 )) {
443
- // Has a reference with a scalar innermost dim, just load as a
444
- // scalar. No need to add the induction variable.
445
- Value scalarVal = create.krnl .loadIE (input, AF);
446
- scalarInputVals.emplace_back (scalarVal);
447
- } else {
448
- AF[rank - 1 ] = AF[rank - 1 ] + ind;
449
- Value scalarVal = create.krnl .loadIE (input, AF);
450
- scalarInputVals.emplace_back (scalarVal);
451
- }
452
- }
453
- // Call the method to compute the values.
454
- llvm::SmallVector<Value, 4 > scalarResVals;
455
- bodyBuilderFn (create.krnl , scalarInputVals, scalarResVals, /* VL*/ 1 );
456
- assert ((int64_t )scalarResVals.size () == outputNum &&
457
- " loop body with incorrect number of results" );
458
- // Store all the outputs as vectors of VL values,
459
- for (int64_t i = 0 ; i < outputNum; ++i) {
460
- MemRefType type = mlir::cast<MemRefType>(outputs[i].getType ());
461
- DimsExpr AF = SymListIE (outputAFs[i]);
462
- int64_t rank = type.getRank ();
463
- assert (rank == (int64_t )AF.size () && " AF expected ouput rank refs" );
464
- AF[rank - 1 ] = AF[rank - 1 ] + ind;
465
- create.krnl .storeIE (scalarResVals[i], outputs[i], AF);
466
- }
467
- });
468
- }
469
-
470
- void KrnlBuilder::yield (ValueRange iterArgs) const {
226
+ bodyBuilderFn) const {
227
+ onnx_mlir::impl::simdIterateIE<KrnlBuilder, KrnlBuilder>(*this , lb, ub, VL,
228
+ fullySimd, useParallel, inputs, inputAFs, outputs, outputAFs,
229
+ bodyBuilderFn);
230
+ }
231
+
232
+ void KrnlBuilder::simdReduceIE (IndexExpr lb, IndexExpr ub, int64_t VL,
233
+ bool fullySimd, ArrayRef<Value> inputs, ArrayRef<DimsExpr> inputAFs,
234
+ ArrayRef<Value> tmps, ArrayRef<DimsExpr> tmpAFs, ArrayRef<Value> outputs,
235
+ ArrayRef<DimsExpr> outputAFs, ArrayRef<Value> initVals,
236
+ /* reduction function (simd or scalar) */
237
+ function_ref<void (const KrnlBuilder &b, ArrayRef<Value> inputVals,
238
+ ArrayRef<Value> tmpVals, llvm::SmallVectorImpl<Value> &resultVals,
239
+ int64_t VL)>
240
+ reductionBuilderFn,
241
+ /* post reduction function (simd to scalar + post processing)*/
242
+ function_ref<void(const KrnlBuilder &b, ArrayRef<Value> tmpVals,
243
+ llvm::SmallVectorImpl<Value> &scalarOutputs, int64_t VL)>
244
+ postProcessingBuilderFn) const {
245
+ onnx_mlir::impl::simdReduceIE<KrnlBuilder, KrnlBuilder>(*this , lb, ub, VL,
246
+ fullySimd, inputs, inputAFs, tmps, tmpAFs, outputs, outputAFs, initVals,
247
+ reductionBuilderFn, postProcessingBuilderFn);
248
+ }
249
+
250
+ void KrnlBuilder::yield (mlir::ValueRange iterArgs) const {
471
251
b ().create <KrnlYieldOp>(loc (), iterArgs);
472
252
}
473
253
@@ -533,8 +313,8 @@ Value KrnlBuilder::constant(MemRefType type, StringRef name,
533
313
void KrnlBuilder::memcpy (Value dest, Value src, Value numElems) const {
534
314
MultiDialectBuilder<MathBuilder> create (*this );
535
315
Value zero = create.math .constantIndex (0 );
536
- b ().create <KrnlMemcpyOp>(
537
- loc (), dest, src, numElems, /* dest_offset=*/ zero, /* src_offset=*/ zero);
316
+ b ().create <KrnlMemcpyOp>(loc (), dest, src, numElems,
317
+ /* dest_offset=*/ zero, /* src_offset=*/ zero);
538
318
}
539
319
540
320
void KrnlBuilder::memcpy (Value dest, Value src, Value numElems,
0 commit comments