@@ -78,21 +78,31 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
78
78
}
79
79
80
80
/* static*/ double VectorMachineSupport::getAvgArchVectorLength (GenOpMix &genOps,
81
- Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum) {
81
+ Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum,
82
+ int64_t &maxVectorRegisterPressure) {
82
83
int64_t size = genOps.size ();
83
84
if (!hasSimd ()) {
84
- vectorizedOpNum = 0 ;
85
+ vectorizedOpNum = maxVectorRegisterPressure = 0 ;
85
86
scalarOpNum = size;
86
87
return 1 ;
87
88
}
88
89
int64_t totProcessedValues = 0.0 ;
89
- vectorizedOpNum = 0 ;
90
+ vectorizedOpNum = maxVectorRegisterPressure = 0 ;
90
91
scalarOpNum = 0 ;
92
+ bool hasRegisterPressure = false ;
93
+
91
94
// Determine which operations support SIMD and accumulate their vector
92
95
// lengths.
93
96
for (auto pair : genOps) {
94
97
GenericOps genOp = pair.first ;
95
98
int64_t num = pair.second ;
99
+ // Handle other metrics first.
100
+ if (genOp == GenericOps::EstimatedVectorRegisterPressure) {
101
+ maxVectorRegisterPressure = std::max (maxVectorRegisterPressure, num);
102
+ hasRegisterPressure = true ;
103
+ continue ;
104
+ }
105
+ assert (genOp < GenericOps::LastGop && " no metrics here, only genOps" );
96
106
int64_t vl = getArchVectorLength (genOp, elementType);
97
107
// If past last value, assume 1; otherwise use actual value.
98
108
// Accumulate weighted scalar/vectorized num and vl length.
@@ -107,6 +117,10 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
107
117
// Compute final values
108
118
int64_t totNum = vectorizedOpNum + scalarOpNum;
109
119
scalarOpNum = size - vectorizedOpNum;
120
+ if (!hasRegisterPressure) {
121
+ // Estimate default register pressure as one per 2 vector operation.
122
+ maxVectorRegisterPressure = std::max (vectorizedOpNum / 2 , (int64_t )1 );
123
+ }
110
124
return totNum != 0 ? (1.0 * totProcessedValues) / (1.0 * totNum) : 1.0 ;
111
125
}
112
126
@@ -115,13 +129,13 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
115
129
// =============================================================================
116
130
117
131
int64_t Z16VectorMachineSupport::computeArchVectorLength (
118
- GenericOps Gop, Type elementType) {
132
+ GenericOps genOp, Type elementType) {
133
+ assert (genOp < GenericOps::LastGop && " no metrics here, only genOps" );
119
134
int64_t bitWidth = elementType.getIntOrFloatBitWidth ();
120
135
int64_t archVL = VectorMachineSupport::getArchVectorLength (elementType);
121
136
bool isFloat = mlir::isa<FloatType>(elementType);
122
-
123
137
// Support shared between int and float.
124
- switch (Gop ) {
138
+ switch (genOp ) {
125
139
case GenericOps::ScalarOnlyGop:
126
140
return 1 ; // Must be scalar.
127
141
case GenericOps::SelectGop:
@@ -137,10 +151,10 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
137
151
// Supports only 32 and 64 bit Floats; There is support for extended too
138
152
// but ignore this for now.
139
153
if (!(bitWidth == 32 || bitWidth == 64 ||
140
- (bitWidth == 16 && Gop == GenericOps::ConversionGop)))
154
+ (bitWidth == 16 && genOp == GenericOps::ConversionGop)))
141
155
return UNSUPPORTED;
142
156
// Now we have a supported length, test for specific operations.
143
- switch (Gop ) {
157
+ switch (genOp ) {
144
158
case GenericOps::AbsGop: /* Supported via compare and select */
145
159
case GenericOps::ArithmeticGop: /* Add/sub,... */
146
160
case GenericOps::CeilGop: /* Use load integer & rounding modes*/
@@ -161,7 +175,7 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
161
175
}
162
176
}
163
177
// Support for integer (we consider bit-wide ops as byte wide ops).
164
- switch (Gop ) {
178
+ switch (genOp ) {
165
179
// 1 - 16 byte operations.
166
180
case GenericOps::ArithmeticGop: /* Add/sub,... */
167
181
case GenericOps::ConversionGop:
@@ -190,13 +204,14 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
190
204
// =============================================================================
191
205
192
206
int64_t SSE42x86VectorMachineSupport::computeArchVectorLength (
193
- GenericOps Gop, Type elementType) {
207
+ GenericOps genOp, Type elementType) {
208
+ assert (genOp < GenericOps::LastGop && " no metrics here, only genOps" );
194
209
int64_t bitWidth = elementType.getIntOrFloatBitWidth ();
195
210
int64_t archVL = VectorMachineSupport::getArchVectorLength (elementType);
196
211
bool isFloat = mlir::isa<FloatType>(elementType);
197
212
198
213
// Support shared between int and float.
199
- switch (Gop ) {
214
+ switch (genOp ) {
200
215
case GenericOps::ScalarOnlyGop:
201
216
return 1 ; // Must be scalar.
202
217
case GenericOps::SelectGop:
@@ -212,10 +227,10 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
212
227
// Supports only 32 and 64 bit Floats; There is support for extended too
213
228
// but ignore this for now.
214
229
if (!(bitWidth == 32 || bitWidth == 64 ||
215
- (bitWidth == 16 && Gop == GenericOps::ConversionGop)))
230
+ (bitWidth == 16 && genOp == GenericOps::ConversionGop)))
216
231
return UNSUPPORTED;
217
232
// Now we have a supported length, test for specific operations.
218
- switch (Gop ) {
233
+ switch (genOp ) {
219
234
case GenericOps::AbsGop:
220
235
case GenericOps::ArithmeticGop: /* Add/sub,... */
221
236
case GenericOps::CeilGop:
@@ -237,7 +252,7 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
237
252
}
238
253
}
239
254
// Support for integer (we consider bit-wide ops as byte wide ops).
240
- switch (Gop ) {
255
+ switch (genOp ) {
241
256
// 1 - 16 byte operations.
242
257
case GenericOps::ArithmeticGop: /* Add/sub,... */
243
258
case GenericOps::ConversionGop:
@@ -276,13 +291,14 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
276
291
// =============================================================================
277
292
278
293
int64_t NeonVectorMachineSupport::computeArchVectorLength (
279
- GenericOps Gop, Type elementType) {
294
+ GenericOps genOp, Type elementType) {
295
+ assert (genOp < GenericOps::LastGop && " no metrics here, only genOps" );
280
296
int64_t bitWidth = elementType.getIntOrFloatBitWidth ();
281
297
int64_t archVL = VectorMachineSupport::getArchVectorLength (elementType);
282
298
bool isFloat = mlir::isa<FloatType>(elementType);
283
299
284
300
// Support shared between int and float.
285
- switch (Gop ) {
301
+ switch (genOp ) {
286
302
case GenericOps::ScalarOnlyGop:
287
303
return 1 ; // Must be scalar.
288
304
case GenericOps::SelectGop:
@@ -297,10 +313,10 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength(
297
313
if (isFloat) {
298
314
// Supports only 32 and 64 bit Floats;
299
315
if (!(bitWidth == 32 || bitWidth == 64 ||
300
- (bitWidth == 16 && Gop == GenericOps::ConversionGop)))
316
+ (bitWidth == 16 && genOp == GenericOps::ConversionGop)))
301
317
return UNSUPPORTED;
302
318
// Now we have a supported length, test for specific operations.
303
- switch (Gop ) {
319
+ switch (genOp ) {
304
320
case GenericOps::AbsGop:
305
321
case GenericOps::ArithmeticGop: /* Add/sub,... */
306
322
case GenericOps::CeilGop:
@@ -322,7 +338,7 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength(
322
338
}
323
339
}
324
340
// Support for integer (we consider bit-wide ops as byte wide ops).
325
- switch (Gop ) {
341
+ switch (genOp ) {
326
342
// 1 - 16 byte operations.
327
343
case GenericOps::ArithmeticGop: /* Add/sub,... */
328
344
case GenericOps::ConversionGop:
@@ -370,10 +386,19 @@ GenOpMix computeGenOpMixUnion(const GenOpMix &mix1, const GenOpMix &mix2) {
370
386
for (auto pair : mix1) {
371
387
GenericOps genOp = pair.first ;
372
388
int64_t num = pair.second ;
373
- if (u.find (genOp) != u.end ())
374
- u[genOp] += num; // Has this op already, add to it.
375
- else
389
+ if (u.find (genOp) != u.end ()) {
390
+ // Merge the 2 operation counts/metrics.
391
+ if (genOp == GenericOps::EstimatedVectorRegisterPressure) {
392
+ // For register pressure, pick the max of both.
393
+ u[genOp] = std::max (u[genOp], num);
394
+ } else {
395
+ // For operation count, use the sum of both
396
+ u[genOp] += num;
397
+ }
398
+ } else {
399
+ // First time we have this.
376
400
u[genOp] = num;
401
+ }
377
402
}
378
403
return u;
379
404
}
0 commit comments