Skip to content

Commit a6ebca0

Browse files
added support for no-zero-point quantization (#2938)
Signed-off-by: Alexandre Eichenberger <[email protected]> Co-authored-by: Tung D. Le <[email protected]>
1 parent fd3eb99 commit a6ebca0

File tree

8 files changed

+242
-25
lines changed

8 files changed

+242
-25
lines changed

src/Compiler/CompilerOptions.cpp

+13-4
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ bool enableONNXHybridPass; // common for both
4242
std::vector<std::string> functionsToDecompose; // common for both
4343
std::string opsForCall; // common for both
4444
bool disableKrnlOpFusion; // common for both
45+
bool disableQuantZeroPoint; // common for both
4546
bool enableKrnlBufferReuse; // common for both
4647
bool disableMemRefPrefetch; // common for both
4748
EmissionTargetType emissionTarget; // onnx-mlir only
@@ -195,7 +196,7 @@ static llvm::cl::list<std::string, std::vector<std::string>>
195196
llvm::cl::cat(OnnxMlirCommonOptions));
196197

197198
static llvm::cl::opt<bool, true> enableONNXHybridPassOpt("onnx-hybrid-pass",
198-
llvm::cl::desc("Enable ONNX hybrid pass (default=true)\n"
199+
llvm::cl::desc("Enable ONNX hybrid pass (default=true).\n"
199200
"Set to 'false' if you want to disable ONNX hybrid pass."),
200201
llvm::cl::location(enableONNXHybridPass), llvm::cl::init(true),
201202
llvm::cl::cat(OnnxMlirCommonOptions));
@@ -208,11 +209,20 @@ static llvm::cl::list<std::string, std::vector<std::string>>
208209

209210
static llvm::cl::opt<bool, true> disableKrnlOpFusionOpt(
210211
"disable-krnl-op-fusion",
211-
llvm::cl::desc("disable op fusion in onnx-to-krnl pass (default=false)\n"
212+
llvm::cl::desc("Disable op fusion in onnx-to-krnl pass (default=false).\n"
212213
"Set to 'true' if you want to disable fusion."),
213214
llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false),
214215
llvm::cl::cat(OnnxMlirCommonOptions));
215216

217+
static llvm::cl::opt<bool, true> disable_quantization_zero_point(
218+
"disable-quantization-zero-point",
219+
llvm::cl::desc(
220+
"Disable the use of zero-point in quantization (default=false).\n"
221+
"Set to 'true' if you want to disable the use of zero-point\n"
222+
"in dyn/static quantization/dequantization."),
223+
llvm::cl::location(disableQuantZeroPoint), llvm::cl::init(false),
224+
llvm::cl::cat(OnnxMlirCommonOptions));
225+
216226
static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
217227
"enable-krnl-buffer-reuse",
218228
llvm::cl::desc("enable buffer reuse within an op in onnx-to-krnl pass"
@@ -223,7 +233,7 @@ static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
223233

224234
static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
225235
"disable-memref-prefetch",
226-
llvm::cl::desc("disable generation of memref.prefetch (default=false)\n"
236+
llvm::cl::desc("Disable generation of memref.prefetch (default=false).\n"
227237
"Set to 'true' if you want to disable prefetch."),
228238
llvm::cl::location(disableMemRefPrefetch), llvm::cl::init(false),
229239
llvm::cl::cat(OnnxMlirCommonOptions));
@@ -1145,7 +1155,6 @@ std::string getLibraryPath() {
11451155
// as lrodataScript.
11461156
std::string getToolPath(
11471157
const std::string &tool, bool flag /*false by default*/) {
1148-
11491158
if (!flag) {
11501159
std::string execDir = llvm::sys::path::parent_path(getExecPath()).str();
11511160
llvm::SmallString<8> toolPath(execDir);

src/Compiler/CompilerOptions.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ extern bool enableONNXHybridPass; // common for both
8787
extern std::vector<std::string> functionsToDecompose; // common for both
8888
extern std::string opsForCall; // common for both
8989
extern bool disableKrnlOpFusion; // common for both
90+
extern bool disableQuantZeroPoint; // common for both
9091
extern bool enableKrnlBufferReuse; // common for both
9192
extern bool disableMemRefPrefetch; // common for both
9293
extern EmissionTargetType emissionTarget; // onnx-mlir only

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -1358,9 +1358,15 @@ Value emitScalarOpFor<ONNXDequantizeLinearOp>(
13581358
Value scaleFloat = scalarOperands[1];
13591359
Value zeroPointInt = scalarOperands[2];
13601360

1361-
Value zeroPointFloat = create.math.cast(elementType, zeroPointInt);
13621361
Value xFloat = create.math.cast(elementType, XInt);
1363-
Value sub = create.math.sub(xFloat, zeroPointFloat);
1362+
1363+
Value sub;
1364+
if (!disableQuantZeroPoint && !isNoneValue(zeroPointInt)) {
1365+
Value zeroPointFloat = create.math.cast(elementType, zeroPointInt);
1366+
sub = create.math.sub(xFloat, zeroPointFloat);
1367+
} else {
1368+
sub = xFloat;
1369+
}
13641370
Value res = create.math.mul(sub, scaleFloat);
13651371
return res;
13661372
}

src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp

+15-8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15+
#include "src/Compiler/CompilerOptions.hpp"
1516
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
1617
#include "src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp"
1718
#include "src/Dialect/Krnl/DialectBuilder.hpp"
@@ -29,7 +30,7 @@ void emitDynamicQuantizationLinearScalarParameters(
2930
ConversionPatternRewriter &rewriter, Location loc, Operation *op,
3031
MemRefType inputType, MemRefType quantizedType, Value input, Value qMin,
3132
Value qMax, Value &scale, Value &zeroPoint, Value &quantizedZeroPoint,
32-
bool enableSIMD, bool enableParallel) {
33+
bool wantZeroPoint, bool enableSIMD, bool enableParallel) {
3334
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
3435

3536
// Types
@@ -62,11 +63,15 @@ void emitDynamicQuantizationLinearScalarParameters(
6263
scale = create.math.div(xDiff, boundDiff);
6364

6465
// Compute y_zero_point.
65-
Value interZeroPoint = create.math.sub(qMin, create.math.div(xMin, scale));
66-
// Saturate zero point.
67-
Value saturateZeroPoint = create.math.clip(interZeroPoint, qMin, qMax);
68-
// Round zero point.
69-
zeroPoint = create.math.round(saturateZeroPoint);
66+
if (wantZeroPoint) {
67+
Value interZeroPoint = create.math.sub(qMin, create.math.div(xMin, scale));
68+
// Saturate zero point.
69+
Value saturateZeroPoint = create.math.clip(interZeroPoint, qMin, qMax);
70+
// Round zero point.
71+
zeroPoint = create.math.round(saturateZeroPoint);
72+
} else {
73+
zeroPoint = zero;
74+
}
7075
quantizedZeroPoint = create.math.cast(quantizedElementType, zeroPoint);
7176
}
7277

@@ -122,15 +127,17 @@ struct ONNXDynamicQuantizeLinearOpLowering
122127
Value qMin = create.math.constant(elementType, 0.0);
123128
Value scale, zeroPoint, zeroPointInt;
124129

130+
bool wantZeroPoint = !disableQuantZeroPoint;
125131
emitDynamicQuantizationLinearScalarParameters(rewriter, loc, op,
126132
xMemRefType, yMemRefType, X, qMin, qMax, scale, zeroPoint, zeroPointInt,
127-
enableSIMD, enableParallel);
133+
wantZeroPoint, enableSIMD, enableParallel);
128134
create.krnl.store(scale, YScale);
129135
create.krnl.store(zeroPointInt, YZeroPoint);
130136

131137
emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType,
132138
yMemRefType, Y, shapeHelper.getOutputDims(0), X, qMin, qMax, scale,
133-
zeroPoint, enableSIMD, enableParallel);
139+
zeroPoint, wantZeroPoint /*wanted one, so we have a zero point*/,
140+
enableSIMD, enableParallel);
134141

135142
rewriter.replaceOp(op, {Y, YScale, YZeroPoint});
136143
onnxToKrnlSimdReport(op);

src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ void emitQuantizationLinearScalarParameters(
2323
mlir::Operation *op, mlir::MemRefType inputType,
2424
mlir::MemRefType quantizedType, mlir::Value alloc, DimsExpr &allocDims,
2525
mlir::Value input, mlir::Value qMin, mlir::Value qMax, mlir::Value scale,
26-
mlir::Value zeroPoint, bool enableSIMD, bool enableParallel);
26+
mlir::Value zeroPoint, bool hasZeroPoint, bool enableSIMD,
27+
bool enableParallel);
2728

2829
// Scan the input to compute scale, zeroPoint, and quantizedZeroPoint given qMin
2930
// and qMax.
@@ -32,5 +33,6 @@ void emitDynamicQuantizationLinearScalarParameters(
3233
mlir::Operation *op, mlir::MemRefType inputType,
3334
mlir::MemRefType quantizedType, mlir::Value input, mlir::Value qMin,
3435
mlir::Value qMax, mlir::Value &scale, mlir::Value &zeroPoint,
35-
mlir::Value &quantizedZeroPoint, bool enableSIMD, bool enableParallel);
36+
mlir::Value &quantizedZeroPoint, bool wantZeroPoint, bool enableSIMD,
37+
bool enableParallel);
3638
} // namespace onnx_mlir

