Skip to content

Commit 5c53b7e

Browse files
Added explicit register pressure estimate for SIMD and tuned [Dynamic]LinearQuantization operations (#2945)
Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent 087f069 commit 5c53b7e

13 files changed

+710
-500
lines changed

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -1289,7 +1289,9 @@ template <>
12891289
GenOpMix getGenOpMix<ONNXRoundOp>(Type t, Operation *op) {
12901290
return {{GenericOps::ArithmeticGop, 4}, {GenericOps::MulGop, 2},
12911291
{GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3},
1292-
{GenericOps::FloorGop, 2}};
1292+
{GenericOps::FloorGop, 2},
1293+
{GenericOps::EstimatedVectorRegisterPressure,
1294+
4 /* Little parallelism in code. */}};
12931295
}
12941296

12951297
template <>

src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp

+28-6
Original file line numberDiff line numberDiff line change
@@ -662,22 +662,28 @@ int64_t computeSuitableUnrollFactor(MemRefType memRefType,
662662
return 1;
663663
}
664664
// Gather operation statics
665-
int64_t vectorizedOpNum, scalarOpNum;
666-
double avgVL = VectorMachineSupport::getAvgArchVectorLength(
667-
genOps, elementType, vectorizedOpNum, scalarOpNum);
665+
int64_t vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure;
666+
double avgVL =
667+
VectorMachineSupport::getAvgArchVectorLength(genOps, elementType,
668+
vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure);
668669
if (avgVL < 1.5) {
669670
LLVM_DEBUG(llvm::dbgs() << " simd disabled: too few SIMD operations with "
670671
<< avgVL << " avg VL\n");
671672
return 1;
672673
}
673-
LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL << "\n");
674+
LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL
675+
<< ", vec op num " << vectorizedOpNum
676+
<< ", max reg pressure "
677+
<< estimatedMaxVectorRegisterPressure << "\n");
674678

675679
// Define a target max unroll as a function of register pressure.
676680
int64_t unrollVL;
677681
int64_t vrNum = VectorMachineSupport::getArchVectorRegisterNum();
678-
if (vectorizedOpNum >= vrNum / 2)
682+
if (estimatedMaxVectorRegisterPressure >= vrNum)
683+
unrollVL = 1;
684+
else if (estimatedMaxVectorRegisterPressure * 2 >= vrNum)
679685
unrollVL = 2;
680-
else if (vectorizedOpNum >= vrNum / 4)
686+
else if (estimatedMaxVectorRegisterPressure * 4 >= vrNum)
681687
unrollVL = 4;
682688
else
683689
unrollVL = 8;
@@ -743,6 +749,22 @@ int64_t capVLForMaxUnroll(
743749
return archVL * unrollVL;
744750
}
745751

