Skip to content

Commit 7879d17

Browse files
Saturation in compiler generated Stickify (#2877)
* added saturation for ZLowStickExpansion.cpp under #def at this time Signed-off-by: Alexandre Eichenberger <[email protected]> --------- Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent a198252 commit 7879d17

File tree

3 files changed

+56
-21
lines changed

3 files changed

+56
-21
lines changed

src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp

+51-17
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp"
3030
#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
3131
#include "src/Accelerators/NNPA/Support/LayoutHelper.hpp"
32+
#include "src/Accelerators/NNPA/Support/NNPALimit.hpp"
3233
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
3334
#include "src/Dialect/Krnl/DialectBuilder.hpp"
3435
#include "src/Dialect/Krnl/KrnlHelper.hpp"
@@ -43,6 +44,9 @@
4344
#define PREFETCH_CSU_DIST 0
4445
#define PREFETCH_CSU 1
4546

47+
// TODO, integrate.
48+
#define SATURATION_ON 0
49+
4650
using namespace mlir;
4751

4852
namespace onnx_mlir {
@@ -71,14 +75,14 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
7175
layout.getValue().equals_insensitive("3D") ||
7276
layout.getValue().equals_insensitive("2D") ||
7377
layout.getValue().equals_insensitive("3DS")) {
74-
return generateUnstickCodeNoBuffer(rewriter, unstickOp, layout);
78+
return generateUnstickCodeNoBuffer(rewriter, unstickOp);
7579
}
7680
// Otherwise, we don't replace and keep the zdnn call.
7781
return failure();
7882
}
7983