src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp

+18-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15+
#include "src/Compiler/CompilerOptions.hpp"
1516
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
1617
#include "src/Dialect/Krnl/DialectBuilder.hpp"
1718
#include "src/Dialect/ONNX/DialectBuilder.hpp"
@@ -26,7 +27,8 @@ namespace onnx_mlir {
2627
void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
2728
Location loc, Operation *op, MemRefType inputType, MemRefType quantizedType,
2829
Value alloc, DimsExpr &allocDims, Value input, Value qMin, Value qMax,
29-
Value scale, Value zeroPoint, bool enableSIMD, bool enableParallel) {
30+
Value scale, Value zeroPoint, bool hasZeroPoint, bool enableSIMD,
31+
bool enableParallel) {
3032
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(
3133
rewriter, loc);
3234

@@ -77,7 +79,11 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
7779
// Round
7880
Value roundX = create.math.round(scaleX);
7981
// Adjust
80-
Value adjustX = create.math.add(roundX, zeroPoint);
82+
Value adjustX;
83+
if (hasZeroPoint)
84+
adjustX = create.math.add(roundX, zeroPoint);
85+
else
86+
adjustX = roundX;
8187
// Saturate
8288
Value saturateX = create.math.clip(adjustX, qMin, qMax);
8389
Value res = create.math.cast(quantizedElementType, saturateX);
@@ -160,15 +166,21 @@ struct ONNXQuantizeLinearOpLowering
160166

161167
// Load y_zero_point.
162168
Value zeroPoint;
169+
bool hasZeroPoint = false;
163170
if (!isNoneValue(YZeroPoint)) {
164171
zeroPoint = create.krnl.load(adaptor.getYZeroPoint());
165172
zeroPoint = create.math.cast(elementType, zeroPoint);
166-
} else
167-
zeroPoint = create.math.constant(elementType, 0.0);
168-
173+
hasZeroPoint = true;
174+
}
175+
if (disableQuantZeroPoint) {
176+
// TODO: should we expect to disable hasZeroPoint forcefully, or generate
177+
// an error if we had a zero point? Right now, just forcefully assert we
178+
// have no zero point, i.e. ignore one even if we had a zero point.
179+
hasZeroPoint = false;
180+
}
169181
emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType,
170182
yMemRefType, Y, shapeHelper.getOutputDims(0), X, qMin, qMax, scale,
171-
zeroPoint, enableSIMD, enableParallel);
183+
zeroPoint, hasZeroPoint, enableSIMD, enableParallel);
172184

