Skip to content

Commit f563c41

Browse files
srcarrolltungld
andauthored
Relax dynamic restriction on ScatterND stablehlo conversion (#2772)
Signed-off-by: Sam <[email protected]> Co-authored-by: Tung D. Le <[email protected]>
1 parent 08d4fed commit f563c41

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

src/Conversion/ONNXToStablehlo/Tensor/ScatterND.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ struct ONNXScatterNDOpLoweringToStablehlo
4040
auto indicesType = indices.getType().cast<ShapedType>();
4141
int64_t dataRank = dataType.getRank();
4242
int64_t indicesRank = indicesType.getRank();
43-
assert(indicesType.hasStaticShape() &&
44-
"only support indices with static shape");
43+
if (indicesType.isDynamicDim(indicesRank - 1))
44+
return rewriter.notifyMatchFailure(
45+
op, "only support indices with static last dim");
4546
int64_t partialIdxDim = indicesType.getDimSize(indicesRank - 1);
4647

4748
assert(dataRank >= 1 && "The rank of 'data' must be >= 1");

test/mlir/conversion/onnx_to_stablehlo/Tensor/ScatterND.mlir

+15
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,18 @@ func.func @test_scatternd_2(%arg0 : tensor<4x4x4xi32>, %arg1 : tensor<2x1xi64>,
2727
// CHECK: return [[VAR_0_]] : tensor<4x4x4xi32>
2828
// CHECK: }
2929
}
30+
31+
// -----
32+
33+
func.func @test_scatternd_dynamic(%arg0 : tensor<1x?x32x128xf32>, %arg1 : tensor<?x?x32x64x4xi64>, %arg2 : tensor<?x?x?x?xf32>) -> tensor<1x?x32x128xf32> {
34+
%0 = "onnx.ScatterND"(%arg0, %arg1, %arg2) : (tensor<1x?x32x128xf32>, tensor<?x?x32x64x4xi64>, tensor<?x?x?x?xf32>) -> tensor<1x?x32x128xf32>
35+
return %0 : tensor<1x?x32x128xf32>
36+
// CHECK-LABEL: func.func @test_scatternd_dynamic
37+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x?x32x128xf32>, [[PARAM_1_:%.+]]: tensor<?x?x32x64x4xi64>, [[PARAM_2_:%.+]]: tensor<?x?x?x?xf32>) -> tensor<1x?x32x128xf32> {
38+
// CHECK: [[VAR_0_:%.+]] = "stablehlo.scatter"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) ({
39+
// CHECK: ^bb0([[arg3_:%.+]]: tensor<f32>, [[arg4_:%.+]]: tensor<f32>):
40+
// CHECK: stablehlo.return [[arg4_]] : tensor<f32>
41+
// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1, 2, 3], scatter_dims_to_operand_dims = [0, 1, 2, 3], index_vector_dim = 4>, unique_indices = false} : (tensor<1x?x32x128xf32>, tensor<?x?x32x64x4xi64>, tensor<?x?x?x?xf32>) -> tensor<1x?x32x128xf32>
42+
// CHECK: return [[VAR_0_]] : tensor<1x?x32x128xf32>
43+
// CHECK: }
44+
}

0 commit comments

Comments
 (0)