752+
int64_t boostVLForMinUnroll(
753+
MemRefType memRefType, MemRefType convertedMemRefType, int64_t totVL) {
754+
if (totVL == 1)
755+
return 1; // Simd already disabled, nothing to cap.
756+
Type convertedElementType = convertedMemRefType.getElementType();
757+
int64_t convertedArchVL =
758+
VectorMachineSupport::getArchVectorLength(convertedElementType);
759+
if (convertedArchVL > totVL) {
760+
LLVM_DEBUG(llvm::dbgs()
761+
<< " simd enable: boost totVL to " << convertedArchVL
762+
<< " because of type conversions.\n");
763+
return convertedArchVL;
764+
}
765+
return totVL;
766+
}
767+
746768
int64_t capVLForSimdOnly(
747769
MemRefType memRefType, int64_t totVL, int64_t simdLoopStaticTripCount) {
748770
if (totVL == 1)

src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,12 @@ int64_t computeSuitableUnrollFactor(mlir::MemRefType memRefType,
663663
// Cap totVL so that it is at most maxUnrollVL * archVL.
664664
int64_t capVLForMaxUnroll(
665665
mlir::MemRefType memRefType, int64_t totVL, int64_t maxUnrollVL);
666+
// In some type conversion loops we may have a given totVL based on a given
667+
// memRef type and gen op mix. But the final result may be converted to a
668+
// different type, which may requires a minimum unroll to proceed as a single
669+
// SIMD operation. This call adjust the totVL for that case.
670+
int64_t boostVLForMinUnroll(mlir::MemRefType memRefType,
671+
mlir::MemRefType convertedMemRefType, int64_t totVL);
666672
// Enabling a simdOnly code generation scheme by capping totVL so that it
667673
// divides simdLoopStaticTripCount. When not possible (either because
668674
// there is no totVL that divides simdLoopStaticTripCount or trip count is

src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp

+37-7
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
2929
Value alloc, DimsExpr &allocDims, Value input, Value qMin, Value qMax,
3030
Value scale, Value zeroPoint, bool hasZeroPoint, bool enableSIMD,
3131
bool enableParallel) {
32-
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(
33-
rewriter, loc);
32+
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, VectorBuilder, MathBuilder>
33+
create(rewriter, loc);
3434

3535
// Types
3636
Type quantizedElementType = quantizedType.getElementType();
@@ -54,7 +54,9 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
5454
GenOpMix mix = {{GenericOps::DivGop, 1}, {GenericOps::ArithmeticGop, 5},
5555
{GenericOps::ConversionGop, 1}, {GenericOps::MinMaxGop, 2},
5656
{GenericOps::MulGop, 2}, {GenericOps::SelectGop, 3},
57-
{GenericOps::FloorGop, 2}};
57+
{GenericOps::FloorGop, 2},
58+
{GenericOps::EstimatedVectorRegisterPressure,
59+
8 /* Little parallelism in code. */}};
5860
totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/,
5961
innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount,
6062
simdOnly);
@@ -68,8 +70,16 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
6870
inputAF.emplace_back(zero);
6971
DimsExpr outputAF;
7072
outputAF.emplace_back(zero);
73+
74+
// faster than original loop on z16, takes 124us for 64k vals
75+
// Allocate output buffers.
76+
MemRefType flatBufferType = llvm::cast<MemRefType>(flatInput.getType());
77+
Value flatBuffer = create.mem.alignedAlloc(flatBufferType, flatInputDims);
78+
DimsExpr bufferAF;
79+
bufferAF.emplace_back(zero);
80+
7181
create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
72-
{flatInput}, {inputAF}, {flatAlloc}, {outputAF},
82+
{flatInput}, {inputAF}, {flatBuffer}, {bufferAF},
7383
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
7484
MultiDialectBuilder<MathBuilder> create(kb);
7585
Value x = inputVals[0];
@@ -83,11 +93,31 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
8393
adjustX = create.math.add(roundX, zeroPoint);
8494
else
8595
adjustX = roundX;
86-
// Saturate
96+
// Saturate: use max into a min.
8797
Value saturateX = create.math.clip(adjustX, qMin, qMax);
88-
Value res = create.math.cast(quantizedElementType, saturateX);
89-
return res;
98+
// Old approach.
99+
// return create.math.cast(quantizedElementType, saturateX);
100+
return saturateX;
90101
}});
102+
103+
// A second loop that performs scalar float to int performs better than the
104+
// compiler's attempt to generate SIMD conversion code. This might not hold
105+
// with all data types, but is definitely noticeable with uint8.
106+
//
107+
// Investigate further: we might save the vector to a buffer on the fly
108+
// (avoiding a second loop as below), and then reload each value as scalar and
109+
// then saved them as scalar (thus avoiding the insert/extract SIMD operations
110+
// that also do not perform well). We can have a SIMD buffer in memory for the
111+
// non-quantized and quantized simd values, but then we also need to privatize
112+
// it, which is also not easy in this scheme. So ignore this for now.
113+
create.krnl.forLoopIE(simdLb, simdUb, 1, enableParallel,
114+
[&](KrnlBuilder &kb, ValueRange loopInd) {
115+
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(kb);
116+
Value buffVal = create.krnl.loadIE(flatBuffer, {zero}, {loopInd[0]});
117+
Value res = create.math.cast(quantizedElementType, buffVal);
118+
create.krnl.storeIE(res, flatAlloc, {zero}, {loopInd[0]});
119+
});
120+
91121
if (totVL > 1)
92122
onnxToKrnlSimdReport(op, /*successful*/ true, totVL,
93123
simdLoopStaticTripCount, "quantizationLinear whole tensor");

src/Dialect/Mlir/DialectBuilder.cpp

+23
Original file line numberDiff line numberDiff line change
@@ -2073,6 +2073,29 @@ void VectorBuilder::multiReduction(ArrayRef<Value> inputVecArray,
20732073
}
20742074
}
20752075

2076+
Value VectorBuilder::extractElement(Value vector, int64_t index) const {
2077+
MultiDialectBuilder<VectorBuilder, MathBuilder> create(*this);
2078+
VectorType type = llvm::cast<VectorType>(vector.getType());
2079+
int64_t VL = type.getShape()[0];
2080+
assert(type.getRank() == 1 && "expected 1D vector only");
2081+
assert(index >= 0 && index < VL && "out of range vector index");
2082+
Value position = create.math.constantIndex(index);
2083+
return b().create<vector::ExtractElementOp>(loc(), vector, position);
2084+
}
2085+
2086+
Value VectorBuilder::insertElement(
2087+
Value vector, Value element, int64_t index) const {
2088+
MultiDialectBuilder<VectorBuilder, MathBuilder> create(*this);
2089+
VectorType type = llvm::cast<VectorType>(vector.getType());
2090+
int64_t VL = type.getShape()[0];
2091+
assert(type.getRank() == 1 && "expected 1D vector only");
2092+
assert(index >= 0 && index < VL && "out of range vector index");
2093+
Value position = create.math.constantIndex(index);
2094+
// Unlike LLVM insert element which takes <dest, source, position>, vector
2095+
// take <source, dest, position>
2096+
return b().create<vector::InsertElementOp>(loc(), element, vector, position);
2097+
}
2098+
20762099
//===----------------------------------------------------------------------===//
20772100
// LLVM Builder
20782101
//===----------------------------------------------------------------------===//

