Skip to content

Commit 359e095

Browse files
added register pressure estimate to be explicit when needed
Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent 9dd7c4a commit 359e095

File tree

7 files changed

+216
-162
lines changed

7 files changed

+216
-162
lines changed

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1289,7 +1289,9 @@ template <>
12891289
GenOpMix getGenOpMix<ONNXRoundOp>(Type t, Operation *op) {
12901290
return {{GenericOps::ArithmeticGop, 4}, {GenericOps::MulGop, 2},
12911291
{GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3},
1292-
{GenericOps::FloorGop, 2}};
1292+
{GenericOps::FloorGop, 2},
1293+
{GenericOps::EstimatedVectorRegisterPressure,
1294+
4 /* Little parallelism in code. */}};
12931295
}
12941296

12951297
template <>

src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -662,22 +662,28 @@ int64_t computeSuitableUnrollFactor(MemRefType memRefType,
662662
return 1;
663663
}
664664
// Gather operation statistics
665-
int64_t vectorizedOpNum, scalarOpNum;
666-
double avgVL = VectorMachineSupport::getAvgArchVectorLength(
667-
genOps, elementType, vectorizedOpNum, scalarOpNum);
665+
int64_t vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure;
666+
double avgVL =
667+
VectorMachineSupport::getAvgArchVectorLength(genOps, elementType,
668+
vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure);
668669
if (avgVL < 1.5) {
669670
LLVM_DEBUG(llvm::dbgs() << " simd disabled: too few SIMD operations with "
670671
<< avgVL << " avg VL\n");
671672
return 1;
672673
}
673-
LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL << "\n");
674+
LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL
675+
<< ", vec op num " << vectorizedOpNum
676+
<< ", max reg pressure "
677+
<< estimatedMaxVectorRegisterPressure << "\n");
674678

675679
// Define a target max unroll as a function of register pressure.
676680
int64_t unrollVL;
677681
int64_t vrNum = VectorMachineSupport::getArchVectorRegisterNum();
678-
if (vectorizedOpNum >= vrNum / 2)
682+
if (estimatedMaxVectorRegisterPressure >= vrNum)
683+
unrollVL = 1;
684+
else if (estimatedMaxVectorRegisterPressure * 2 >= vrNum)
679685
unrollVL = 2;
680-
else if (vectorizedOpNum >= vrNum / 4)
686+
else if (estimatedMaxVectorRegisterPressure * 4 >= vrNum)
681687
unrollVL = 4;
682688
else
683689
unrollVL = 8;

src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
5454
GenOpMix mix = {{GenericOps::DivGop, 1}, {GenericOps::ArithmeticGop, 5},
5555
{GenericOps::ConversionGop, 1}, {GenericOps::MinMaxGop, 2},
5656
{GenericOps::MulGop, 2}, {GenericOps::SelectGop, 3},
57-
{GenericOps::FloorGop, 2}};
57+
{GenericOps::FloorGop, 2},
58+
{GenericOps::EstimatedVectorRegisterPressure,
59+
8 /* Little parallelism in code. */}};
5860
totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/,
5961
innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount,
6062
simdOnly);
@@ -83,7 +85,7 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
8385
adjustX = create.math.add(roundX, zeroPoint);
8486
else
8587
adjustX = roundX;
86-
// Saturate
88+
// Saturate: use max into a min.
8789
Value saturateX = create.math.clip(adjustX, qMin, qMax);
8890
Value res = create.math.cast(quantizedElementType, saturateX);
8991
return res;

src/Dialect/Mlir/VectorMachineSupport.cpp

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -78,21 +78,31 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
7878
}
7979

