@@ -900,6 +900,104 @@ onnx.Return %0 : tensor<*xf16>
 // CHECK: onnx.Return [[VAR_1_]] : tensor<4x?x3xf16>
 // CHECK: }

+// -----
+
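+// COM: The Reshape target shape [dim0, dim1, -1, 64] reuses the input's two
+// COM: dynamic dims via "onnx.Dim", so the -1 must cover 2048/64 = 32 and the
+// COM: result type can be refined from ?x?x?x64 to ?x?x32x64.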
+func.func @test_reshape_dim(%arg0: tensor<?x?x2048xf32>) -> tensor<?x?x?x64xf32> {
+  %1 = onnx.Constant dense<64> : tensor<1xi64>
+  %2 = onnx.Constant dense<-1> : tensor<1xi64>
+  %3 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+  %4 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+  %5 = "onnx.Concat"(%3, %4, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+  %6 = "onnx.Reshape"(%arg0, %5) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x?x64xf32>
+  return %6 : tensor<?x?x?x64xf32>
+
+// CHECK-LABEL: func.func @test_reshape_dim
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x2048xf32>) -> tensor<?x?x32x64xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<64> : tensor<1xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64>
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+// CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_3_]], [[VAR_1_]], [[VAR_0_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+// CHECK: [[VAR_5_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_4_]]) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x32x64xf32>
+// CHECK: return [[VAR_5_]] : tensor<?x?x32x64xf32>
+// CHECK: }
+}
+
+// -----
+
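+// COM: Same pattern, but "onnx.Dim" of axis 0 feeds the last target dim. The
+// COM: mapping between input dims and target dims is still a bijection, so the
+// COM: -1 resolves to 32 and the result is refined to ?x32x64x?.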
+func.func @test_reshape_dim_bijective_at_last_dim(%arg0: tensor<?x?x2048xf32>) -> tensor<?x?x64x?xf32> {
+  %1 = onnx.Constant dense<64> : tensor<1xi64>
+  %2 = onnx.Constant dense<-1> : tensor<1xi64>
+  %3 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+  %4 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+  %5 = "onnx.Concat"(%4, %2, %1, %3) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+  %6 = "onnx.Reshape"(%arg0, %5) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x64x?xf32>
+  return %6 : tensor<?x?x64x?xf32>
+
+// CHECK-LABEL: func.func @test_reshape_dim_bijective_at_last_dim
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x2048xf32>) -> tensor<?x32x64x?xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<64> : tensor<1xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64>
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+// CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_1_]], [[VAR_0_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+// CHECK: [[VAR_5_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_4_]]) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x32x64x?xf32>
+// CHECK: return [[VAR_5_]] : tensor<?x32x64x?xf32>
+// CHECK: }
+}
+
+// -----
+
+// COM: This pattern is found in the IBM granite-3.1-2b-instruct model.
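+// COM: The "onnx.MatMul" keeps the ?x?x2048 shape, so the reshape result is
+// COM: still refined to ?x?x32x64 even though the "onnx.Dim" values are taken
+// COM: from the MatMul's input rather than from the Reshape's data operand.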
+func.func @test_reshape_matmul_dim(%arg0: tensor<?x?x2048xf32>) -> tensor<?x?x?x64xf32> {
+  %0 = onnx.Constant dense<1.000000e+00> : tensor<2048x2048xf32>
+  %1 = onnx.Constant dense<64> : tensor<1xi64>
+  %2 = onnx.Constant dense<-1> : tensor<1xi64>
+  %3 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+  %4 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+  %5 = "onnx.MatMul"(%arg0, %0) : (tensor<?x?x2048xf32>, tensor<2048x2048xf32>) -> tensor<?x?x2048xf32>
+  %6 = "onnx.Concat"(%3, %4, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+  %7 = "onnx.Reshape"(%5, %6) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x?x64xf32>
+  return %7 : tensor<?x?x?x64xf32>
+
+// CHECK-LABEL: func.func @test_reshape_matmul_dim
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x2048xf32>) -> tensor<?x?x32x64xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<2048x2048xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<64> : tensor<1xi64>
+// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_0_]]) : (tensor<?x?x2048xf32>, tensor<2048x2048xf32>) -> tensor<?x?x2048xf32>
+// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_4_]], [[VAR_2_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+// CHECK: [[VAR_7_:%.+]] = "onnx.Reshape"([[VAR_5_]], [[VAR_6_]]) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x32x64xf32>
+// CHECK: return [[VAR_7_]] : tensor<?x?x32x64xf32>
+// CHECK: }
+}
+
+// -----
+
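+// COM: Negative test: "onnx.Dim" of axis 0 appears twice in the target shape,
+// COM: so the dim mapping is not a bijection, the -1 cannot be resolved, and
+// COM: the result type stays ?x?x?x64.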
+func.func @test_reshape_dim_not_bijection(%arg0: tensor<?x?x2048xf32>) -> tensor<?x?x?x64xf32> {
+  %1 = onnx.Constant dense<64> : tensor<1xi64>
+  %2 = onnx.Constant dense<-1> : tensor<1xi64>
+  %3 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+  %4 = "onnx.Concat"(%3, %3, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+  %5 = "onnx.Reshape"(%arg0, %4) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x?x64xf32>
+  return %5 : tensor<?x?x?x64xf32>
+
+// CHECK-LABEL: func.func @test_reshape_dim_not_bijection
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x2048xf32>) -> tensor<?x?x?x64xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<64> : tensor<1xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64>
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+// CHECK: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_2_]], [[VAR_1_]], [[VAR_0_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+// CHECK: [[VAR_4_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_3_]]) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x?x64xf32>
+// CHECK: return [[VAR_4_]] : tensor<?x?x?x64xf32>
+// CHECK: }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 /// Test the flatten op inference.
 //===----------------------------------------------------------------------===//
@@ -3910,4 +4008,4 @@ func.func @test_grid_sample_dim_shape3(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor
 // CHECK: return [[GRID]] : tensor<?x?x10x20xf32>
 // CHECK: }
 return %0 : tensor<*xf32>
-}
+}