80-
LogicalResult generateUnstickCodeNoBuffer(PatternRewriter &rewriter,
81-
ZLowUnstickOp unstickOp, StringAttr layout) const {
84+
LogicalResult generateUnstickCodeNoBuffer(
85+
PatternRewriter &rewriter, ZLowUnstickOp unstickOp) const {
8286
Operation *op = unstickOp.getOperation();
8387
Location loc = unstickOp.getLoc();
8488
MDBuilder create(rewriter, loc);
@@ -187,7 +191,7 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
187191
// Then (is full).
188192
[&](SCFBuilder b) {
189193
MDBuilder create(b);
190-
// Loop
194+
// Loop (tried unroll of 2 and 8, 4 was best).
191195
const int64_t U = 4;
192196
assert(U * VL <= 64 && "bad unroll");
193197
create.scf.forLoop(litZero.getValue(), lit64.getValue(), U * VL,
@@ -309,15 +313,15 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
309313
layout.getValue().equals_insensitive("3D") ||
310314
layout.getValue().equals_insensitive("2D") ||
311315
layout.getValue().equals_insensitive("3DS")) {
312-
return generateStickCodeNoBuffer(rewriter, stickOp, layout);
316+
return generateStickCodeNoBuffer(rewriter, stickOp);
313317
}
314318
// Otherwise, we don't replace and keep the zdnn call.
315319
return failure();
316320
}
317321

318322
/* Version without buffer, more like zdnn */
319323
LogicalResult generateStickCodeNoBuffer(
320-
PatternRewriter &rewriter, ZLowStickOp stickOp, StringAttr layout) const {
324+
PatternRewriter &rewriter, ZLowStickOp stickOp) const {
321325
Operation *op = stickOp.getOperation();
322326
Location loc = stickOp.getLoc();
323327
MDBuilder create(rewriter, loc);
@@ -327,6 +331,12 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
327331
Value input = stickOp.getX();
328332
Value alloc = stickOp.getOut();
329333

334+
bool saturation = false;
335+
#if SATURATION_ON
336+
// TODO: hook to operation's attribute.
337+
saturation = true;
338+
#endif
339+
330340
DimsExpr outputDims;
331341
create.krnlIE.getShapeAsSymbols(alloc, outputDims);
332342
int64_t rank = outputDims.size();
@@ -344,6 +354,15 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
344354
IndexExpr litVLHalf = LiteralIndexExpr(VLHalf);
345355
IndexExpr lit64 = LiteralIndexExpr(64);
346356

357+
// Values for saturation.
358+
Value vecDlf16Min, vecDlf16Max;
359+
if (saturation) {
360+
Value dlf16Min = create.math.constant(f32Type, DLF16_MIN);
361+
vecDlf16Min = create.vec.splat(vecF32Type, dlf16Min);
362+
Value dlf16Max = create.math.constant(f32Type, DLF16_MAX);
363+
vecDlf16Max = create.vec.splat(vecF32Type, dlf16Max);
364+
}
365+
347366
// Useful references for indexing dimensions (neg val are not used).
348367
int64_t E1 = rank - 1;
349368

@@ -406,7 +425,7 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
406425
#endif
407426
#endif
408427

409-
const int64_t U = 4;
428+
const int64_t U = 4; // Tried 2 and 8, 4 was best.
410429
assert(U * VL <= 64 && "bad unroll");
411430
create.affine.forIE(litZero, lit64, U * VL,
412431
[&](AffineBuilder &b, ValueRange loopInd) {
@@ -417,21 +436,36 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
417436
getIndexExprList<SymbolIndexExpr>(memAF, inputAF);
418437
// E1: add the "l" local E1 offset.
419438
inputAF[E1] = inputAF[E1] + l;
439+
// Load the f32.
420440
Value vecF32H[U], vecF32L[U], vecF16[U];
421-
for (int64_t i = 0; i < U; ++i) {
422-
LiteralIndexExpr iH(i * VL), iL(i * VL + VL / 2);
423-
vecF32H[i] = create.vec.loadIE(
441+
for (int64_t u = 0; u < U; ++u) {
442+
LiteralIndexExpr iH(u * VL), iL(u * VL + VL / 2);
443+
vecF32H[u] = create.vec.loadIE(
424444
vecF32Type, input, inputAF, {iH.getValue()});
425-
vecF32L[i] = create.vec.loadIE(
445+
vecF32L[u] = create.vec.loadIE(
426446
vecF32Type, input, inputAF, {iL.getValue()});
427447
}
428-
for (int64_t i = 0; i < U; ++i) {
429-
vecF16[i] = rewriter.create<ZLowConvertF32ToDLF16VectorOp>(
430-
loc, vecF32H[i], vecF32L[i]);
448+
if (saturation) {
449+
// Get rid of too-high values.
450+
for (int64_t u = 0; u < U; ++u) {
451+
vecF32H[u] = create.math.min(vecF32H[u], vecDlf16Max);
452+
vecF32L[u] = create.math.min(vecF32L[u], vecDlf16Max);
453+
}
454+
// Get rid of too-low values.
455+
for (int64_t u = 0; u < U; ++u) {
456+
vecF32H[u] = create.math.max(vecF32H[u], vecDlf16Min);
457+
vecF32L[u] = create.math.max(vecF32L[u], vecDlf16Min);
458+
}
459+
}
460+
// Convert f32 to dlfloat16.
461+
for (int64_t u = 0; u < U; ++u) {
462+
vecF16[u] = rewriter.create<ZLowConvertF32ToDLF16VectorOp>(
463+
loc, vecF32H[u], vecF32L[u]);
431464
}
432-
for (int64_t i = 0; i < U; ++i) {
433-
create.vec.storeIE(vecF16[i], allocAsTx64,
434-
{SymIE(allocTileIndex), l + (i * VL)}, {});
465+
// Store the dlfloat16.
466+
for (int64_t u = 0; u < U; ++u) {
467+
create.vec.storeIE(vecF16[u], allocAsTx64,
468+
{SymIE(allocTileIndex), l + (u * VL)}, {});
435469
}
436470
});
437471
});

src/Dialect/Mlir/DialectBuilder.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -414,9 +414,9 @@ Value MathBuilder::neq(Value lhs, Value rhs) const {
414414
llvm_unreachable("expected int or float");
415415
}
416416

417-
Value MathBuilder::select(Value cmp, Value lhs, Value rhs) const {
418-
assert(lhs.getType() == rhs.getType() && "expected same type");
419-
return b().create<arith::SelectOp>(loc(), cmp, lhs, rhs);
417+
Value MathBuilder::select(Value cmp, Value trueVal, Value falseVal) const {
418+
assert(trueVal.getType() == falseVal.getType() && "expected same type");
419+
return b().create<arith::SelectOp>(loc(), cmp, trueVal, falseVal);
420420
}
421421

422422
Value MathBuilder::constant(Type type, double val) const {

src/Dialect/Mlir/DialectBuilder.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ struct MathBuilder final : DialectBuilder {
131131
mlir::Value tanh(mlir::Value val) const; // Float only.
132132
mlir::Value xori(mlir::Value lhs, mlir::Value rhs) const; // Int only.
133133

134-
mlir::Value select(mlir::Value cmp, mlir::Value lhs, mlir::Value rhs) const;
134+
mlir::Value select(
135+
mlir::Value cmp, mlir::Value trueVal, mlir::Value valseVal) const;
135136
mlir::Value gt(mlir::Value lhs, mlir::Value rhs) const;
136137
mlir::Value ge(mlir::Value lhs, mlir::Value rhs) const;
137138
mlir::Value lt(mlir::Value lhs, mlir::Value rhs) const;

0 commit comments

Comments
 (0)