Commit 869f152

A pattern to simplify WhereOp (#2818)
Signed-off-by: Tung D. Le <[email protected]>
Parent: 733dfac

5 files changed, +157 -0 lines changed


src/Dialect/ONNX/ONNXOps.td.inc

+1
@@ -10029,6 +10029,7 @@ def ONNXUpsampleV7Op:ONNX_Op<"UpsampleV7",
 
 def ONNXWhereOp:ONNX_Op<"Where",
   [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
+  let hasCanonicalizer = 1;
   let summary = "ONNX Where operation";
   let description = [{
   Return elements, either from X or Y, depending on condition.
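Setting "let hasCanonicalizer = 1;" only declares the canonicalization hook on the generated op class; the definition lands in Canonicalize.cpp below. Roughly, assuming standard MLIR ODS behavior, the generated class gains a declaration like the following (a sketch; the real code is emitted by mlir-tblgen):

// Sketch of the declaration that `let hasCanonicalizer = 1;` adds to the
// generated ONNXWhereOp class (actual output comes from mlir-tblgen).
class ONNXWhereOp /* : generated op base classes */ {
public:
  // Declared by ODS; defined in src/Dialect/ONNX/ONNXOps/Canonicalize.cpp.
  static void getCanonicalizationPatterns(
      ::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);
};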

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

+56
@@ -198,6 +198,56 @@ bool haveSameStaticShape(Value lhs, Value rhs) {
   return hasStaticShape(lhsT) && (getShape(lhsT) == getShape(rhsT));
 }
 
+/// Test if the input is a splat constant with a negative value or not.
+bool isNegativeSplatConstant(Value val) {
+  if (!isDenseONNXConstant(val))
+    return false;
+  ONNXConstantOp constOp = val.getDefiningOp<ONNXConstantOp>();
+  auto valAttr =
+      llvm::dyn_cast_or_null<DenseElementsAttr>(constOp.getValueAttr());
+  if (!valAttr)
+    return false;
+
+  if (!valAttr.isSplat())
+    return false;
+
+  Type elemTy = val.getType().cast<ShapedType>().getElementType();
+  if (elemTy.isa<FloatType>()) {
+    double v = valAttr.getSplatValue<double>();
+    return (v < 0.0);
+  } else if (elemTy.isa<IntegerType>()) {
+    int64_t v = valAttr.getSplatValue<int64_t>();
+    return (v < 0);
+  }
+  return false;
+}
+
+/// Test if all values in the input ValueRange are dimension sizes.
+bool areAllDimSizes(ValueRange vals) {
+  return llvm::all_of(vals, [](Value val) {
+    // Block arguments.
+    if (val.isa<BlockArgument>())
+      return false;
+    // Defined by DimOp.
+    if (val.getDefiningOp<ONNXDimOp>())
+      return true;
+    // Defined by ConstantOp.
+    if (isDenseONNXConstant(val) && isScalarTensor(val)) {
+      Type elemTy = val.getType().cast<ShapedType>().getElementType();
+      if (!elemTy.isa<IntegerType>())
+        return false;
+      ONNXConstantOp constOp = val.getDefiningOp<ONNXConstantOp>();
+      auto valAttr =
+          llvm::dyn_cast_or_null<DenseElementsAttr>(constOp.getValueAttr());
+      if (!valAttr)
+        return false;
+      int64_t v = (*valAttr.getValues<APInt>().begin()).getSExtValue();
+      return (v > 0);
+    }
+    return false;
+  });
+}
+
 // Match v = shape_transform(X*A + B).
 // shape_transform is a sequence of operations like Reshape, Transpose,
 // Squeeze, Unsqueeze, etc. that do not change the numerical values by data

@@ -1799,3 +1849,9 @@ void ONNXXorOp::getCanonicalizationPatterns(
     RewritePatternSet &result, MLIRContext *context) {
   result.insert<BinaryOpBroadcastAxisPattern<ONNXXorOp>>(context);
 }
+
+// on the ONNXWhereOp.
+void ONNXWhereOp::getCanonicalizationPatterns(
+    RewritePatternSet &result, MLIRContext *context) {
+  result.insert<AlwaysFalseWherePattern>(context);
+}
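Taken together, the two new helpers let a rewrite prove that an onnx.Equal over concatenated dimension sizes and a negative splat constant can never hold: every dimension size is strictly positive, the splat is strictly negative, so the element-wise comparison is false everywhere. A minimal hand-written sketch of that check (isAlwaysFalseEqual is a hypothetical name, and the getA/getB/getInputs accessors are assumed from the ODS-generated op classes):

// Hypothetical composition of the two helpers above; not part of the commit.
// Matches cond = onnx.Equal(onnx.Concat(d0, d1, ...), negSplat) and proves
// that cond is statically all-false.
bool isAlwaysFalseEqual(ONNXEqualOp equalOp) {
  auto concatOp = equalOp.getA().getDefiningOp<ONNXConcatOp>();
  if (!concatOp)
    return false;
  // LHS: only dimension sizes (runtime dims or positive constants, > 0).
  // RHS: a splat constant < 0. No element can compare equal.
  return areAllDimSizes(concatOp.getInputs()) &&
         isNegativeSplatConstant(equalOp.getB());
}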

src/Dialect/ONNX/ONNXOps/Canonicalize.td

+27
@@ -205,6 +205,16 @@ def IsFromONNXConstantOpWithDenseElementsAttr: Constraint<
   CPred<" isa<DenseElementsAttr>(onnx_mlir::getONNXConstantOp($_self).getValueAttr()) ">
 ]>, "Value is not a ONNXConstantOp with a DenseElementsAttr">;
 
+def IsNegativeSplatConstant: Constraint<
+  CPred<"onnx_mlir::isNegativeSplatConstant($_self)">,
+  "Is a splat constant with a negative value."
+>;
+
+def AreAllDimSizes: Constraint<
+  CPred<"onnx_mlir::areAllDimSizes($_self)">,
+  "All values in the input ValueRange are dimension sizes."
+>;
+
 def AreTheSameAxesConstant: Constraint<
   CPred<"onnx_mlir::AreTheSameAxesConstant("
     "(onnx_mlir::hasShapeAndRank($0) ? $0.getType().cast<ShapedType>().getRank() : 0),"

@@ -1024,4 +1034,21 @@ def ShapeTransformComposePattern : Pat<
   []
 >;
 
+//===----------------------------------------------------------------------===//
+// Canonicalization for ONNXWhere
+//===----------------------------------------------------------------------===//
+
+// In this pattern, the condition in onnx.Where is always false, so we can
+// replace onnx.Where by its "false" value.
+// The condition in this pattern is a comparison between dimension sizes and
+// negative values. Since dimension sizes are always positive, the condition
+// evaluates to false.
+
+// This pattern was found in the xlm-roberta-base-language-detection model on
+// HuggingFace.
+
+def AlwaysFalseWherePattern : Pat<
+  (ONNXWhereOp (ONNXEqualOp (ONNXConcatOp $dims, $_), $negative_constant), $true_val, $false_val),
+  (replaceWithValue $false_val),
+  [(IsNegativeSplatConstant:$negative_constant), (AreAllDimSizes:$dims)]
+>;
+
 #endif // ONNX_REWRITE
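Conceptually, AlwaysFalseWherePattern expands to an OpRewritePattern along the following lines (a sketch under the same assumptions as the isAlwaysFalseEqual snippet earlier; the code TableGen actually generates differs in detail). Since the condition is provably all-false, onnx.Where(cond, X, Y) always yields Y:

// Sketch of the rewrite the DRR pattern expresses declaratively; the struct
// name and the reuse of isAlwaysFalseEqual are illustrative, not the
// generated code.
struct AlwaysFalseWhereSketch : public OpRewritePattern<ONNXWhereOp> {
  using OpRewritePattern<ONNXWhereOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      ONNXWhereOp whereOp, PatternRewriter &rewriter) const override {
    auto equalOp = whereOp.getCondition().getDefiningOp<ONNXEqualOp>();
    if (!equalOp || !isAlwaysFalseEqual(equalOp))
      return failure();
    // Where(false, X, Y) == Y: forward the "false" value and drop the op.
    rewriter.replaceOp(whereOp, whereOp.getY());
    return success();
  }
};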

test/mlir/onnx/onnx_canonicalization.mlir

+72
@@ -1753,3 +1753,75 @@ func.func @test_mul_in_attention(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x7
 // CHECK: onnx.Return [[VAR_21_]] : tensor<?x12x?x?xf32>
 // CHECK: }
 }
+
+// -----
+
+// Canonicalize WhereOp whose condition is always false.
+// This pattern was found in the xlm-roberta-base-language-detection model on HuggingFace.
+func.func @test_where_with_always_false_1(%arg0: tensor<?x?xi64>) -> tensor<2xi64> {
+  %0 = onnx.Constant dense<-1> : tensor<2xi64>
+  %1 = onnx.Constant dense<1> : tensor<2xi64>
+  %2 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
+  %3 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
+  %4 = "onnx.Concat"(%2, %3) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
+  %5 = "onnx.Equal"(%4, %0) : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1>
+  %6 = "onnx.Where"(%5, %1, %4) : (tensor<2xi1>, tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64>
+  onnx.Return %6 : tensor<2xi64>
+
+// CHECK-LABEL:  func.func @test_where_with_always_false_1
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<?x?xi64>) -> tensor<2xi64> {
+// CHECK-DAG:    [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
+// CHECK-DAG:    [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
+// CHECK:        [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
+// CHECK:        onnx.Return [[VAR_2_]] : tensor<2xi64>
+// CHECK:        }
+}
+
+// -----
+
+// Mix of DimOp and ConstantOp.
+func.func @test_where_with_always_false_2(%arg0: tensor<?x?xi64>) -> tensor<2xi64> {
+  %0 = onnx.Constant dense<-1> : tensor<2xi64>
+  %1 = onnx.Constant dense<1> : tensor<2xi64>
+  %2 = onnx.Constant dense<2> : tensor<1xi64>
+  %3 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
+  %4 = "onnx.Concat"(%2, %3) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
+  %5 = "onnx.Equal"(%4, %0) : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1>
+  %6 = "onnx.Where"(%5, %1, %4) : (tensor<2xi1>, tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64>
+  onnx.Return %6 : tensor<2xi64>
+
+// CHECK-LABEL:  func.func @test_where_with_always_false_2
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<?x?xi64>) -> tensor<2xi64> {
+// CHECK-DAG:    [[VAR_0_:%.+]] = onnx.Constant dense<2> : tensor<1xi64>
+// CHECK-DAG:    [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
+// CHECK:        [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
+// CHECK:        onnx.Return [[VAR_2_]] : tensor<2xi64>
+// CHECK:        }
+}
+
+// -----
+
+// Mix of DimOp and ConstantOp, but the constant is negative, so we cannot
+// guarantee the false condition in WhereOp. No rewrite happens.
+func.func @test_where_with_always_false_3(%arg0: tensor<?x?xi64>) -> tensor<2xi64> {
+  %0 = onnx.Constant dense<-1> : tensor<2xi64>
+  %1 = onnx.Constant dense<1> : tensor<2xi64>
+  %2 = onnx.Constant dense<-2> : tensor<1xi64>
+  %3 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
+  %4 = "onnx.Concat"(%2, %3) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
+  %5 = "onnx.Equal"(%4, %0) : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1>
+  %6 = "onnx.Where"(%5, %1, %4) : (tensor<2xi1>, tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64>
+  onnx.Return %6 : tensor<2xi64>
+
+// CHECK-LABEL:  func.func @test_where_with_always_false_3
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<?x?xi64>) -> tensor<2xi64> {
+// CHECK-DAG:    [[VAR_0_:%.+]] = onnx.Constant dense<-1> : tensor<2xi64>
+// CHECK-DAG:    [[VAR_1_:%.+]] = onnx.Constant dense<1> : tensor<2xi64>
+// CHECK-DAG:    [[VAR_2_:%.+]] = onnx.Constant dense<-2> : tensor<1xi64>
+// CHECK-DAG:    [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
+// CHECK:        [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_3_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
+// CHECK:        [[VAR_5_:%.+]] = "onnx.Equal"([[VAR_4_]], [[VAR_0_]]) : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1>
+// CHECK:        [[VAR_6_:%.+]] = "onnx.Where"([[VAR_5_]], [[VAR_1_]], [[VAR_4_]]) : (tensor<2xi1>, tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64>
+// CHECK:        onnx.Return [[VAR_6_]] : tensor<2xi64>
+// CHECK:        }
+}

utils/gen_onnx_mlir.py

+1
@@ -360,6 +360,7 @@
     "Transpose",
     "Unsqueeze",
    "UnsqueezeV11",
+    "Where",
     "Xor",
 ]

Adding "Where" to this list (the generator's set of ops with a custom canonicalizer) is what makes utils/gen_onnx_mlir.py emit the "let hasCanonicalizer = 1;" line seen in ONNXOps.td.inc above.