173185
rewriter.replaceOp(op, {Y});
174186
onnxToKrnlSimdReport(op);

test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir

+7-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
// Adding canonicalize is important here as this is the only way to check the values of the map,
44
// which are otherwise before the function, and thus are hard to test.
55

6+
// -----
7+
68
func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor<f32>, %arg2: tensor<i8>) -> tensor<4xf32> {
79
%0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xi8>, tensor<f32>, tensor<i8>) -> tensor<4xf32>
810
return %0 : tensor<4xf32>
@@ -29,10 +31,12 @@ func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor<f32>, %ar
2931

3032
// -----
3133

34+
3235
func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor<f32>, %arg2: tensor<ui8>) -> tensor<4xf32> {
3336
%0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xui8>, tensor<f32>, tensor<ui8>) -> tensor<4xf32>
3437
return %0 : tensor<4xf32>
3538

39+
// mlir2FileCheck.py
3640
// CHECK-LABEL: func.func @test_dequantizelinear_ui8
3741
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xui8>, [[PARAM_1_:%.+]]: memref<f32>, [[PARAM_2_:%.+]]: memref<ui8>) -> memref<4xf32> {
3842
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<4xf32>
@@ -42,11 +46,11 @@ func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor<f32>, %
4246
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<4xui8>
4347
// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref<f32>
4448
// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref<ui8>
45-
// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8
49+
// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8
4650
// CHECK-DAG: [[VAR_6_:%.+]] = arith.uitofp [[VAR_5_]] : i8 to f32
47-
// CHECK-DAG: [[VAR_7_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8
51+
// CHECK-DAG: [[VAR_7_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8
4852
// CHECK: [[VAR_8_:%.+]] = arith.uitofp [[VAR_7_]] : i8 to f32
49-
// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_8_]], [[VAR_6_]] : f32
53+
// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_6_]], [[VAR_8_]] : f32
5054
// CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[LOAD_PARAM_1_MEM_]] : f32
5155
// CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32>
5256
// CHECK: }

0 commit comments

Comments
 (0)