Commit 4f0a141

Softmax op dynamic shape addition for stablehlo (#2918)
Signed-off-by: Abhishek-TyRnT <[email protected]>
Co-authored-by: Alexandre Eichenberger <[email protected]>
1 parent 4e99738 commit 4f0a141

2 files changed, +74 −53 lines

src/Conversion/ONNXToStablehlo/Math/Softmax.cpp (+34 −13)
@@ -121,8 +121,7 @@ struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern {
       ConversionPatternRewriter &rewriter) const final {
 
     Value operand = operands[0];
-    assert(
-        hasStaticShape(operand.getType()) && "Only Static shapes are accepted");
+    bool isStaticShape = hasStaticShape(operand.getType());
 
     Location loc = op->getLoc();
     Type outputType = *op->result_type_begin();
@@ -151,29 +150,51 @@ struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern {
     // Sum of the all the exponents for the denominator
     SmallVector<int64_t> reducedShape =
         getReductionShape(ExpOutputType, axes, false);
-    ShapedType ReducedShapeType = mlir::cast<ShapedType>(
-        RankedTensorType::get(reducedShape, ExpOutputType.getElementType()));
+    ShapedType ReducedShapeType;
+    if (isStaticShape) {
+      ReducedShapeType = mlir::cast<ShapedType>(
+          RankedTensorType::get(reducedShape, ExpOutputType.getElementType()));
+    } else {
+      SmallVector<int64_t> ReducedShapeVector =
+          getReductionShape(ExpOutputType, axes, true);
+      ReducedShapeType = mlir::cast<ShapedType>(RankedTensorType::get(
+          ReducedShapeVector, ExpOutputType.getElementType()));
+    }
     Value identity = rewriter.create<stablehlo::ConstantOp>(
         loc, rewriter.getZeroAttr(ExpOutputType.getElementType()));
     Value ReduceSum = computeReduceSum(loc, ElementwiseExpStableHLO, identity,
-        reducedShape, axes, rewriter, false, ReducedShapeType);
+        reducedShape, axes, rewriter, !isStaticShape, ReducedShapeType);
+
     if (ReduceSum == nullptr)
       return failure();
 
-    SmallVector<int64_t> broadcast_dims =
-        getBroadcastDims(ElementwiseExpStableHLO, axes);
-    Value BroadCastOp =
-        rewriter.create<stablehlo::BroadcastInDimOp>(loc, ExpOutputType,
-            ReduceSum, rewriter.getDenseI64ArrayAttr(broadcast_dims));
+    Value BroadCastOp;
+    if (isStaticShape) {
+      SmallVector<int64_t> broadcast_dims =
+          getBroadcastDims(ElementwiseExpStableHLO, axes);
+      BroadCastOp =
+          rewriter.create<stablehlo::BroadcastInDimOp>(loc, ExpOutputType,
+              ReduceSum, rewriter.getDenseI64ArrayAttr(broadcast_dims));
+    } else {
+      mlir::Value OutputDimensions =
+          rewriter.create<shape::ShapeOfOp>(loc, operand);
+      SmallVector<int64_t> DimIndex;
+      for (int64_t i = 0; i < ExpOutputType.getRank(); i++)
+        DimIndex.push_back(i);
+      BroadCastOp = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(loc,
+          ExpOutputType, ReduceSum, OutputDimensions,
+          rewriter.getDenseI64ArrayAttr(DimIndex));
+    }
     if (BroadCastOp == nullptr)
       return failure();
 
-    Value Softmax_output = rewriter.create<stablehlo::DivOp>(
+    Value SoftmaxOutput = rewriter.create<stablehlo::DivOp>(
         loc, ElementwiseExpStableHLO, BroadCastOp);
-    if (Softmax_output == nullptr)
+
+    if (SoftmaxOutput == nullptr)
       return failure();
 
-    rewriter.replaceOp(op, Softmax_output);
+    rewriter.replaceOp(op, SoftmaxOutput);
     return success();
   }
 };
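For context, the lit test below checks the StableHLO produced for a dynamically shaped input. A minimal sketch of what that input presumably looks like, reconstructed from the function signature in the hunk header and the reduction axis in the CHECK lines (the exact onnx.Softmax attribute syntax here is an assumption, not part of this diff):

func.func @test_softmax_dynamic(%arg0 : tensor<?x20x30xf32>) -> tensor<?x20x30xf32> {
  // Softmax over axis 1; the leading dimension is dynamic (?), which previously tripped the static-shape assert.
  %0 = "onnx.Softmax"(%arg0) {axis = 1 : si64} : (tensor<?x20x30xf32>) -> tensor<?x20x30xf32>
  "func.return"(%0) : (tensor<?x20x30xf32>) -> ()
}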

test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir (+40 −40)
@@ -32,46 +32,46 @@ func.func @test_softmax_dynamic(%arg0 : tensor<?x20x30xf32>) -> tensor<?x20x30xf32> {
   "func.return"(%0) : (tensor<?x20x30xf32>) -> ()
 }
 
-//TODO: Renable dynamic shape test
-// func.func @test_softmax_dynamic
-// ([[PARAM_0_:%.+]]: tensor<?x20x30xf32>) -> tensor<?x20x30xf32> {
-// [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-// [[CST_2_:%.+]] = arith.constant 2 : index
-// [[CST_1_:%.+]] = arith.constant 1 : index
-// [[CST_0_:%.+]] = arith.constant 0 : index
-// [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
-// separator of consecutive DAGs
-// [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
-// [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
-// separator of consecutive DAGs
-// [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index
-// [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index
-// [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index
-// [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex>
-// [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
-// [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
-// [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor<?x1x30xf32> -> tensor<3xindex>
-// [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
-// [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
-// [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
-// [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor<?x20x30xf32>
-// [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor<?x20x30xf32>
-// [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
-// [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
-// separator of consecutive DAGs
-// [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index
-// [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index
-// [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index
-// [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex>
-// [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
-// [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
-// [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor<?x1x30xf32> -> tensor<3xindex>
-// [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
-// [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
-// [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
-// [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor<?x20x30xf32>
-// return [[VAR_28_]] : tensor<?x20x30xf32>
-// }
+// CHECK-LABEL: func.func @test_softmax_dynamic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x20x30xf32>) -> tensor<?x20x30xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index
+// CHECK-DAG: [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index
+// CHECK: [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index
+// CHECK: [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex>
+// CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
+// CHECK-DAG: [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
+// CHECK: [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor<?x1x30xf32> -> tensor<3xindex>
+// CHECK: [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
+// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
+// CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
+// CHECK: [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor<?x20x30xf32>
+// CHECK: [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor<?x20x30xf32>
+// CHECK-DAG: [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
+// CHECK-DAG: [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index
+// CHECK-DAG: [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index
+// CHECK: [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index
+// CHECK: [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex>
+// CHECK-DAG: [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
+// CHECK-DAG: [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
+// CHECK: [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor<?x1x30xf32> -> tensor<3xindex>
+// CHECK: [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
+// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
+// CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
+// CHECK: [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor<?x20x30xf32>
+// CHECK: return [[VAR_28_]] : tensor<?x20x30xf32>
+// CHECK: }
+
 
 // -----
 