src/Dialect/Mlir/DialectBuilder.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,11 @@ struct VectorBuilder final : DialectBuilder {
574574
void multiReduction(mlir::ArrayRef<mlir::Value> inputVecArray,
575575
F2 reductionFct, llvm::SmallVectorImpl<mlir::Value> &outputVecArray);
576576

577+
// Insert and extract.
578+
mlir::Value extractElement(mlir::Value vector, int64_t position) const;
579+
mlir::Value insertElement(
580+
mlir::Value vector, mlir::Value element, int64_t position) const;
581+
577582
private:
578583
bool isPowerOf2(uint64_t num) const;
579584
uint64_t getLengthOf1DVector(mlir::Value vec) const;

src/Dialect/Mlir/VectorMachineSupport.cpp

+46-23
Original file line numberDiff line numberDiff line change
@@ -78,21 +78,30 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
7878
}
7979

8080
/*static*/ double VectorMachineSupport::getAvgArchVectorLength(GenOpMix &genOps,
81-
Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum) {
81+
Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum,
82+
int64_t &maxVectorRegisterPressure) {
8283
int64_t size = genOps.size();
84+
vectorizedOpNum = maxVectorRegisterPressure = 0;
8385
if (!hasSimd()) {
84-
vectorizedOpNum = 0;
8586
scalarOpNum = size;
8687
return 1;
8788
}
8889
int64_t totProcessedValues = 0.0;
89-
vectorizedOpNum = 0;
9090
scalarOpNum = 0;
91+
bool hasRegisterPressure = false;
92+
9193
// Determine which operations support SIMD and accumulate their vector
9294
// lengths.
9395
for (auto pair : genOps) {
9496
GenericOps genOp = pair.first;
9597
int64_t num = pair.second;
98+
// Handle other metrics first.
99+
if (genOp == GenericOps::EstimatedVectorRegisterPressure) {
100+
maxVectorRegisterPressure = std::max(maxVectorRegisterPressure, num);
101+
hasRegisterPressure = true;
102+
continue;
103+
}
104+
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
96105
int64_t vl = getArchVectorLength(genOp, elementType);
97106
// If past last value, assume 1; otherwise use actual value.
98107
// Accumulate weighted scalar/vectorized num and vl length.
@@ -106,7 +115,10 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
106115
}
107116
// Compute final values
108117
int64_t totNum = vectorizedOpNum + scalarOpNum;
109-
scalarOpNum = size - vectorizedOpNum;
118+
if (!hasRegisterPressure) {
119+
// Estimate default register pressure as one per 2 vector operation.
120+
maxVectorRegisterPressure = std::max(vectorizedOpNum / 2, (int64_t)1);
121+
}
110122
return totNum != 0 ? (1.0 * totProcessedValues) / (1.0 * totNum) : 1.0;
111123
}
112124

@@ -115,13 +127,13 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
115127
// =============================================================================
116128

