Skip to content

Commit 4a241ef

Browse files
authored
Recompose QLinearMatMul and remove Quantize-Dequantize pairs (#2875)
* Recompose QLinearMatMul and remove Quantize-Dequantize pairs Signed-off-by: Tung D. Le <[email protected]> --------- Signed-off-by: Tung D. Le <[email protected]>
1 parent 7879d17 commit 4a241ef

11 files changed

+161
-1
lines changed

src/Dialect/ONNX/DialectBuilder.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,15 @@ Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale,
165165
return layerNormOp.getY();
166166
}
167167

168+
Value OnnxBuilder::qlinearMatMul(Type outputType, Value a, Value aScale,
169+
Value aZeroPoint, Value b, Value bScale, Value bZeroPoint, Value yScale,
170+
Value yZeroPoint) const {
171+
return createOpAndInferShapes<ONNXQLinearMatMulOp>(toTensor(outputType),
172+
toTensor(a), toTensor(aScale), toTensor(aZeroPoint), toTensor(b),
173+
toTensor(bScale), toTensor(bZeroPoint), toTensor(yScale),
174+
toTensor(yZeroPoint));
175+
}
176+
168177
Value OnnxBuilder::RMSLayerNorm(Type outputType, Value input, Value scale,
169178
Value bias, int64_t axis, FloatAttr epsilon) const {
170179
IntegerAttr axisAttr = getSignedInt64Attr(axis);

src/Dialect/ONNX/DialectBuilder.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ struct OnnxBuilder : DialectBuilder {
9191
mlir::Value scale, mlir::Value bias, int64_t axis,
9292
mlir::FloatAttr epsilon) const;
9393

94+
// ONNXQLinearMatMulOp
95+
mlir::Value qlinearMatMul(mlir::Type outputType, mlir::Value a,
96+
mlir::Value aScale, mlir::Value aZeroPoint, mlir::Value b,
97+
mlir::Value bScale, mlir::Value bZeroPoint, mlir::Value yScale,
98+
mlir::Value yZeroPoint) const;
99+
94100
// ONNXRMSLayerNormalizationOp, version with one output only (Y).
95101
mlir::Value RMSLayerNorm(mlir::Type outputType, mlir::Value input,
96102
mlir::Value scale, mlir::Value bias, int64_t axis,

src/Dialect/ONNX/ONNXOps.td.inc

+1
Original file line numberDiff line numberDiff line change
@@ -1822,6 +1822,7 @@ def ONNXDepthToSpaceOp:ONNX_Op<"DepthToSpace",
18221822

18231823
def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear",
18241824
[Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
1825+
let hasCanonicalizer = 1;
18251826
let summary = "ONNX DequantizeLinear operation";
18261827
let description = [{
18271828
The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the full precision tensor.

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -1858,3 +1858,9 @@ void ONNXWhereOp::getCanonicalizationPatterns(
18581858
RewritePatternSet &result, MLIRContext *context) {
18591859
result.insert<AlwaysFalseWherePattern>(context);
18601860
}
1861+
1862+
// on the ONNXDequantizeLinearOp.
1863+
void ONNXDequantizeLinearOp::getCanonicalizationPatterns(
1864+
RewritePatternSet &result, MLIRContext *context) {
1865+
result.insert<QuantizeDequantizePattern>(context);
1866+
}

src/Dialect/ONNX/ONNXOps/Canonicalize.td

+11
Original file line numberDiff line numberDiff line change
@@ -1055,4 +1055,15 @@ def AlwaysFalseWherePattern : Pat<
10551055
[(IsNegativeSplatConstant:$negative_constant), (AreAllDimSizes:$dims)]
10561056
>;
10571057

1058+
//===----------------------------------------------------------------------===//
1059+
// Canonicalization for ONNXDequantizeLinear
1060+
//===----------------------------------------------------------------------===//
1061+
1062+
// Convert QuantizeLinear+DequantizeLinear to Identity.
1063+
def QuantizeDequantizePattern: Pat<
1064+
(ONNXDequantizeLinearOp (ONNXQuantizeLinearOp $x, $x_scale, $x_zeropoint, $x_axis, $x_saturate),
1065+
$y_scale, $y_zeropoint, $y_axis),
1066+
(replaceWithValue $x)
1067+
>;
1068+
10581069
#endif // ONNX_REWRITE

src/Dialect/ONNX/Transforms/Recompose.cpp

+71
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,65 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern<ONNXMulOp> {
340340
}
341341
};
342342

343+
struct RecomposeQLinearMatMulFromQuantizeLinearPattern
344+
: public OpRewritePattern<ONNXQuantizeLinearOp> {
345+
using OpRewritePattern<ONNXQuantizeLinearOp>::OpRewritePattern;
346+
347+
LogicalResult matchAndRewrite(
348+
ONNXQuantizeLinearOp qlOp, PatternRewriter &rewriter) const final {
349+
using namespace onnx_mlir;
350+
Location loc = qlOp.getLoc();
351+
// Match
352+
Value a, aScale, aZeroPoint, b, bScale, bZeroPoint, outScale, outZeroPoint;
353+
if (!matchQLinearMatMulPattern(qlOp, a, aScale, aZeroPoint, b, bScale,
354+
bZeroPoint, outScale, outZeroPoint))
355+
return failure();
356+
357+
// Replace
358+
MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
359+
Value res = create.onnx.qlinearMatMul(qlOp.getY().getType(), a, aScale,
360+
aZeroPoint, b, bScale, bZeroPoint, outScale, outZeroPoint);
361+
362+
rewriter.replaceOp(qlOp, res);
363+
return success();
364+
}
365+
366+
// Recompose QLinearMatMul, starting from QuantizeLinear.
367+
// Pattern: DequantizeLinear + MatMul + QuantizeLinear.
368+
static bool matchQLinearMatMulPattern(ONNXQuantizeLinearOp op, Value &a,
369+
Value &aScale, Value &aZeroPoint, Value &b, Value &bScale,
370+
Value &bZeroPoint, Value &outScale, Value &outZeroPoint) {
371+
Operation *quantizeOp = op.getOperation();
372+
outScale = op.getYScale();
373+
outZeroPoint = op.getYZeroPoint();
374+
// Matching MatMul.
375+
Value qlX, matA, matB;
376+
Operation *matmulOp;
377+
bool matchMatMul = onnx_mlir::operandOfOpDefinedBy<ONNXMatMulOp>(
378+
matmulOp, quantizeOp, qlX, 0);
379+
if (!matchMatMul)
380+
return false;
381+
matA = cast<ONNXMatMulOp>(matmulOp).getA();
382+
matB = cast<ONNXMatMulOp>(matmulOp).getB();
383+
// Matching input A of MatMul.
384+
auto dlOpA = matA.getDefiningOp<ONNXDequantizeLinearOp>();
385+
if (!dlOpA)
386+
return false;
387+
a = dlOpA.getX();
388+
aScale = dlOpA.getXScale();
389+
aZeroPoint = dlOpA.getXZeroPoint();
390+
// Matching input B of MatMul.
391+
auto dlOpB = matB.getDefiningOp<ONNXDequantizeLinearOp>();
392+
if (!dlOpB)
393+
return false;
394+
b = dlOpB.getX();
395+
bScale = dlOpB.getXScale();
396+
bZeroPoint = dlOpB.getXZeroPoint();
397+
// Matched the pattern.
398+
return true;
399+
}
400+
};
401+
343402
struct RecomposeONNXToONNXPass
344403
: public PassWrapper<RecomposeONNXToONNXPass, OperationPass<func::FuncOp>> {
345404
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RecomposeONNXToONNXPass)
@@ -387,6 +446,17 @@ void RecomposeONNXToONNXPass::runOnOperation() {
387446
op, x, scale, axis, epsilon, isRMSLayerNorm);
388447
});
389448

449+
// Recompose QLinearMatMul, starting from QuantizeLinear.
450+
// Pattern: DequantizeLinear + MatMul + QuantizeLinear.
451+
target.addDynamicallyLegalOp<ONNXQuantizeLinearOp>(
452+
[](ONNXQuantizeLinearOp op) {
453+
Value a, aScale, aZeroPoint, b, bScale, bZeroPoint, outScale,
454+
outZeroPoint;
455+
return !RecomposeQLinearMatMulFromQuantizeLinearPattern::
456+
matchQLinearMatMulPattern(op, a, aScale, aZeroPoint, b, bScale,
457+
bZeroPoint, outScale, outZeroPoint);
458+
});
459+
390460
RewritePatternSet patterns(context);
391461
onnx_mlir::getRecomposeONNXToONNXPatterns(patterns);
392462

@@ -400,6 +470,7 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
400470
mlir::RewritePatternSet &patterns) {
401471
MLIRContext *context = patterns.getContext();
402472
patterns.insert<RecomposeLayerNormFromMulPattern>(context);
473+
patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
403474
}
404475

405476
/*!

test/mlir/driver/compile_phases.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: onnx-mlir %s | FileCheck %s
1+
// RUN: onnx-mlir %s -o %t | FileCheck %s && rm %t.so
22

33
// CHECK: [1/5] {{.*}} Importing ONNX Model to MLIR Module
44
// CHECK: [2/5] {{.*}} Compiling and Optimizing MLIR Module
+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: onnx-mlir --printIR --EmitONNXIR %s | FileCheck %s
2+
3+
// COM: Check that Dequantize-MatMul-Quantize is always recomposed to QLinearMatMul before the removal of Quantize-Dequantize is applied.
4+
// COM: Otherwise, the recomposition of QLinearMatMul would fail due to a pattern mismatch (lack of DequantizeLinear).
5+
module {
6+
func.func @qlinear_matmul(%arg0: tensor<?x?x768xf32>, %arg1: tensor<f32>, %arg2: tensor<i8>, %arg3: tensor<768x768xi8>, %arg4: tensor<f32>, %arg5: tensor<i8>, %arg6: tensor<f32>, %arg7: tensor<i8>) -> (tensor<?x?x768xi8>) {
7+
%0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
8+
%1 = "onnx.DequantizeLinear"(%0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xf32>
9+
%2 = "onnx.DequantizeLinear"(%arg3, %arg4, %arg5) {axis = 1 : si64} : (tensor<768x768xi8>, tensor<f32>, tensor<i8>) -> tensor<768x768xf32>
10+
%3 = "onnx.MatMul"(%1, %2) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
11+
%4 = "onnx.QuantizeLinear"(%3, %arg6, %arg7) {axis = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
12+
return %4: tensor<?x?x768xi8>
13+
14+
}
15+
"onnx.EntryPoint"() {func = @main_graph} : () -> ()
16+
17+
// CHECK-LABEL: func.func @qlinear_matmul
18+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xf32>, [[PARAM_1_:%.+]]: tensor<f32>, [[PARAM_2_:%.+]]: tensor<i8>, [[PARAM_3_:%.+]]: tensor<768x768xi8>, [[PARAM_4_:%.+]]: tensor<f32>, [[PARAM_5_:%.+]]: tensor<i8>, [[PARAM_6_:%.+]]: tensor<f32>, [[PARAM_7_:%.+]]: tensor<i8>) -> tensor<?x?x768xi8> {
19+
// CHECK: [[VAR_0_:%.+]] = "onnx.QuantizeLinear"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 1 : si64, onnx_node_name = "onnx.QuantizeLinear_0", saturate = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
20+
// CHECK: [[VAR_1_:%.+]] = "onnx.QLinearMatMul"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]], [[PARAM_5_]], [[PARAM_6_]], [[PARAM_7_]]) {onnx_node_name = "onnx.QLinearMatMul_1"} : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>, tensor<768x768xi8>, tensor<f32>, tensor<i8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
21+
// CHECK: return [[VAR_1_]] : tensor<?x?x768xi8>
22+
// CHECK: }
23+
// CHECK: "onnx.EntryPoint"() {func = @main_graph} : () -> ()
24+
}

test/mlir/onnx/onnx_canonicalization.mlir

+15
Original file line numberDiff line numberDiff line change
@@ -1825,3 +1825,18 @@ func.func @test_where_with_always_false_3(%arg0: tensor<?x?xi64>) -> tensor<2xi6
18251825
// CHECK: onnx.Return [[VAR_6_]] : tensor<2xi64>
18261826
// CHECK: }
18271827
}
1828+
1829+
// -----
1830+
1831+
func.func @test_dequantize_linear(%arg0: tensor<?x?x768xf32>, %arg1: tensor<f32>, %arg2: tensor<i8>) -> (tensor<?x?x768xf32>) {
1832+
%0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
1833+
%1 = "onnx.DequantizeLinear"(%0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xf32>
1834+
return %1: tensor<?x?x768xf32>
1835+
1836+
// CHECK-LABEL: func.func @test_dequantize_linear
1837+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xf32>, [[PARAM_1_:%.+]]: tensor<f32>, [[PARAM_2_:%.+]]: tensor<i8>) -> tensor<?x?x768xf32> {
1838+
// CHECK-NOT: "onnx.QuantizeLinear"
1839+
// CHECK-NOT: "onnx.DequantizeLinear"
1840+
// CHECK: return [[PARAM_0_]] : tensor<?x?x768xf32>
1841+
// CHECK: }
1842+
}

test/mlir/onnx/onnx_recompose.mlir

+16
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,19 @@ func.func @rms_layer_norm_v2(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>,
245245
// CHECK: }
246246
}
247247

248+
// -----
249+
250+
// COM: QLinearMatMul
251+
func.func @qlinear_matmul(%arg0: tensor<?x?x768xi8>, %arg1: tensor<f32>, %arg2: tensor<i8>, %arg3: tensor<768x768xi8>, %arg4: tensor<f32>, %arg5: tensor<i8>, %arg6: tensor<f32>, %arg7: tensor<i8>) -> (tensor<?x?x768xi8>) {
252+
%0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xf32>
253+
%1 = "onnx.DequantizeLinear"(%arg3, %arg4, %arg5) {axis = 1 : si64} : (tensor<768x768xi8>, tensor<f32>, tensor<i8>) -> tensor<768x768xf32>
254+
%2 = "onnx.MatMul"(%0, %1) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
255+
%3 = "onnx.QuantizeLinear"(%2, %arg6, %arg7) {axis = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
256+
return %3: tensor<?x?x768xi8>
257+
258+
// CHECK-LABEL: func.func @qlinear_matmul
259+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xi8>, [[PARAM_1_:%.+]]: tensor<f32>, [[PARAM_2_:%.+]]: tensor<i8>, [[PARAM_3_:%.+]]: tensor<768x768xi8>, [[PARAM_4_:%.+]]: tensor<f32>, [[PARAM_5_:%.+]]: tensor<i8>, [[PARAM_6_:%.+]]: tensor<f32>, [[PARAM_7_:%.+]]: tensor<i8>) -> tensor<?x?x768xi8> {
260+
// CHECK: [[VAR_0_:%.+]] = "onnx.QLinearMatMul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]], [[PARAM_5_]], [[PARAM_6_]], [[PARAM_7_]]) : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>, tensor<768x768xi8>, tensor<f32>, tensor<i8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
261+
// CHECK: return [[VAR_0_]] : tensor<?x?x768xi8>
262+
// CHECK: }
263+
}

utils/gen_onnx_mlir.py

+1
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@
332332
"Cast",
333333
"Constant",
334334
"DepthToSpace",
335+
"DequantizeLinear",
335336
"Div",
336337
"Dropout",
337338
"Equal",

0 commit comments

Comments
 (0)