12
12
//
13
13
// ===----------------------------------------------------------------------===//
14
14
15
+ #include " src/Compiler/CompilerOptions.hpp"
15
16
#include " src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
16
17
#include " src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp"
17
18
#include " src/Dialect/Krnl/DialectBuilder.hpp"
@@ -29,7 +30,7 @@ void emitDynamicQuantizationLinearScalarParameters(
29
30
ConversionPatternRewriter &rewriter, Location loc, Operation *op,
30
31
MemRefType inputType, MemRefType quantizedType, Value input, Value qMin,
31
32
Value qMax, Value &scale, Value &zeroPoint, Value &quantizedZeroPoint,
32
- bool enableSIMD, bool enableParallel) {
33
+ bool wantZeroPoint, bool enableSIMD, bool enableParallel) {
33
34
MultiDialectBuilder<KrnlBuilder, MathBuilder> create (rewriter, loc);
34
35
35
36
// Types
@@ -62,11 +63,15 @@ void emitDynamicQuantizationLinearScalarParameters(
62
63
scale = create.math .div (xDiff, boundDiff);
63
64
64
65
// Compute y_zero_point.
65
- Value interZeroPoint = create.math .sub (qMin, create.math .div (xMin, scale));
66
- // Saturate zero point.
67
- Value saturateZeroPoint = create.math .clip (interZeroPoint, qMin, qMax);
68
- // Round zero point.
69
- zeroPoint = create.math .round (saturateZeroPoint);
66
+ if (wantZeroPoint) {
67
+ Value interZeroPoint = create.math .sub (qMin, create.math .div (xMin, scale));
68
+ // Saturate zero point.
69
+ Value saturateZeroPoint = create.math .clip (interZeroPoint, qMin, qMax);
70
+ // Round zero point.
71
+ zeroPoint = create.math .round (saturateZeroPoint);
72
+ } else {
73
+ zeroPoint = zero;
74
+ }
70
75
quantizedZeroPoint = create.math .cast (quantizedElementType, zeroPoint);
71
76
}
72
77
@@ -122,15 +127,17 @@ struct ONNXDynamicQuantizeLinearOpLowering
122
127
Value qMin = create.math .constant (elementType, 0.0 );
123
128
Value scale, zeroPoint, zeroPointInt;
124
129
130
+ bool wantZeroPoint = !disableQuantZeroPoint;
125
131
emitDynamicQuantizationLinearScalarParameters (rewriter, loc, op,
126
132
xMemRefType, yMemRefType, X, qMin, qMax, scale, zeroPoint, zeroPointInt,
127
- enableSIMD, enableParallel);
133
+ wantZeroPoint, enableSIMD, enableParallel);
128
134
create.krnl .store (scale, YScale);
129
135
create.krnl .store (zeroPointInt, YZeroPoint);
130
136
131
137
emitQuantizationLinearScalarParameters (rewriter, loc, op, xMemRefType,
132
138
yMemRefType, Y, shapeHelper.getOutputDims (0 ), X, qMin, qMax, scale,
133
- zeroPoint, enableSIMD, enableParallel);
139
+ zeroPoint, wantZeroPoint /* wanted one, so we have a zero point*/ ,
140
+ enableSIMD, enableParallel);
134
141
135
142
rewriter.replaceOp (op, {Y, YScale, YZeroPoint});
136
143
onnxToKrnlSimdReport (op);
0 commit comments