117129
int64_t Z16VectorMachineSupport::computeArchVectorLength(
118-
GenericOps Gop, Type elementType) {
130+
GenericOps genOp, Type elementType) {
131+
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
119132
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
120133
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
121134
bool isFloat = mlir::isa<FloatType>(elementType);
122-
123135
// Support shared between int and float.
124-
switch (Gop) {
136+
switch (genOp) {
125137
case GenericOps::ScalarOnlyGop:
126138
return 1; // Must be scalar.
127139
case GenericOps::SelectGop:
@@ -137,10 +149,10 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
137149
// Supports only 32 and 64 bit Floats; There is support for extended too
138150
// but ignore this for now.
139151
if (!(bitWidth == 32 || bitWidth == 64 ||
140-
(bitWidth == 16 && Gop == GenericOps::ConversionGop)))
152+
(bitWidth == 16 && genOp == GenericOps::ConversionGop)))
141153
return UNSUPPORTED;
142154
// Now we have a supported length, test for specific operations.
143-
switch (Gop) {
155+
switch (genOp) {
144156
case GenericOps::AbsGop: /* Supported via compare and select */
145157
case GenericOps::ArithmeticGop: /* Add/sub,... */
146158
case GenericOps::CeilGop: /* Use load integer & rounding modes*/
@@ -161,7 +173,7 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
161173
}
162174
}
163175
// Support for integer (we consider bit-wide ops as byte wide ops).
164-
switch (Gop) {
176+
switch (genOp) {
165177
// 1 - 16 byte operations.
166178
case GenericOps::ArithmeticGop: /* Add/sub,... */
167179
case GenericOps::ConversionGop:
@@ -190,13 +202,14 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
190202
// =============================================================================
191203

192204
int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
193-
GenericOps Gop, Type elementType) {
205+
GenericOps genOp, Type elementType) {
206+
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
194207
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
195208
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
196209
bool isFloat = mlir::isa<FloatType>(elementType);
197210

198211
// Support shared between int and float.
199-
switch (Gop) {
212+
switch (genOp) {
200213
case GenericOps::ScalarOnlyGop:
201214
return 1; // Must be scalar.
202215
case GenericOps::SelectGop:
@@ -212,10 +225,10 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
212225
// Supports only 32 and 64 bit Floats; There is support for extended too
213226
// but ignore this for now.
214227
if (!(bitWidth == 32 || bitWidth == 64 ||
215-
(bitWidth == 16 && Gop == GenericOps::ConversionGop)))
228+
(bitWidth == 16 && genOp == GenericOps::ConversionGop)))
216229
return UNSUPPORTED;
217230
// Now we have a supported length, test for specific operations.
218-
switch (Gop) {
231+
switch (genOp) {
219232
case GenericOps::AbsGop:
220233
case GenericOps::ArithmeticGop: /* Add/sub,... */
221234
case GenericOps::CeilGop:
@@ -237,7 +250,7 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
237250
}
238251
}
239252
// Support for integer (we consider bit-wide ops as byte wide ops).
240-
switch (Gop) {
253+
switch (genOp) {
241254
// 1 - 16 byte operations.
242255
case GenericOps::ArithmeticGop: /* Add/sub,... */
243256
case GenericOps::ConversionGop:
@@ -276,13 +289,14 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
276289
// =============================================================================
277290

278291
int64_t NeonVectorMachineSupport::computeArchVectorLength(
279-
GenericOps Gop, Type elementType) {
292+
GenericOps genOp, Type elementType) {
293+
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
280294
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
281295
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
282296
bool isFloat = mlir::isa<FloatType>(elementType);
283297

284298
// Support shared between int and float.
285-
switch (Gop) {
299+
switch (genOp) {
286300
case GenericOps::ScalarOnlyGop:
287301
return 1; // Must be scalar.
288302
case GenericOps::SelectGop:
@@ -297,10 +311,10 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength(
297311
if (isFloat) {
298312
// Supports only 32 and 64 bit Floats;
299313
if (!(bitWidth == 32 || bitWidth == 64 ||
300-
(bitWidth == 16 && Gop == GenericOps::ConversionGop)))
314+
(bitWidth == 16 && genOp == GenericOps::ConversionGop)))
301315
return UNSUPPORTED;
302316
// Now we have a supported length, test for specific operations.
303-
switch (Gop) {
317+
switch (genOp) {
304318
case GenericOps::AbsGop:
305319
case GenericOps::ArithmeticGop: /* Add/sub,... */
306320
case GenericOps::CeilGop:
@@ -322,7 +336,7 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength(
322336
}
323337
}
324338
// Support for integer (we consider bit-wide ops as byte wide ops).
325-
switch (Gop) {
339+
switch (genOp) {
326340
// 1 - 16 byte operations.
327341
case GenericOps::ArithmeticGop: /* Add/sub,... */
328342
case GenericOps::ConversionGop:
@@ -370,10 +384,19 @@ GenOpMix computeGenOpMixUnion(const GenOpMix &mix1, const GenOpMix &mix2) {
370384
for (auto pair : mix1) {
371385
GenericOps genOp = pair.first;
372386
int64_t num = pair.second;
373-
if (u.find(genOp) != u.end())
374-
u[genOp] += num; // Has this op already, add to it.
375-
else
387+
if (u.find(genOp) != u.end()) {
388+
// Merge the 2 operation counts/metrics.
389+
if (genOp == GenericOps::EstimatedVectorRegisterPressure) {
390+
// For register pressure, pick the max of both.
391+
u[genOp] = std::max(u[genOp], num);
392+
} else {
393+
// For operation count, use the sum of both
394+
u[genOp] += num;
395+
}
396+
} else {
397+
// First time we have this.
376398
u[genOp] = num;
399+
}
377400
}
378401
return u;
379402
}

0 commit comments

Comments
 (0)