[StableHLO] Add onnx.Dim lowering to StableHLO #2738
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged

Commits (7 total; the diff below shows changes from 3 of them):
ab445d6  ONNX to StableHLO, add lowering of Dim op  (brnorris03)
4fc3fe6  emit errors in verifier instead of failing silently  (brnorris03)
159fd93  add invalid mlir lit tests  (brnorris03)
af40408  Merge branch 'main' into feature/dim-to-stablehlo  (brnorris03)
5f4f851  use shape.get_extent instead of tensor.dim (since …  (brnorris03)
e90d37c  Merge branch 'main' into feature/dim-to-stablehlo  (brnorris03)
31a9616  Merge branch 'main' into feature/dim-to-stablehlo  (tungld)
@@ -0,0 +1,64 @@
/*
 * SPDX-License-Identifier: Apache-2.0
 */

//===----------------- Dim.cpp - Lowering Dim Op ----------------===//
//
// Copyright 2022-2024
//
// =============================================================================
//
// This file lowers the ONNX Dim operator to the Tensor dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

struct ONNXDimOpLoweringToStablehlo : public ConversionPattern {
  ONNXDimOpLoweringToStablehlo(MLIRContext *ctx)
      : ConversionPattern(ONNXDimOp::getOperationName(), 1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    ONNXDimOp dimOp = cast<ONNXDimOp>(op);
    int64_t axis = dimOp.getAxis();

    // Check that axis is a valid dimension index.
    Value tensorArg = operands[0];
    assert(tensorArg.getType().isa<RankedTensorType>() &&
           "Expected ranked tensor type");

    RankedTensorType tensorType = tensorArg.getType().cast<RankedTensorType>();
    int64_t rank = tensorType.getRank();

    assert((axis >= 0 && axis < rank) &&
           "Axis must be in the range [0, input tensor rank - 1]");

    // Query the requested dimension size as an index value.
    Value dimValue = rewriter.create<tensor::DimOp>(loc, tensorArg, axis);

    // Cast the index to the result element type and wrap it in a rank-1
    // tensor to match the onnx.Dim result type (tensor<1xi64>).
    Type dimType = dimOp.getDim().getType();
    Type indexValueType = dimType.cast<ShapedType>().getElementType();
    Value castedIndex =
        rewriter.create<arith::IndexCastOp>(loc, indexValueType, dimValue);
    Value indexTensor = rewriter.create<tensor::FromElementsOp>(
        loc, dimType, ArrayRef<Value>{castedIndex});
    rewriter.replaceOp(op, indexTensor);
    return success();
  }
};

} // namespace

void populateLoweringONNXDimOpToStablehloPattern(
    RewritePatternSet &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXDimOpLoweringToStablehlo>(ctx);
}

} // namespace onnx_mlir
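For reference, a minimal before/after sketch of what this pattern emits for a dynamic dimension; it mirrors the test_dim_1 lit test further below, and the value names are illustrative:

// Before lowering: query dimension 1 (dynamic) of the input.
%d = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<5x?x1x32xf32>) -> tensor<1xi64>

// After lowering: tensor.dim + arith.index_cast + tensor.from_elements.
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c1 : tensor<5x?x1x32xf32>
%cast = arith.index_cast %dim : index to i64
%res = tensor.from_elements %cast : tensor<1xi64>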
@@ -37,12 +37,16 @@ LogicalResult ONNXDimOpShapeHelper::computeShape() {
 //===----------------------------------------------------------------------===//

 LogicalResult ONNXDimOp::verify() {
   // Input data must be ranked.
   if (!hasShapeAndRank(this->getData()))
-    return failure();
+    return emitOpError("input must have shape and rank.");

   // Axis must be in [0, rank - 1].
   int64_t axis = this->getAxis();
-  return failure((axis < 0) || (axis >= getRank(this->getData().getType())));
+  if ((axis < 0) || (axis >= getRank(this->getData().getType())))
+    return emitOpError("attribute ")
+           << ONNXDimOp::getAxisAttrName() << " value is " << axis
+           << ", accepted range is [0, "
+           << getRank(this->getData().getType()) - 1 << "].";
+  return success();
 }

 //===----------------------------------------------------------------------===//

Reviewer: Thanks so much for the explicit error messages! Really appreciate it.

Author: :-) You are welcome!
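The "add invalid mlir lit tests" commit adds negative coverage for this verifier. A minimal sketch of such a test, assuming MLIR's standard -verify-diagnostics flow (the RUN line, function name, and matched diagnostic substring are assumptions, not the PR's exact test file):

// RUN: onnx-mlir-opt %s -split-input-file -verify-diagnostics

// -----

// Axis 3 is out of range for this rank-2 input, so verification must fail.
func.func @test_dim_invalid_axis(%arg0 : tensor<5x7xf32>) -> tensor<1xi64> {
  // expected-error @+1 {{accepted range is [0, 1]}}
  %1 = "onnx.Dim"(%arg0) {axis = 3 : si64} : (tensor<5x7xf32>) -> tensor<1xi64>
  return %1 : tensor<1xi64>
}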
@@ -0,0 +1,29 @@
// RUN: onnx-mlir-opt --convert-onnx-to-stablehlo --canonicalize %s -split-input-file -verify-diagnostics | FileCheck %s

// -----

func.func @test_dim_1(%arg0 : tensor<5x?x1x32xf32>) -> tensor<1xi64> {
  %1 = "onnx.Dim"(%arg0) { axis = 1 : si64} : (tensor<5x?x1x32xf32>) -> tensor<1xi64>
  return %1 : tensor<1xi64>
}

// CHECK-LABEL: func.func @test_dim_1
// CHECK-SAME:  ([[PARAM:%.+]]: tensor<5x?x1x32xf32>) -> tensor<1xi64> {
// CHECK-NEXT:  [[CONST_1:%.+]] = arith.constant 1 : index
// CHECK-NEXT:  [[DIM:%.+]] = tensor.dim [[PARAM]], [[CONST_1]] : tensor<5x?x1x32xf32>
// CHECK-NEXT:  [[INDEX_CAST:%.+]] = arith.index_cast [[DIM]] : index to i64
// CHECK-NEXT:  [[FROM_ELEMENTS:%.+]] = tensor.from_elements [[INDEX_CAST]] : tensor<1xi64>
// CHECK-NEXT:  return [[FROM_ELEMENTS]] : tensor<1xi64>
// CHECK:       }

// -----

func.func @test_dim_2(%arg0 : tensor<5x7xf32>) -> tensor<1xi64> {
  %1 = "onnx.Dim"(%arg0) { axis = 0 : si64} : (tensor<5x7xf32>) -> tensor<1xi64>
  return %1 : tensor<1xi64>
}

// CHECK-LABEL: func.func @test_dim_2
// CHECK-SAME:  ([[PARAM:%.+]]: tensor<5x7xf32>) -> tensor<1xi64> {
// CHECK-NEXT:  [[CONST:%.+]] = arith.constant dense<5> : tensor<1xi64>
// CHECK-NEXT:  return [[CONST]] : tensor<1xi64>
// CHECK:       }
Review discussion:

Reviewer: Could we use stablehlo.get_dimension_size so we don't need to use the tensor dialect?

Author: Possibly, but it makes this much more complex than the tensor option, primarily because stablehlo.get_dimension_size returns tensor<i32>, which must then be converted to the expected tensor<1xi64> result type (so both the shape and the element type size change). And of course there is a chance of overflow if the result exceeds i32 (which wouldn't happen with the tensor conversion). I think that in general it is best to avoid i64 -> i32 -> i64 conversions.

Author: Is there a particular reason why the tensor dialect should be avoided? I am also able to use the shape dialect (a bit more verbosely) for this, but that gets lowered to the tensor.dim op anyway. I also don't know how to avoid using tensor.from_elements (which is used in the lowering of a couple of other ops).

Reviewer: Perhaps @Connor-XY would like to use stablehlo ops at this level as much as possible. Not sure why stablehlo.get_dimension_size returns tensor<i32> while its input axis is small but i64 :)

Author: That's understandable, but I don't actually see how I can avoid using tensor dialect ops entirely here. I also don't quite understand why I should avoid the ones (e.g., tensor.from_elements) that are already used in merged conversions. Also, when using stablehlo.get_dimension_size, I am not able to successfully do the different-size index type conversions (i32 -> i64; going through index doesn't work). Any suggestions?

Connor-XY: Yes, I would like to convert it to stablehlo ops as much as we can. It is fine to convert it to the shape or tensor dialect if it is hard to do so with pure stablehlo ops. It is possible to use stablehlo.get_dimension_size to get tensor<i32>, then stablehlo.reshape to get tensor<1xi32>, and then stablehlo.convert to convert it from tensor<1xi32> to tensor<1xi64>.
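For illustration, here is what that suggested stablehlo-only sequence could look like in IR. This is a sketch of the alternative discussed above, not the lowering this PR merged; the function and value names are hypothetical.

// Hypothetical stablehlo-only version of the onnx.Dim lowering (not merged).
func.func @dim_via_stablehlo(%arg0 : tensor<5x?x1x32xf32>) -> tensor<1xi64> {
  // get_dimension_size yields a 0-D tensor<i32> holding the dynamic extent.
  %size = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<5x?x1x32xf32>) -> tensor<i32>
  // Reshape the scalar to rank 1 to match the onnx.Dim result shape.
  %size_1d = stablehlo.reshape %size : (tensor<i32>) -> tensor<1xi32>
  // Widen i32 -> i64 to match the onnx.Dim result element type.
  %size_i64 = stablehlo.convert %size_1d : (tensor<1xi32>) -> tensor<1xi64>
  return %size_i64 : tensor<1xi64>
}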