diff --git a/src/Conversion/ONNXToStablehlo/CMakeLists.txt b/src/Conversion/ONNXToStablehlo/CMakeLists.txt
index 690c58ef44..4b1b0ec002 100644
--- a/src/Conversion/ONNXToStablehlo/CMakeLists.txt
+++ b/src/Conversion/ONNXToStablehlo/CMakeLists.txt
@@ -56,6 +56,7 @@ add_onnx_mlir_library(OMONNXToStablehlo
   Tensor/Concat.cpp
   Tensor/Constant.cpp
   Tensor/DepthToSpace.cpp
+  Tensor/Dim.cpp
   Tensor/Expand.cpp
   Tensor/Flatten.cpp
   Tensor/Gather.cpp
diff --git a/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp b/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp
index 74ea09a3dc..1550214d60 100644
--- a/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp
+++ b/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp
@@ -41,6 +41,7 @@ void populateONNXToStablehloConversionPattern(
   populateLoweringONNXConcatOpToStablehloPattern(patterns, ctx);
   populateLoweringONNXConstantOpToStablehloPattern(patterns, ctx);
   populateLoweringONNXDepthToSpaceOpToStablehloPattern(patterns, ctx);
+  populateLoweringONNXDimOpToStablehloPattern(patterns, ctx);
   populateLoweringONNXExpandOpToStablehloPattern(patterns, ctx);
   populateLoweringONNXFlattenOpToStablehloPattern(patterns, ctx);
   populateLoweringONNXGatherOpToStablehloPattern(patterns, ctx);
@@ -87,6 +88,7 @@ struct FrontendToStablehloLoweringPass
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<mlir::stablehlo::StablehloDialect>();
+    registry.insert<mlir::shape::ShapeDialect>();
     registry.insert<mlir::func::FuncDialect>();
   }
 
diff --git a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp
index 832e72b973..5618f9962c 100644
--- a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp
+++ b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp
@@ -179,6 +179,8 @@ void populateLoweringONNXConcatOpToStablehloPattern(
     RewritePatternSet &, MLIRContext *);
 void populateLoweringONNXConstantOpToStablehloPattern(
     RewritePatternSet &, MLIRContext *);
+void populateLoweringONNXDimOpToStablehloPattern(
+    RewritePatternSet &, MLIRContext *);
 void populateLoweringONNXDepthToSpaceOpToStablehloPattern(
     RewritePatternSet &, MLIRContext *);
 void populateLoweringONNXExpandOpToStablehloPattern(
diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp
new file mode 100644
index 0000000000..2e40c2ade6
--- /dev/null
+++ b/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp
@@ -0,0 +1,65 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===----------------- Dim.cpp - Lowering Dim Op ----------------===//
+//
+// Copyright 2022-2024
+//
+// =============================================================================
+//
+// This file lowers the ONNXDim operator to the Shape and Tensor dialects.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp"
+
+using namespace mlir;
+
+namespace onnx_mlir {
+
+namespace {
+
+struct ONNXDimOpLoweringToStablehlo : public ConversionPattern {
+  ONNXDimOpLoweringToStablehlo(MLIRContext *ctx)
+      : ConversionPattern(ONNXDimOp::getOperationName(), 1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    Location loc = op->getLoc();
+    ONNXDimOp dimOp = cast<ONNXDimOp>(op);
+    int64_t axisLit = dimOp.getAxis();
+
+    // Check that axisLit is a valid dimension index
+    Value tensorArg = operands[0];
+    assert(tensorArg.getType().isa<RankedTensorType>() &&
+           "Expected ranked tensor type");
+
+    int64_t rank = tensorArg.getType().cast<RankedTensorType>().getRank();
+
+    assert((axisLit >= 0 && axisLit < rank) &&
+           "Axis must be in the range [0, input tensor rank - 1]");
+
+    Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, tensorArg);
+    Value dimValue =
+        rewriter.create<shape::GetExtentOp>(loc, inputShape, axisLit);
+    Type dimType = dimOp.getDim().getType();
+    Type indexValueType = dimType.cast<ShapedType>().getElementType();
+    Value castedIndex =
+        rewriter.create<arith::IndexCastOp>(loc, indexValueType, dimValue);
+    Value indexTensor = rewriter.create<tensor::FromElementsOp>(
+        loc, dimType, ArrayRef<Value>{castedIndex});
+    rewriter.replaceOp(op, indexTensor);
+    return success();
+  }
+};
+
+} // namespace
+
+void populateLoweringONNXDimOpToStablehloPattern(
+    RewritePatternSet &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXDimOpLoweringToStablehlo>(ctx);
+}
+
+} // namespace onnx_mlir
diff --git a/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp b/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp
index 2f0a513b35..7afdc6bbf4 100644
--- a/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp
@@ -37,12 +37,16 @@ LogicalResult ONNXDimOpShapeHelper::computeShape() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ONNXDimOp::verify() {
-  // Input data must be ranked.
   if (!hasShapeAndRank(this->getData()))
-    return failure();
-  // Axis must be in [0, rank -1].
+    return emitOpError("input must have shape and rank.");
+
   int64_t axis = this->getAxis();
-  return failure((axis < 0) || (axis >= getRank(this->getData().getType())));
+  if ((axis < 0) || (axis >= getRank(this->getData().getType())))
+    return emitOpError("attribute ")
+           << ONNXDimOp::getAxisAttrName() << " value is " << axis
+           << ", accepted range is [0, "
+           << getRank(this->getData().getType()) - 1 << "].";
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Dim.mlir b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Dim.mlir
new file mode 100644
index 0000000000..e3013a57e5
--- /dev/null
+++ b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Dim.mlir
@@ -0,0 +1,46 @@
+// RUN: onnx-mlir-opt --convert-onnx-to-stablehlo --canonicalize %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// -----
+
+func.func @test_dim_1(%arg0 : tensor<5x?x1x32xf32>) -> tensor<1xi64> {
+  %1 = "onnx.Dim"(%arg0) { axis = 1 : si64} : (tensor<5x?x1x32xf32>) -> tensor<1xi64>
+  return %1 : tensor<1xi64>
+}
+// CHECK-LABEL:  func.func @test_dim_1
+// CHECK-SAME:   ([[PARAM:%.+]]: tensor<5x?x1x32xf32>) -> tensor<1xi64> {
+// CHECK-NEXT:     [[CONST_1:%.+]] = arith.constant 1 : index
+// CHECK-NEXT:     [[SHAPE:%.+]] = shape.shape_of [[PARAM]] : tensor<5x?x1x32xf32> -> tensor<4xindex>
+// CHECK-NEXT:     [[DIM:%.+]] = shape.get_extent [[SHAPE]], [[CONST_1]] : tensor<4xindex>, index -> index
+// CHECK-NEXT:     [[INDEX_CAST:%.+]] = arith.index_cast [[DIM]] : index to i64
+// CHECK-NEXT:     [[FROM_ELEMENTS:%.+]] = tensor.from_elements [[INDEX_CAST]] : tensor<1xi64>
+// CHECK-NEXT:     return [[FROM_ELEMENTS]] : tensor<1xi64>
+// CHECK:        }
+
+// -----
+
+func.func @test_dim_2(%arg0 : tensor<5x7xf32>) -> tensor<1xi64> {
+  %1 = "onnx.Dim"(%arg0) { axis = 0 : si64} : (tensor<5x7xf32>) -> tensor<1xi64>
+  return %1 : tensor<1xi64>
+}
+
+// CHECK-LABEL:  func.func @test_dim_2
+// CHECK-SAME:   ([[PARAM:%.+]]: tensor<5x7xf32>) -> tensor<1xi64> {
+// CHECK-NEXT:     [[CONST:%.+]] = arith.constant dense<5> : tensor<1xi64>
+// CHECK-NEXT:     return [[CONST]] : tensor<1xi64>
+// CHECK:        }
+
+// -----
+
+func.func @test_dim_invalid_1(%arg0 : tensor<5x7xf32>) -> tensor<1xi64> {
+  // expected-error @+1 {{attribute "axis" value is 3, accepted range is [0, 1].}}
+  %1 = "onnx.Dim"(%arg0) { axis = 3 : si64} : (tensor<5x7xf32>) -> tensor<1xi64>
+  return %1 : tensor<1xi64>
+}
+
+// -----
+
+func.func @test_dim_invalid_2(%arg0 : tensor<*xf32>) -> tensor<1xi64> {
+  // expected-error @+1 {{input must have shape and rank.}}
+  %1 = "onnx.Dim"(%arg0) { axis = 0 : si64} : (tensor<*xf32>) -> tensor<1xi64>
+  return %1 : tensor<1xi64>
+}
diff --git a/test/mlir/onnx/invalid.mlir b/test/mlir/onnx/invalid.mlir
index edb2eee264..f91d261eaa 100644
--- a/test/mlir/onnx/invalid.mlir
+++ b/test/mlir/onnx/invalid.mlir
@@ -83,6 +83,22 @@ func.func @test_concat_from_sequence_verifier_2(%arg0 : !onnx.Seq) -> tensor {
+func.func @test_dim_verifier_1(%arg0 : tensor<*xf32>) -> tensor<1xi64> {
+  // expected-error @+1 {{input must have shape and rank}}
+  %1 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<*xf32>) -> tensor<1xi64>
+  "onnx.Return"(%1) : (tensor<1xi64>) -> ()
+}
+
+// -----
+
+func.func @test_dim_verifier_2(%arg0 : tensor<5x5xf32>) -> tensor<1xi64> {
+  // expected-error @+1 {{'onnx.Dim' op attribute "axis" value is -1, accepted range is [0, 1].}}
+  %1 = "onnx.Dim"(%arg0) {axis = -1 : si64} : (tensor<5x5xf32>) -> tensor<1xi64>
+  "onnx.Return"(%1) : (tensor<1xi64>) -> ()
+}
+
+// -----
+
 func.func @test_dequantize_linear_verifier_1(%arg0 : tensor<5x5x1xi32>, %arg1 : tensor<3xf32>, %arg2 : tensor<3xi32>) -> tensor<*xf32> {
   // expected-error @+1 {{onnx.DequantizeLinear: 'axis' value is 3, accepted range is [-3, 2]}}
   %1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 3 : si64} : (tensor<5x5x1xi32>, tensor<3xf32>, tensor<3xi32>) -> tensor<*xf32>
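
Reviewer note (illustrative only, not part of the patch; the SSA names %t, %d, %c1, %shape, %extent, %i64 below are made up for the sketch). The new pattern rewrites a dynamic-dimension query such as

    %d = "onnx.Dim"(%t) {axis = 1 : si64} : (tensor<5x?x1x32xf32>) -> tensor<1xi64>

into the shape/arith/tensor sequence that test_dim_1 checks:

    %c1 = arith.constant 1 : index
    %shape = shape.shape_of %t : tensor<5x?x1x32xf32> -> tensor<4xindex>
    %extent = shape.get_extent %shape, %c1 : tensor<4xindex>, index -> index
    %i64 = arith.index_cast %extent : index to i64
    %d = tensor.from_elements %i64 : tensor<1xi64>

When the queried dimension is static, --canonicalize folds the whole chain into a single arith.constant (dense<5> : tensor<1xi64> in test_dim_2).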