
Commit b06a7b9

Merge branch 'main' into mem_reduction_stickified

2 parents: 691ec33 + 2c8e8e5

File tree

3 files changed: +28, -12 lines

  CMakeLists.txt
  src/Dialect/ONNX/Transforms/Decompose.cpp
  test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gru.mlir
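In brief: this merge picks up two changes from main. The ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE build option is removed, so the ONNXConvTransposeOp decomposition is no longer a compile-time switch and is gated only by the runtime checks in Decompose.cpp; and a new NNPA lit test exercises onnx.GRU with a sequence_lens input.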

CMakeLists.txt (-5)

@@ -8,7 +8,6 @@ project(onnx-mlir)
 option(ONNX_MLIR_BUILD_TESTS "Build ONNX-MLIR test executables. If OFF, just generate build targets." ON)
 option(ONNX_MLIR_CCACHE_BUILD "Set to ON for a ccache enabled build." OFF)
 option(ONNX_MLIR_ENABLE_STABLEHLO "Enable StableHLO support." ON)
-option(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE "Enable ONNXConvTransposeOp decomposition." ON)
 option(ONNX_MLIR_ENABLE_WERROR "Enable warnings as errors." OFF)
 option(ONNX_MLIR_SUPPRESS_THIRD_PARTY_WARNINGS "Suppress warning in third_party code." ON)
 option(ONNX_MLIR_ENABLE_JAVA "Set to ON for building the Java runtime, tools, and tests" ON)

@@ -208,10 +207,6 @@ if (ONNX_MLIR_ENABLE_STABLEHLO)
   add_compile_definitions(ONNX_MLIR_ENABLE_STABLEHLO)
 endif()

-if (ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE)
-  add_compile_definitions(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE)
-endif()
-
 add_subdirectory(utils)
 add_subdirectory(include)
 add_subdirectory(src)
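For context (reconstructed from the #else branch deleted in the Decompose.cpp hunk below, not new behavior): when a build was configured with the now-removed option set to OFF, the macro was never defined and the guard collapsed to a stub that disabled the decomposition patterns unconditionally:

    // Pre-commit behavior with ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE undefined,
    // as compiled from the #else branch this commit deletes:
    bool shouldDecomposeConvTransposeOp(Value convTransposeResult) {
      // Disable the ONNXConvTransposeOp decomposition patterns.
      return false;
    }

After the merge, the function always runs its shape checks, so decomposing ConvTranspose is a per-op runtime decision rather than a build-time one.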

src/Dialect/ONNX/Transforms/Decompose.cpp (-7)

@@ -332,15 +332,10 @@ bool hasStaticSpatialDims(Value v) {
 }

 bool shouldDecomposeConvTransposeOp(Value convTransposeResult) {
-#ifdef ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE
   ONNXConvTransposeOp op =
       mlir::cast<ONNXConvTransposeOp>(convTransposeResult.getDefiningOp());
   return hasShapeAndRank(convTransposeResult) &&
          hasStaticSpatialDims(op.getX()) && hasStaticSpatialDims(op.getW());
-#else
-  // Disable the ONNXConvTransposeOp decomposition patterns.
-  return false;
-#endif
 }

 // Split on the specified axis. The length of each output is one.

@@ -1128,7 +1123,6 @@ void DecomposeONNXToONNXPass::runOnOperation() {
           op, alpha, rankA, rankB);
     });

-#ifdef ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE
 #ifdef ONNX_MLIR_ENABLE_STABLEHLO
   // ONNXtoStablehlo pass has own rewriting for ConvTranspose Op using
   // stablehlo ops. To avoid conflict with it, decomposing for ConvTranspose

@@ -1141,7 +1135,6 @@ void DecomposeONNXToONNXPass::runOnOperation() {
     });
 #ifdef ONNX_MLIR_ENABLE_STABLEHLO
   }
-#endif
 #endif

   RewritePatternSet patterns(context);
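The `});` in the last hunk appears to close a dynamic-legality callback on the pass's ConversionTarget. A minimal sketch of that guard, assuming the target variable is named `target` (the lambda body is illustrative, not copied verbatim from the pass):

    // ONNXConvTransposeOp is treated as legal (left untouched) exactly when
    // the runtime checks say it should not be decomposed; the deleted
    // ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE #ifdef used to wrap this block.
    target.addDynamicallyLegalOp<ONNXConvTransposeOp>(
        [](ONNXConvTransposeOp op) {
          return !shouldDecomposeConvTransposeOp(op.getResult());
        });

Per the comment in the hunk, when ONNX_MLIR_ENABLE_STABLEHLO is defined this guard is additionally skipped for the stablehlo target, since the ONNXToStablehlo path has its own ConvTranspose rewrite.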

test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gru.mlir (+28)
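The added test checks the ONNX-to-ZHigh lowering of onnx.GRU when a sequence_lens input is present. As the CHECK lines show, the input is stickified with zhigh.Stick, W and R are transposed and split into the three gate slices with onnx.SplitV11 and packed by zhigh.StickForGRU, zhigh.GRU runs over all steps, and the per-sequence lengths are then applied to the unstickified result by zhigh.FixGRUY (for Y) and zhigh.FixGRUYh (for Y_h).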
@@ -247,6 +247,34 @@ func.func @test_onnx_to_zhigh_gru0_bidir_dyn(%X: tensor<?x?x?xf32>, %W: tensor<2

 // -----

+func.func @gru_with_len(%arg0: tensor<2x2x1xf32>, %arg1: tensor<1x3x1xf32>, %arg2 : tensor<1x3x1xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
+  %lens = onnx.Constant dense<[2, 1]> : tensor<2xi32>
+  %cst = "onnx.NoValue"() {value} : () -> none
+  %res:2 = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %lens, %cst) {layout = 0 : si64, linear_before_reset = 1 : si64}
+      : ( tensor<2x2x1xf32>, tensor<1x3x1xf32>, tensor<1x3x1xf32>, none, tensor<2xi32>, none) -> (tensor<*xf32>, tensor<*xf32>)
+  onnx.Return %res#0, %res#1 : tensor<*xf32>, tensor<*xf32>
+
+// CHECK-LABEL:  func.func @gru_with_len
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<2x2x1xf32>, [[PARAM_1_:%.+]]: tensor<1x3x1xf32>, [[PARAM_2_:%.+]]: tensor<1x3x1xf32>) -> (tensor<2x1x2x1xf32>, tensor<1x2x1xf32>) {
+// CHECK-DAG:    [[VAR_0_:%.+]] = onnx.Constant dense<[2, 1]> : tensor<2xi32>
+// CHECK-DAG:    [[VAR_1_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG:    [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<2x2x1xf32>) -> tensor<2x2x1xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK-DAG:    [[VAR_3_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [0, 2, 1]} : (tensor<1x3x1xf32>) -> tensor<1x1x3xf32>
+// CHECK:        [[VAR_4_:%.+]]:3 = "onnx.SplitV11"([[VAR_3_]]) {axis = 2 : si64} : (tensor<1x1x3xf32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>)
+// CHECK-DAG:    [[VAR_5_:%.+]] = "zhigh.StickForGRU"([[VAR_4_]]#0, [[VAR_4_]]#1, [[VAR_4_]]#2) : (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>) -> tensor<*xf16>
+// CHECK-DAG:    [[VAR_6_:%.+]] = "onnx.Transpose"([[PARAM_2_]]) {perm = [0, 2, 1]} : (tensor<1x3x1xf32>) -> tensor<1x1x3xf32>
+// CHECK:        [[VAR_7_:%.+]]:3 = "onnx.SplitV11"([[VAR_6_]]) {axis = 2 : si64} : (tensor<1x1x3xf32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>)
+// CHECK:        [[VAR_8_:%.+]] = "zhigh.StickForGRU"([[VAR_7_]]#0, [[VAR_7_]]#1, [[VAR_7_]]#2) : (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>) -> tensor<*xf16>
+// CHECK:        [[VAR_9_:%.+]] = "zhigh.GRU"([[VAR_2_]], [[VAR_1_]], [[VAR_5_]], [[VAR_1_]], [[VAR_8_]], [[VAR_1_]]) {direction = "forward", hidden_size = 1 : si64, return_all_steps = -1 : si64} : (tensor<2x2x1xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none, tensor<*xf16>, none, tensor<*xf16>, none) -> tensor<*xf16>
+// CHECK:        [[VAR_10_:%.+]] = "zhigh.Unstick"([[VAR_9_]]) : (tensor<*xf16>) -> tensor<2x1x2x1xf32>
+// CHECK-DAG:    [[VAR_11_:%.+]] = "zhigh.FixGRUY"([[VAR_10_]], [[VAR_0_]], [[VAR_1_]]) : (tensor<2x1x2x1xf32>, tensor<2xi32>, none) -> tensor<2x1x2x1xf32>
+// CHECK-DAG:    [[VAR_12_:%.+]] = "zhigh.FixGRUYh"([[VAR_10_]], [[VAR_0_]]) : (tensor<2x1x2x1xf32>, tensor<2xi32>) -> tensor<1x2x1xf32>
+// CHECK:        onnx.Return [[VAR_11_]], [[VAR_12_]] : tensor<2x1x2x1xf32>, tensor<1x2x1xf32>
+// CHECK:        }
+}
+
+// -----
+
 // COM : Maximum hidden_size in GRU is 10880. Not lowered when using 10881.

 func.func @test_onnx_to_zhigh_gru_exceed_num_hidden(%X: tensor<7x2000x204xf32>, %W: tensor<1x16384x204xf32>, %R: tensor<1x16384x10881xf32>, %B: tensor<1x16386xf32>) -> (tensor<7x1x2000x10881xf32>, tensor<1x2000x10881xf32>) {
