Skip to content

Commit 7f4c785

Browse files
authored
Merge branch 'main' into hamptonm/feature/four-convo-models
2 parents 1b45ed1 + 56a610c commit 7f4c785

File tree

3 files changed

+165
-62
lines changed

3 files changed

+165
-62
lines changed

src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp

+55-34
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
125125
IndexExpr T1 = outputDims[E1].ceilDiv(64);
126126
ubs[E1] = T1; // E1 dim is over tiles.
127127

128+
// Predicates used to avoid creating code that is never used.
129+
bool neverHas64 = outputDims[E1].isLiteralAndSmallerThan(64);
130+
bool neverHas8 = outputDims[E1].isLiteralAndSmallerThan(8);
131+
bool hasOnly64 =
132+
outputDims[E1].isLiteral() && (outputDims[E1].getLiteral() % 64 == 0);
133+
128134
// Parallel...
129135
if (enableParallel) {
130136
int64_t parId;
@@ -184,10 +190,16 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
184190

185191
// I may process here up to [e1 ... e1 + m*64), make sure its
186192
// not going out of bound, i.e. beyond outputDIms[E1];
193+
IndexExpr isFullLogical;
187194
IndexExpr ub1 = SymIE(outputDims[E1]);
188-
IndexExpr lit64Bis = LitIE(64);
189-
IndexExpr isFull = create.krnlIE.isTileFull(e1, lit64, ub1);
190-
IndexExpr isFullLogical = isFull >= 0;
195+
if (hasOnly64) {
196+
isFullLogical = PredIE(true);
197+
} else if (neverHas64) {
198+
isFullLogical = PredIE(false);
199+
} else {
200+
IndexExpr isFull = create.krnlIE.isTileFull(e1, lit64, ub1);
201+
isFullLogical = isFull >= 0;
202+
}
191203
create.scf.ifThenElse(
192204
// Condition
193205
isFullLogical.getValue(),
@@ -198,6 +210,9 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
198210
const int64_t unrollVL = 4;
199211
const int64_t totVL = unrollVL * archVL;
200212
assert(totVL <= 64 && "bad unroll");
213+
if (neverHas64)
214+
return; // Nothing to do here.
215+
201216
create.scf.forLoop(litZero.getValue(), lit64.getValue(), totVL,
202217
[&](const SCFBuilder b, ValueRange loopInd) {
203218
MDBuilder create(b);
@@ -206,7 +221,8 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
206221
IndexExpr l = DimIE(loopIndex);
207222
Value vecF16[unrollVL], vecF32H[unrollVL],
208223
vecF32L[unrollVL];
209-
// Load f16 values from input via reinterpreted data tile.
224+
// Load f16 values from input via reinterpreted data
225+
// tile.
210226
for (int64_t i = 0; i < unrollVL; ++i) {
211227
vecF16[i] = create.vec.loadIE(vecF16Type, inputAsTx64,
212228
{SymIE(inputTileOffset), l + (i * archVL)}, {});
@@ -231,40 +247,45 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
231247
}
232248
});
233249
},
234-
// else, we don't have a full (64 e1) tile.
250+
// Else, we don't have a full (64 e1) tile.
235251
[&](SCFBuilder b) {
236252
MDBuilder create(b);
237253
IndexExprScope middleScope(b, &outerScope);
238254
IndexExpr tripCount = SymIE(ub1) - SymIE(e1);
239-
// Note: if we only have multiple of VL, loop below will handle
240-
// all as we subtract (VL-1). Aka if VL=8 and tripCount = 16,
241-
// tripCountWithoutPartialLastVL is 16 - 7 = 9. Thus we iterate
242-
// over i=0 & i=8 as both are < 9.
243-
IndexExpr tripCountWithoutPartialLastVL =
244-
tripCount - (archVL - 1);
245-
create.scf.forLoop(litZero.getValue(),
246-
tripCountWithoutPartialLastVL.getValue(), archVL,
247-
[&](SCFBuilder b, ValueRange loopInd) {
248-
MDBuilder create(b);
249-
IndexExprScope innerScope(b, &middleScope);
250-
Value loopIndex = loopInd[0];
251-
IndexExpr l = DimIE(loopIndex);
252-
// Load f16 values from input via reinterpreted data tile.
253-
Value vecF16 = create.vec.loadIE(vecF16Type, inputAsTx64,
254-
{SymIE(inputTileOffset), l}, {});
255-
// Convert back to f32.
256-
auto convertOp =
257-
rewriter.create<ZLowConvertDLF16ToF32VectorOp>(
258-
loc, vecF16);
259-
Value vecF32H = convertOp.getResult(0);
260-
Value vecF32L = convertOp.getResult(1);
261-
// Store f32 values back to the (normal layout) output.
262-
DimsExpr outputAF = SymListIE(inputAF);
263-
outputAF[E1] = outputAF[E1] + l;
264-
create.vec.storeIE(vecF32H, alloc, outputAF);
265-
create.vec.storeIE(
266-
vecF32L, alloc, outputAF, {litArchVLHalf.getValue()});
267-
});
255+
if (hasOnly64)
256+
return;
257+
if (!neverHas8) {
258+
// Note: if we only have multiple of VL, loop below will
259+
// handle all as we subtract (VL-1). Aka if VL=8 and tripCount
260+
// = 16, tripCountWithoutPartialLastVL is 16 - 7 = 9. Thus we
261+
// iterate over i=0 & i=8 as both are < 9.
262+
IndexExpr tripCountWithoutPartialLastVL =
263+
tripCount - (archVL - 1);
264+
create.scf.forLoop(litZero.getValue(),
265+
tripCountWithoutPartialLastVL.getValue(), archVL,
266+
[&](SCFBuilder b, ValueRange loopInd) {
267+
MDBuilder create(b);
268+
IndexExprScope innerScope(b, &middleScope);
269+
Value loopIndex = loopInd[0];
270+
IndexExpr l = DimIE(loopIndex);
271+
// Load f16 values from input via reinterpreted data
272+
// tile.
273+
Value vecF16 = create.vec.loadIE(vecF16Type,
274+
inputAsTx64, {SymIE(inputTileOffset), l}, {});
275+
// Convert back to f32.
276+
auto convertOp =
277+
rewriter.create<ZLowConvertDLF16ToF32VectorOp>(
278+
loc, vecF16);
279+
Value vecF32H = convertOp.getResult(0);
280+
Value vecF32L = convertOp.getResult(1);
281+
// Store f32 values back to the (normal layout) output.
282+
DimsExpr outputAF = SymListIE(inputAF);
283+
outputAF[E1] = outputAF[E1] + l;
284+
create.vec.storeIE(vecF32H, alloc, outputAF);
285+
create.vec.storeIE(vecF32L, alloc, outputAF,
286+
{litArchVLHalf.getValue()});
287+
});
288+
}
268289
// Deal with the last values: compute f32 using simd.
269290
IndexExpr remainingScalarValues = tripCount % archVL;
270291
IndexExpr lastL = tripCount - remainingScalarValues;

src/Dialect/Mlir/IndexExpr.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,7 @@ class SymbolIndexExpr : public IndexExpr {
828828
//===----------------------------------------------------------------------===//
829829

830830
using LitIE = LiteralIndexExpr;
831+
using PredIE = PredicateIndexExpr;
831832
using SymIE = SymbolIndexExpr;
832833
using DimIE = DimIndexExpr;
833834

0 commit comments

Comments
 (0)