From ab445d628705302eb49a5a4efda98b845bb52df3 Mon Sep 17 00:00:00 2001 From: Boyana Norris Date: Sun, 3 Mar 2024 16:49:17 -0800 Subject: [PATCH 1/4] ONNX to StableHLO, add lowering of Dim op Signed-off-by: Boyana Norris --- src/Conversion/ONNXToStablehlo/CMakeLists.txt | 1 + .../ConvertONNXToStablehlo.cpp | 2 + .../ONNXToStablehlo/ONNXToStablehloCommon.hpp | 2 + src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp | 65 +++++++++++++++++++ .../onnx_to_stablehlo/Tensor/Dim.mlir | 29 +++++++++ 5 files changed, 99 insertions(+) create mode 100644 src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp create mode 100644 test/mlir/conversion/onnx_to_stablehlo/Tensor/Dim.mlir diff --git a/src/Conversion/ONNXToStablehlo/CMakeLists.txt b/src/Conversion/ONNXToStablehlo/CMakeLists.txt index 690c58ef44..4b1b0ec002 100644 --- a/src/Conversion/ONNXToStablehlo/CMakeLists.txt +++ b/src/Conversion/ONNXToStablehlo/CMakeLists.txt @@ -56,6 +56,7 @@ add_onnx_mlir_library(OMONNXToStablehlo Tensor/Concat.cpp Tensor/Constant.cpp Tensor/DepthToSpace.cpp + Tensor/Dim.cpp Tensor/Expand.cpp Tensor/Flatten.cpp Tensor/Gather.cpp diff --git a/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp b/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp index 74ea09a3dc..1550214d60 100644 --- a/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp +++ b/src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp @@ -41,6 +41,7 @@ void populateONNXToStablehloConversionPattern( populateLoweringONNXConcatOpToStablehloPattern(patterns, ctx); populateLoweringONNXConstantOpToStablehloPattern(patterns, ctx); populateLoweringONNXDepthToSpaceOpToStablehloPattern(patterns, ctx); + populateLoweringONNXDimOpToStablehloPattern(patterns, ctx); populateLoweringONNXExpandOpToStablehloPattern(patterns, ctx); populateLoweringONNXFlattenOpToStablehloPattern(patterns, ctx); populateLoweringONNXGatherOpToStablehloPattern(patterns, ctx); @@ -87,6 +88,7 @@ struct FrontendToStablehloLoweringPass void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + registry.insert(); registry.insert(); } diff --git a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp index 832e72b973..5618f9962c 100644 --- a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp +++ b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp @@ -179,6 +179,8 @@ void populateLoweringONNXConcatOpToStablehloPattern( RewritePatternSet &, MLIRContext *); void populateLoweringONNXConstantOpToStablehloPattern( RewritePatternSet &, MLIRContext *); +void populateLoweringONNXDimOpToStablehloPattern( + RewritePatternSet &, MLIRContext *); void populateLoweringONNXDepthToSpaceOpToStablehloPattern( RewritePatternSet &, MLIRContext *); void populateLoweringONNXExpandOpToStablehloPattern( diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp new file mode 100644 index 0000000000..5c1320617a --- /dev/null +++ b/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp @@ -0,0 +1,65 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===----------------- Dim.cpp - Lowering Dim Op ----------------===// +// +// Copyright 2022-2024 +// +// ============================================================================= +// +// This file lowers the ONNXDim operator to the Tensor dialect. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +struct ONNXDimOpLoweringToStablehlo : public ConversionPattern { + ONNXDimOpLoweringToStablehlo(MLIRContext *ctx) + : ConversionPattern(ONNXDimOp::getOperationName(), 1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + Location loc = op->getLoc(); + ONNXDimOp dimOp = cast(op); + int64_t axis = dimOp.getAxis(); + + // Check that axis is a valid dimension index + Value tensorArg = operands[0]; + if (!tensorArg.getType().isa()) { + return rewriter.notifyMatchFailure(op, "Expected ranked tensor type"); + } + RankedTensorType tensorType = tensorArg.getType().cast(); + int64_t rank = tensorType.getRank(); + + if (axis < 0 || axis >= rank) { + return rewriter.notifyMatchFailure( + op, "Invalid axis, must be in range 0 to rank-1 of the input tensor"); + } + Value dimValue = rewriter.create(loc, tensorArg, axis); + + Type dimType = dimOp.getDim().getType(); + Type indexValueType = dimType.cast().getElementType(); + Value castedIndex = + rewriter.create(loc, indexValueType, dimValue); + Value indexTensor = rewriter.create( + loc, dimType, ArrayRef{castedIndex}); + rewriter.replaceOp(op, indexTensor); + return success(); + } +}; + +} // namespace + +void populateLoweringONNXDimOpToStablehloPattern( + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir diff --git a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Dim.mlir b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Dim.mlir new file mode 100644 index 0000000000..86b1c5d9c8 --- /dev/null +++ b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Dim.mlir @@ -0,0 +1,29 @@ +// RUN: onnx-mlir-opt --convert-onnx-to-stablehlo --canonicalize %s -split-input-file -verify-diagnostics | FileCheck %s + +// ----- + +func.func @test_dim_1(%arg0 : tensor<5x?x1x32xf32>) -> tensor<1xi64> { + %1 = "onnx.Dim"(%arg0) { axis = 1 : si64} : (tensor<5x?x1x32xf32>) -> tensor<1xi64> + return %1 : tensor<1xi64> +} +// CHECK-LABEL: func.func @test_dim_1 +// CHECK-SAME: ([[PARAM:%.+]]: tensor<5x?x1x32xf32>) -> tensor<1xi64> { +// CHECK-NEXT: [[CONST_1:%.+]] = arith.constant 1 : index +// CHECK-NEXT: [[DIM:%.+]] = tensor.dim [[PARAM]], [[CONST_1]] : tensor<5x?x1x32xf32> +// CHECK-NEXT: [[INDEX_CAST:%.+]] = arith.index_cast [[DIM]] : index to i64 +// CHECK-NEXT: [[FROM_ELEMENTS:%.+]] = tensor.from_elements [[INDEX_CAST]] : tensor<1xi64> +// CHECK-NEXT: return [[FROM_ELEMENTS]] : tensor<1xi64> +// CHECK: } + +// ----- + +func.func @test_dim_2(%arg0 : tensor<5x7xf32>) -> tensor<1xi64> { + %1 = "onnx.Dim"(%arg0) { axis = 0 : si64} : (tensor<5x7xf32>) -> tensor<1xi64> + return %1 : tensor<1xi64> +} + +// CHECK-LABEL: func.func @test_dim_2 +// CHECK-SAME: ([[PARAM:%.+]]: tensor<5x7xf32>) -> tensor<1xi64> { +// CHECK-NEXT: [[CONST:%.+]] = arith.constant dense<5> : tensor<1xi64> +// CHECK-NEXT: return [[CONST]] : tensor<1xi64> +// CHECK: } From 4fc3fe698be9fc21819615f7cf6c5b68e0c94867 Mon Sep 17 00:00:00 2001 From: Boyana Norris Date: Sun, 3 Mar 2024 18:41:01 -0800 Subject: [PATCH 2/4] emit errors in verifier instead of failing silently Signed-off-by: Boyana Norris --- src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp | 13 ++++++------- src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp | 11 +++++++---- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp index 5c1320617a..ea09db0724 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp @@ -32,16 +32,15 @@ struct ONNXDimOpLoweringToStablehlo : public ConversionPattern { // Check that axis is a valid dimension index Value tensorArg = operands[0]; - if (!tensorArg.getType().isa()) { - return rewriter.notifyMatchFailure(op, "Expected ranked tensor type"); - } + assert(tensorArg.getType().isa() && + "Expected ranked tensor type"); + RankedTensorType tensorType = tensorArg.getType().cast(); int64_t rank = tensorType.getRank(); - if (axis < 0 || axis >= rank) { - return rewriter.notifyMatchFailure( - op, "Invalid axis, must be in range 0 to rank-1 of the input tensor"); - } + assert((axis >= 0 && axis < rank) && + "Invalid axis, must be in the range [0, input tensor rank)"); + Value dimValue = rewriter.create(loc, tensorArg, axis); Type dimType = dimOp.getDim().getType(); diff --git a/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp b/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp index 2f0a513b35..403009d73f 100644 --- a/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp +++ b/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp @@ -37,12 +37,15 @@ LogicalResult ONNXDimOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ONNXDimOp::verify() { - // Input data must be ranked. if (!hasShapeAndRank(this->getData())) - return failure(); - // Axis must be in [0, rank -1]. + return emitOpError("input must have shape and rank."); + int64_t axis = this->getAxis(); - return failure((axis < 0) || (axis >= getRank(this->getData().getType()))); + if ((axis < 0) || (axis >= getRank(this->getData().getType()))) + return emitOpError("attribute ") + << ONNXDimOp::getAxisAttrName() << " must be in the range [0, " + << getRank(this->getData().getType()) << ")."; + return success(); } //===----------------------------------------------------------------------===// From 159fd939825bbb6186b7d66379579b1bf38dfd48 Mon Sep 17 00:00:00 2001 From: Boyana Norris Date: Sun, 3 Mar 2024 18:59:09 -0800 Subject: [PATCH 3/4] add invalid mlir lit tests Signed-off-by: Boyana Norris --- src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp | 2 +- src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp | 5 +++-- test/mlir/onnx/invalid.mlir | 16 ++++++++++++++++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp index ea09db0724..0dae351c07 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp @@ -39,7 +39,7 @@ struct ONNXDimOpLoweringToStablehlo : public ConversionPattern { int64_t rank = tensorType.getRank(); assert((axis >= 0 && axis < rank) && - "Invalid axis, must be in the range [0, input tensor rank)"); + "Axis must be in the range [0, input tensor rank - 1]"); Value dimValue = rewriter.create(loc, tensorArg, axis); diff --git a/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp b/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp index 403009d73f..7afdc6bbf4 100644 --- a/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp +++ b/src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp @@ -43,8 +43,9 @@ LogicalResult ONNXDimOp::verify() { int64_t axis = this->getAxis(); if ((axis < 0) || (axis >= getRank(this->getData().getType()))) return emitOpError("attribute ") - << ONNXDimOp::getAxisAttrName() << " must be in the range [0, " - << getRank(this->getData().getType()) << ")."; + << ONNXDimOp::getAxisAttrName() << " value is " << axis + << ", accepted range is [0, " + << getRank(this->getData().getType()) - 1 << "]."; return success(); } diff --git a/test/mlir/onnx/invalid.mlir b/test/mlir/onnx/invalid.mlir index edb2eee264..f91d261eaa 100644 --- a/test/mlir/onnx/invalid.mlir +++ b/test/mlir/onnx/invalid.mlir @@ -83,6 +83,22 @@ func.func @test_concat_from_sequence_verifier_2(%arg0 : !onnx.Seq) -> tensor { + // expected-error @+1 {{input must have shape and rank}} + %1 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<*xf32>) -> tensor + "onnx.Return"(%1) : (tensor) -> () +} + +// ----- + +func.func @test_dim_verifier_2(%arg0 : tensor<5x5xf32>) -> tensor { + // expected-error @+1 {{'onnx.Dim' op attribute "axis" value is -1, accepted range is [0, 1].}} + %1 = "onnx.Dim"(%arg0) {axis = -1 : si64} : (tensor<5x5xf32>) -> tensor + "onnx.Return"(%1) : (tensor) -> () +} + +// ----- + func.func @test_dequantize_linear_verifier_1(%arg0 : tensor<5x5x1xi32>, %arg1 : tensor<3xf32>, %arg2 : tensor<3xi32>) -> tensor<*xf32> { // expected-error @+1 {{onnx.DequantizeLinear: 'axis' value is 3, accepted range is [-3, 2]}} %1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 3 : si64} : (tensor<5x5x1xi32>, tensor<3xf32>, tensor<3xi32>) -> tensor<*xf32> From 5f4f85108a47dbbaf224fc1c5f609bb0af29813c Mon Sep 17 00:00:00 2001 From: Boyana Norris Date: Tue, 5 Mar 2024 20:01:44 -0800 Subject: [PATCH 4/4] use shape.get_extent instead of tensor.dim (since the shape dialect is already used similarly in the conversion of other ops). Signed-off-by: Boyana Norris --- src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp | 15 ++++++++------- .../onnx_to_stablehlo/Tensor/Dim.mlir | 19 ++++++++++++++++++- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp index 0dae351c07..2e40c2ade6 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Dim.cpp @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Shape/IR/Shape.h" #include "src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp" using namespace mlir; @@ -28,21 +29,21 @@ struct ONNXDimOpLoweringToStablehlo : public ConversionPattern { ConversionPatternRewriter &rewriter) const final { Location loc = op->getLoc(); ONNXDimOp dimOp = cast(op); - int64_t axis = dimOp.getAxis(); + int64_t axisLit = dimOp.getAxis(); - // Check that axis is a valid dimension index + // Check that axisLit is a valid dimension index Value tensorArg = operands[0]; assert(tensorArg.getType().isa() && "Expected ranked tensor type"); - RankedTensorType tensorType = tensorArg.getType().cast(); - int64_t rank = tensorType.getRank(); + int64_t rank = tensorArg.getType().cast().getRank(); - assert((axis >= 0 && axis < rank) && + assert((axisLit >= 0 && axisLit < rank) && "Axis must be in the range [0, input tensor rank - 1]"); - Value dimValue = rewriter.create(loc, tensorArg, axis); - + Value inputShape = rewriter.create(loc, tensorArg); + Value dimValue = + rewriter.create(loc, inputShape, axisLit); Type dimType = dimOp.getDim().getType(); Type indexValueType = dimType.cast().getElementType(); Value castedIndex = diff --git a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Dim.mlir b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Dim.mlir index 86b1c5d9c8..e3013a57e5 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/Tensor/Dim.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/Tensor/Dim.mlir @@ -9,7 +9,8 @@ func.func @test_dim_1(%arg0 : tensor<5x?x1x32xf32>) -> tensor<1xi64> { // CHECK-LABEL: func.func @test_dim_1 // CHECK-SAME: ([[PARAM:%.+]]: tensor<5x?x1x32xf32>) -> tensor<1xi64> { // CHECK-NEXT: [[CONST_1:%.+]] = arith.constant 1 : index -// CHECK-NEXT: [[DIM:%.+]] = tensor.dim [[PARAM]], [[CONST_1]] : tensor<5x?x1x32xf32> +// CHECK-NEXT: [[SHAPE:%.+]] = shape.shape_of [[PARAM]] : tensor<5x?x1x32xf32> -> tensor<4xindex> +// CHECK-NEXT: [[DIM:%.+]] = shape.get_extent [[SHAPE]], [[CONST_1]] : tensor<4xindex>, index -> index // CHECK-NEXT: [[INDEX_CAST:%.+]] = arith.index_cast [[DIM]] : index to i64 // CHECK-NEXT: [[FROM_ELEMENTS:%.+]] = tensor.from_elements [[INDEX_CAST]] : tensor<1xi64> // CHECK-NEXT: return [[FROM_ELEMENTS]] : tensor<1xi64> @@ -27,3 +28,19 @@ func.func @test_dim_2(%arg0 : tensor<5x7xf32>) -> tensor<1xi64> { // CHECK-NEXT: [[CONST:%.+]] = arith.constant dense<5> : tensor<1xi64> // CHECK-NEXT: return [[CONST]] : tensor<1xi64> // CHECK: } + +// ----- + +func.func @test_dim_invalid_1(%arg0 : tensor<5x7xf32>) -> tensor<1xi64> { + // expected-error @+1 {{attribute "axis" value is 3, accepted range is [0, 1].}} + %1 = "onnx.Dim"(%arg0) { axis = 3 : si64} : (tensor<5x7xf32>) -> tensor<1xi64> + return %1 : tensor<1xi64> +} + +// ----- + +func.func @test_dim_invalid_2(%arg0 : tensor<*xf32>) -> tensor<1xi64> { + // expected-error @+1 {{input must have shape and rank.}} + %1 = "onnx.Dim"(%arg0) { axis = 0 : si64} : (tensor<*xf32>) -> tensor<1xi64> + return %1 : tensor<1xi64> +}