Skip to content

Commit 3a9cf3c

Browse files
committed
Better pattern check and add one more lit test
Signed-off-by: Tung D. Le <[email protected]>
1 parent 5202bac commit 3a9cf3c

File tree

2 files changed

+121
-84
lines changed

2 files changed

+121
-84
lines changed

src/Dialect/ONNX/Transforms/Recompose.cpp

Lines changed: 97 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -391,108 +391,120 @@ struct RecomposeGeluFromMulPattern : public OpRewritePattern<ONNXMulOp> {
391391
return isDenseONNXConstant(v) && isConstOf(v, n);
392392
};
393393

394-
// Match A * B * C: 0.5 * x * (+)
395-
// Three associative cases depending on the order of Mul:
396-
// - (0.5 * x) * (+)
397-
// - 0.5 * (x * (+))
398-
// - (0.5 * (+)) * x
399-
// For each case, we have two communitive cases.
400-
// In total, we handle 6 cases.
394+
// Match 0.5 * a * b
395+
// Two associative cases depending on which Mul 0.5 belongs to:
396+
// - 0.5 * (a * b)
397+
// - (0.5 * a) * b
398+
// For each case, we have two communitive cases for the outer Mul (not count
399+
// the inner Mul). In total, we handle 4 cases.
401400
Value lhs = mulOp.getOperand(0);
402401
Value rhs = mulOp.getOperand(1);
403-
auto lhsMulOp = lhs.getDefiningOp<ONNXMulOp>();
404-
auto rhsMulOp = rhs.getDefiningOp<ONNXMulOp>();
405-
auto lhsAddOp = lhs.getDefiningOp<ONNXAddOp>();
406-
auto rhsAddOp = rhs.getDefiningOp<ONNXAddOp>();
407-
bool lhsIsCst = constOf(lhs, 0.5);
408-
bool rhsIsCst = constOf(rhs, 0.5);
409-
410-
// Match to get AddOp and x from 0.5 * x * (+).
411-
ONNXAddOp add1Op;
412-
if (lhsIsCst && rhsMulOp) {
413-
// - 0.5 * (x * (+))
414-
if (auto aOp = rhsMulOp.getOperand(0).getDefiningOp<ONNXAddOp>()) {
415-
x = rhsMulOp.getOperand(1);
416-
add1Op = aOp;
417-
} else if (auto aOp = rhsMulOp.getOperand(1).getDefiningOp<ONNXAddOp>()) {
418-
x = rhsMulOp.getOperand(0);
419-
add1Op = aOp;
420-
} else
421-
return reportFailure("missing 0.5 * (x * (+))");
422-
} else if (lhsMulOp && rhsIsCst) {
423-
// - (x * (+)) * 0.5
424-
if (auto aOp = lhsMulOp.getOperand(0).getDefiningOp<ONNXAddOp>()) {
425-
x = lhsMulOp.getOperand(1);
426-
add1Op = aOp;
427-
} else if (auto aOp = lhsMulOp.getOperand(1).getDefiningOp<ONNXAddOp>()) {
428-
x = lhsMulOp.getOperand(0);
429-
add1Op = aOp;
430-
} else
431-
return reportFailure("missing (x * (+)) * 0.5");
432-
} else if (lhsMulOp && rhsAddOp) {
433-
// - (0.5 * x) * (+)
434-
if (constOf(lhsMulOp.getOperand(0), 0.5)) {
435-
x = lhsMulOp.getOperand(1);
436-
add1Op = rhsAddOp;
437-
} else if (constOf(lhsMulOp.getOperand(1), 0.5)) {
438-
x = lhsMulOp.getOperand(0);
439-
add1Op = rhsAddOp;
440-
} else
441-
return reportFailure("missing (0.5 * x) * (+)");
442-
} else if (lhsAddOp && rhsMulOp) {
443-
// - (+) * (0.5 * x)
444-
if (constOf(rhsMulOp.getOperand(0), 0.5)) {
445-
x = rhsMulOp.getOperand(1);
446-
add1Op = lhsAddOp;
447-
} else if (constOf(lhsMulOp.getOperand(1), 0.5)) {
448-
x = rhsMulOp.getOperand(0);
449-
add1Op = lhsAddOp;
450-
} else
451-
return reportFailure("missing (+) * (0.5 * x)");
452-
} else if (rhsMulOp) {
453-
// - (0.5 * (+)) * x
454-
if (matchConstAndOp<ONNXAddOp>(
455-
rhsMulOp.getOperand(0), rhsMulOp.getOperand(1), 0.5, add1Op))
456-
x = lhs;
457-
else
458-
return reportFailure("missing (0.5 * (+)) * x");
459-
} else if (lhsMulOp) {
460-
// - x * (0.5 * (+))
461-
if (matchConstAndOp<ONNXAddOp>(
462-
lhsMulOp.getOperand(0), lhsMulOp.getOperand(1), 0.5, add1Op))
463-
x = rhs;
464-
else
465-
return reportFailure("missing x * (0.5 * (+))");
466-
} else {
467-
return reportFailure("missing 0.5 * x * (+)");
402+
403+
Value fstMulVal, sndMulVal;
404+
bool foundHalf = false;
405+
406+
ONNXMulOp innerMulOp;
407+
if (matchConstAndOp<ONNXMulOp>(lhs, rhs, 0.5, innerMulOp)) {
408+
// - 0.5 * (a * b) or (a * b) * 0.5
409+
fstMulVal = innerMulOp.getOperand(0);
410+
sndMulVal = innerMulOp.getOperand(1);
411+
foundHalf = true;
412+
}
413+
if (!foundHalf && !constOf(lhs, 0.5) && !constOf(rhs, 0.5)) {
414+
if (auto lhsMulOp = lhs.getDefiningOp<ONNXMulOp>()) {
415+
// - (0.5 * a) * b
416+
Value l = lhsMulOp.getOperand(0);
417+
Value r = lhsMulOp.getOperand(1);
418+
if (constOf(l, 0.5)) {
419+
fstMulVal = r;
420+
sndMulVal = rhs;
421+
foundHalf = true;
422+
} else if (constOf(r, 0.5)) {
423+
fstMulVal = l;
424+
sndMulVal = rhs;
425+
foundHalf = true;
426+
}
427+
}
428+
if (!foundHalf) {
429+
if (auto rhsMulOp = rhs.getDefiningOp<ONNXMulOp>()) {
430+
// - b * (0.5 * a)
431+
Value l = rhsMulOp.getOperand(0);
432+
Value r = rhsMulOp.getOperand(1);
433+
if (constOf(l, 0.5)) {
434+
fstMulVal = lhs;
435+
sndMulVal = r;
436+
foundHalf = true;
437+
} else if (constOf(r, 0.5)) {
438+
fstMulVal = lhs;
439+
sndMulVal = l;
440+
foundHalf = true;
441+
}
442+
}
443+
}
468444
}
445+
if (!foundHalf)
446+
return reportFailure("missing 0.5 * a * b");
469447

470448
// Exact gelu.
471449
// Match 1 + erf()
450+
bool foundErf = false;
472451
ONNXErfOp erfOp;
473-
if (matchConstAndOp<ONNXErfOp>(
474-
add1Op.getOperand(0), add1Op.getOperand(1), 1.0, erfOp)) {
452+
// Try the first operand.
453+
if (auto add1Op = fstMulVal.getDefiningOp<ONNXAddOp>()) {
454+
foundErf = matchConstAndOp<ONNXErfOp>(
455+
add1Op.getOperand(0), add1Op.getOperand(1), 1.0, erfOp);
456+
if (foundErf)
457+
x = sndMulVal;
458+
}
459+
if (!foundErf) {
460+
// Try the second operand.
461+
if (auto add1Op = sndMulVal.getDefiningOp<ONNXAddOp>()) {
462+
foundErf = matchConstAndOp<ONNXErfOp>(
463+
add1Op.getOperand(0), add1Op.getOperand(1), 1.0, erfOp);
464+
if (foundErf)
465+
x = fstMulVal;
466+
}
467+
}
468+
if (foundErf) {
475469
// gelu(x) = 0.5 * x * (1 + erf(x/1.41421354))
476470
Value erfInput = erfOp.getOperand();
477471
auto divOp = erfInput.getDefiningOp<ONNXDivOp>();
478472
if (!divOp)
479473
return reportFailure("[Exact] missing div op");
474+
if (divOp.getOperand(0) != x)
475+
return reportFailure("[Exact] missing x in x/1.41421354");
480476
if (!constOf(divOp.getOperand(1), 1.41421354))
481477
return reportFailure("[Exact] missing 1.41421354");
482478
isExactGelu = true;
483479
return true;
484480
} else {
485481
// Do not return here, we still check the approximate case.
486-
reportFailure("[Exact] missing erf op");
482+
reportFailure("[Exact] missing (1 + erf)");
487483
}
488484

489485
// Approximate gelu.
490486
// gelu(x) = 0.5 * x * (1 + tanh[0.797884583 * (x + 0.044715 * x^3)])
491487
// Match 1 + tanh()
488+
bool foundTanh = false;
492489
ONNXTanhOp tanhOp;
493-
if (!matchConstAndOp<ONNXTanhOp>(
494-
add1Op.getOperand(0), add1Op.getOperand(1), 1.0, tanhOp))
495-
return reportFailure("[Approximate] missing tanh op");
490+
// Try the first operand.
491+
if (auto add1Op = fstMulVal.getDefiningOp<ONNXAddOp>()) {
492+
foundTanh = matchConstAndOp<ONNXTanhOp>(
493+
add1Op.getOperand(0), add1Op.getOperand(1), 1.0, tanhOp);
494+
if (foundTanh)
495+
x = sndMulVal;
496+
}
497+
if (!foundTanh) {
498+
// Try the second operand.
499+
if (auto add1Op = sndMulVal.getDefiningOp<ONNXAddOp>()) {
500+
foundTanh = matchConstAndOp<ONNXTanhOp>(
501+
add1Op.getOperand(0), add1Op.getOperand(1), 1.0, tanhOp);
502+
if (foundTanh)
503+
x = fstMulVal;
504+
}
505+
}
506+
if (!foundTanh)
507+
return reportFailure("[Approximate] missing (1 + tanh)");
496508

497509
// Match 0.797884583 * (x + 0.044715 * x^3)
498510
auto mul1Op = tanhOp.getOperand().getDefiningOp<ONNXMulOp>();
@@ -633,14 +645,15 @@ void RecomposeONNXToONNXPass::runOnOperation() {
633645
FloatAttr epsilon;
634646
int64_t axis;
635647
bool isRMSLayerNorm;
636-
bool rewriteToLayerNorm =
637-
RecomposeLayerNormFromMulPattern::matchLayerNormPattern(
638-
op, x, scale, axis, epsilon, isRMSLayerNorm);
648+
if (RecomposeLayerNormFromMulPattern::matchLayerNormPattern(
649+
op, x, scale, axis, epsilon, isRMSLayerNorm))
650+
return false;
639651

640652
bool isExactGelu;
641-
bool rewriteToGelu =
642-
RecomposeGeluFromMulPattern::matchGeluPattern(op, x, isExactGelu);
643-
return (!(rewriteToLayerNorm || rewriteToGelu));
653+
if (RecomposeGeluFromMulPattern::matchGeluPattern(op, x, isExactGelu))
654+
return false;
655+
656+
return true;
644657
});
645658

646659
// Recompose QLinearMatMul, starting from QuantizeLinear.

test/mlir/onnx/onnx_recompose.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,27 @@ func.func @test_gelu_tanh(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
372372
// CHECK: }
373373
}
374374

375+
// -----
376+
377+
func.func @test_gelu_erf_two_adds(%arg0: tensor<?x?x3072xf32>, %arg1: tensor<3072x768xf32>) -> tensor<?x?x768xf32> {
378+
%0 = onnx.Constant dense<5.000000e-01> : tensor<f32>
379+
%1 = onnx.Constant dense<1.000000e+00> : tensor<f32>
380+
%2 = onnx.Constant dense<1.41421354> : tensor<f32>
381+
%3 = onnx.Constant dense<3.000000e-01> : tensor<3072xf32>
382+
%4 = "onnx.Add"(%arg0, %3) : (tensor<?x?x3072xf32>, tensor<3072xf32>) -> tensor<?x?x3072xf32>
383+
%5 = "onnx.Div"(%4, %2) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
384+
%6 = "onnx.Erf"(%5) : (tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
385+
%7 = "onnx.Add"(%6, %1) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
386+
%8 = "onnx.Mul"(%4, %7) : (tensor<?x?x3072xf32>, tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
387+
%9 = "onnx.Mul"(%8, %0) : (tensor<?x?x3072xf32>, tensor<f32>) -> tensor<?x?x3072xf32>
388+
%10 = "onnx.MatMul"(%9, %arg1) : (tensor<?x?x3072xf32>, tensor<3072x768xf32>) -> tensor<?x?x768xf32>
389+
return %10 : tensor<?x?x768xf32>
390+
}
391+
// CHECK-LABEL: func.func @test_gelu_erf_two_adds
392+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x3072xf32>, [[PARAM_1_:%.+]]: tensor<3072x768xf32>) -> tensor<?x?x768xf32> {
393+
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<3.000000e-01> : tensor<3072xf32>
394+
// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[VAR_0_]]) : (tensor<?x?x3072xf32>, tensor<3072xf32>) -> tensor<?x?x3072xf32>
395+
// CHECK: [[VAR_2_:%.+]] = "onnx.Gelu"([[VAR_1_]]) {approximate = "none"} : (tensor<?x?x3072xf32>) -> tensor<?x?x3072xf32>
396+
// CHECK: [[VAR_3_:%.+]] = "onnx.MatMul"([[VAR_2_]], [[PARAM_1_]]) : (tensor<?x?x3072xf32>, tensor<3072x768xf32>) -> tensor<?x?x768xf32>
397+
// CHECK: return [[VAR_3_]] : tensor<?x?x768xf32>
398+
// CHECK: }

0 commit comments

Comments
 (0)