
Commit 3a62034

attempt with temp buffers inside main loop
Signed-off-by: Alexandre Eichenberger <[email protected]>
Parent: 9dfe268


src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp

Lines changed: 38 additions & 52 deletions
@@ -57,12 +57,6 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
       {GenericOps::FloorGop, 2},
       {GenericOps::EstimatedVectorRegisterPressure,
           8 /* Little parallelism in code. */}};
-  // Because quantization transforms, for example, a 4 bytes input type of
-  // float into 1 byte output type of char, and since most of the computations
-  // are in float, we need to provide the float type below to let the function
-  // see that most generic operations are supported for floats. But at the
-  // same time, we need a minimum total unrolling of 16 so as to generate a
-  // single vector of uint8.
   totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/,
       innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount,
       simdOnly);
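
Note on the deleted sizing comment above: the arithmetic still holds and is easy to check. A minimal sketch, assuming 128-bit vector registers (the width of the z16 vector facility); the constant names are illustrative, not from the commit:

    constexpr int vectorBits = 128;
    constexpr int floatLanes = vectorBits / 32; //  4 floats per register
    constexpr int uint8Lanes = vectorBits / 8;  // 16 uint8 values per register
    // One full uint8 vector consumes four float vectors' worth of elements,
    // hence the minimum total unroll of 16 cited in the deleted comment.
    static_assert(uint8Lanes / floatLanes == 4, "16 floats fill one uint8 vector");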
@@ -77,18 +71,19 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
   DimsExpr outputAF;
   outputAF.emplace_back(zero);
 
-#if 0
-  // Insert / extract are slow on z16: 169us for 64K vals.
-  MemRefType outputType = llvm::cast<MemRefType>(alloc.getType());
-  totVL = boostVLForMinUnroll(inputType, outputType, totVL);
-  VectorType quantizedVectorType =
-      VectorType::get({totVL}, quantizedElementType);
-  Value qDummy = create.vec.loadIE(quantizedVectorType, flatAlloc, {zero});
+#if 1
+  // Allocate temporary input and output buffers.
+  MemRefType inputBufferType =
+      MemRefType::get({totVL}, inputType.getElementType());
+  Value inputBuffer = create.mem.alignedAlloc(inputBufferType);
+  MemRefType outputBufferType = MemRefType::get({totVL}, quantizedElementType);
+  VectorType outputVectorType = VectorType::get({totVL}, quantizedElementType);
+  Value outputBuffer = create.mem.alignedAlloc(outputBufferType);
 
   create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
       {flatInput}, {inputAF}, {flatAlloc}, {outputAF},
       {[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
-        MultiDialectBuilder<VectorBuilder, MathBuilder> create(kb);
+        MultiDialectBuilder<KrnlBuilder, MathBuilder, VectorBuilder> create(kb);
         Value x = inputVals[0];
         // Scale
         Value scaleX = create.math.div(x, scale);
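
For reference, the lambda body that begins here computes the standard ONNX QuantizeLinear formula, lane by lane:

    y = saturate(round(x / scale) + zeroPoint)

where saturate clips to [qMin, qMax], the representable range of the quantized element type, and a final cast narrows to that type.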
@@ -102,21 +97,22 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
           adjustX = roundX;
         // Saturate: use max into a min.
         Value saturateX = create.math.clip(adjustX, qMin, qMax);
-        Value res;
-        if (VL == 1) {
-          res = create.math.cast(quantizedElementType, saturateX);
-        } else {
-          res = qDummy; //
-          for (int64_t v = 0; v < VL; ++v) {
-            Value element = create.vec.extractElement(saturateX, v);
-            Value resElement = create.math.cast(quantizedElementType, element);
-            res = create.vec.insertElement(res, resElement, v);
-          }
+        if (VL == 1)
+          return create.math.cast(quantizedElementType, saturateX);
+        // Has VL values; first save all VL of them into the input buffer.
+        create.vec.storeIE(saturateX, inputBuffer, {zero});
+        // Now cast each value in turn, as scalars.
+        for (int64_t v = 0; v < VL; ++v) {
+          IndexExpr vv = LitIE(v);
+          Value scalarSaturateX = create.krnl.loadIE(inputBuffer, {vv});
+          Value scalarRes =
+              create.math.cast(quantizedElementType, scalarSaturateX);
+          create.krnl.storeIE(scalarRes, outputBuffer, {vv});
         }
-        return res;
+        // Reload the output buffer as one vector.
+        return create.vec.loadIE(outputVectorType, outputBuffer, {zero});
       }});
-
-#elif 1
+#else
   // faster than original loop on z16, takes 124us for 64k vals
   // Allocate output buffers.
   MemRefType flatBufferType = llvm::cast<MemRefType>(flatInput.getType());
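
To make the new #if 1 scheme concrete, here is its per-iteration runtime effect written as scalar C++ (a sketch only; the pointer names mirror the inputBuffer/outputBuffer values above, and the vector load/store become plain loops):

    #include <cstdint>

    // One SIMD iteration of VL lanes: spill the saturated float vector to a
    // temp buffer, cast lane by lane as scalars, then reload the result as a
    // single uint8 vector.
    void castOneVector(const float *saturateX, float *inputBuffer,
        uint8_t *outputBuffer, int64_t VL) {
      for (int64_t v = 0; v < VL; ++v) // vec.storeIE: vector store to buffer
        inputBuffer[v] = saturateX[v];
      for (int64_t v = 0; v < VL; ++v) // scalar load, cast, scalar store
        outputBuffer[v] = static_cast<uint8_t>(inputBuffer[v]);
      // vec.loadIE then reloads outputBuffer as one quantized vector to return.
    }

The design choice is to trade the vector insert/extract element operations (measured at 169us for 64K values on z16) for two small, cache-resident buffer round trips.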
@@ -141,40 +137,30 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
           adjustX = roundX;
         // Saturate: use max into a min.
         Value saturateX = create.math.clip(adjustX, qMin, qMax);
+        // Old approach.
+        // return create.math.cast(quantizedElementType, saturateX);
         return saturateX;
       }});
-  create.krnl.forLoopIE(simdLb, simdUb, 1, /*parallel*/ false,
+
+  // A second loop that performs the scalar float-to-int casts works better
+  // than the compiler's attempt to generate SIMD conversion code. This might
+  // not hold for all data types, but it is definitely noticeable with uint8.
+  //
+  // Todo: we might save the vector to a buffer on the fly (avoiding the
+  // second loop below), then reload each value as a scalar and save it back
+  // as a scalar (thus avoiding the insert/extract SIMD operations, which
+  // also do not perform well). The problem is that the current SIMD scheme
+  // expects a return value, either SIMD in SIMD mode or scalar in scalar
+  // mode. Thus that alternative scheme is not easy to pull off here.
+  create.krnl.forLoopIE(simdLb, simdUb, 1, enableParallel,
       [&](KrnlBuilder &kb, ValueRange loopInd) {
         MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(kb);
         Value buffVal = create.krnl.loadIE(flatBuffer, {zero}, {loopInd[0]});
         Value res = create.math.cast(quantizedElementType, buffVal);
         create.krnl.storeIE(res, flatAlloc, {zero}, {loopInd[0]});
       });
-#else
-  // original, slow on z16 where it takes 158us
-  MemRefType outputType = llvm::cast<MemRefType>(alloc.getType());
-  totVL = boostVLForMinUnroll(inputType, outputType, totVL);
-  create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
-      {flatInput}, {inputAF}, {flatAlloc}, {outputAF},
-      {[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
-        MultiDialectBuilder<MathBuilder> create(kb);
-        Value x = inputVals[0];
-        // Scale
-        Value scaleX = create.math.div(x, scale);
-        // Round
-        Value roundX = create.math.round(scaleX);
-        // Adjust
-        Value adjustX;
-        if (hasZeroPoint)
-          adjustX = create.math.add(roundX, zeroPoint);
-        else
-          adjustX = roundX;
-        // Saturate: use max into a min.
-        Value saturateX = create.math.clip(adjustX, qMin, qMax);
-        Value res = create.math.cast(quantizedElementType, saturateX);
-        return res;
-      }});
 #endif
+
   if (totVL > 1)
     onnxToKrnlSimdReport(op, /*successful*/ true, totVL,
         simdLoopStaticTripCount, "quantizationLinear whole tensor");
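
For comparison, the two-loop variant retained under #else (124us for 64K values on z16) is roughly the following in scalar C++ terms. All names are illustrative, std::round stands in for whatever rounding create.math.round emits, and the zero-point adjustment is shown unconditionally:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Loop 1 does the SIMD-friendly math into a float staging buffer; loop 2
    // does scalar float->uint8 casts, which measured faster on z16 than the
    // compiler's auto-vectorized conversion code.
    void quantizeTwoLoops(const float *x, float *buffer, uint8_t *out,
        int64_t n, float scale, float zeroPoint, float qMin, float qMax) {
      for (int64_t i = 0; i < n; ++i)
        buffer[i] = std::clamp(std::round(x[i] / scale) + zeroPoint, qMin, qMax);
      for (int64_t i = 0; i < n; ++i)
        out[i] = static_cast<uint8_t>(buffer[i]);
    }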
