Skip to content

Commit 79a8875

Browse files
removed alternative code versions
Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent d30d7c7 commit 79a8875

File tree

1 file changed

+0
-122
lines changed

1 file changed

+0
-122
lines changed

src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp

Lines changed: 0 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
7171
DimsExpr outputAF;
7272
outputAF.emplace_back(zero);
7373

74-
#if 1
7574
Type inputElementType = inputType.getElementType();
7675
unsigned inputWidth;
7776
if (isa<Float32Type>(inputElementType))
@@ -111,127 +110,6 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
111110
return res;
112111
}});
113112

114-
#elif 1
115-
// hi alex: test with 2 loops for easier debugging
116-
// Allocate output buffers (same type as input).
117-
MemRefType flatBufferType = llvm::cast<MemRefType>(flatInput.getType());
118-
Value flatBuffer = create.mem.alignedAlloc(flatBufferType, flatInputDims);
119-
DimsExpr bufferAF;
120-
bufferAF.emplace_back(zero);
121-
122-
create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
123-
{flatInput}, {inputAF}, {flatBuffer}, {bufferAF},
124-
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
125-
MultiDialectBuilder<MathBuilder> create(kb);
126-
Value x = inputVals[0];
127-
// Scale
128-
Value scaleX = create.math.div(x, scale);
129-
// Round
130-
Value roundX = create.math.round(scaleX);
131-
// Adjust
132-
Value adjustX;
133-
if (hasZeroPoint)
134-
adjustX = create.math.add(roundX, zeroPoint);
135-
else
136-
adjustX = roundX;
137-
// Saturate: use max into a min.
138-
Value saturateX = create.math.clip(adjustX, qMin, qMax);
139-
// Old approach.
140-
// return create.math.cast(quantizedElementType, saturateX);
141-
return saturateX;
142-
}});
143-
144-
// Need transient types.
145-
Type inputElementType = flatBufferType.getElementType();
146-
unsigned inputWidth;
147-
if (isa<Float32Type>(inputElementType))
148-
inputWidth = 32;
149-
else if (isa<Float64Type>(inputElementType))
150-
inputWidth = 64;
151-
else
152-
llvm_unreachable("unsupported input type");
153-
IntegerType quantizedIntType = cast<IntegerType>(quantizedElementType);
154-
bool isSignless = quantizedIntType.isSignless();
155-
bool isSigned = quantizedIntType.isSigned();
156-
Type quantizedElementTypeSameSizeAsInput =
157-
rewriter.getIntegerType(inputWidth, isSignless || isSigned);
158-
159-
create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
160-
{flatBuffer}, {bufferAF}, {flatAlloc}, {outputAF},
161-
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
162-
MultiDialectBuilder<KrnlBuilder, VectorBuilder, MathBuilder> create(kb);
163-
// Convert float* to int*/uint* where * is 32 (64?)
164-
Value input = inputVals[0];
165-
Value quantizedSameSizeAsInput =
166-
create.math.cast(quantizedElementTypeSameSizeAsInput, input);
167-
// Convert int32/uint32 to int*/uint* where * is 8, 16...
168-
#if 0
169-
// Code get normalized to the code below
170-
unsigned quantizedWidth = quantizedIntType.getWidth();
171-
unsigned currWidth = inputWidth;
172-
Value qVal = quantizedSameSizeAsInput;
173-
while (currWidth > quantizedWidth) {
174-
currWidth = currWidth / 2;
175-
Type qType =
176-
rewriter.getIntegerType(currWidth, isSignless || isSigned);
177-
qVal = create.math.cast(qType, qVal);
178-
}
179-
#else
180-
Value qVal =
181-
create.math.cast(quantizedElementType, quantizedSameSizeAsInput);
182-
#endif
183-
return qVal;
184-
}});
185-
186-
#else
187-
// faster than original loop on z16, takes 124us for 64k vals
188-
// Allocate output buffers.
189-
MemRefType flatBufferType = llvm::cast<MemRefType>(flatInput.getType());
190-
Value flatBuffer = create.mem.alignedAlloc(flatBufferType, flatInputDims);
191-
DimsExpr bufferAF;
192-
bufferAF.emplace_back(zero);
193-
194-
create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
195-
{flatInput}, {inputAF}, {flatBuffer}, {bufferAF},
196-
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
197-
MultiDialectBuilder<MathBuilder> create(kb);
198-
Value x = inputVals[0];
199-
// Scale
200-
Value scaleX = create.math.div(x, scale);
201-
// Round
202-
Value roundX = create.math.round(scaleX);
203-
// Adjust
204-
Value adjustX;
205-
if (hasZeroPoint)
206-
adjustX = create.math.add(roundX, zeroPoint);
207-
else
208-
adjustX = roundX;
209-
// Saturate: use max into a min.
210-
Value saturateX = create.math.clip(adjustX, qMin, qMax);
211-
// Old approach.
212-
// return create.math.cast(quantizedElementType, saturateX);
213-
return saturateX;
214-
}});
215-
216-
// A second loop that performs scalar float-to-int conversion performs better than the
217-
// compiler's attempt to generate SIMD conversion code. This might not hold
218-
// with all data types, but is definitely noticeable with uint8.
219-
//
220-
// Investigate further: we might save the vector to a buffer on the fly
221-
// (avoiding a second loop as below), and then reload each value as scalar and
222-
// then save them as scalar (thus avoiding the insert/extract SIMD operations
223-
// that also do not perform well). We can have a SIMD buffer in memory for the
224-
// non-quantized and quantized simd values, but then we also need to privatize
225-
// it, which is also not easy in this scheme. So ignore this for now.
226-
create.krnl.forLoopIE(simdLb, simdUb, 1, enableParallel,
227-
[&](const KrnlBuilder &kb, ValueRange loopInd) {
228-
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(kb);
229-
Value buffVal = create.krnl.loadIE(flatBuffer, {zero}, {loopInd[0]});
230-
Value res = create.math.cast(quantizedElementType, buffVal);
231-
create.krnl.storeIE(res, flatAlloc, {zero}, {loopInd[0]});
232-
});
233-
#endif
234-
235113
if (totVL > 1)
236114
onnxToKrnlSimdReport(op, /*successful*/ true, totVL,
237115
simdLoopStaticTripCount, "quantizationLinear whole tensor");

0 commit comments

Comments
 (0)