@@ -391,108 +391,120 @@ struct RecomposeGeluFromMulPattern : public OpRewritePattern<ONNXMulOp> {
391
391
return isDenseONNXConstant (v) && isConstOf (v, n);
392
392
};
393
393
394
- // Match A * B * C: 0.5 * x * (+)
395
- // Three associative cases depending on the order of Mul:
396
- // - (0.5 * x) * (+)
397
- // - 0.5 * (x * (+))
398
- // - (0.5 * (+)) * x
399
- // For each case, we have two communitive cases.
400
- // In total, we handle 6 cases.
394
+ // Match 0.5 * a * b
395
+ // Two associative cases depending on which Mul 0.5 belongs to:
396
+ // - 0.5 * (a * b)
397
+ // - (0.5 * a) * b
398
+ // For each case, we have two communitive cases for the outer Mul (not count
399
+ // the inner Mul). In total, we handle 4 cases.
401
400
Value lhs = mulOp.getOperand (0 );
402
401
Value rhs = mulOp.getOperand (1 );
403
- auto lhsMulOp = lhs.getDefiningOp <ONNXMulOp>();
404
- auto rhsMulOp = rhs.getDefiningOp <ONNXMulOp>();
405
- auto lhsAddOp = lhs.getDefiningOp <ONNXAddOp>();
406
- auto rhsAddOp = rhs.getDefiningOp <ONNXAddOp>();
407
- bool lhsIsCst = constOf (lhs, 0.5 );
408
- bool rhsIsCst = constOf (rhs, 0.5 );
409
-
410
- // Match to get AddOp and x from 0.5 * x * (+).
411
- ONNXAddOp add1Op;
412
- if (lhsIsCst && rhsMulOp) {
413
- // - 0.5 * (x * (+))
414
- if (auto aOp = rhsMulOp.getOperand (0 ).getDefiningOp <ONNXAddOp>()) {
415
- x = rhsMulOp.getOperand (1 );
416
- add1Op = aOp;
417
- } else if (auto aOp = rhsMulOp.getOperand (1 ).getDefiningOp <ONNXAddOp>()) {
418
- x = rhsMulOp.getOperand (0 );
419
- add1Op = aOp;
420
- } else
421
- return reportFailure (" missing 0.5 * (x * (+))" );
422
- } else if (lhsMulOp && rhsIsCst) {
423
- // - (x * (+)) * 0.5
424
- if (auto aOp = lhsMulOp.getOperand (0 ).getDefiningOp <ONNXAddOp>()) {
425
- x = lhsMulOp.getOperand (1 );
426
- add1Op = aOp;
427
- } else if (auto aOp = lhsMulOp.getOperand (1 ).getDefiningOp <ONNXAddOp>()) {
428
- x = lhsMulOp.getOperand (0 );
429
- add1Op = aOp;
430
- } else
431
- return reportFailure (" missing (x * (+)) * 0.5" );
432
- } else if (lhsMulOp && rhsAddOp) {
433
- // - (0.5 * x) * (+)
434
- if (constOf (lhsMulOp.getOperand (0 ), 0.5 )) {
435
- x = lhsMulOp.getOperand (1 );
436
- add1Op = rhsAddOp;
437
- } else if (constOf (lhsMulOp.getOperand (1 ), 0.5 )) {
438
- x = lhsMulOp.getOperand (0 );
439
- add1Op = rhsAddOp;
440
- } else
441
- return reportFailure (" missing (0.5 * x) * (+)" );
442
- } else if (lhsAddOp && rhsMulOp) {
443
- // - (+) * (0.5 * x)
444
- if (constOf (rhsMulOp.getOperand (0 ), 0.5 )) {
445
- x = rhsMulOp.getOperand (1 );
446
- add1Op = lhsAddOp;
447
- } else if (constOf (lhsMulOp.getOperand (1 ), 0.5 )) {
448
- x = rhsMulOp.getOperand (0 );
449
- add1Op = lhsAddOp;
450
- } else
451
- return reportFailure (" missing (+) * (0.5 * x)" );
452
- } else if (rhsMulOp) {
453
- // - (0.5 * (+)) * x
454
- if (matchConstAndOp<ONNXAddOp>(
455
- rhsMulOp.getOperand (0 ), rhsMulOp.getOperand (1 ), 0.5 , add1Op))
456
- x = lhs;
457
- else
458
- return reportFailure (" missing (0.5 * (+)) * x" );
459
- } else if (lhsMulOp) {
460
- // - x * (0.5 * (+))
461
- if (matchConstAndOp<ONNXAddOp>(
462
- lhsMulOp.getOperand (0 ), lhsMulOp.getOperand (1 ), 0.5 , add1Op))
463
- x = rhs;
464
- else
465
- return reportFailure (" missing x * (0.5 * (+))" );
466
- } else {
467
- return reportFailure (" missing 0.5 * x * (+)" );
402
+
403
+ Value fstMulVal, sndMulVal;
404
+ bool foundHalf = false ;
405
+
406
+ ONNXMulOp innerMulOp;
407
+ if (matchConstAndOp<ONNXMulOp>(lhs, rhs, 0.5 , innerMulOp)) {
408
+ // - 0.5 * (a * b) or (a * b) * 0.5
409
+ fstMulVal = innerMulOp.getOperand (0 );
410
+ sndMulVal = innerMulOp.getOperand (1 );
411
+ foundHalf = true ;
412
+ }
413
+ if (!foundHalf && !constOf (lhs, 0.5 ) && !constOf (rhs, 0.5 )) {
414
+ if (auto lhsMulOp = lhs.getDefiningOp <ONNXMulOp>()) {
415
+ // - (0.5 * a) * b
416
+ Value l = lhsMulOp.getOperand (0 );
417
+ Value r = lhsMulOp.getOperand (1 );
418
+ if (constOf (l, 0.5 )) {
419
+ fstMulVal = r;
420
+ sndMulVal = rhs;
421
+ foundHalf = true ;
422
+ } else if (constOf (r, 0.5 )) {
423
+ fstMulVal = l;
424
+ sndMulVal = rhs;
425
+ foundHalf = true ;
426
+ }
427
+ }
428
+ if (!foundHalf) {
429
+ if (auto rhsMulOp = rhs.getDefiningOp <ONNXMulOp>()) {
430
+ // - b * (0.5 * a)
431
+ Value l = rhsMulOp.getOperand (0 );
432
+ Value r = rhsMulOp.getOperand (1 );
433
+ if (constOf (l, 0.5 )) {
434
+ fstMulVal = lhs;
435
+ sndMulVal = r;
436
+ foundHalf = true ;
437
+ } else if (constOf (r, 0.5 )) {
438
+ fstMulVal = lhs;
439
+ sndMulVal = l;
440
+ foundHalf = true ;
441
+ }
442
+ }
443
+ }
468
444
}
445
+ if (!foundHalf)
446
+ return reportFailure (" missing 0.5 * a * b" );
469
447
470
448
// Exact gelu.
471
449
// Match 1 + erf()
450
+ bool foundErf = false ;
472
451
ONNXErfOp erfOp;
473
- if (matchConstAndOp<ONNXErfOp>(
474
- add1Op.getOperand (0 ), add1Op.getOperand (1 ), 1.0 , erfOp)) {
452
+ // Try the first operand.
453
+ if (auto add1Op = fstMulVal.getDefiningOp <ONNXAddOp>()) {
454
+ foundErf = matchConstAndOp<ONNXErfOp>(
455
+ add1Op.getOperand (0 ), add1Op.getOperand (1 ), 1.0 , erfOp);
456
+ if (foundErf)
457
+ x = sndMulVal;
458
+ }
459
+ if (!foundErf) {
460
+ // Try the second operand.
461
+ if (auto add1Op = sndMulVal.getDefiningOp <ONNXAddOp>()) {
462
+ foundErf = matchConstAndOp<ONNXErfOp>(
463
+ add1Op.getOperand (0 ), add1Op.getOperand (1 ), 1.0 , erfOp);
464
+ if (foundErf)
465
+ x = fstMulVal;
466
+ }
467
+ }
468
+ if (foundErf) {
475
469
// gelu(x) = 0.5 * x * (1 + erf(x/1.41421354))
476
470
Value erfInput = erfOp.getOperand ();
477
471
auto divOp = erfInput.getDefiningOp <ONNXDivOp>();
478
472
if (!divOp)
479
473
return reportFailure (" [Exact] missing div op" );
474
+ if (divOp.getOperand (0 ) != x)
475
+ return reportFailure (" [Exact] missing x in x/1.41421354" );
480
476
if (!constOf (divOp.getOperand (1 ), 1.41421354 ))
481
477
return reportFailure (" [Exact] missing 1.41421354" );
482
478
isExactGelu = true ;
483
479
return true ;
484
480
} else {
485
481
// Do not return here, we still check the approximate case.
486
- reportFailure (" [Exact] missing erf op " );
482
+ reportFailure (" [Exact] missing (1 + erf) " );
487
483
}
488
484
489
485
// Approximate gelu.
490
486
// gelu(x) = 0.5 * x * (1 + tanh[0.797884583 * (x + 0.044715 * x^3)])
491
487
// Match 1 + tanh()
488
+ bool foundTanh = false ;
492
489
ONNXTanhOp tanhOp;
493
- if (!matchConstAndOp<ONNXTanhOp>(
494
- add1Op.getOperand (0 ), add1Op.getOperand (1 ), 1.0 , tanhOp))
495
- return reportFailure (" [Approximate] missing tanh op" );
490
+ // Try the first operand.
491
+ if (auto add1Op = fstMulVal.getDefiningOp <ONNXAddOp>()) {
492
+ foundTanh = matchConstAndOp<ONNXTanhOp>(
493
+ add1Op.getOperand (0 ), add1Op.getOperand (1 ), 1.0 , tanhOp);
494
+ if (foundTanh)
495
+ x = sndMulVal;
496
+ }
497
+ if (!foundTanh) {
498
+ // Try the second operand.
499
+ if (auto add1Op = sndMulVal.getDefiningOp <ONNXAddOp>()) {
500
+ foundTanh = matchConstAndOp<ONNXTanhOp>(
501
+ add1Op.getOperand (0 ), add1Op.getOperand (1 ), 1.0 , tanhOp);
502
+ if (foundTanh)
503
+ x = fstMulVal;
504
+ }
505
+ }
506
+ if (!foundTanh)
507
+ return reportFailure (" [Approximate] missing (1 + tanh)" );
496
508
497
509
// Match 0.797884583 * (x + 0.044715 * x^3)
498
510
auto mul1Op = tanhOp.getOperand ().getDefiningOp <ONNXMulOp>();
@@ -633,14 +645,15 @@ void RecomposeONNXToONNXPass::runOnOperation() {
633
645
FloatAttr epsilon;
634
646
int64_t axis;
635
647
bool isRMSLayerNorm;
636
- bool rewriteToLayerNorm =
637
- RecomposeLayerNormFromMulPattern::matchLayerNormPattern (
638
- op, x, scale, axis, epsilon, isRMSLayerNorm) ;
648
+ if ( RecomposeLayerNormFromMulPattern::matchLayerNormPattern (
649
+ op, x, scale, axis, epsilon, isRMSLayerNorm))
650
+ return false ;
639
651
640
652
bool isExactGelu;
641
- bool rewriteToGelu =
642
- RecomposeGeluFromMulPattern::matchGeluPattern (op, x, isExactGelu);
643
- return (!(rewriteToLayerNorm || rewriteToGelu));
653
+ if (RecomposeGeluFromMulPattern::matchGeluPattern (op, x, isExactGelu))
654
+ return false ;
655
+
656
+ return true ;
644
657
});
645
658
646
659
// Recompose QLinearMatMul, starting from QuantizeLinear.
0 commit comments