@@ -125,6 +125,12 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
125
125
IndexExpr T1 = outputDims[E1 ].ceilDiv (64 );
126
126
ubs[E1 ] = T1; // E1 dim is over tiles.
127
127
128
+ // Predicates used to avoid creating code that is never used.
129
+ bool neverHas64 = outputDims[E1 ].isLiteralAndSmallerThan (64 );
130
+ bool neverHas8 = outputDims[E1 ].isLiteralAndSmallerThan (8 );
131
+ bool hasOnly64 =
132
+ outputDims[E1 ].isLiteral () && (outputDims[E1 ].getLiteral () % 64 == 0 );
133
+
128
134
// Parallel...
129
135
if (enableParallel) {
130
136
int64_t parId;
@@ -184,10 +190,16 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
184
190
185
191
// I may process here up to [e1 ... e1 + m*64), make sure it's
186
192
// not going out of bounds, i.e. beyond outputDims[E1];
193
+ IndexExpr isFullLogical;
187
194
IndexExpr ub1 = SymIE (outputDims[E1 ]);
188
- IndexExpr lit64Bis = LitIE (64 );
189
- IndexExpr isFull = create.krnlIE .isTileFull (e1 , lit64, ub1);
190
- IndexExpr isFullLogical = isFull >= 0 ;
195
+ if (hasOnly64) {
196
+ isFullLogical = PredIE (true );
197
+ } else if (neverHas64) {
198
+ isFullLogical = PredIE (false );
199
+ } else {
200
+ IndexExpr isFull = create.krnlIE .isTileFull (e1 , lit64, ub1);
201
+ isFullLogical = isFull >= 0 ;
202
+ }
191
203
create.scf .ifThenElse (
192
204
// Condition
193
205
isFullLogical.getValue (),
@@ -198,6 +210,9 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
198
210
const int64_t unrollVL = 4 ;
199
211
const int64_t totVL = unrollVL * archVL;
200
212
assert (totVL <= 64 && " bad unroll" );
213
+ if (neverHas64)
214
+ return ; // Nothing to do here.
215
+
201
216
create.scf .forLoop (litZero.getValue (), lit64.getValue (), totVL,
202
217
[&](const SCFBuilder b, ValueRange loopInd) {
203
218
MDBuilder create (b);
@@ -206,7 +221,8 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
206
221
IndexExpr l = DimIE (loopIndex);
207
222
Value vecF16[unrollVL], vecF32H[unrollVL],
208
223
vecF32L[unrollVL];
209
- // Load f16 values from input via reinterpreted data tile.
224
+ // Load f16 values from input via reinterpreted data
225
+ // tile.
210
226
for (int64_t i = 0 ; i < unrollVL; ++i) {
211
227
vecF16[i] = create.vec .loadIE (vecF16Type, inputAsTx64,
212
228
{SymIE (inputTileOffset), l + (i * archVL)}, {});
@@ -231,40 +247,45 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
231
247
}
232
248
});
233
249
},
234
- // else , we don't have a full (64 e1) tile.
250
+ // Else , we don't have a full (64 e1) tile.
235
251
[&](SCFBuilder b) {
236
252
MDBuilder create (b);
237
253
IndexExprScope middleScope (b, &outerScope);
238
254
IndexExpr tripCount = SymIE (ub1) - SymIE (e1 );
239
- // Note: if we only have multiple of VL, loop below will handle
240
- // all as we subtract (VL-1). Aka if VL=8 and tripCount = 16,
241
- // tripCountWithoutPartialLastVL is 16 - 7 = 9. Thus we iterate
242
- // over i=0 & i=8 as both are < 9.
243
- IndexExpr tripCountWithoutPartialLastVL =
244
- tripCount - (archVL - 1 );
245
- create.scf .forLoop (litZero.getValue (),
246
- tripCountWithoutPartialLastVL.getValue (), archVL,
247
- [&](SCFBuilder b, ValueRange loopInd) {
248
- MDBuilder create (b);
249
- IndexExprScope innerScope (b, &middleScope);
250
- Value loopIndex = loopInd[0 ];
251
- IndexExpr l = DimIE (loopIndex);
252
- // Load f16 values from input via reinterpreted data tile.
253
- Value vecF16 = create.vec .loadIE (vecF16Type, inputAsTx64,
254
- {SymIE (inputTileOffset), l}, {});
255
- // Convert back to f32.
256
- auto convertOp =
257
- rewriter.create <ZLowConvertDLF16ToF32VectorOp>(
258
- loc, vecF16);
259
- Value vecF32H = convertOp.getResult (0 );
260
- Value vecF32L = convertOp.getResult (1 );
261
- // Store f32 values back to the (normal layout) output.
262
- DimsExpr outputAF = SymListIE (inputAF);
263
- outputAF[E1 ] = outputAF[E1 ] + l;
264
- create.vec .storeIE (vecF32H, alloc, outputAF);
265
- create.vec .storeIE (
266
- vecF32L, alloc, outputAF, {litArchVLHalf.getValue ()});
267
- });
255
+ if (hasOnly64)
256
+ return ;
257
+ if (!neverHas8) {
258
+ // Note: if we only have multiple of VL, loop below will
259
+ // handle all as we subtract (VL-1). E.g., if VL=8 and tripCount
260
+ // = 16, tripCountWithoutPartialLastVL is 16 - 7 = 9. Thus we
261
+ // iterate over i=0 & i=8 as both are < 9.
262
+ IndexExpr tripCountWithoutPartialLastVL =
263
+ tripCount - (archVL - 1 );
264
+ create.scf .forLoop (litZero.getValue (),
265
+ tripCountWithoutPartialLastVL.getValue (), archVL,
266
+ [&](SCFBuilder b, ValueRange loopInd) {
267
+ MDBuilder create (b);
268
+ IndexExprScope innerScope (b, &middleScope);
269
+ Value loopIndex = loopInd[0 ];
270
+ IndexExpr l = DimIE (loopIndex);
271
+ // Load f16 values from input via reinterpreted data
272
+ // tile.
273
+ Value vecF16 = create.vec .loadIE (vecF16Type,
274
+ inputAsTx64, {SymIE (inputTileOffset), l}, {});
275
+ // Convert back to f32.
276
+ auto convertOp =
277
+ rewriter.create <ZLowConvertDLF16ToF32VectorOp>(
278
+ loc, vecF16);
279
+ Value vecF32H = convertOp.getResult (0 );
280
+ Value vecF32L = convertOp.getResult (1 );
281
+ // Store f32 values back to the (normal layout) output.
282
+ DimsExpr outputAF = SymListIE (inputAF);
283
+ outputAF[E1 ] = outputAF[E1 ] + l;
284
+ create.vec .storeIE (vecF32H, alloc, outputAF);
285
+ create.vec .storeIE (vecF32L, alloc, outputAF,
286
+ {litArchVLHalf.getValue ()});
287
+ });
288
+ }
268
289
// Deal with the last values: compute f32 using simd.
269
290
IndexExpr remainingScalarValues = tripCount % archVL;
270
291
IndexExpr lastL = tripCount - remainingScalarValues;
0 commit comments