@@ -78,21 +78,30 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
78
78
}
79
79
80
80
/* static*/ double VectorMachineSupport::getAvgArchVectorLength (GenOpMix &genOps,
81
- Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum) {
81
+ Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum,
82
+ int64_t &maxVectorRegisterPressure) {
82
83
int64_t size = genOps.size ();
84
+ vectorizedOpNum = maxVectorRegisterPressure = 0 ;
83
85
if (!hasSimd ()) {
84
- vectorizedOpNum = 0 ;
85
86
scalarOpNum = size;
86
87
return 1 ;
87
88
}
88
89
int64_t totProcessedValues = 0.0 ;
89
- vectorizedOpNum = 0 ;
90
90
scalarOpNum = 0 ;
91
+ bool hasRegisterPressure = false ;
92
+
91
93
// Determine which operations support SIMD and accumulate their vector
92
94
// lengths.
93
95
for (auto pair : genOps) {
94
96
GenericOps genOp = pair.first ;
95
97
int64_t num = pair.second ;
98
+ // Handle other metrics first.
99
+ if (genOp == GenericOps::EstimatedVectorRegisterPressure) {
100
+ maxVectorRegisterPressure = std::max (maxVectorRegisterPressure, num);
101
+ hasRegisterPressure = true ;
102
+ continue ;
103
+ }
104
+ assert (genOp < GenericOps::LastGop && " no metrics here, only genOps" );
96
105
int64_t vl = getArchVectorLength (genOp, elementType);
97
106
// If past last value, assume 1; otherwise use actual value.
98
107
// Accumulate weighted scalar/vectorized num and vl length.
@@ -106,7 +115,10 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
106
115
}
107
116
// Compute final values
108
117
int64_t totNum = vectorizedOpNum + scalarOpNum;
109
- scalarOpNum = size - vectorizedOpNum;
118
+ if (!hasRegisterPressure) {
119
+ // Estimate default register pressure as one per 2 vector operation.
120
+ maxVectorRegisterPressure = std::max (vectorizedOpNum / 2 , (int64_t )1 );
121
+ }
110
122
return totNum != 0 ? (1.0 * totProcessedValues) / (1.0 * totNum) : 1.0 ;
111
123
}
112
124
@@ -115,13 +127,13 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
115
127
// =============================================================================
116
128
117
129
int64_t Z16VectorMachineSupport::computeArchVectorLength (
118
- GenericOps Gop, Type elementType) {
130
+ GenericOps genOp, Type elementType) {
131
+ assert (genOp < GenericOps::LastGop && " no metrics here, only genOps" );
119
132
int64_t bitWidth = elementType.getIntOrFloatBitWidth ();
120
133
int64_t archVL = VectorMachineSupport::getArchVectorLength (elementType);
121
134
bool isFloat = mlir::isa<FloatType>(elementType);
122
-
123
135
// Support shared between int and float.
124
- switch (Gop ) {
136
+ switch (genOp ) {
125
137
case GenericOps::ScalarOnlyGop:
126
138
return 1 ; // Must be scalar.
127
139
case GenericOps::SelectGop:
@@ -137,10 +149,10 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
137
149
// Supports only 32 and 64 bit Floats; There is support for extended too
138
150
// but ignore this for now.
139
151
if (!(bitWidth == 32 || bitWidth == 64 ||
140
- (bitWidth == 16 && Gop == GenericOps::ConversionGop)))
152
+ (bitWidth == 16 && genOp == GenericOps::ConversionGop)))
141
153
return UNSUPPORTED;
142
154
// Now we have a supported length, test for specific operations.
143
- switch (Gop ) {
155
+ switch (genOp ) {
144
156
case GenericOps::AbsGop: /* Supported via compare and select */
145
157
case GenericOps::ArithmeticGop: /* Add/sub,... */
146
158
case GenericOps::CeilGop: /* Use load integer & rounding modes*/
@@ -161,7 +173,7 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
161
173
}
162
174
}
163
175
// Support for integer (we consider bit-wide ops as byte wide ops).
164
- switch (Gop ) {
176
+ switch (genOp ) {
165
177
// 1 - 16 byte operations.
166
178
case GenericOps::ArithmeticGop: /* Add/sub,... */
167
179
case GenericOps::ConversionGop:
@@ -190,13 +202,14 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
190
202
// =============================================================================
191
203
192
204
int64_t SSE42x86VectorMachineSupport::computeArchVectorLength (
193
- GenericOps Gop, Type elementType) {
205
+ GenericOps genOp, Type elementType) {
206
+ assert (genOp < GenericOps::LastGop && " no metrics here, only genOps" );
194
207
int64_t bitWidth = elementType.getIntOrFloatBitWidth ();
195
208
int64_t archVL = VectorMachineSupport::getArchVectorLength (elementType);
196
209
bool isFloat = mlir::isa<FloatType>(elementType);
197
210
198
211
// Support shared between int and float.
199
- switch (Gop ) {
212
+ switch (genOp ) {
200
213
case GenericOps::ScalarOnlyGop:
201
214
return 1 ; // Must be scalar.
202
215
case GenericOps::SelectGop:
@@ -212,10 +225,10 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
212
225
// Supports only 32 and 64 bit Floats; There is support for extended too
213
226
// but ignore this for now.
214
227
if (!(bitWidth == 32 || bitWidth == 64 ||
215
- (bitWidth == 16 && Gop == GenericOps::ConversionGop)))
228
+ (bitWidth == 16 && genOp == GenericOps::ConversionGop)))
216
229
return UNSUPPORTED;
217
230
// Now we have a supported length, test for specific operations.
218
- switch (Gop ) {
231
+ switch (genOp ) {
219
232
case GenericOps::AbsGop:
220
233
case GenericOps::ArithmeticGop: /* Add/sub,... */
221
234
case GenericOps::CeilGop:
@@ -237,7 +250,7 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
237
250
}
238
251
}
239
252
// Support for integer (we consider bit-wide ops as byte wide ops).
240
- switch (Gop ) {
253
+ switch (genOp ) {
241
254
// 1 - 16 byte operations.
242
255
case GenericOps::ArithmeticGop: /* Add/sub,... */
243
256
case GenericOps::ConversionGop:
@@ -276,13 +289,14 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
276
289
// =============================================================================
277
290
278
291
int64_t NeonVectorMachineSupport::computeArchVectorLength (
279
- GenericOps Gop, Type elementType) {
292
+ GenericOps genOp, Type elementType) {
293
+ assert (genOp < GenericOps::LastGop && " no metrics here, only genOps" );
280
294
int64_t bitWidth = elementType.getIntOrFloatBitWidth ();
281
295
int64_t archVL = VectorMachineSupport::getArchVectorLength (elementType);
282
296
bool isFloat = mlir::isa<FloatType>(elementType);
283
297
284
298
// Support shared between int and float.
285
- switch (Gop ) {
299
+ switch (genOp ) {
286
300
case GenericOps::ScalarOnlyGop:
287
301
return 1 ; // Must be scalar.
288
302
case GenericOps::SelectGop:
@@ -297,10 +311,10 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength(
297
311
if (isFloat) {
298
312
// Supports only 32 and 64 bit Floats;
299
313
if (!(bitWidth == 32 || bitWidth == 64 ||
300
- (bitWidth == 16 && Gop == GenericOps::ConversionGop)))
314
+ (bitWidth == 16 && genOp == GenericOps::ConversionGop)))
301
315
return UNSUPPORTED;
302
316
// Now we have a supported length, test for specific operations.
303
- switch (Gop ) {
317
+ switch (genOp ) {
304
318
case GenericOps::AbsGop:
305
319
case GenericOps::ArithmeticGop: /* Add/sub,... */
306
320
case GenericOps::CeilGop:
@@ -322,7 +336,7 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength(
322
336
}
323
337
}
324
338
// Support for integer (we consider bit-wide ops as byte wide ops).
325
- switch (Gop ) {
339
+ switch (genOp ) {
326
340
// 1 - 16 byte operations.
327
341
case GenericOps::ArithmeticGop: /* Add/sub,... */
328
342
case GenericOps::ConversionGop:
@@ -370,10 +384,19 @@ GenOpMix computeGenOpMixUnion(const GenOpMix &mix1, const GenOpMix &mix2) {
370
384
for (auto pair : mix1) {
371
385
GenericOps genOp = pair.first ;
372
386
int64_t num = pair.second ;
373
- if (u.find (genOp) != u.end ())
374
- u[genOp] += num; // Has this op already, add to it.
375
- else
387
+ if (u.find (genOp) != u.end ()) {
388
+ // Merge the 2 operation counts/metrics.
389
+ if (genOp == GenericOps::EstimatedVectorRegisterPressure) {
390
+ // For register pressure, pick the max of both.
391
+ u[genOp] = std::max (u[genOp], num);
392
+ } else {
393
+ // For operation count, use the sum of both
394
+ u[genOp] += num;
395
+ }
396
+ } else {
397
+ // First time we have this.
376
398
u[genOp] = num;
399
+ }
377
400
}
378
401
return u;
379
402
}
0 commit comments