Skip to content

Commit a9d7957

Browse files
committed
Adding shape inference support for Mish op
Signed-off-by: Jagadeesh V <[email protected]>
1 parent 0205281 commit a9d7957

File tree

5 files changed

+16
-7
lines changed

5 files changed

+16
-7
lines changed

src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,15 @@ LogicalResult ONNXMeanVarianceNormalizationOp::inferShapes(
395395
return inferShapeForUnaryOps(this->getOperation());
396396
}
397397

398+
//===----------------------------------------------------------------------===//
399+
// MishOp
400+
//===----------------------------------------------------------------------===//
401+
402+
LogicalResult ONNXMishOp::inferShapes(
403+
std::function<void(Region &)> doShapeInference) {
404+
return inferShapeForUnaryOps(this->getOperation());
405+
}
406+
398407
//===----------------------------------------------------------------------===//
399408
// NegOp
400409
//===----------------------------------------------------------------------===//

src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ using ONNXLogOpShapeHelper = ONNXUnaryOpShapeHelper;
382382
using ONNXLogSoftmaxOpShapeHelper = ONNXUnaryOpShapeHelper;
383383
using ONNXLpNormalizationOpShapeHelper = ONNXUnaryOpShapeHelper;
384384
using ONNXMeanVarianceNormalizationOpShapeHelper = ONNXUnaryOpShapeHelper;
385+
using ONNXMishOpShapeHelper = ONNXUnaryOpShapeHelper;
385386
using ONNXNegOpShapeHelper = ONNXUnaryOpShapeHelper;
386387
using ONNXNotOpShapeHelper = ONNXUnaryOpShapeHelper;
387388
using ONNXRandomNormalLikeOpShapeHelper = ONNXUnaryOpShapeHelper;

src/Dialect/ONNX/ONNXUnsupportedOps.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ UNSUPPORTED_OPS(ONNXLpPoolOp)
4747
UNSUPPORTED_OPS(ONNXMaxPoolOp)
4848
UNSUPPORTED_OPS(ONNXMaxUnpoolOp)
4949
UNSUPPORTED_OPS(ONNXMelWeightMatrixOp)
50-
UNSUPPORTED_OPS(ONNXMishOp)
5150
UNSUPPORTED_OPS(ONNXMomentumOp)
5251
UNSUPPORTED_OPS(ONNXMultinomialOp)
5352
UNSUPPORTED_OPS(ONNXNegativeLogLikelihoodLossOp)

test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1662,9 +1662,9 @@ func.func private @test_shrink(%arg0 : tensor<512xf32>) -> tensor<*xf32> {
16621662

16631663
// -----
16641664

1665-
func.func @test_mish(%arg0 : tensor<64x128xf32>) -> tensor<64x128xf32> {
1666-
%0 = "onnx.Mish"(%arg0) : (tensor<64x128xf32>) -> tensor<64x128xf32>
1667-
return %0 : tensor<64x128xf32>
1665+
func.func @test_mish(%arg0 : tensor<64x128xf32>) -> tensor<*xf32> {
1666+
%0 = "onnx.Mish"(%arg0) : (tensor<64x128xf32>) -> tensor<*xf32>
1667+
"func.return"(%0) : (tensor<*xf32>) -> ()
16681668

16691669
// CHECK-LABEL: func.func @test_mish
16701670
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<64x128xf32>) -> memref<64x128xf32> {

test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2578,9 +2578,9 @@ func.func private @test_shrink(%arg0 : tensor<512xf32>) -> tensor<*xf32> {
25782578

25792579
// -----
25802580

2581-
func.func @test_mish(%arg0 : tensor<64x128xf32>) -> tensor<64x128xf32> {
2582-
%0 = "onnx.Mish"(%arg0) : (tensor<64x128xf32>) -> tensor<64x128xf32>
2583-
return %0 : tensor<64x128xf32>
2581+
func.func @test_mish(%arg0 : tensor<64x128xf32>) -> tensor<*xf32> {
2582+
%0 = "onnx.Mish"(%arg0) : (tensor<64x128xf32>) -> tensor<*xf32>
2583+
"func.return"(%0) : (tensor<*xf32>) -> ()
25842584

25852585
// CHECK-LABEL: func.func @test_mish
25862586
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<64x128xf32>) -> memref<64x128xf32> {

0 commit comments

Comments
 (0)