29
29
#include " src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp"
30
30
#include " src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
31
31
#include " src/Accelerators/NNPA/Support/LayoutHelper.hpp"
32
+ #include " src/Accelerators/NNPA/Support/NNPALimit.hpp"
32
33
#include " src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
33
34
#include " src/Dialect/Krnl/DialectBuilder.hpp"
34
35
#include " src/Dialect/Krnl/KrnlHelper.hpp"
43
44
#define PREFETCH_CSU_DIST 0
44
45
#define PREFETCH_CSU 1
45
46
47
+ // TODO: integrate saturation support (compiled out until it is hooked up).
48
+ #define SATURATION_ON 0
49
+
46
50
using namespace mlir ;
47
51
48
52
namespace onnx_mlir {
@@ -71,14 +75,14 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
71
75
layout.getValue ().equals_insensitive (" 3D" ) ||
72
76
layout.getValue ().equals_insensitive (" 2D" ) ||
73
77
layout.getValue ().equals_insensitive (" 3DS" )) {
74
- return generateUnstickCodeNoBuffer (rewriter, unstickOp, layout );
78
+ return generateUnstickCodeNoBuffer (rewriter, unstickOp);
75
79
}
76
80
// Otherwise, we don't replace and keep the zdnn call.
77
81
return failure ();
78
82
}
79
83
80
- LogicalResult generateUnstickCodeNoBuffer (PatternRewriter &rewriter,
81
- ZLowUnstickOp unstickOp, StringAttr layout ) const {
84
+ LogicalResult generateUnstickCodeNoBuffer (
85
+ PatternRewriter &rewriter, ZLowUnstickOp unstickOp ) const {
82
86
Operation *op = unstickOp.getOperation ();
83
87
Location loc = unstickOp.getLoc ();
84
88
MDBuilder create (rewriter, loc);
@@ -187,7 +191,7 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
187
191
// Then (is full).
188
192
[&](SCFBuilder b) {
189
193
MDBuilder create (b);
190
- // Loop
194
+ // Loop (tried unroll of 2 and 8, 4 was best).
191
195
const int64_t U = 4 ;
192
196
assert (U * VL <= 64 && " bad unroll" );
193
197
create.scf .forLoop (litZero.getValue (), lit64.getValue (), U * VL,
@@ -309,15 +313,15 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
309
313
layout.getValue ().equals_insensitive (" 3D" ) ||
310
314
layout.getValue ().equals_insensitive (" 2D" ) ||
311
315
layout.getValue ().equals_insensitive (" 3DS" )) {
312
- return generateStickCodeNoBuffer (rewriter, stickOp, layout );
316
+ return generateStickCodeNoBuffer (rewriter, stickOp);
313
317
}
314
318
// Otherwise, we don't replace and keep the zdnn call.
315
319
return failure ();
316
320
}
317
321
318
322
/* Version without buffer, more like zdnn */
319
323
LogicalResult generateStickCodeNoBuffer (
320
- PatternRewriter &rewriter, ZLowStickOp stickOp, StringAttr layout ) const {
324
+ PatternRewriter &rewriter, ZLowStickOp stickOp) const {
321
325
Operation *op = stickOp.getOperation ();
322
326
Location loc = stickOp.getLoc ();
323
327
MDBuilder create (rewriter, loc);
@@ -327,6 +331,12 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
327
331
Value input = stickOp.getX ();
328
332
Value alloc = stickOp.getOut ();
329
333
334
+ bool saturation = false ;
335
+ #if SATURATION_ON
336
+ // TODO: hook to operation's attribute.
337
+ saturation = true ;
338
+ #endif
339
+
330
340
DimsExpr outputDims;
331
341
create.krnlIE .getShapeAsSymbols (alloc, outputDims);
332
342
int64_t rank = outputDims.size ();
@@ -344,6 +354,15 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
344
354
IndexExpr litVLHalf = LiteralIndexExpr (VLHalf);
345
355
IndexExpr lit64 = LiteralIndexExpr (64 );
346
356
357
+ // Values for saturation.
358
+ Value vecDlf16Min, vecDlf16Max;
359
+ if (saturation) {
360
+ Value dlf16Min = create.math .constant (f32Type, DLF16_MIN);
361
+ vecDlf16Min = create.vec .splat (vecF32Type, dlf16Min);
362
+ Value dlf16Max = create.math .constant (f32Type, DLF16_MAX);
363
+ vecDlf16Max = create.vec .splat (vecF32Type, dlf16Max);
364
+ }
365
+
347
366
// Useful references for indexing dimensions (negative values are not used).
348
367
int64_t E1 = rank - 1 ;
349
368
@@ -406,7 +425,7 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
406
425
#endif
407
426
#endif
408
427
409
- const int64_t U = 4 ;
428
+ const int64_t U = 4 ; // Tried 2 and 8, 4 was best.
410
429
assert (U * VL <= 64 && " bad unroll" );
411
430
create.affine .forIE (litZero, lit64, U * VL,
412
431
[&](AffineBuilder &b, ValueRange loopInd) {
@@ -417,21 +436,36 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
417
436
getIndexExprList<SymbolIndexExpr>(memAF, inputAF);
418
437
// E1: add the "l" local E1 offset.
419
438
inputAF[E1 ] = inputAF[E1 ] + l;
439
+ // Load the f32.
420
440
Value vecF32H[U], vecF32L[U], vecF16[U];
421
- for (int64_t i = 0 ; i < U; ++i ) {
422
- LiteralIndexExpr iH (i * VL), iL (i * VL + VL / 2 );
423
- vecF32H[i ] = create.vec .loadIE (
441
+ for (int64_t u = 0 ; u < U; ++u ) {
442
+ LiteralIndexExpr iH (u * VL), iL (u * VL + VL / 2 );
443
+ vecF32H[u ] = create.vec .loadIE (
424
444
vecF32Type, input, inputAF, {iH.getValue ()});
425
- vecF32L[i ] = create.vec .loadIE (
445
+ vecF32L[u ] = create.vec .loadIE (
426
446
vecF32Type, input, inputAF, {iL.getValue ()});
427
447
}
428
- for (int64_t i = 0 ; i < U; ++i) {
429
- vecF16[i] = rewriter.create <ZLowConvertF32ToDLF16VectorOp>(
430
- loc, vecF32H[i], vecF32L[i]);
448
+ if (saturation) {
449
+ // Get rid of too-high values.
450
+ for (int64_t u = 0 ; u < U; ++u) {
451
+ vecF32H[u] = create.math .min (vecF32H[u], vecDlf16Max);
452
+ vecF32L[u] = create.math .min (vecF32L[u], vecDlf16Max);
453
+ }
454
+ // Get rid of too-low values.
455
+ for (int64_t u = 0 ; u < U; ++u) {
456
+ vecF32H[u] = create.math .max (vecF32H[u], vecDlf16Min);
457
+ vecF32L[u] = create.math .max (vecF32L[u], vecDlf16Min);
458
+ }
459
+ }
460
+ // Convert f32 to dlfloat16.
461
+ for (int64_t u = 0 ; u < U; ++u) {
462
+ vecF16[u] = rewriter.create <ZLowConvertF32ToDLF16VectorOp>(
463
+ loc, vecF32H[u], vecF32L[u]);
431
464
}
432
- for (int64_t i = 0 ; i < U; ++i) {
433
- create.vec .storeIE (vecF16[i], allocAsTx64,
434
- {SymIE (allocTileIndex), l + (i * VL)}, {});
465
+ // Store the dlfloat16.
466
+ for (int64_t u = 0 ; u < U; ++u) {
467
+ create.vec .storeIE (vecF16[u], allocAsTx64,
468
+ {SymIE (allocTileIndex), l + (u * VL)}, {});
435
469
}
436
470
});
437
471
});
0 commit comments