@@ -1753,3 +1753,75 @@ func.func @test_mul_in_attention(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x7
1753
1753
// CHECK: onnx.Return [[VAR_21_]] : tensor<?x12x?x?xf32>
1754
1754
// CHECK: }
1755
1755
}
1756
+
1757
+ // -----
1758
+
1759
+ // Canonicalize WhereOp whose condition is always false.
1760
+ // This pattern was found in the model xlm-roberta-base-language-detection in HuggingFace.
1761
+ func.func @test_where_with_always_false_1 (%arg0: tensor <?x?xi64 >) -> tensor <2 xi64 > {
1762
+ %0 = onnx.Constant dense <-1 > : tensor <2 xi64 >
1763
+ %1 = onnx.Constant dense <1 > : tensor <2 xi64 >
1764
+ %2 = " onnx.Dim" (%arg0 ) {axis = 0 : si64 } : (tensor <?x?xi64 >) -> tensor <1 xi64 >
1765
+ %3 = " onnx.Dim" (%arg0 ) {axis = 1 : si64 } : (tensor <?x?xi64 >) -> tensor <1 xi64 >
1766
+ %4 = " onnx.Concat" (%2 , %3 ) {axis = 0 : si64 } : (tensor <1 xi64 >, tensor <1 xi64 >) -> tensor <2 xi64 >
1767
+ %5 = " onnx.Equal" (%4 , %0 ) : (tensor <2 xi64 >, tensor <2 xi64 >) -> tensor <2 xi1 >
1768
+ %6 = " onnx.Where" (%5 , %1 , %4 ) : (tensor <2 xi1 >, tensor <2 xi64 >, tensor <2 xi64 >) -> tensor <2 xi64 >
1769
+ onnx.Return %6 : tensor <2 xi64 >
1770
+
1771
+ // CHECK-LABEL: func.func @test_where_with_always_false_1
1772
+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?xi64>) -> tensor<2xi64> {
1773
+ // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
1774
+ // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
1775
+ // CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
1776
+ // CHECK: onnx.Return [[VAR_2_]] : tensor<2xi64>
1777
+ // CHECK: }
1778
+ }
1779
+
1780
+ // -----
1781
+
1782
+ // Mix of DimOp and ConstantOp.
1783
+ func.func @test_where_with_always_false_2 (%arg0: tensor <?x?xi64 >) -> tensor <2 xi64 > {
1784
+ %0 = onnx.Constant dense <-1 > : tensor <2 xi64 >
1785
+ %1 = onnx.Constant dense <1 > : tensor <2 xi64 >
1786
+ %2 = onnx.Constant dense <2 > : tensor <1 xi64 >
1787
+ %3 = " onnx.Dim" (%arg0 ) {axis = 1 : si64 } : (tensor <?x?xi64 >) -> tensor <1 xi64 >
1788
+ %4 = " onnx.Concat" (%2 , %3 ) {axis = 0 : si64 } : (tensor <1 xi64 >, tensor <1 xi64 >) -> tensor <2 xi64 >
1789
+ %5 = " onnx.Equal" (%4 , %0 ) : (tensor <2 xi64 >, tensor <2 xi64 >) -> tensor <2 xi1 >
1790
+ %6 = " onnx.Where" (%5 , %1 , %4 ) : (tensor <2 xi1 >, tensor <2 xi64 >, tensor <2 xi64 >) -> tensor <2 xi64 >
1791
+ onnx.Return %6 : tensor <2 xi64 >
1792
+
1793
+ // CHECK-LABEL: func.func @test_where_with_always_false_2
1794
+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?xi64>) -> tensor<2xi64> {
1795
+ // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<2> : tensor<1xi64>
1796
+ // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
1797
+ // CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
1798
+ // CHECK: onnx.Return [[VAR_2_]] : tensor<2xi64>
1799
+ // CHECK: }
1800
+ }
1801
+
1802
+ // -----
1803
+
1804
+ // Mix of DimOp and ConstantOp but the constant is negative, so cannot guarantee the false condition in WhereOp.
1805
+ // No rewrite happened.
1806
+ func.func @test_where_with_always_false_3 (%arg0: tensor <?x?xi64 >) -> tensor <2 xi64 > {
1807
+ %0 = onnx.Constant dense <-1 > : tensor <2 xi64 >
1808
+ %1 = onnx.Constant dense <1 > : tensor <2 xi64 >
1809
+ %2 = onnx.Constant dense <-2 > : tensor <1 xi64 >
1810
+ %3 = " onnx.Dim" (%arg0 ) {axis = 1 : si64 } : (tensor <?x?xi64 >) -> tensor <1 xi64 >
1811
+ %4 = " onnx.Concat" (%2 , %3 ) {axis = 0 : si64 } : (tensor <1 xi64 >, tensor <1 xi64 >) -> tensor <2 xi64 >
1812
+ %5 = " onnx.Equal" (%4 , %0 ) : (tensor <2 xi64 >, tensor <2 xi64 >) -> tensor <2 xi1 >
1813
+ %6 = " onnx.Where" (%5 , %1 , %4 ) : (tensor <2 xi1 >, tensor <2 xi64 >, tensor <2 xi64 >) -> tensor <2 xi64 >
1814
+ onnx.Return %6 : tensor <2 xi64 >
1815
+
1816
+ // CHECK-LABEL: func.func @test_where_with_always_false_3
1817
+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?xi64>) -> tensor<2xi64> {
1818
+ // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<-1> : tensor<2xi64>
1819
+ // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1> : tensor<2xi64>
1820
+ // CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<-2> : tensor<1xi64>
1821
+ // CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?xi64>) -> tensor<1xi64>
1822
+ // CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_3_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
1823
+ // CHECK: [[VAR_5_:%.+]] = "onnx.Equal"([[VAR_4_]], [[VAR_0_]]) : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1>
1824
+ // CHECK: [[VAR_6_:%.+]] = "onnx.Where"([[VAR_5_]], [[VAR_1_]], [[VAR_4_]]) : (tensor<2xi1>, tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64>
1825
+ // CHECK: onnx.Return [[VAR_6_]] : tensor<2xi64>
1826
+ // CHECK: }
1827
+ }
0 commit comments