8080
/*static*/ double VectorMachineSupport::getAvgArchVectorLength(GenOpMix &genOps,
81-
Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum) {
81+
Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum,
82+
int64_t &maxVectorRegisterPressure) {
8283
int64_t size = genOps.size();
8384
if (!hasSimd()) {
84-
vectorizedOpNum = 0;
85+
vectorizedOpNum = maxVectorRegisterPressure = 0;
8586
scalarOpNum = size;
8687
return 1;
8788
}
8889
int64_t totProcessedValues = 0.0;
89-
vectorizedOpNum = 0;
90+
vectorizedOpNum = maxVectorRegisterPressure = 0;
9091
scalarOpNum = 0;
92+
bool hasRegisterPressure = false;
93+
9194
// Determine which operations support SIMD and accumulate their vector
9295
// lengths.
9396
for (auto pair : genOps) {
9497
GenericOps genOp = pair.first;
9598
int64_t num = pair.second;
99+
// Handle other metrics first.
100+
if (genOp == GenericOps::EstimatedVectorRegisterPressure) {
101+
maxVectorRegisterPressure = std::max(maxVectorRegisterPressure, num);
102+
hasRegisterPressure = true;
103+
continue;
104+
}
105+
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
96106
int64_t vl = getArchVectorLength(genOp, elementType);
97107
// If past last value, assume 1; otherwise use actual value.
98108
// Accumulate weighted scalar/vectorized num and vl length.
@@ -107,6 +117,10 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
107117
// Compute final values
108118
int64_t totNum = vectorizedOpNum + scalarOpNum;
109119
scalarOpNum = size - vectorizedOpNum;
120+
if (!hasRegisterPressure) {
121+
// Default register pressure estimate: one register per 2 vector ops.
122+
maxVectorRegisterPressure = std::max(vectorizedOpNum / 2, (int64_t)1);
123+
}
110124
return totNum != 0 ? (1.0 * totProcessedValues) / (1.0 * totNum) : 1.0;
111125
}
112126

@@ -115,13 +129,13 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
115129
// =============================================================================
116130

117131
int64_t Z16VectorMachineSupport::computeArchVectorLength(
118-
GenericOps Gop, Type elementType) {
132+
GenericOps genOp, Type elementType) {
133+
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
119134
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
120135
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
121136
bool isFloat = mlir::isa<FloatType>(elementType);
122-
123137
// Support shared between int and float.
124-
switch (Gop) {
138+
switch (genOp) {
125139
case GenericOps::ScalarOnlyGop:
126140
return 1; // Must be scalar.
127141
case GenericOps::SelectGop:
@@ -137,10 +151,10 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
137151
// Supports only 32 and 64 bit Floats; There is support for extended too
138152
// but ignore this for now.
139153
if (!(bitWidth == 32 || bitWidth == 64 ||
140-
(bitWidth == 16 && Gop == GenericOps::ConversionGop)))
154+
(bitWidth == 16 && genOp == GenericOps::ConversionGop)))
141155
return UNSUPPORTED;
142156
// Now we have a supported length, test for specific operations.
143-
switch (Gop) {
157+
switch (genOp) {
144158
case GenericOps::AbsGop: /* Supported via compare and select */
145159
case GenericOps::ArithmeticGop: /* Add/sub,... */
146160
case GenericOps::CeilGop: /* Use load integer & rounding modes*/
@@ -161,7 +175,7 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
161175
}
162176
}
163177
// Support for integer (we consider bit-wide ops as byte wide ops).
164-
switch (Gop) {
178+
switch (genOp) {
165179
// 1 - 16 byte operations.
166180
case GenericOps::ArithmeticGop: /* Add/sub,... */
167181
case GenericOps::ConversionGop:
@@ -190,13 +204,14 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
190204
// =============================================================================
191205

192206
int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
193-
GenericOps Gop, Type elementType) {
207+
GenericOps genOp, Type elementType) {
208+
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
194209
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
195210
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
196211
bool isFloat = mlir::isa<FloatType>(elementType);
197212

198213
// Support shared between int and float.
199-
switch (Gop) {
214+
switch (genOp) {
200215
case GenericOps::ScalarOnlyGop:
201216
return 1; // Must be scalar.
202217
case GenericOps::SelectGop:
@@ -212,10 +227,10 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
212227
// Supports only 32 and 64 bit Floats; There is support for extended too
213228
// but ignore this for now.
214229
if (!(bitWidth == 32 || bitWidth == 64 ||
215-
(bitWidth == 16 && Gop == GenericOps::ConversionGop)))
230+
(bitWidth == 16 && genOp == GenericOps::ConversionGop)))
216231
return UNSUPPORTED;
217232
// Now we have a supported length, test for specific operations.
218-
switch (Gop) {
233+
switch (genOp) {
219234
case GenericOps::AbsGop:
220235
case GenericOps::ArithmeticGop: /* Add/sub,... */
221236
case GenericOps::CeilGop:
@@ -237,7 +252,7 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
237252
}
238253
}
239254
// Support for integer (we consider bit-wide ops as byte wide ops).
240-
switch (Gop) {
255+
switch (genOp) {
241256
// 1 - 16 byte operations.
242257
case GenericOps::ArithmeticGop: /* Add/sub,... */
243258
case GenericOps::ConversionGop:
@@ -276,13 +291,14 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
276291
// =============================================================================
277292

278293
int64_t NeonVectorMachineSupport::computeArchVectorLength(
279-
GenericOps Gop, Type elementType) {
294+
GenericOps genOp, Type elementType) {
295+
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
280296
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
281297
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
282298
bool isFloat = mlir::isa<FloatType>(elementType);
283299

284300
// Support shared between int and float.
285-
switch (Gop) {
301+
switch (genOp) {
286302
case GenericOps::ScalarOnlyGop:
287303
return 1; // Must be scalar.
288304
case GenericOps::SelectGop:
@@ -297,10 +313,10 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength(
297313
if (isFloat) {
298314
// Supports only 32 and 64 bit Floats;
299315
if (!(bitWidth == 32 || bitWidth == 64 ||
300-
(bitWidth == 16 && Gop == GenericOps::ConversionGop)))
316+
(bitWidth == 16 && genOp == GenericOps::ConversionGop)))
301317
return UNSUPPORTED;
302318
// Now we have a supported length, test for specific operations.
303-
switch (Gop) {
319+
switch (genOp) {
304320
case GenericOps::AbsGop:
305321
case GenericOps::ArithmeticGop: /* Add/sub,... */
306322
case GenericOps::CeilGop:
@@ -322,7 +338,7 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength(
322338
}
323339
}
324340
// Support for integer (we consider bit-wide ops as byte wide ops).
325-
switch (Gop) {
341+
switch (genOp) {
326342
// 1 - 16 byte operations.
327343
case GenericOps::ArithmeticGop: /* Add/sub,... */
328344
case GenericOps::ConversionGop:
@@ -370,10 +386,19 @@ GenOpMix computeGenOpMixUnion(const GenOpMix &mix1, const GenOpMix &mix2) {
370386
for (auto pair : mix1) {
371387
GenericOps genOp = pair.first;
372388
int64_t num = pair.second;
373-
if (u.find(genOp) != u.end())
374-
u[genOp] += num; // Has this op already, add to it.
375-
else
389+
if (u.find(genOp) != u.end()) {
390+
// Merge the 2 operation counts/metrics.
391+
if (genOp == GenericOps::EstimatedVectorRegisterPressure) {
392+
// For register pressure, pick the max of both.
393+
u[genOp] = std::max(u[genOp], num);
394+
} else {
395+
// For operation count, use the sum of both
396+
u[genOp] += num;
397+
}
398+
} else {
399+
// First time we have this.
376400
u[genOp] = num;
401+
}
377402
}
378403
return u;
379404
}

src/Dialect/Mlir/VectorMachineSupport.hpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ namespace onnx_mlir {
3232
// (e.g. all the compares).
3333

3434
enum class GenericOps {
35+
/////////////////////////////////////
36+
// Generic ops.
37+
/////////////////////////////////////
38+
3539
AbsGop,
3640
ArithmeticGop, /* Simple compute ops: add/sub/neg + ops of same complexity. */
3741
CeilDivGop,
@@ -62,6 +66,17 @@ enum class GenericOps {
6266
TrigArcGop, /* Arc trigonometry ops: asin, acos, atan. */
6367
TrigGop, /* Trigonometry ops: sin, cos, tan. */
6468
TrigHyperbolicGop, /* Hyperbolic trig. */
69+
70+
LastGop, /* Marker of the last op. Used to delineate from other metrics. */
71+
72+
/////////////////////////////////////
73+
// Metrics other than operations.
74+
/////////////////////////////////////
75+
76+
// Metric that provides an estimate of the maximum number of vector registers
77+
// used in a kernel. If none is provided, we estimate the pressure based on
78+
// the number of operations.
79+
EstimatedVectorRegisterPressure,
6580
};
6681

6782
// Describe the mix of Generic operations in a given kernel. Each generic
@@ -132,8 +147,12 @@ class VectorMachineSupport {
132147
// number of times that generic operation was found. Note that scalar
133148
// operations have a vector length of one in the weighted average as they
134149
// still contribute one result.
150+
// Max vector register pressure is also reported, either from an explicit
151+
// mention in the genOps, or estimated as one vector register per two vector
152+
// operations.
135153
static double getAvgArchVectorLength(GenOpMix &genOps, mlir::Type elementType,
136-
int64_t &vectorizedOpNum, int64_t &scalarOpNum);
154+
int64_t &vectorizedOpNum, int64_t &scalarOpNum,
155+
int64_t &maxVectorRegisterPressure);
137156

138157
protected:
139158
// Virtual functions that do the actual work. Called by the "get" functions.

0 commit comments

Comments
 (0)