Commit 076492a
Enhance shape inference for ONNX Reshape (#3122)
* Add a special case in shape inference for reshape

Signed-off-by: Tung D. Le <[email protected]>
1 parent 55e335e

File tree: 3 files changed, +177 -5 lines changed

src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp

Lines changed: 75 additions & 1 deletion
@@ -48,12 +48,83 @@ LogicalResult ONNXReshapeOpShapeHelper::computeShape() {
   // - -1: the output dim is calculated from the other output dims. No more
   // than one dim in the output has value -1.

+  // Shape inference can be simplified if there is a bijection between a set of
+  // unknown dimensions in data and unknown dimensions in shape. In such a case,
+  // there is no need to include these unknown dimensions in computing the
+  // dimension at position of -1, which increases the chance that the dim value
+  // at position of -1 can be a static value.
+  //
+  // For example,
+  // - data is tensor<1x?x2048xf32>,
+  // - shape is tensor<4xi64> of [1, dim_1_of_data, -1, 64]
+  // In this case, the 2nd dimension of data is unknown but it is the same as
+  // the 2nd value in shape. So to compute the output dim at position of -1,
+  // we just do 2048/64, that is 32. Without this simplification, the output
+  // dim at position of -1 would be unknown at compile time.
+  std::set<int64_t> dataIgnoredDims, outputIgnoredDims;
+  SmallVector<Value> shapeDimVals;
+  if (areDimsFromConcat(shape)) {
+    getDims(shape, shapeDimVals);
+    Value refData = data;
+
+    // Get the input A of MatMul that is the producer of "data" if applicable.
+    // This is a special case to handle a pattern found in the IBM
+    // granite-3.1-2b-instruct model:
+    // clang-format off
+    // %0 = onnx.Constant dense<1.000000e+00> : tensor<2048x2048xf32>
+    // %1 = onnx.Constant dense<64> : tensor<1xi64>
+    // %2 = onnx.Constant dense<-1> : tensor<1xi64>
+    // %3 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+    // %4 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+    // %5 = "onnx.MatMul"(%arg0, %0) : (tensor<?x?x2048xf32>, tensor<2048x2048xf32>) -> tensor<?x?x2048xf32>
+    // %6 = "onnx.Concat"(%3, %4, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+    // %7 = "onnx.Reshape"(%5, %6) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x?x64xf32>
+    // clang-format on
+    // This special handling is not encouraged to be used widely. Since there
+    // is no good mechanism to handle this situation in a systematic way
+    // (e.g. using dynamic dimension analysis), we handle it here.
+    ONNXMatMulOp mmOp = data.getDefiningOp<ONNXMatMulOp>();
+    bool fromMatMul = false;
+    if (mmOp && isRankedShapedType(mmOp.getB().getType()) &&
+        getRank(mmOp.getB().getType()) == 2) {
+      refData = mmOp.getA();
+      fromMatMul = true;
+    }
+
+    // Find the bijective mapping.
+    // We do not compute the actual mapping; just storing the source and
+    // target sets is enough if the map exists.
+    bool isBijective = true;
+    for (int64_t i = 0; i < outputRank; ++i) {
+      Value dim = shapeDimVals[i];
+      if (auto dimOp = dim.getDefiningOp<ONNXDimOp>()) {
+        if (dimOp.getData() != refData)
+          continue;
+        int64_t axis = dimOp.getAxis();
+        if (auto search = dataIgnoredDims.find(axis);
+            search != dataIgnoredDims.end())
+          isBijective = false;
+        if (fromMatMul && axis == getRank(refData.getType()) - 1)
+          isBijective = false;
+        outputIgnoredDims.insert(i);
+        dataIgnoredDims.insert(axis);
+      }
+    }
+    if (!isBijective) {
+      outputIgnoredDims.clear();
+      dataIgnoredDims.clear();
+    }
+  }
+
   // Compute the total number of elements using the input data operand.
   // dataRank will be 0 if Data is unranked tensor.
   // The number of element will not be computed
   IndexExpr numOfElements = LitIE(1);
-  for (unsigned i = 0; i < dataRank; ++i)
+  for (unsigned i = 0; i < dataRank; ++i) {
+    if (auto search = dataIgnoredDims.find(i); search != dataIgnoredDims.end())
+      continue;
     numOfElements = numOfElements * createIE->getShapeAsDim(data, i);
+  }

   // Compute the total number of elements from the shape values.
   IndexExpr numOfElementsFromShape = LitIE(1);
@@ -74,6 +145,9 @@ LogicalResult ONNXReshapeOpShapeHelper::computeShape() {

     // dimShape == -1: use 1 to compute the number of elements to avoid
     // negative value.
+    if (auto search = outputIgnoredDims.find(i);
+        search != outputIgnoredDims.end())
+      continue;
     dim = dim.selectOrSelf(dim == -1, LitIE(1));
     numOfElementsFromShape = numOfElementsFromShape * dim;
   }
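To make the arithmetic behind the new bijection check concrete, here is a minimal stand-alone C++ sketch. It is not onnx-mlir code: it uses plain STL containers instead of IndexExpr, Value, and ONNXDimOp; the function name inferMinusOneDim and its encoding conventions are invented for illustration; and it omits the MatMul special case (where the reduction axis must additionally break the bijection). Dynamic data dims that reappear verbatim in the shape operand are cancelled on both sides, and the -1 slot is resolved from the remaining static sizes.

#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

// Hypothetical stand-in for the committed logic. Conventions:
// - dataShape: -1 marks a dynamic dimension.
// - shapeVals: the reshape target; -1 is the slot to infer.
// - fromDataAxis[i] >= 0 means output dim i is onnx.Dim(data, axis);
//   its runtime value is unknown, but it equals that data dim exactly.
// Returns the inferred value of the -1 slot, or -1 if still dynamic.
int64_t inferMinusOneDim(const std::vector<int64_t> &dataShape,
    const std::vector<int64_t> &shapeVals,
    const std::vector<int64_t> &fromDataAxis) {
  std::set<int64_t> dataIgnored, outputIgnored;
  bool isBijective = true;
  for (size_t i = 0; i < shapeVals.size(); ++i) {
    int64_t axis = fromDataAxis[i];
    if (axis < 0)
      continue; // not an onnx.Dim of data
    if (!dataIgnored.insert(axis).second)
      isBijective = false; // same data axis used twice: no bijection
    outputIgnored.insert((int64_t)i);
  }
  if (!isBijective)
    return -1; // sketch gives up here (the committed code instead clears
               // the sets and falls back to the general path)

  // Product of the remaining (hopefully static) data dims.
  int64_t numElems = 1;
  for (size_t i = 0; i < dataShape.size(); ++i) {
    if (dataIgnored.count((int64_t)i))
      continue; // cancelled by the bijection
    if (dataShape[i] < 0)
      return -1; // a leftover dynamic dim: cannot resolve statically
    numElems *= dataShape[i];
  }
  // Divide by the remaining static output dims.
  for (size_t i = 0; i < shapeVals.size(); ++i) {
    if (outputIgnored.count((int64_t)i) || shapeVals[i] == -1)
      continue; // cancelled, or the slot we are solving for
    numElems /= shapeVals[i];
  }
  return numElems;
}

int main() {
  // The granite example: data is tensor<?x?x2048xf32>,
  // shape is [Dim(data, 0), Dim(data, 1), -1, 64].
  std::vector<int64_t> data = {-1, -1, 2048};
  std::vector<int64_t> shape = {0, 0, -1, 64}; // 0 = placeholder for Dim slots
  std::vector<int64_t> fromAxis = {0, 1, -1, -1};
  std::cout << inferMinusOneDim(data, shape, fromAxis) << "\n"; // prints 32
  return 0;
}

Compiled and run, this prints 32, matching the tensor<?x?x32x64xf32> result type in the tests below.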

test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 3 additions & 3 deletions
@@ -406,8 +406,8 @@ func.func @test_reshape_fusion3(%arg0: tensor<?x4x2x2xf32>) -> tensor<?x2x?xf32>
 // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<2> : tensor<1xi64>
 // CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x4x2x2xf32>) -> tensor<1xi64>
 // CHECK: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_1_]], [[VAR_0_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64>
-// CHECK: [[VAR_4_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_3_]]) {allowzero = 0 : si64} : (tensor<?x4x2x2xf32>, tensor<3xi64>) -> tensor<?x2x?xf32>
-// CHECK: onnx.Return [[VAR_4_]] : tensor<?x2x?xf32>
+// CHECK: [[VAR_4_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_3_]]) {allowzero = 0 : si64} : (tensor<?x4x2x2xf32>, tensor<3xi64>) -> tensor<?x2x8xf32>
+// CHECK: onnx.Return [[VAR_4_]] : tensor<?x2x8xf32>
 // CHECK: }
 }

@@ -1952,4 +1952,4 @@ func.func @test_reorder_relu_maxpool(%arg0: tensor<1x64x32x32xf32>) -> tensor<1x
 // CHECK: [[VAR_1_:%.+]] = "onnx.Relu"([[VAR_0_]]) : (tensor<*xf32>) -> tensor<1x64x16x16xf32>
 // CHECK-NEXT: return [[VAR_1_]] : tensor<1x64x16x16xf32>
 // CHECK: }
-}
+}
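The retyped result above follows from the same cancellation, and the numbers can be checked by hand: data is tensor<?x4x2x2xf32>, the target shape is [Dim(data, 0), 2, -1], so dim 0 cancels on both sides and 4 * 2 * 2 / 2 = 8 fills the -1 slot. A trivial check, with the values hardcoded from the test above rather than read from any IR:

#include <cassert>

int main() {
  // test_reshape_fusion3: data ?x4x2x2, shape [Dim(data,0), 2, -1].
  // Dim 0 is matched bijectively, so only static dims remain.
  long knownDataElems = 4 * 2 * 2; // static dims of data
  long knownOutDims = 2;           // static dims of shape, minus the -1 slot
  assert(knownDataElems / knownOutDims == 8); // hence tensor<?x2x8xf32>
  return 0;
}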

test/mlir/onnx/onnx_shape_inference.mlir

Lines changed: 99 additions & 1 deletion
@@ -900,6 +900,104 @@ onnx.Return %0 : tensor<*xf16>
 // CHECK: onnx.Return [[VAR_1_]] : tensor<4x?x3xf16>
 // CHECK: }

+// -----
+
+func.func @test_reshape_dim(%arg0: tensor<?x?x2048xf32>) -> tensor<?x?x?x64xf32> {
+  %1 = onnx.Constant dense<64> : tensor<1xi64>
+  %2 = onnx.Constant dense<-1> : tensor<1xi64>
+  %3 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+  %4 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+  %5 = "onnx.Concat"(%3, %4, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+  %6 = "onnx.Reshape"(%arg0, %5) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x?x64xf32>
+  return %6 : tensor<?x?x?x64xf32>
+
+// CHECK-LABEL: func.func @test_reshape_dim
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x2048xf32>) -> tensor<?x?x32x64xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<64> : tensor<1xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64>
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+// CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_3_]], [[VAR_1_]], [[VAR_0_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+// CHECK: [[VAR_5_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_4_]]) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x32x64xf32>
+// CHECK: return [[VAR_5_]] : tensor<?x?x32x64xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_reshape_dim_bijective_at_last_dim(%arg0: tensor<?x?x2048xf32>) -> tensor<?x?x64x?xf32> {
+  %1 = onnx.Constant dense<64> : tensor<1xi64>
+  %2 = onnx.Constant dense<-1> : tensor<1xi64>
+  %3 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+  %4 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+  %5 = "onnx.Concat"(%4, %2, %1, %3) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+  %6 = "onnx.Reshape"(%arg0, %5) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x64x?xf32>
+  return %6 : tensor<?x?x64x?xf32>
+
+// CHECK-LABEL: func.func @test_reshape_dim_bijective_at_last_dim
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x2048xf32>) -> tensor<?x32x64x?xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<64> : tensor<1xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64>
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+// CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_1_]], [[VAR_0_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+// CHECK: [[VAR_5_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_4_]]) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x32x64x?xf32>
+// CHECK: return [[VAR_5_]] : tensor<?x32x64x?xf32>
+// CHECK: }
+}
+
+// -----
+
+// COM: This pattern is found in the IBM granite-3.1-2b-instruct model.
+func.func @test_reshape_matmul_dim(%arg0: tensor<?x?x2048xf32>) -> tensor<?x?x?x64xf32> {
+  %0 = onnx.Constant dense<1.000000e+00> : tensor<2048x2048xf32>
+  %1 = onnx.Constant dense<64> : tensor<1xi64>
+  %2 = onnx.Constant dense<-1> : tensor<1xi64>
+  %3 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+  %4 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+  %5 = "onnx.MatMul"(%arg0, %0) : (tensor<?x?x2048xf32>, tensor<2048x2048xf32>) -> tensor<?x?x2048xf32>
+  %6 = "onnx.Concat"(%3, %4, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+  %7 = "onnx.Reshape"(%5, %6) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x?x64xf32>
+  return %7 : tensor<?x?x?x64xf32>
+
+// CHECK-LABEL: func.func @test_reshape_matmul_dim
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x2048xf32>) -> tensor<?x?x32x64xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<2048x2048xf32>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<64> : tensor<1xi64>
+// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_0_]]) : (tensor<?x?x2048xf32>, tensor<2048x2048xf32>) -> tensor<?x?x2048xf32>
+// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_4_]], [[VAR_2_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+// CHECK: [[VAR_7_:%.+]] = "onnx.Reshape"([[VAR_5_]], [[VAR_6_]]) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x32x64xf32>
+// CHECK: return [[VAR_7_]] : tensor<?x?x32x64xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_reshape_dim_not_bijection(%arg0: tensor<?x?x2048xf32>) -> tensor<?x?x?x64xf32> {
+  %1 = onnx.Constant dense<64> : tensor<1xi64>
+  %2 = onnx.Constant dense<-1> : tensor<1xi64>
+  %3 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+  %4 = "onnx.Concat"(%3, %3, %2, %1) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+  %5 = "onnx.Reshape"(%arg0, %4) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x?x64xf32>
+  return %5 : tensor<?x?x?x64xf32>
+
+// CHECK-LABEL: func.func @test_reshape_dim_not_bijection
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x2048xf32>) -> tensor<?x?x?x64xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<64> : tensor<1xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64>
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x2048xf32>) -> tensor<1xi64>
+// CHECK: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_2_]], [[VAR_1_]], [[VAR_0_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+// CHECK: [[VAR_4_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_3_]]) {allowzero = 0 : si64} : (tensor<?x?x2048xf32>, tensor<4xi64>) -> tensor<?x?x?x64xf32>
+// CHECK: return [[VAR_4_]] : tensor<?x?x?x64xf32>
+// CHECK: }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 /// Test the flatten op inference.
 //===----------------------------------------------------------------------===//

@@ -3910,4 +4008,4 @@ func.func @test_grid_sample_dim_shape3(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor
 // CHECK: return [[GRID]] : tensor<?x?x10x20xf32>
 // CHECK: }
 return %0 : tensor<*xf32>
-}
+}
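The last test above is the negative case: the same onnx.Dim result feeds two output slots, so the mapping is not one-to-one and the result type stays fully dynamic. Below is a small hypothetical sketch of just the duplicate-axis detection; the encoding and variable names are illustrative, not onnx-mlir APIs.

#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

int main() {
  // test_reshape_dim_not_bijection: shape is
  // [Dim(data,0), Dim(data,0), -1, 64], so data axis 0 is the source
  // of two output dims.
  std::vector<int64_t> fromDataAxis = {0, 0, -1, -1}; // -1 = not a Dim op
  std::set<int64_t> seen;
  bool isBijective = true;
  for (int64_t axis : fromDataAxis) {
    if (axis < 0)
      continue;
    if (!seen.insert(axis).second)
      isBijective = false; // duplicate source axis breaks the bijection
  }
  // Prints "not bijective": the ignored-dim sets are cleared and the
  // -1 dim stays dynamic, as the unchanged tensor<?x?x?x64xf32> shows.
  std::cout << (isBijective ? "bijective" : "not bijective") << "\n";
  return 0;
}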
