Skip to content

Commit 18f4e07

Browse files
brnorris03tungld
andauthored
[StableHLO] Add onnx.Dim lowering to StableHLO (#2738)
* ONNX to StableHLO, add lowering of Dim op Signed-off-by: Boyana Norris <[email protected]> --------- Signed-off-by: Boyana Norris <[email protected]> Co-authored-by: Tung D. Le <[email protected]>
1 parent a4438fe commit 18f4e07

File tree

7 files changed

+140
-4
lines changed

7 files changed

+140
-4
lines changed

src/Conversion/ONNXToStablehlo/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ add_onnx_mlir_library(OMONNXToStablehlo
5656
Tensor/Concat.cpp
5757
Tensor/Constant.cpp
5858
Tensor/DepthToSpace.cpp
59+
Tensor/Dim.cpp
5960
Tensor/Expand.cpp
6061
Tensor/Flatten.cpp
6162
Tensor/Gather.cpp

src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ void populateONNXToStablehloConversionPattern(
4141
populateLoweringONNXConcatOpToStablehloPattern(patterns, ctx);
4242
populateLoweringONNXConstantOpToStablehloPattern(patterns, ctx);
4343
populateLoweringONNXDepthToSpaceOpToStablehloPattern(patterns, ctx);
44+
populateLoweringONNXDimOpToStablehloPattern(patterns, ctx);
4445
populateLoweringONNXExpandOpToStablehloPattern(patterns, ctx);
4546
populateLoweringONNXFlattenOpToStablehloPattern(patterns, ctx);
4647
populateLoweringONNXGatherOpToStablehloPattern(patterns, ctx);
@@ -87,6 +88,7 @@ struct FrontendToStablehloLoweringPass
8788

8889
void getDependentDialects(DialectRegistry &registry) const override {
8990
registry.insert<mlir::stablehlo::StablehloDialect>();
91+
registry.insert<mlir::arith::ArithDialect>();
9092
registry.insert<shape::ShapeDialect>();
9193
}
9294

src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ void populateLoweringONNXConcatOpToStablehloPattern(
179179
RewritePatternSet &, MLIRContext *);
180180
void populateLoweringONNXConstantOpToStablehloPattern(
181181
RewritePatternSet &, MLIRContext *);
182+
void populateLoweringONNXDimOpToStablehloPattern(
183+
RewritePatternSet &, MLIRContext *);
182184
void populateLoweringONNXDepthToSpaceOpToStablehloPattern(
183185
RewritePatternSet &, MLIRContext *);
184186
void populateLoweringONNXExpandOpToStablehloPattern(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
//===----------------- Dim.cpp - Lowering Dim Op ----------------===//
6+
//
7+
// Copyright 2022-2024
8+
//
9+
// =============================================================================
10+
//
11+
// This file lowers the ONNXDim operator to the Tensor dialect.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Dialect/Shape/IR/Shape.h"
16+
#include "src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp"
17+
18+
using namespace mlir;
19+
20+
namespace onnx_mlir {
21+
22+
namespace {
23+
24+
struct ONNXDimOpLoweringToStablehlo : public ConversionPattern {
25+
ONNXDimOpLoweringToStablehlo(MLIRContext *ctx)
26+
: ConversionPattern(ONNXDimOp::getOperationName(), 1, ctx) {}
27+
28+
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
29+
ConversionPatternRewriter &rewriter) const final {
30+
Location loc = op->getLoc();
31+
ONNXDimOp dimOp = cast<ONNXDimOp>(op);
32+
int64_t axisLit = dimOp.getAxis();
33+
34+
// Check that axisLit is a valid dimension index
35+
Value tensorArg = operands[0];
36+
assert(tensorArg.getType().isa<RankedTensorType>() &&
37+
"Expected ranked tensor type");
38+
39+
int64_t rank = tensorArg.getType().cast<RankedTensorType>().getRank();
40+
41+
assert((axisLit >= 0 && axisLit < rank) &&
42+
"Axis must be in the range [0, input tensor rank - 1]");
43+
44+
Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, tensorArg);
45+
Value dimValue =
46+
rewriter.create<shape::GetExtentOp>(loc, inputShape, axisLit);
47+
Type dimType = dimOp.getDim().getType();
48+
Type indexValueType = dimType.cast<ShapedType>().getElementType();
49+
Value castedIndex =
50+
rewriter.create<arith::IndexCastOp>(loc, indexValueType, dimValue);
51+
Value indexTensor = rewriter.create<tensor::FromElementsOp>(
52+
loc, dimType, ArrayRef<Value>{castedIndex});
53+
rewriter.replaceOp(op, indexTensor);
54+
return success();
55+
}
56+
};
57+
58+
} // namespace
59+
60+
void populateLoweringONNXDimOpToStablehloPattern(
61+
RewritePatternSet &patterns, MLIRContext *ctx) {
62+
patterns.insert<ONNXDimOpLoweringToStablehlo>(ctx);
63+
}
64+
65+
} // namespace onnx_mlir

src/Dialect/ONNX/ONNXOps/Additional/Dim.cpp

+8-4
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,16 @@ LogicalResult ONNXDimOpShapeHelper::computeShape() {
3737
//===----------------------------------------------------------------------===//
3838

3939
LogicalResult ONNXDimOp::verify() {
40-
// Input data must be ranked.
4140
if (!hasShapeAndRank(this->getData()))
42-
return failure();
43-
// Axis must be in [0, rank -1].
41+
return emitOpError("input must have shape and rank.");
42+
4443
int64_t axis = this->getAxis();
45-
return failure((axis < 0) || (axis >= getRank(this->getData().getType())));
44+
if ((axis < 0) || (axis >= getRank(this->getData().getType())))
45+
return emitOpError("attribute ")
46+
<< ONNXDimOp::getAxisAttrName() << " value is " << axis
47+
<< ", accepted range is [0, "
48+
<< getRank(this->getData().getType()) - 1 << "].";
49+
return success();
4650
}
4751

4852
//===----------------------------------------------------------------------===//
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: onnx-mlir-opt --convert-onnx-to-stablehlo --canonicalize %s -split-input-file -verify-diagnostics | FileCheck %s
2+
3+
// -----
4+
5+
func.func @test_dim_1(%arg0 : tensor<5x?x1x32xf32>) -> tensor<1xi64> {
6+
%1 = "onnx.Dim"(%arg0) { axis = 1 : si64} : (tensor<5x?x1x32xf32>) -> tensor<1xi64>
7+
return %1 : tensor<1xi64>
8+
}
9+
// CHECK-LABEL: func.func @test_dim_1
10+
// CHECK-SAME: ([[PARAM:%.+]]: tensor<5x?x1x32xf32>) -> tensor<1xi64> {
11+
// CHECK-NEXT: [[CONST_1:%.+]] = arith.constant 1 : index
12+
// CHECK-NEXT: [[SHAPE:%.+]] = shape.shape_of [[PARAM]] : tensor<5x?x1x32xf32> -> tensor<4xindex>
13+
// CHECK-NEXT: [[DIM:%.+]] = shape.get_extent [[SHAPE]], [[CONST_1]] : tensor<4xindex>, index -> index
14+
// CHECK-NEXT: [[INDEX_CAST:%.+]] = arith.index_cast [[DIM]] : index to i64
15+
// CHECK-NEXT: [[FROM_ELEMENTS:%.+]] = tensor.from_elements [[INDEX_CAST]] : tensor<1xi64>
16+
// CHECK-NEXT: return [[FROM_ELEMENTS]] : tensor<1xi64>
17+
// CHECK: }
18+
19+
// -----
20+
21+
func.func @test_dim_2(%arg0 : tensor<5x7xf32>) -> tensor<1xi64> {
22+
%1 = "onnx.Dim"(%arg0) { axis = 0 : si64} : (tensor<5x7xf32>) -> tensor<1xi64>
23+
return %1 : tensor<1xi64>
24+
}
25+
26+
// CHECK-LABEL: func.func @test_dim_2
27+
// CHECK-SAME: ([[PARAM:%.+]]: tensor<5x7xf32>) -> tensor<1xi64> {
28+
// CHECK-NEXT: [[CONST:%.+]] = arith.constant dense<5> : tensor<1xi64>
29+
// CHECK-NEXT: return [[CONST]] : tensor<1xi64>
30+
// CHECK: }
31+
32+
// -----
33+
34+
func.func @test_dim_invalid_1(%arg0 : tensor<5x7xf32>) -> tensor<1xi64> {
35+
// expected-error @+1 {{attribute "axis" value is 3, accepted range is [0, 1].}}
36+
%1 = "onnx.Dim"(%arg0) { axis = 3 : si64} : (tensor<5x7xf32>) -> tensor<1xi64>
37+
return %1 : tensor<1xi64>
38+
}
39+
40+
// -----
41+
42+
func.func @test_dim_invalid_2(%arg0 : tensor<*xf32>) -> tensor<1xi64> {
43+
// expected-error @+1 {{input must have shape and rank.}}
44+
%1 = "onnx.Dim"(%arg0) { axis = 0 : si64} : (tensor<*xf32>) -> tensor<1xi64>
45+
return %1 : tensor<1xi64>
46+
}

test/mlir/onnx/invalid.mlir

+16
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,22 @@ func.func @test_concat_from_sequence_verifier_2(%arg0 : !onnx.Seq<tensor<5x5x1x3
8383

8484
// -----
8585

86+
func.func @test_dim_verifier_1(%arg0 : tensor<*xf32>) -> tensor<i64> {
87+
// expected-error @+1 {{input must have shape and rank}}
88+
%1 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<*xf32>) -> tensor<i64>
89+
"onnx.Return"(%1) : (tensor<i64>) -> ()
90+
}
91+
92+
// -----
93+
94+
func.func @test_dim_verifier_2(%arg0 : tensor<5x5xf32>) -> tensor<i64> {
95+
// expected-error @+1 {{'onnx.Dim' op attribute "axis" value is -1, accepted range is [0, 1].}}
96+
%1 = "onnx.Dim"(%arg0) {axis = -1 : si64} : (tensor<5x5xf32>) -> tensor<i64>
97+
"onnx.Return"(%1) : (tensor<i64>) -> ()
98+
}
99+
100+
// -----
101+
86102
func.func @test_dequantize_linear_verifier_1(%arg0 : tensor<5x5x1xi32>, %arg1 : tensor<3xf32>, %arg2 : tensor<3xi32>) -> tensor<*xf32> {
87103
// expected-error @+1 {{onnx.DequantizeLinear: 'axis' value is 3, accepted range is [-3, 2]}}
88104
%1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 3 : si64} : (tensor<5x5x1xi32>, tensor<3xf32>, tensor<3xi32>) -> tensor<*xf32>

0 commit comments

Comments
 (0)