@@ -340,6 +340,65 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern<ONNXMulOp> {
340
340
}
341
341
};
342
342
343
+ struct RecomposeQLinearMatMulFromQuantizeLinearPattern
344
+ : public OpRewritePattern<ONNXQuantizeLinearOp> {
345
+ using OpRewritePattern<ONNXQuantizeLinearOp>::OpRewritePattern;
346
+
347
+ LogicalResult matchAndRewrite (
348
+ ONNXQuantizeLinearOp qlOp, PatternRewriter &rewriter) const final {
349
+ using namespace onnx_mlir ;
350
+ Location loc = qlOp.getLoc ();
351
+ // Match
352
+ Value a, aScale, aZeroPoint, b, bScale, bZeroPoint, outScale, outZeroPoint;
353
+ if (!matchQLinearMatMulPattern (qlOp, a, aScale, aZeroPoint, b, bScale,
354
+ bZeroPoint, outScale, outZeroPoint))
355
+ return failure ();
356
+
357
+ // Replace
358
+ MultiDialectBuilder<OnnxBuilder> create (rewriter, loc);
359
+ Value res = create.onnx .qlinearMatMul (qlOp.getY ().getType (), a, aScale,
360
+ aZeroPoint, b, bScale, bZeroPoint, outScale, outZeroPoint);
361
+
362
+ rewriter.replaceOp (qlOp, res);
363
+ return success ();
364
+ }
365
+
366
+ // Recompose QLinearMatMul, starting from QuantizeLinear.
367
+ // Pattern: DequanizeLinear + MatMul + QuantizeLinear.
368
+ static bool matchQLinearMatMulPattern (ONNXQuantizeLinearOp op, Value &a,
369
+ Value &aScale, Value &aZeroPoint, Value &b, Value &bScale,
370
+ Value &bZeroPoint, Value &outScale, Value &outZeroPoint) {
371
+ Operation *quantizeOp = op.getOperation ();
372
+ outScale = op.getYScale ();
373
+ outZeroPoint = op.getYZeroPoint ();
374
+ // Matching MatMul.
375
+ Value qlX, matA, matB;
376
+ Operation *matmulOp;
377
+ bool matchMatMul = onnx_mlir::operandOfOpDefinedBy<ONNXMatMulOp>(
378
+ matmulOp, quantizeOp, qlX, 0 );
379
+ if (!matchMatMul)
380
+ return false ;
381
+ matA = cast<ONNXMatMulOp>(matmulOp).getA ();
382
+ matB = cast<ONNXMatMulOp>(matmulOp).getB ();
383
+ // Matching input A of MatMul.
384
+ auto dlOpA = matA.getDefiningOp <ONNXDequantizeLinearOp>();
385
+ if (!dlOpA)
386
+ return false ;
387
+ a = dlOpA.getX ();
388
+ aScale = dlOpA.getXScale ();
389
+ aZeroPoint = dlOpA.getXZeroPoint ();
390
+ // Matching input B of MatMul.
391
+ auto dlOpB = matB.getDefiningOp <ONNXDequantizeLinearOp>();
392
+ if (!dlOpB)
393
+ return false ;
394
+ b = dlOpB.getX ();
395
+ bScale = dlOpB.getXScale ();
396
+ bZeroPoint = dlOpB.getXZeroPoint ();
397
+ // Matched the pattern.
398
+ return true ;
399
+ }
400
+ };
401
+
343
402
struct RecomposeONNXToONNXPass
344
403
: public PassWrapper<RecomposeONNXToONNXPass, OperationPass<func::FuncOp>> {
345
404
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (RecomposeONNXToONNXPass)
@@ -387,6 +446,17 @@ void RecomposeONNXToONNXPass::runOnOperation() {
387
446
op, x, scale, axis, epsilon, isRMSLayerNorm);
388
447
});
389
448
449
+ // Recompose QLinearMatMul, starting from QuantizeLinear.
450
+ // Pattern: DequanizeLinear + MatMul + QuantizeLinear.
451
+ target.addDynamicallyLegalOp <ONNXQuantizeLinearOp>(
452
+ [](ONNXQuantizeLinearOp op) {
453
+ Value a, aScale, aZeroPoint, b, bScale, bZeroPoint, outScale,
454
+ outZeroPoint;
455
+ return !RecomposeQLinearMatMulFromQuantizeLinearPattern::
456
+ matchQLinearMatMulPattern (op, a, aScale, aZeroPoint, b, bScale,
457
+ bZeroPoint, outScale, outZeroPoint);
458
+ });
459
+
390
460
RewritePatternSet patterns (context);
391
461
onnx_mlir::getRecomposeONNXToONNXPatterns (patterns);
392
462
@@ -400,6 +470,7 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
400
470
mlir::RewritePatternSet &patterns) {
401
471
MLIRContext *context = patterns.getContext ();
402
472
patterns.insert <RecomposeLayerNormFromMulPattern>(context);
473
+ patterns.insert <RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
403
474
}
404
475
405
476
/* !
0 commit comments