Skip to content

Commit a4438fe

Browse files
srcarroll and tungld authored
Fix onnx.Gather conversion to stablehlo in dynamic case (#2736)
* Fix onnx.Gather conversion to stablehlo in dynamic case

Signed-off-by: Sam <[email protected]>

---------

Signed-off-by: Sam <[email protected]>
Co-authored-by: Tung D. Le <[email protected]>
1 parent d84e0a4 commit a4438fe

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

src/Conversion/ONNXToStablehlo/Tensor/Gather.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ struct ONNXGatherOpLoweringToStablehlo : public ConversionPattern {
4040
shapeHelper.computeShapeAndAssertOnFailure();
4141

4242
Type outputType = *op->result_type_begin();
43-
assert(isRankedShapedType(outputType) && "Expected Ranked ShapedType");
43+
if (!isRankedShapedType(outputType))
44+
return rewriter.notifyMatchFailure(op, "Expected Ranked ShapedType");
4445

4546
// Operands and attributes.
4647
Value data = operandAdaptor.getData();
@@ -56,7 +57,7 @@ struct ONNXGatherOpLoweringToStablehlo : public ConversionPattern {
5657
// start indices
5758
Value zero = getShapedZero(loc, rewriter, indices);
5859
Value axisDimSize;
59-
if (inputType.hasStaticShape()) {
60+
if (!inputType.isDynamicDim(axisLit)) {
6061
int64_t axisDimSizeLit = inputType.getShape()[axisLit];
6162
axisDimSize = getShapedInt(loc, rewriter, axisDimSizeLit, indices);
6263
} else {
@@ -66,6 +67,9 @@ struct ONNXGatherOpLoweringToStablehlo : public ConversionPattern {
6667
rewriter.create<shape::GetExtentOp>(loc, inputShape, axisLit);
6768
Value axisDimSizeValue = rewriter.create<arith::IndexCastOp>(
6869
loc, indicesType.getElementType(), axisDimSizeIndexValue);
70+
axisDimSizeValue = rewriter.create<tensor::FromElementsOp>(loc,
71+
RankedTensorType::get({}, indicesType.getElementType()),
72+
axisDimSizeValue);
6973
axisDimSize =
7074
rewriter.create<stablehlo::DynamicBroadcastInDimOp>(loc, indicesType,
7175
axisDimSizeValue, indicesShape, rewriter.getI64TensorAttr({}));

test/mlir/conversion/onnx_to_stablehlo/Tensor/Gather.mlir

+27
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,33 @@ func.func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
2121

2222
// -----
2323

24+
func.func @test_gather_dynamic_axis0(%arg0 : tensor<?x?xf32>) -> tensor<2x2x?xf32> {
25+
%indices = "onnx.Constant"() {value = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64>
26+
%0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64} : (tensor<?x?xf32>, tensor<2x2xi64>) -> tensor<2x2x?xf32>
27+
"func.return"(%0) : (tensor<2x2x?xf32>) -> ()
28+
}
29+
30+
// CHECK-LABEL: func.func @test_gather_dynamic_axis0
31+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?xf32>) -> tensor<2x2x?xf32> {
32+
// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
33+
// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<{{.}}[0, 1], [1, 2]{{.}}> : tensor<2x2xi64>
34+
// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0> : tensor<2x2xi64>
35+
// CHECK-DAG: [[INDICES_SHAPE_:%.+]] = shape.const_shape [2, 2] : tensor<2xindex>
36+
// CHECK-DAG: [[SHAPE_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x?xf32> -> tensor<2xindex>
37+
// CHECK-DAG: [[DIM_:%.+]] = shape.get_extent [[SHAPE_]], [[C0]] : tensor<2xindex>, index -> index
38+
// CHECK-DAG: [[DIM_CAST_:%.+]] = arith.index_cast [[DIM_]] : index to i64
39+
// CHECK-DAG: [[DIM_TENSOR_:%.+]] = tensor.from_elements [[DIM_CAST_]] : tensor<i64>
40+
// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[DIM_TENSOR_]], [[INDICES_SHAPE_]], dims = [] : (tensor<i64>, tensor<2xindex>) -> tensor<2x2xi64>
41+
// CHECK-NOT: separator of consecutive DAGs
42+
// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[VAR_0_]], [[VAR_1_]], NOTYPE : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1>
43+
// CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.add [[VAR_0_]], [[VAR_2_]] : tensor<2x2xi64>
44+
// CHECK: [[VAR_5_:%.+]] = stablehlo.select [[VAR_3_]], [[VAR_4_]], [[VAR_0_]] : tensor<2x2xi1>, tensor<2x2xi64>
45+
// CHECK: [[VAR_6_:%.+]] = "stablehlo.torch_index_select"([[PARAM_0_]], [[VAR_5_]]) {batch_dims = 0 : i64, dim = 0 : i64} : (tensor<?x?xf32>, tensor<2x2xi64>) -> tensor<2x2x?xf32>
46+
// CHECK: return [[VAR_6_]] : tensor<2x2x?xf32>
47+
// CHECK: }
48+
49+
// -----
50+
2451
func.func @test_gather_axis0neg(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
2552
%indices = "onnx.Constant"() {value = dense<[[0, -1], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64>
2653
%0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2x2xf32>

0 commit comments

Comments (0)