@@ -23,14 +23,18 @@ namespace onnx_mlir {
 
 static void emitInnerLoops(KrnlBuilder &createKrnl, int64_t numberOfLoops,
     SmallVectorImpl<IndexExpr> &Lbs, SmallVectorImpl<IndexExpr> &Ubs,
-    ValueRange outerIndices, Value input, Value alloc, Value sumOp, Value maxOp,
-    int64_t axis, bool coerced = true) {
+    ValueRange outerIndices, Value input, Value alloc, Value zero,
+    Value negInfinity, int64_t axis, bool coerced = true) {
   int64_t rank = alloc.getType().cast<MemRefType>().getRank();
 
+  ValueRange maxInits = ValueRange(negInfinity);
   // Compute the maximum value along axis.
   ValueRange maxLoops = createKrnl.defineLoops(numberOfLoops);
-  createKrnl.iterateIE(maxLoops, maxLoops, Lbs, Ubs,
-      [&](KrnlBuilder &createKrnl, ValueRange maxIndices) {
+  auto maxLoop = createKrnl.iterateIE(maxLoops, maxLoops, Lbs, Ubs, maxInits,
+      [&](KrnlBuilder &createKrnl, ValueRange maxIndices, ValueRange iterArgs) {
+        // Get last argument for the iterate body.
+        Value iterArg = iterArgs.back();
+
         MultiDialectBuilder<KrnlBuilder, MathBuilder> create(createKrnl);
         IndexExprScope ieScope(createKrnl);
 
@@ -49,19 +53,24 @@ static void emitInnerLoops(KrnlBuilder &createKrnl, int64_t numberOfLoops,
           maxLoopIVs.push_back(outerIndices[i - 1]);
         }
 
-        Value max = create.krnl.load(maxOp, {});
+        Value max = iterArg;
         Value nextMax = create.krnl.load(input, maxLoopIVs);
         auto maxCond = create.math.sgt(max, nextMax);
         max = create.math.select(maxCond, max, nextMax);
-        create.krnl.store(max, maxOp, ArrayRef<Value>{});
+
+        create.krnl.yield(max);
       });
-  // Load the maximum value.
-  Value max = createKrnl.load(maxOp, {});
+  // Get the maximum value.
+  Value max = maxLoop.getResult(0);
 
+  ValueRange sumInits = ValueRange(zero);
   // Compute the sum of all values along axis.
   ValueRange sumLoops = createKrnl.defineLoops(numberOfLoops);
-  createKrnl.iterateIE(sumLoops, sumLoops, Lbs, Ubs,
-      [&](KrnlBuilder &createKrnl, ValueRange sumIndices) {
+  auto sumLoop = createKrnl.iterateIE(sumLoops, sumLoops, Lbs, Ubs, sumInits,
+      [&](KrnlBuilder &createKrnl, ValueRange sumIndices, ValueRange iterArgs) {
+        // Get last argument for the iterate body.
+        Value iterArg = iterArgs.back();
+
         MultiDialectBuilder<KrnlBuilder, MathBuilder> create(createKrnl);
         IndexExprScope ieScope(createKrnl);
 
@@ -80,19 +89,19 @@ static void emitInnerLoops(KrnlBuilder &createKrnl, int64_t numberOfLoops,
           sumLoopIVs.push_back(outerIndices[i - 1]);
         }
 
-        Value sum = create.krnl.load(sumOp, {});
+        Value sum = iterArg;
         Value next = create.krnl.load(input, sumLoopIVs);
         Value sub = create.math.sub(next, max);
         Value exp = create.math.exp(sub);
         sum = create.math.add(sum, exp);
-        create.krnl.store(sum, sumOp, ArrayRef<Value>{});
         // Store intermediate values in the result to avoid
         // recomputation.
         create.krnl.store(exp, alloc, sumLoopIVs);
+        create.krnl.yield(sum);
       });
 
   // Load the sum value.
-  Value sum = createKrnl.load(sumOp, {});
+  Value sum = sumLoop.getResult(0);
 
   // Compute the softmax.
   ValueRange softmaxLoops = createKrnl.defineLoops(numberOfLoops);
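
These first hunks replace memory-carried accumulators with loop-carried values: `iterateIE` now takes init values (`maxInits`, `sumInits`), the body receives them as iter args and yields the updated value, and the final result is read back with `getResult(0)`. A minimal plain-C++ analogy of the before/after, not onnx-mlir API; `maxViaBuffer` and `maxViaIterArg` are illustrative names:

#include <algorithm>
#include <cstddef>
#include <limits>

// Old scheme: accumulate through a scalar buffer with explicit loads and
// stores, as create.krnl.load/store on maxOp did.
float maxViaBuffer(const float *xs, std::size_t n, float *scratch) {
  *scratch = -std::numeric_limits<float>::infinity(); // reset accumulator
  for (std::size_t i = 0; i < n; ++i) {
    float m = *scratch;            // krnl.load(maxOp)
    *scratch = std::max(m, xs[i]); // krnl.store(max, maxOp)
  }
  return *scratch; // final load of the maximum
}

// New scheme: the accumulator is a value threaded through the loop, like
// the iter arg that krnl.iterate now carries.
float maxViaIterArg(const float *xs, std::size_t n) {
  float acc = -std::numeric_limits<float>::infinity(); // maxInits
  for (std::size_t i = 0; i < n; ++i)
    acc = std::max(acc, xs[i]); // body yields the next carried value
  return acc;                   // maxLoop.getResult(0)
}

Keeping the accumulator as an SSA-style value removes the scalar load/store round trips and, as the later hunks show, the need for per-thread scratch buffers.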
@@ -124,16 +133,14 @@ static void emitInnerLoops(KrnlBuilder &createKrnl, int64_t numberOfLoops,
 
 template <typename T>
 void emitInstForSoftmax(ConversionPatternRewriter &rewriter, Operation *op,
-    Location loc, Value alloc, Value input, MemRefType scalarMemRefType,
-    Value sumOp, Value maxOp, Value zero, Value negInfinity, int64_t axis,
-    bool enableParallel) = delete;
+    Location loc, Value alloc, Value input, Value zero, Value negInfinity,
+    int64_t axis, bool enableParallel) = delete;
 
 // For Softmax opset < 13, `axis` is the coerced point. All dimensions
 // after `axis` will be logically coerced into a single dimension.
 template <>
 void emitInstForSoftmax<ONNXSoftmaxV11Op>(ConversionPatternRewriter &rewriter,
-    Operation *op, Location loc, Value alloc, Value input,
-    MemRefType scalarMemRefType, Value sumOp, Value maxOp, Value zero,
+    Operation *op, Location loc, Value alloc, Value input, Value zero,
     Value negInfinity, int64_t axis, bool enableParallel) {
   int64_t rank = alloc.getType().cast<MemRefType>().getRank();
 
@@ -151,18 +158,15 @@ void emitInstForSoftmax<ONNXSoftmaxV11Op>(ConversionPatternRewriter &rewriter,
   if (axis == 0) {
     assert(!enableParallel && "only outer loop parallelism at this time");
     // There is no need having outer loops.
-    // Reset accumulators.
-    create.krnl.store(zero, sumOp, ArrayRef<Value>{});
-    create.krnl.store(negInfinity, maxOp, ArrayRef<Value>{});
 
     // Common information to create nested loops.
     int64_t numberOfLoops = rank;
     SmallVector<IndexExpr, 4> Lbs(numberOfLoops, zeroIE);
     SmallVector<IndexExpr, 4> Ubs;
     create.krnlIE.getShapeAsDims(input, Ubs);
 
-    emitInnerLoops(create.krnl, numberOfLoops, Lbs, Ubs, {}, input, alloc,
-        sumOp, maxOp, axis, /*coerced=*/true);
+    emitInnerLoops(create.krnl, numberOfLoops, Lbs, Ubs, {}, input, alloc, zero,
+        negInfinity, axis, /*coerced=*/true);
   } else {
     // Define outer loops.
     ValueRange outerLoops = create.krnl.defineLoops(axis);
@@ -183,16 +187,6 @@ void emitInstForSoftmax<ONNXSoftmaxV11Op>(ConversionPatternRewriter &rewriter,
               create(ck);
           IndexExprScope ieScope(ck);
 
-          if (enableParallel) {
-            // Temporary results must be private when parallel. Use alloca here
-            // as scalars are small.
-            sumOp = create.mem.alignedAlloca(scalarMemRefType);
-            maxOp = create.mem.alignedAlloca(scalarMemRefType);
-          }
-          // Reset accumulators.
-          create.krnl.store(zero, sumOp, ArrayRef<Value>{});
-          create.krnl.store(negInfinity, maxOp, ArrayRef<Value>{});
-
           // Common information to create inner nested loops.
           int64_t numberOfLoops = rank - axis;
           SmallVector<IndexExpr, 4> Lbs(numberOfLoops, zeroIE);
@@ -202,7 +196,7 @@ void emitInstForSoftmax<ONNXSoftmaxV11Op>(ConversionPatternRewriter &rewriter,
 
           // Emit the inner loops.
           emitInnerLoops(create.krnl, numberOfLoops, Lbs, Ubs, outerIndices,
-              input, alloc, sumOp, maxOp, axis, /*coerced=*/true);
+              input, alloc, zero, negInfinity, axis, /*coerced=*/true);
         });
   }
 }
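
For reference, the three loop nests emitted by `emitInnerLoops` compute a numerically stable softmax along the (possibly coerced) axis: a max reduction, a sum of `exp(x - max)` that caches the exponentials in the output buffer, and a final normalization. A scalar sketch of that computation; `softmax1D` is an illustrative name, not an onnx-mlir function:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Stable softmax along one axis of length n: max, sum of shifted
// exponentials, then normalization. The exponentials are cached in the
// output buffer, mirroring the "store intermediate values" comment above.
void softmax1D(const float *in, float *out, int64_t n) {
  float mx = -std::numeric_limits<float>::infinity();
  for (int64_t i = 0; i < n; ++i)
    mx = std::max(mx, in[i]); // loop nest 1: maximum
  float sum = 0.0f;
  for (int64_t i = 0; i < n; ++i) {
    out[i] = std::exp(in[i] - mx); // loop nest 2: cache exp(x - max)
    sum += out[i];
  }
  for (int64_t i = 0; i < n; ++i)
    out[i] /= sum; // loop nest 3: normalize
}

Subtracting the maximum before exponentiating keeps every argument to exp non-positive, which avoids overflow without changing the result.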
@@ -212,8 +206,7 @@ void emitInstForSoftmax<ONNXSoftmaxV11Op>(ConversionPatternRewriter &rewriter,
 // `axis`.
 template <>
 void emitInstForSoftmax<ONNXSoftmaxOp>(ConversionPatternRewriter &rewriter,
-    Operation *op, Location loc, Value alloc, Value input,
-    MemRefType scalarMemRefType, Value sumOp, Value maxOp, Value zero,
+    Operation *op, Location loc, Value alloc, Value input, Value zero,
     Value negInfinity, int64_t axis, bool enableParallel) {
   int64_t rank = alloc.getType().cast<MemRefType>().getRank();
 
@@ -246,17 +239,6 @@ void emitInstForSoftmax<ONNXSoftmaxOp>(ConversionPatternRewriter &rewriter,
               create(ck);
           IndexExprScope ieScope(ck);
 
-          if (enableParallel) {
-            // Temporary results must be private when parallel. Use alloca here as
-            // scalars are small.
-            sumOp = create.mem.alignedAlloca(scalarMemRefType);
-            maxOp = create.mem.alignedAlloca(scalarMemRefType);
-          }
-
-          // Reset accumulators.
-          create.krnl.store(zero, sumOp, ArrayRef<Value>{});
-          create.krnl.store(negInfinity, maxOp, ArrayRef<Value>{});
-
           // Common information to create inner nested loops for axis only.
           int64_t numberOfLoops = 1;
           SmallVector<IndexExpr, 4> Lbs(numberOfLoops, zeroIE);
@@ -265,7 +247,7 @@ void emitInstForSoftmax<ONNXSoftmaxOp>(ConversionPatternRewriter &rewriter,
 
           // Emit the inner loops.
           emitInnerLoops(create.krnl, numberOfLoops, Lbs, Ubs, outerIndices,
-              input, alloc, sumOp, maxOp, axis, /*coerced=*/false);
+              input, alloc, zero, negInfinity, axis, /*coerced=*/false);
         });
 }
 
@@ -316,22 +298,12 @@ struct ONNXSoftmaxLowering : public OpConversionPattern<SoftmaxOp> {
     MultiDialectBuilder<MemRefBuilder, MathBuilder> create(rewriter, loc);
     Value alloc = create.mem.alignedAlloc(input, memRefType);
 
-    // Insert allocations and deallocations for sum and max.
-    MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
-    Value sumOp, maxOp;
-    if (!enableParallelLocal) {
-      // Temporary results must be private when parallel.
-      sumOp = create.mem.alignedAlloc(scalarMemRefType);
-      maxOp = create.mem.alignedAlloc(scalarMemRefType);
-    }
-
     Value zero = create.math.constant(elementType, 0);
     Value negInfinity = create.math.constant(
         elementType, -std::numeric_limits<float>::infinity());
 
-    emitInstForSoftmax<SoftmaxOp>(rewriter, op, loc, alloc, input,
-        scalarMemRefType, sumOp, maxOp, zero, negInfinity, axis,
-        enableParallelLocal);
+    emitInstForSoftmax<SoftmaxOp>(rewriter, op, loc, alloc, input, zero,
+        negInfinity, axis, enableParallelLocal);
 
     rewriter.replaceOp(op, alloc);
     onnxToKrnlSimdReport(op);
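
The deleted `scalarMemRefType`/`sumOp`/`maxOp` machinery existed only to give each parallel outer iteration a private scratch scalar, hence the `alignedAlloca` calls inside the parallel regions. With loop-carried values, every iteration owns its accumulators by construction, so the outer loop parallelizes without per-thread allocation. A sketch of the resulting shape, assuming the illustrative `softmax1D` from the sketch above:

#include <cstdint>

void softmax1D(const float *in, float *out, int64_t n); // sketched above

// Every iteration reduces into its own local values, so the row loop can
// run its iterations in parallel without per-thread scratch buffers.
void softmaxRows(const float *in, float *out, int64_t rows, int64_t cols) {
  for (int64_t r = 0; r < rows; ++r)
    softmax1D(in + r * cols, out + r * cols, cols);
}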