Commit 6d53c6a
Merge branch 'main' into pr_support_softplus_on_nnpa
2 parents: 7fcff74 + ee7eaca

18 files changed: +893 −413 lines

src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp

Lines changed: 258 additions & 67 deletions
Large diffs are not rendered by default.

src/Conversion/KrnlToAffine/KrnlTerminator.cpp

Lines changed: 15 additions & 1 deletion
@@ -42,9 +42,23 @@ class KrnlTerminatorLowering : public ConversionPattern {
   }
 };
 
+class KrnlYieldLowering : public ConversionPattern {
+public:
+  explicit KrnlYieldLowering(TypeConverter &typeConverter, MLIRContext *context)
+      : ConversionPattern(
+            typeConverter, KrnlYieldOp::getOperationName(), 1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<affine::AffineYieldOp>(op, op->getOperands());
+    return success();
+  }
+};
+
 void populateLoweringKrnlTerminatorOpPattern(TypeConverter &typeConverter,
     RewritePatternSet &patterns, MLIRContext *ctx) {
-  patterns.insert<KrnlTerminatorLowering>(typeConverter, ctx);
+  patterns.insert<KrnlTerminatorLowering, KrnlYieldLowering>(
      typeConverter, ctx);
 }
 
 } // namespace krnl
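
The new pattern is mechanical: krnl.yield maps one-to-one onto affine.yield, so a krnl.iterate that carries an accumulator can lower to a plain affine.for with iter_args. A minimal sketch of the affine form this targets (loop bound, memref shape, and value names are illustrative, not taken from this commit; it assumes %A : memref<64xf32> is defined in the enclosing function):

// %acc enters as %zero; each iteration yields the updated value,
// and the loop op itself returns the final sum.
%zero = arith.constant 0.0 : f32
%sum = affine.for %k = 0 to 64 iter_args(%acc = %zero) -> (f32) {
  %v = affine.load %A[%k] : memref<64xf32>
  %next = arith.addf %acc, %v : f32
  affine.yield %next : f32
}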

src/Conversion/ONNXToKrnl/Math/MatMul.cpp

Lines changed: 14 additions & 9 deletions
@@ -84,13 +84,15 @@ struct ONNXMatMulOpLowering : public OpConversionPattern<ONNXMatMulOp> {
         [&](KrnlBuilder &createKrnl, ValueRange outerIndices) {
           MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(
               createKrnl);
-          // Single scalar, no need for default alignment.
-          Value reductionVal =
-              create.mem.alignedAlloca(MemRefType::get({}, elementType));
-          create.krnl.store(fZero, reductionVal);
+
+          ValueRange inits = ValueRange(fZero);
           // Inner loop for reduction.
-          create.krnl.iterate({}, innerLoop, {}, {},
-              [&](KrnlBuilder &createKrnl, ValueRange innerIndex) {
+          auto innerIterate = create.krnl.iterate({}, innerLoop, {}, {}, inits,
+              [&](KrnlBuilder &createKrnl, ValueRange innerIndex,
+                  ValueRange iterArgs) {
+                // Get last argument for the iterate body.
+                Value iterArg = iterArgs.back();
+
                 MultiDialectBuilder<KrnlBuilder, MathBuilder> create(
                     createKrnl);
                 Value k = innerIndex[0];
@@ -128,13 +130,16 @@ struct ONNXMatMulOpLowering : public OpConversionPattern<ONNXMatMulOp> {
                     create.krnl.load(operandAdaptor.getA(), aAccessFct);
                 Value loadedB =
                     create.krnl.load(operandAdaptor.getB(), bAccessFct);
-                Value loadedY = create.krnl.load(reductionVal);
+                Value loadedY = iterArg;
                 Value AB = create.math.mul(loadedA, loadedB);
                 Value accumulated = create.math.add(loadedY, AB);
-                create.krnl.store(accumulated, reductionVal);
+                // Create yield.
+                create.krnl.yield(accumulated);
               });
-          Value accumulated = create.krnl.load(reductionVal);
+          Value accumulated = innerIterate.getResult(0);
           create.krnl.store(accumulated, alloc, outerIndices);
+          // Create yield.
+          create.krnl.yield({});
         });
   }
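
The effect of this change: the inner reduction no longer round-trips through a rank-0 scratch memref (alloca, then a load and store on every iteration); the running sum is a loop-carried SSA value, read back via innerIterate.getResult(0), and the outer lambda ends with an empty krnl.yield({}) since it carries no values. At the affine level, the dot-product kernel for one output element shifts roughly as sketched below (bounds, shapes, and the names %A, %B, %fzero are hypothetical):

// Before: accumulate through a rank-0 scratch memref.
%red = memref.alloca() : memref<f32>
memref.store %fzero, %red[] : memref<f32>
affine.for %k = 0 to 32 {
  %a = affine.load %A[%k] : memref<32xf32>
  %b = affine.load %B[%k] : memref<32xf32>
  %ab = arith.mulf %a, %b : f32
  %y = memref.load %red[] : memref<f32>
  %acc = arith.addf %y, %ab : f32
  memref.store %acc, %red[] : memref<f32>
}
%result = memref.load %red[] : memref<f32>

// After: the accumulator is a loop-carried value; no scratch memory.
%result = affine.for %k = 0 to 32 iter_args(%y = %fzero) -> (f32) {
  %a = affine.load %A[%k] : memref<32xf32>
  %b = affine.load %B[%k] : memref<32xf32>
  %ab = arith.mulf %a, %b : f32
  %acc = arith.addf %y, %ab : f32
  affine.yield %acc : f32
}

Besides removing the scratch buffer, this keeps the reduction as an SSA value that later passes can reason about directly.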

src/Conversion/ONNXToKrnl/Math/Softmax.cpp

Lines changed: 32 additions & 60 deletions
@@ -23,14 +23,18 @@ namespace onnx_mlir {
 
 static void emitInnerLoops(KrnlBuilder &createKrnl, int64_t numberOfLoops,
     SmallVectorImpl<IndexExpr> &Lbs, SmallVectorImpl<IndexExpr> &Ubs,
-    ValueRange outerIndices, Value input, Value alloc, Value sumOp, Value maxOp,
-    int64_t axis, bool coerced = true) {
+    ValueRange outerIndices, Value input, Value alloc, Value zero,
+    Value negInfinity, int64_t axis, bool coerced = true) {
   int64_t rank = alloc.getType().cast<MemRefType>().getRank();
 
+  ValueRange maxInits = ValueRange(negInfinity);
   // Compute the maximum value along axis.
   ValueRange maxLoops = createKrnl.defineLoops(numberOfLoops);
-  createKrnl.iterateIE(maxLoops, maxLoops, Lbs, Ubs,
-      [&](KrnlBuilder &createKrnl, ValueRange maxIndices) {
+  auto maxLoop = createKrnl.iterateIE(maxLoops, maxLoops, Lbs, Ubs, maxInits,
+      [&](KrnlBuilder &createKrnl, ValueRange maxIndices, ValueRange iterArgs) {
+        // Get last argument for the iterate body.
+        Value iterArg = iterArgs.back();
+
         MultiDialectBuilder<KrnlBuilder, MathBuilder> create(createKrnl);
         IndexExprScope ieScope(createKrnl);
 
@@ -49,19 +53,24 @@ static void emitInnerLoops(KrnlBuilder &createKrnl, int64_t numberOfLoops,
           maxLoopIVs.push_back(outerIndices[i - 1]);
         }
 
-        Value max = create.krnl.load(maxOp, {});
+        Value max = iterArg;
         Value nextMax = create.krnl.load(input, maxLoopIVs);
         auto maxCond = create.math.sgt(max, nextMax);
         max = create.math.select(maxCond, max, nextMax);
-        create.krnl.store(max, maxOp, ArrayRef<Value>{});
+
+        create.krnl.yield(max);
       });
-  // Load the maximum value.
-  Value max = createKrnl.load(maxOp, {});
+  // Get the maximum value.
+  Value max = maxLoop.getResult(0);
 
+  ValueRange sumInits = ValueRange(zero);
   // Compute the sum of all values along axis.
   ValueRange sumLoops = createKrnl.defineLoops(numberOfLoops);
-  createKrnl.iterateIE(sumLoops, sumLoops, Lbs, Ubs,
-      [&](KrnlBuilder &createKrnl, ValueRange sumIndices) {
+  auto sumLoop = createKrnl.iterateIE(sumLoops, sumLoops, Lbs, Ubs, sumInits,
+      [&](KrnlBuilder &createKrnl, ValueRange sumIndices, ValueRange iterArgs) {
+        // Get last argument for the iterate body.
+        Value iterArg = iterArgs.back();
+
         MultiDialectBuilder<KrnlBuilder, MathBuilder> create(createKrnl);
         IndexExprScope ieScope(createKrnl);
 
@@ -80,19 +89,19 @@ static void emitInnerLoops(KrnlBuilder &createKrnl, int64_t numberOfLoops,
           sumLoopIVs.push_back(outerIndices[i - 1]);
         }
 
-        Value sum = create.krnl.load(sumOp, {});
+        Value sum = iterArg;
         Value next = create.krnl.load(input, sumLoopIVs);
         Value sub = create.math.sub(next, max);
         Value exp = create.math.exp(sub);
         sum = create.math.add(sum, exp);
-        create.krnl.store(sum, sumOp, ArrayRef<Value>{});
         // Store intermediate values in the result to avoid
         // recomputation.
         create.krnl.store(exp, alloc, sumLoopIVs);
+        create.krnl.yield(sum);
       });
 
   // Load the sum value.
-  Value sum = createKrnl.load(sumOp, {});
+  Value sum = sumLoop.getResult(0);
 
   // Compute the softmax.
   ValueRange softmaxLoops = createKrnl.defineLoops(numberOfLoops);
@@ -124,16 +133,14 @@ static void emitInnerLoops(KrnlBuilder &createKrnl, int64_t numberOfLoops,
 
 template <typename T>
 void emitInstForSoftmax(ConversionPatternRewriter &rewriter, Operation *op,
-    Location loc, Value alloc, Value input, MemRefType scalarMemRefType,
-    Value sumOp, Value maxOp, Value zero, Value negInfinity, int64_t axis,
-    bool enableParallel) = delete;
+    Location loc, Value alloc, Value input, Value zero, Value negInfinity,
+    int64_t axis, bool enableParallel) = delete;
 
 // For Softmax opset < 13, `axis` is the coerced point. All dimensions
 // after `axis` will be logically coerced into a single dimension.
 template <>
 void emitInstForSoftmax<ONNXSoftmaxV11Op>(ConversionPatternRewriter &rewriter,
-    Operation *op, Location loc, Value alloc, Value input,
-    MemRefType scalarMemRefType, Value sumOp, Value maxOp, Value zero,
+    Operation *op, Location loc, Value alloc, Value input, Value zero,
     Value negInfinity, int64_t axis, bool enableParallel) {
   int64_t rank = alloc.getType().cast<MemRefType>().getRank();
 
@@ -151,18 +158,15 @@ void emitInstForSoftmax<ONNXSoftmaxV11Op>(ConversionPatternRewriter &rewriter,
   if (axis == 0) {
     assert(!enableParallel && "only outer loop parallelism at this time");
     // There is no need having outer loops.
-    // Reset accumulators.
-    create.krnl.store(zero, sumOp, ArrayRef<Value>{});
-    create.krnl.store(negInfinity, maxOp, ArrayRef<Value>{});
 
     // Common information to create nested loops.
     int64_t numberOfLoops = rank;
     SmallVector<IndexExpr, 4> Lbs(numberOfLoops, zeroIE);
     SmallVector<IndexExpr, 4> Ubs;
     create.krnlIE.getShapeAsDims(input, Ubs);
 
-    emitInnerLoops(create.krnl, numberOfLoops, Lbs, Ubs, {}, input, alloc,
-        sumOp, maxOp, axis, /*coerced=*/true);
+    emitInnerLoops(create.krnl, numberOfLoops, Lbs, Ubs, {}, input, alloc, zero,
+        negInfinity, axis, /*coerced=*/true);
   } else {
     // Define outer loops.
     ValueRange outerLoops = create.krnl.defineLoops(axis);
@@ -183,16 +187,6 @@ void emitInstForSoftmax<ONNXSoftmaxV11Op>(ConversionPatternRewriter &rewriter,
               create(ck);
           IndexExprScope ieScope(ck);
 
-          if (enableParallel) {
-            // Temporary results must be private when parallel. Use alloca here
-            // as scalars are small.
-            sumOp = create.mem.alignedAlloca(scalarMemRefType);
-            maxOp = create.mem.alignedAlloca(scalarMemRefType);
-          }
-          // Reset accumulators.
-          create.krnl.store(zero, sumOp, ArrayRef<Value>{});
-          create.krnl.store(negInfinity, maxOp, ArrayRef<Value>{});
-
           // Common information to create inner nested loops.
           int64_t numberOfLoops = rank - axis;
           SmallVector<IndexExpr, 4> Lbs(numberOfLoops, zeroIE);
@@ -202,7 +196,7 @@ void emitInstForSoftmax<ONNXSoftmaxV11Op>(ConversionPatternRewriter &rewriter,
 
           // Emit the inner loops.
           emitInnerLoops(create.krnl, numberOfLoops, Lbs, Ubs, outerIndices,
-              input, alloc, sumOp, maxOp, axis, /*coerced=*/true);
+              input, alloc, zero, negInfinity, axis, /*coerced=*/true);
         });
   }
 }
@@ -212,8 +206,7 @@ void emitInstForSoftmax<ONNXSoftmaxV11Op>(ConversionPatternRewriter &rewriter,
 // `axis`.
 template <>
 void emitInstForSoftmax<ONNXSoftmaxOp>(ConversionPatternRewriter &rewriter,
-    Operation *op, Location loc, Value alloc, Value input,
-    MemRefType scalarMemRefType, Value sumOp, Value maxOp, Value zero,
+    Operation *op, Location loc, Value alloc, Value input, Value zero,
     Value negInfinity, int64_t axis, bool enableParallel) {
   int64_t rank = alloc.getType().cast<MemRefType>().getRank();
 
@@ -246,17 +239,6 @@ void emitInstForSoftmax<ONNXSoftmaxOp>(ConversionPatternRewriter &rewriter,
              create(ck);
           IndexExprScope ieScope(ck);
 
-          if (enableParallel) {
-            // Temporary results must be private when parallel. Use alloca here as
-            // scalars are small.
-            sumOp = create.mem.alignedAlloca(scalarMemRefType);
-            maxOp = create.mem.alignedAlloca(scalarMemRefType);
-          }
-
-          // Reset accumulators.
-          create.krnl.store(zero, sumOp, ArrayRef<Value>{});
-          create.krnl.store(negInfinity, maxOp, ArrayRef<Value>{});
-
           // Common information to create inner nested loops for axis only.
           int64_t numberOfLoops = 1;
           SmallVector<IndexExpr, 4> Lbs(numberOfLoops, zeroIE);
@@ -265,7 +247,7 @@ void emitInstForSoftmax<ONNXSoftmaxOp>(ConversionPatternRewriter &rewriter,
 
          // Emit the inner loops.
           emitInnerLoops(create.krnl, numberOfLoops, Lbs, Ubs, outerIndices,
-              input, alloc, sumOp, maxOp, axis, /*coerced=*/false);
+              input, alloc, zero, negInfinity, axis, /*coerced=*/false);
         });
   }
 
@@ -316,22 +298,12 @@ struct ONNXSoftmaxLowering : public OpConversionPattern<SoftmaxOp> {
     MultiDialectBuilder<MemRefBuilder, MathBuilder> create(rewriter, loc);
     Value alloc = create.mem.alignedAlloc(input, memRefType);
 
-    // Insert allocations and deallocations for sum and max.
-    MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
-    Value sumOp, maxOp;
-    if (!enableParallelLocal) {
-      // Temporary results must be private when parallel.
-      sumOp = create.mem.alignedAlloc(scalarMemRefType);
-      maxOp = create.mem.alignedAlloc(scalarMemRefType);
-    }
-
     Value zero = create.math.constant(elementType, 0);
     Value negInfinity = create.math.constant(
         elementType, -std::numeric_limits<float>::infinity());
 
-    emitInstForSoftmax<SoftmaxOp>(rewriter, op, loc, alloc, input,
-        scalarMemRefType, sumOp, maxOp, zero, negInfinity, axis,
-        enableParallelLocal);
+    emitInstForSoftmax<SoftmaxOp>(rewriter, op, loc, alloc, input, zero,
+        negInfinity, axis, enableParallelLocal);
 
     rewriter.replaceOp(op, alloc);
     onnxToKrnlSimdReport(op);
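
Both reductions in emitInnerLoops now follow the same shape: the max loop is seeded with negInfinity and the sum loop with zero, each body yields its updated accumulator, and getResult(0) replaces the final load. This also makes the per-thread alignedAlloca privatization unnecessary, since a loop-carried SSA value is private to its loop by construction. A sketch of the running-max pass in affine form (the seed constant, bound, and shapes are illustrative; it assumes create.math.sgt and select map to arith.cmpf and arith.select for floats):

// Running max seeded with -infinity (0xFF800000 is the f32 bit pattern).
%ninf = arith.constant 0xFF800000 : f32
%max = affine.for %k = 0 to 128 iter_args(%m = %ninf) -> (f32) {
  %x = affine.load %in[%k] : memref<128xf32>
  %gt = arith.cmpf ogt, %m, %x : f32
  %next = arith.select %gt, %m, %x : f32
  affine.yield %next : f32
}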

src/Conversion/ONNXToKrnl/NN/Conv.cpp

Lines changed: 50 additions & 50 deletions
@@ -90,11 +90,7 @@ struct ONNXConvOpLowering : public OpConversionPattern<ONNXConvOp> {
     //   for coPerGroup = 0 .. COPerGroup:
     //     co = g * COPerGroup + coPerGroup;
 
-    // Create a local reduction value.
-    MemRefType tmpType = MemRefType::get({}, memRefType.getElementType());
     auto bodyFunction = [&](ValueRange outerIndices) {
-      // Single scalar, no need for default alignment.
-      Value reductionVal = create.mem.alloca(tmpType);
       // Compute the Channel In Indices.
       IndexExprScope outerScope(create.krnl);
       // Compute the channel out index "co".
@@ -122,8 +118,8 @@ struct ONNXConvOpLowering : public OpConversionPattern<ONNXConvOp> {
          MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl,
              MathBuilder>
              create(createKrnl);
-          // Reset reduction value to zero.
-          create.krnl.store(fZero, reductionVal);
+
+          ValueRange inits = ValueRange(fZero);
 
           // Bounds for reduction loops.
           ValueRange redLoops = create.krnl.defineLoops(spacialRank + 1);
@@ -158,51 +154,55 @@ struct ONNXConvOpLowering : public OpConversionPattern<ONNXConvOp> {
           //   for ciPerGroup = 0 .. CIPerGroup:
           //     for kh in lb .. ub:
           //       for kw in lb .. ub:
-          create.krnl.iterateIE(redLoops, redLoops, redLbs, redUbs,
-              [&](KrnlBuilder &createKrnl, ValueRange redIndices) {
-                IndexExprScope redScope(createKrnl);
-                MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl,
-                    MathBuilder>
-                    create(createKrnl);
-                // Create access function for input image:
-                // [n, ci, ho * sh + kh * dh - ph, wo * sw + kw * dw -
-                // pw].
-                SmallVector<IndexExpr, 4> inputAccessFct;
-                DimIndexExpr n(outerIndices[0]);
-                inputAccessFct.emplace_back(n);
-                // ci = g * CIPerG + ciPerG
-                DimIndexExpr ciPerG(redIndices[0]);
-                IndexExpr ci = SymbolIndexExpr(gTimesCIPerGroup) + ciPerG;
-                inputAccessFct.emplace_back(ci);
-                for (int i = 0; i < spacialRank; ++i) {
-                  // for each spacial dims: access is o * s + k * d - p.
-                  DimIndexExpr k(redIndices[1 + i]);
-                  SymbolIndexExpr pos(pMinOS[i]);
-                  LiteralIndexExpr d(shapeHelper.dilations[i]);
-                  // k*d - (p - o*s) = k*d + o*s - p
-                  IndexExpr t = (k * d) - pos;
-                  inputAccessFct.emplace_back(t);
-                }
-                Value image =
-                    create.krnl.loadIE(inputOperand, inputAccessFct);
-                // Create access fct for filter: [co, ciPerG, kh, kw].
-                SmallVector<IndexExpr, 4> filterAccessFct;
-                filterAccessFct.emplace_back(DimIndexExpr(co));
-                filterAccessFct.emplace_back(DimIndexExpr(ciPerG));
+          auto innerIterate =
+              create.krnl.iterateIE(redLoops, redLoops, redLbs, redUbs, inits,
+                  [&](KrnlBuilder &createKrnl, ValueRange redIndices,
+                      ValueRange iterArgs) {
+                    // Get last argument for the iterate body.
+                    Value iterArg = iterArgs.back();
+                    IndexExprScope redScope(createKrnl);
+                    MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl,
+                        MathBuilder>
+                        create(createKrnl);
+                    // Create access function for input image:
+                    // [n, ci, ho * sh + kh * dh - ph, wo * sw + kw * dw -
+                    // pw].
+                    SmallVector<IndexExpr, 4> inputAccessFct;
+                    DimIndexExpr n(outerIndices[0]);
+                    inputAccessFct.emplace_back(n);
+                    // ci = g * CIPerG + ciPerG
+                    DimIndexExpr ciPerG(redIndices[0]);
+                    IndexExpr ci = SymbolIndexExpr(gTimesCIPerGroup) + ciPerG;
+                    inputAccessFct.emplace_back(ci);
+                    for (int i = 0; i < spacialRank; ++i) {
+                      // for each spacial dims: access is o * s + k * d - p.
+                      DimIndexExpr k(redIndices[1 + i]);
+                      SymbolIndexExpr pos(pMinOS[i]);
+                      LiteralIndexExpr d(shapeHelper.dilations[i]);
+                      // k*d - (p - o*s) = k*d + o*s - p
+                      IndexExpr t = (k * d) - pos;
+                      inputAccessFct.emplace_back(t);
+                    }
+                    Value image =
+                        create.krnl.loadIE(inputOperand, inputAccessFct);
+                    // Create access fct for filter: [co, ciPerG, kh, kw].
+                    SmallVector<IndexExpr, 4> filterAccessFct;
+                    filterAccessFct.emplace_back(DimIndexExpr(co));
+                    filterAccessFct.emplace_back(DimIndexExpr(ciPerG));
 
-                for (int i = 0; i < spacialRank; ++i) {
-                  DimIndexExpr k(redIndices[1 + i]);
-                  filterAccessFct.emplace_back(k);
-                }
-                Value filter =
-                    create.krnl.loadIE(filterOperand, filterAccessFct);
-                Value oldRed = create.krnl.load(reductionVal);
-                Value mul = create.math.mul(image, filter);
-                Value newRed = create.math.add(oldRed, mul);
-                create.krnl.store(newRed, reductionVal);
-              }); // Reduction loops.
-          // Finish the reduction and store in result array.
-          Value result = create.krnl.load(reductionVal);
+                    for (int i = 0; i < spacialRank; ++i) {
+                      DimIndexExpr k(redIndices[1 + i]);
+                      filterAccessFct.emplace_back(k);
+                    }
+                    Value filter =
+                        create.krnl.loadIE(filterOperand, filterAccessFct);
+                    Value oldRed = iterArg;
+                    Value mul = create.math.mul(image, filter);
+                    Value newRed = create.math.add(oldRed, mul);
+                    create.krnl.yield(newRed);
+                  }); // Reduction loops.
+          // Finish the reduction and store in result array.
+          Value result = innerIterate.getResult(0);
           // Store the result. Optionally add bias.
           SymbolIndexExpr coInOutputSpacial(co);
           if (hasBias) {
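
The convolution reduction collapses ciPerGroup, kh, and kw into one iterateIE with a single init, so only one value is threaded through the whole loop nest. When such a multi-loop reduction lowers to nested affine loops, each level carries the accumulator and re-yields its child's result. A sketch under that assumption (the actual lowering lives in ConvertKrnlToAffine.cpp, whose diff is not rendered above; %img, %flt, %fzero, and the 3x3 shapes are made up):

// Each nesting level carries the accumulator: the inner loop's
// result is yielded by its parent.
%res = affine.for %kh = 0 to 3 iter_args(%acc0 = %fzero) -> (f32) {
  %r = affine.for %kw = 0 to 3 iter_args(%acc1 = %acc0) -> (f32) {
    %x = affine.load %img[%kh, %kw] : memref<3x3xf32>
    %w = affine.load %flt[%kh, %kw] : memref<3x3xf32>
    %m = arith.mulf %x, %w : f32
    %next = arith.addf %acc1, %m : f32
    affine.yield %next : f32
  }
  affine.yield %r : f32
}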
