Skip to content

Commit 5f798ae

Browse files
version with extract/insert
Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent 60a8b9e commit 5f798ae

File tree

3 files changed

+69
-2
lines changed

3 files changed

+69
-2
lines changed

src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
2929
Value alloc, DimsExpr &allocDims, Value input, Value qMin, Value qMax,
3030
Value scale, Value zeroPoint, bool hasZeroPoint, bool enableSIMD,
3131
bool enableParallel) {
32-
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(
33-
rewriter, loc);
32+
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, VectorBuilder, MathBuilder>
33+
create(rewriter, loc);
3434

3535
// Types
3636
Type quantizedElementType = quantizedType.getElementType();
@@ -78,6 +78,45 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
7878
outputAF.emplace_back(zero);
7979

8080
#if 1
81+
82+
MemRefType outputType = llvm::cast<MemRefType>(alloc.getType());
83+
totVL = boostVLForMinUnroll(inputType, outputType, totVL);
84+
VectorType quantizedVectorType =
85+
VectorType::get({totVL}, quantizedElementType);
86+
Value qDummy = create.vec.loadIE(quantizedVectorType, flatAlloc, {zero});
87+
88+
create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
89+
{flatInput}, {inputAF}, {flatAlloc}, {outputAF},
90+
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
91+
MultiDialectBuilder<VectorBuilder, MathBuilder> create(kb);
92+
Value x = inputVals[0];
93+
// Scale
94+
Value scaleX = create.math.div(x, scale);
95+
// Round
96+
Value roundX = create.math.round(scaleX);
97+
// Adjust
98+
Value adjustX;
99+
if (hasZeroPoint)
100+
adjustX = create.math.add(roundX, zeroPoint);
101+
else
102+
adjustX = roundX;
103+
// Saturate: use max into a min.
104+
Value saturateX = create.math.clip(adjustX, qMin, qMax);
105+
Value res;
106+
if (VL == 1) {
107+
res = create.math.cast(quantizedElementType, saturateX);
108+
} else {
109+
res = qDummy; //
110+
for (int64_t v = 0; v < VL; ++v) {
111+
Value element = create.vec.extractElement(saturateX, v);
112+
Value resElement = create.math.cast(quantizedElementType, element);
113+
res = create.vec.insertElement(res, resElement, v);
114+
}
115+
}
116+
return res;
117+
}});
118+
119+
#elif 1
81120
// Allocate output buffers.
82121
MemRefType flatBufferType = llvm::cast<MemRefType>(flatInput.getType());
83122
Value flatBuffer = create.mem.alignedAlloc(flatBufferType, flatInputDims);

src/Dialect/Mlir/DialectBuilder.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2073,6 +2073,29 @@ void VectorBuilder::multiReduction(ArrayRef<Value> inputVecArray,
20732073
}
20742074
}
20752075

2076+
Value VectorBuilder::extractElement(Value vector, int64_t index) const {
2077+
MultiDialectBuilder<VectorBuilder, MathBuilder> create(*this);
2078+
VectorType type = llvm::cast<VectorType>(vector.getType());
2079+
int64_t VL = type.getShape()[0];
2080+
assert(type.getRank() == 1 && "expected 1D vector only");
2081+
assert(index >= 0 && index < VL && "out of range vector index");
2082+
Value position = create.math.constantIndex(index);
2083+
return b().create<vector::ExtractElementOp>(loc(), vector, position);
2084+
}
2085+
2086+
Value VectorBuilder::insertElement(
2087+
Value vector, Value element, int64_t index) const {
2088+
MultiDialectBuilder<VectorBuilder, MathBuilder> create(*this);
2089+
VectorType type = llvm::cast<VectorType>(vector.getType());
2090+
int64_t VL = type.getShape()[0];
2091+
assert(type.getRank() == 1 && "expected 1D vector only");
2092+
assert(index >= 0 && index < VL && "out of range vector index");
2093+
Value position = create.math.constantIndex(index);
2094+
// Unlike LLVM insert element which takes <dest, source, position>, vector
2095+
// take <source, dest, position>
2096+
return b().create<vector::InsertElementOp>(loc(), element, vector, position);
2097+
}
2098+
20762099
//===----------------------------------------------------------------------===//
20772100
// LLVM Builder
20782101
//===----------------------------------------------------------------------===//

src/Dialect/Mlir/DialectBuilder.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,11 @@ struct VectorBuilder final : DialectBuilder {
574574
void multiReduction(mlir::ArrayRef<mlir::Value> inputVecArray,
575575
F2 reductionFct, llvm::SmallVectorImpl<mlir::Value> &outputVecArray);
576576

577+
// Insert and extract.
578+
mlir::Value extractElement(mlir::Value vector, int64_t position) const;
579+
mlir::Value insertElement(
580+
mlir::Value vector, mlir::Value element, int64_t position) const;
581+
577582
private:
578583
bool isPowerOf2(uint64_t num) const;
579584
uint64_t getLengthOf1DVector(mlir::Value vec) const;

0 commit comments

Comments
 (0)