Skip to content

Commit bc18452

Browse files
Parallel code for quantization (#2923)
Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent 1900ea7 commit bc18452

15 files changed

+1743
-653
lines changed

src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp

-14
Original file line numberDiff line numberDiff line change
@@ -372,20 +372,6 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
372372
// canonicalization after the lowering.
373373
target.addLegalOp<::mlir::ONNXNoneOp>();
374374

375-
// Use krnl.load/store instead of std.load/store and affine.load/store.
376-
// krnl.load/store will be lowered to std.load/store and affine.load/store
377-
// by `convert-krnl-to-affine` pass.
378-
target.addIllegalOp<mlir::memref::LoadOp>();
379-
target.addIllegalOp<mlir::affine::AffineLoadOp>();
380-
target.addIllegalOp<mlir::memref::StoreOp>();
381-
// Memref builder can use affine stores, it would be awkward for it to
382-
// generate Krnl stores as mem builder is part of MLIR. Thus the affine
383-
// stores should not be illegal here. Since affine loads are still illegal,
384-
// the regular krnl lowering will most likely trigger errors if non krnl mem
385-
ops were generally used.
386-
//
387-
// target.addIllegalOp<mlir::affine::AffineStoreOp>();
388-
389375
// Option`emitDealloc` is deprecated and turned off, make sure we don't have
390376
// buffer deallocation at this level. Will use MLIR buffer-deallocation for
391377
// this purpose instead. However, since the SequenceErase needs to emit

src/Conversion/ONNXToKrnl/Math/Reduction.cpp

+244-139
Large diffs are not rendered by default.

src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -610,9 +610,9 @@ bool hasNonIdentityLayout(ValueRange operands) {
610610
// requirement by definition. If found one, it is parDim and the function
611611
// returns true.
612612

613-
bool findSuitableParallelDimension(llvm::SmallVectorImpl<IndexExpr> &lb,
614-
llvm::SmallVectorImpl<IndexExpr> &ub, int64_t firstInclusiveDim,
615-
int64_t lastExclusiveDim, int64_t &parDim, int64_t minSize) {
613+
bool findSuitableParallelDimension(ArrayRef<IndexExpr> lb,
614+
ArrayRef<IndexExpr> ub, int64_t firstInclusiveDim, int64_t lastExclusiveDim,
615+
int64_t &parDim, int64_t minSize) {
616616
assert(lb.size() == ub.size() && "expected identical ranks for lb/ub");
617617
if (firstInclusiveDim < 0)
618618
firstInclusiveDim = 0;

src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -611,8 +611,8 @@ bool hasNonIdentityLayout(mlir::ValueRange operands);
611611
// Return the outermost loop within [firstDim, lastDim) for which (ub-lb) >=
612612
// minSize. Runtime dimensions are assumed to satisfy the size requirement by
613613
// definition. If found one, it is parDim and the function returns true.
614-
bool findSuitableParallelDimension(llvm::SmallVectorImpl<IndexExpr> &lb,
615-
llvm::SmallVectorImpl<IndexExpr> &ub, int64_t firstInclusiveDim,
614+
bool findSuitableParallelDimension(mlir::ArrayRef<IndexExpr> lb,
615+
mlir::ArrayRef<IndexExpr> ub, int64_t firstInclusiveDim,
616616
int64_t lastExclusiveDim, int64_t &parDim, int64_t minSize = 4);
617617

618618
//===----------------------------------------------------------------------===//

src/Dialect/Krnl/DialectBuilder.cpp

+36-256
Original file line numberDiff line numberDiff line change
@@ -58,55 +58,26 @@ static StringRef getFormat(const Type &inputType) {
5858

5959
Value KrnlBuilder::load(
6060
Value memref, ValueRange indices, ValueRange offsets) const {
61-
// Handle offsets.
62-
SmallVector<Value, 4> computedIndices;
63-
MathBuilder createMath(*this);
64-
createMath.addOffsetToLeastSignificant(indices, offsets, computedIndices);
65-
// Perform load.
66-
if (computedIndices.size() == 0) {
67-
// case memref<1xdtype>
68-
MemRefType type = dyn_cast_or_null<MemRefType>(memref.getType());
69-
assert(type && "Not MemRefType");
70-
if (type.getRank() == 1 && type.getShape()[0] == 1) {
71-
MultiDialectBuilder<MathBuilder> create(*this);
72-
Value iZero = create.math.constantIndex(0);
73-
return b().create<KrnlLoadOp>(loc(), memref, ValueRange({iZero}));
74-
}
75-
}
76-
return b().create<KrnlLoadOp>(loc(), memref, computedIndices);
61+
return onnx_mlir::impl::load<KrnlBuilder, KrnlLoadOp>(
62+
*this, memref, indices, offsets);
7763
}
7864

7965
Value KrnlBuilder::loadIE(
8066
Value memref, ArrayRef<IndexExpr> indices, ValueRange offsets) const {
81-
SmallVector<Value, 4> indexValues;
82-
IndexExpr::getValues(indices, indexValues);
83-
return load(memref, indexValues, offsets);
67+
return onnx_mlir::impl::loadIE<KrnlBuilder, KrnlLoadOp>(
68+
*this, memref, indices, offsets);
8469
}
8570

8671
void KrnlBuilder::store(
8772
Value val, Value memref, ValueRange indices, ValueRange offsets) const {
88-
SmallVector<Value, 4> computedIndices;
89-
MathBuilder createMath(*this);
90-
createMath.addOffsetToLeastSignificant(indices, offsets, computedIndices);
91-
if (computedIndices.size() == 0) {
92-
// case memref<1xdtype>
93-
MemRefType type = dyn_cast_or_null<MemRefType>(memref.getType());
94-
assert(type && "Not MemRefType");
95-
if (type.getRank() == 1 && type.getShape()[0] == 1) {
96-
MultiDialectBuilder<MathBuilder> create(*this);
97-
Value iZero = create.math.constantIndex(0);
98-
b().create<KrnlStoreOp>(loc(), val, memref, ValueRange({iZero}));
99-
return;
100-
}
101-
}
102-
b().create<KrnlStoreOp>(loc(), val, memref, computedIndices);
73+
onnx_mlir::impl::store<KrnlBuilder, KrnlStoreOp>(
74+
*this, val, memref, indices, offsets);
10375
}
10476

10577
void KrnlBuilder::storeIE(Value val, Value memref, ArrayRef<IndexExpr> indices,
10678
ValueRange offsets) const {
107-
SmallVector<Value, 4> indexValues;
108-
IndexExpr::getValues(indices, indexValues);
109-
store(val, memref, indexValues, offsets);
79+
onnx_mlir::impl::storeIE<KrnlBuilder, KrnlStoreOp>(
80+
*this, val, memref, indices, offsets);
11081
}
11182

11283
Value KrnlBuilder::getLinearOffsetIndex(
@@ -246,228 +217,37 @@ KrnlIterateOp KrnlBuilder::iterateIE(ValueRange originalLoops,
246217
});
247218
}
248219

249-
/*
250-
Example of how to use the interface:
251-
Say you have a loop of i=0..256, j=0..128 and want to exploit r[i,j] = a[i,j] +
252-
b[j] + c. For the loops, we will need access functions for a, b, and r.
253-
254-
Say we already have the loop for the outer loop of i
255-
256-
krnl.iterate(loop i from 0 to 256) {
257-
ii is the loop index.
258-
259-
// 1) compute access function for a, b, c
260-
// 2) launch simd loop with
261-
// 3) simd kernel
262-
}
263-
264-
1) Access functions
265-
Assuming here that we are not blocking the j loop, namely the simd iteration
266-
goes over all j values, the access functions should be defined as follows.
267-
268-
aAF = {ii, 0}
269-
bAF = {0}
270-
rAF = {ii, 0}
271-
272-
If the j loop was blocked (say j=0 to 128 by 16), then instead of `0` in the
273-
last dim, we would have 'blocked_jj'
274-
275-
2) Launch simd loop
276-
277-
create.krnl.simdIterateIE(
278-
lb=LitIE(0), ub=litIE(128), totVL=8, // loop params
279-
fullySimd=true, useParallel=false, // loop options
280-
inputs={A, B}, inputAFs={aAF, bAF}, // inputs
281-
outputs={R}, outputAFs={rAF}, // outputs
282-
krnl) // lambda function for kernel
283-
284-
3) Krnl for SIMD loop
285-
286-
The kernel function has 4 inputs:
287-
a) krnl builder to further build code
288-
b) list of loaded input values, in the same order as in inputs
289-
c) list of results values, that must be enqueued by the kernel
290-
d) totVL used for the loop (VL for simd, 1 for scalar)
291-
292-
The same kernel will be used in a SIMD context, in which the inputs and
293-
outputs must be vectors of VL elements, or in a scalar context, in which the
294-
inputs and outputs must be scalars.
295-
296-
In our example, the kernel is as follows
297-
298-
[&](KrnlBuilder &kb, ArrayRef<Value> inputVals,
299-
SmallVectorImpl<Value> &resVals, int64_t VL) {
300-
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(kb);
301-
Value aVal = inputVals[0]; // simd or scalar
302-
Value bVal = inputVals[1]; // simd or scalar
303-
Value cVal = create.krnl.load(C); // scalar always
304-
Value newVal = create.math.add(aVal, bVal); // simd or scalar
305-
newVal = create.math.add(newVal, cVal); // if newVal is simd, cVal is
306-
// splatted
307-
resVals.emplace_back(newVal); // Save simd or scalar result.
308-
}
309-
310-
The krnl.simdIterateIE will be in charge of loading and saving the values in
311-
memory. The create.math functions have been extended so that when a SIMD
312-
value is computed with a scalar, that scalar will be automatically splatted
313-
(aka promoted to a vector of identical values). As a result, the kernel can
314-
be written in a SIMD agnostic manner. However, in rare situations, we may
315-
want to know if we are in SIMD mode or not. VL will give the totVL used here
316-
(either totVL>1 or 1).
317-
*/
318-
319-
// Determine if an access has one element from the innermost dimensions up to
320-
// innerDim.
321-
bool static hasOneElementInInnermostDims(Value value, int64_t innerDim) {
322-
// Get info.
323-
ShapedType type = mlir::dyn_cast<ShapedType>(value.getType());
324-
assert(type && "expected shaped type");
325-
int64_t rank = type.getRank();
326-
ArrayRef<int64_t> shape = type.getShape();
327-
for (int64_t i = std::max((int64_t)0, rank - innerDim); i < rank; ++i)
328-
if (shape[i] != 1)
329-
return false;
330-
return true;
331-
}
332-
333220
void KrnlBuilder::simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL,
334221
bool fullySimd, bool useParallel, ArrayRef<Value> inputs,
335222
ArrayRef<DimsExpr> inputAFs, ArrayRef<Value> outputs,
336223
ArrayRef<DimsExpr> outputAFs,
337-
function_ref<void(KrnlBuilder &kb, ArrayRef<Value> inputVals,
224+
function_ref<void(KrnlBuilder &b, ArrayRef<Value> inputVals,
338225
llvm::SmallVectorImpl<Value> &resultVals, int64_t VL)>
339-
bodyBuilderFn) {
340-
int64_t inputNum = inputs.size();
341-
assert(inputAFs.size() == inputs.size() && "expected same size");
342-
int64_t outputNum = outputs.size();
343-
assert(outputAFs.size() == outputs.size() && "expected same size");
344-
MultiDialectBuilder<VectorBuilder> create(*this);
345-
346-
if (VL > 1) {
347-
// Want SIMD, execute full SIMD loops blocked by VL.
348-
ValueRange loopDef = defineLoops(1);
349-
ValueRange blockedLoopDef = block(loopDef[0], VL);
350-
if (useParallel)
351-
parallel({blockedLoopDef[0]});
352-
353-
// If we are not guaranteed that every iteration is a SIMD iteration, then
354-
// we need to reduce the trip count by a bit so as to not over compute.
355-
// If we are not guaranteed that every iteration is a SIMD iteration, then
356-
IndexExpr simdUb = ub;
357-
if (!fullySimd)
358-
simdUb = simdUb - (VL - 1);
359-
iterateIE(loopDef, {blockedLoopDef[0]}, {lb}, {simdUb},
360-
[&](KrnlBuilder &ck, ValueRange loopInd) {
361-
IndexExprScope scope(ck);
362-
MultiDialectBuilder<KrnlBuilder, VectorBuilder> create(ck);
363-
IndexExpr ind = DimIE(loopInd[0]);
364-
// Load all the inputs as vectors of VL values, with a few exceptions.
365-
// One is if the value is a "none value", leave as is. Another one is
366-
// if the innermost dim is a scalar (ie dim[rank-1] == 1), then we
367-
// just load the scalar.
368-
llvm::SmallVector<Value, 4> vecInputVals;
369-
for (int64_t i = 0; i < inputNum; ++i) {
370-
Value input = inputs[i];
371-
if (isNoneValue(input)) {
372-
// Simply enqueue the none value.
373-
vecInputVals.emplace_back(input);
374-
continue;
375-
}
376-
MemRefType type = mlir::cast<MemRefType>(input.getType());
377-
int64_t rank = type.getRank();
378-
DimsExpr AF = SymListIE(inputAFs[i]);
379-
assert(rank == (int64_t)AF.size() && "AF expected input rank refs");
380-
if (hasOneElementInInnermostDims(input, 1)) {
381-
// Has a reference with a scalar innermost dim, just load as a
382-
// scalar. No need to add the induction variable.
383-
Value scalarVal = create.krnl.loadIE(input, AF);
384-
vecInputVals.emplace_back(scalarVal);
385-
} else {
386-
// Have a vector.
387-
VectorType vecType = VectorType::get({VL}, type.getElementType());
388-
AF[rank - 1] = AF[rank - 1] + ind; // Add induction var.
389-
Value vecVal = create.vec.loadIE(vecType, input, AF);
390-
vecInputVals.emplace_back(vecVal);
391-
}
392-
}
393-
// Call the method to compute the values.
394-
llvm::SmallVector<Value, 4> vecResVals;
395-
bodyBuilderFn(create.krnl, vecInputVals, vecResVals, VL);
396-
assert((int64_t)vecResVals.size() == outputNum &&
397-
"loop body with incorrect number of results");
398-
// Store all the outputs as vectors of VL values,
399-
for (int64_t i = 0; i < outputNum; ++i) {
400-
MemRefType type = mlir::cast<MemRefType>(outputs[i].getType());
401-
DimsExpr AF = SymListIE(outputAFs[i]);
402-
int64_t rank = type.getRank();
403-
assert(rank == (int64_t)AF.size() && "AF expected ouput rank refs");
404-
AF[rank - 1] = AF[rank - 1] + ind;
405-
create.vec.storeIE(vecResVals[i], outputs[i], AF);
406-
}
407-
});
408-
if (fullySimd)
409-
// Asserted that we only have SIMD iterations, we are done.
410-
return;
411-
// Account for the loop iterations performed above.
412-
IndexExpr tripCount = ub - lb;
413-
IndexExpr missingIters = tripCount % VL;
414-
IndexExpr completedIters = tripCount - missingIters;
415-
if (missingIters.isLiteralAndIdenticalTo(0)) {
416-
// Detect that we only have SIMD iterations, we are also done.
417-
return;
418-
}
419-
// We may have additional iterations to perform, adjust lb to skip the
420-
// completed iterations.
421-
lb = lb + completedIters;
422-
}
423-
// Handle remaining scalar values (from lb to ub without unrolling).
424-
ValueRange loopDef = defineLoops(1);
425-
iterateIE(
426-
loopDef, loopDef, {lb}, {ub}, [&](KrnlBuilder &ck, ValueRange loopInd) {
427-
IndexExprScope scope(ck);
428-
MultiDialectBuilder<KrnlBuilder> create(ck);
429-
IndexExpr ind = DimIE(loopInd[0]);
430-
// Load all the inputs as scalar values,
431-
llvm::SmallVector<Value, 4> scalarInputVals;
432-
for (int64_t i = 0; i < inputNum; ++i) {
433-
Value input = inputs[i];
434-
if (isNoneValue(input)) {
435-
// Simply enqueue the none value.
436-
scalarInputVals.emplace_back(input);
437-
continue;
438-
}
439-
MemRefType type = mlir::cast<MemRefType>(input.getType());
440-
int64_t rank = type.getRank();
441-
DimsExpr AF = SymListIE(inputAFs[i]);
442-
if (hasOneElementInInnermostDims(input, 1)) {
443-
// Has a reference with a scalar innermost dim, just load as a
444-
// scalar. No need to add the induction variable.
445-
Value scalarVal = create.krnl.loadIE(input, AF);
446-
scalarInputVals.emplace_back(scalarVal);
447-
} else {
448-
AF[rank - 1] = AF[rank - 1] + ind;
449-
Value scalarVal = create.krnl.loadIE(input, AF);
450-
scalarInputVals.emplace_back(scalarVal);
451-
}
452-
}
453-
// Call the method to compute the values.
454-
llvm::SmallVector<Value, 4> scalarResVals;
455-
bodyBuilderFn(create.krnl, scalarInputVals, scalarResVals, /*VL*/ 1);
456-
assert((int64_t)scalarResVals.size() == outputNum &&
457-
"loop body with incorrect number of results");
458-
// Store all the outputs as vectors of VL values,
459-
for (int64_t i = 0; i < outputNum; ++i) {
460-
MemRefType type = mlir::cast<MemRefType>(outputs[i].getType());
461-
DimsExpr AF = SymListIE(outputAFs[i]);
462-
int64_t rank = type.getRank();
463-
assert(rank == (int64_t)AF.size() && "AF expected ouput rank refs");
464-
AF[rank - 1] = AF[rank - 1] + ind;
465-
create.krnl.storeIE(scalarResVals[i], outputs[i], AF);
466-
}
467-
});
468-
}
469-
470-
void KrnlBuilder::yield(ValueRange iterArgs) const {
226+
bodyBuilderFn) const {
227+
onnx_mlir::impl::simdIterateIE<KrnlBuilder, KrnlBuilder>(*this, lb, ub, VL,
228+
fullySimd, useParallel, inputs, inputAFs, outputs, outputAFs,
229+
bodyBuilderFn);
230+
}
231+
232+
void KrnlBuilder::simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL,
233+
bool fullySimd, ArrayRef<Value> inputs, ArrayRef<DimsExpr> inputAFs,
234+
ArrayRef<Value> tmps, ArrayRef<DimsExpr> tmpAFs, ArrayRef<Value> outputs,
235+
ArrayRef<DimsExpr> outputAFs, ArrayRef<Value> initVals,
236+
/* reduction function (simd or scalar) */
237+
function_ref<void(const KrnlBuilder &b, ArrayRef<Value> inputVals,
238+
ArrayRef<Value> tmpVals, llvm::SmallVectorImpl<Value> &resultVals,
239+
int64_t VL)>
240+
reductionBuilderFn,
241+
/* post reduction function (simd to scalar + post processing)*/
242+
function_ref<void(const KrnlBuilder &b, ArrayRef<Value> tmpVals,
243+
llvm::SmallVectorImpl<Value> &scalarOutputs, int64_t VL)>
244+
postProcessingBuilderFn) const {
245+
onnx_mlir::impl::simdReduceIE<KrnlBuilder, KrnlBuilder>(*this, lb, ub, VL,
246+
fullySimd, inputs, inputAFs, tmps, tmpAFs, outputs, outputAFs, initVals,
247+
reductionBuilderFn, postProcessingBuilderFn);
248+
}
249+
250+
void KrnlBuilder::yield(mlir::ValueRange iterArgs) const {
471251
b().create<KrnlYieldOp>(loc(), iterArgs);
472252
}
473253

@@ -533,8 +313,8 @@ Value KrnlBuilder::constant(MemRefType type, StringRef name,
533313
void KrnlBuilder::memcpy(Value dest, Value src, Value numElems) const {
534314
MultiDialectBuilder<MathBuilder> create(*this);
535315
Value zero = create.math.constantIndex(0);
536-
b().create<KrnlMemcpyOp>(
537-
loc(), dest, src, numElems, /*dest_offset=*/zero, /*src_offset=*/zero);
316+
b().create<KrnlMemcpyOp>(loc(), dest, src, numElems,
317+
/*dest_offset=*/zero, /*src_offset=*/zero);
538318
}
539319

540320
void KrnlBuilder::memcpy(Value dest, Value src, Value numElems,

0 commit comments

Comments
 (0)