@@ -29,8 +29,8 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
29
29
Value alloc, DimsExpr &allocDims, Value input, Value qMin, Value qMax,
30
30
Value scale, Value zeroPoint, bool hasZeroPoint, bool enableSIMD,
31
31
bool enableParallel) {
32
- MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create (
33
- rewriter, loc);
32
+ MultiDialectBuilder<KrnlBuilder, MemRefBuilder, VectorBuilder, MathBuilder>
33
+ create ( rewriter, loc);
34
34
35
35
// Types
36
36
Type quantizedElementType = quantizedType.getElementType ();
@@ -78,6 +78,45 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
78
78
outputAF.emplace_back (zero);
79
79
80
80
#if 1
81
+
82
+ MemRefType outputType = llvm::cast<MemRefType>(alloc.getType ());
83
+ totVL = boostVLForMinUnroll (inputType, outputType, totVL);
84
+ VectorType quantizedVectorType =
85
+ VectorType::get ({totVL}, quantizedElementType);
86
+ Value qDummy = create.vec .loadIE (quantizedVectorType, flatAlloc, {zero});
87
+
88
+ create.krnl .simdIterateIE (simdLb, simdUb, totVL, simdOnly, enableParallel,
89
+ {flatInput}, {inputAF}, {flatAlloc}, {outputAF},
90
+ {[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
91
+ MultiDialectBuilder<VectorBuilder, MathBuilder> create (kb);
92
+ Value x = inputVals[0 ];
93
+ // Scale
94
+ Value scaleX = create.math .div (x, scale);
95
+ // Round
96
+ Value roundX = create.math .round (scaleX);
97
+ // Adjust
98
+ Value adjustX;
99
+ if (hasZeroPoint)
100
+ adjustX = create.math .add (roundX, zeroPoint);
101
+ else
102
+ adjustX = roundX;
103
+ // Saturate: clip the adjusted value using qMin and qMax (a max fed into a min).
104
+ Value saturateX = create.math .clip (adjustX, qMin, qMax);
105
+ Value res;
106
+ if (VL == 1 ) {
107
+ res = create.math .cast (quantizedElementType, saturateX);
108
+ } else {
109
+ res = qDummy; //
110
+ for (int64_t v = 0 ; v < VL; ++v) {
111
+ Value element = create.vec .extractElement (saturateX, v);
112
+ Value resElement = create.math .cast (quantizedElementType, element);
113
+ res = create.vec .insertElement (res, resElement, v);
114
+ }
115
+ }
116
+ return res;
117
+ }});
118
+
119
+ #elif 1
81
120
// Allocate output buffers.
82
121
MemRefType flatBufferType = llvm::cast<MemRefType>(flatInput.getType ());
83
122
Value flatBuffer = create.mem .alignedAlloc (flatBufferType, flatInputDims);
0 commit comments