Commit 9a4e95e

formatted for lint
1 parent f4365cd commit 9a4e95e

2 files changed, +67 −77 lines changed

src/Conversion/ONNXToStablehlo/Math/Softmax.cpp

+27 −37
@@ -46,15 +46,13 @@ Value getReductionShapeValue(Location loc, PatternRewriter &rewriter,
 
 // Calutes Broadcast dimensions
 SmallVector<int64_t> getBroadcastDims(
-    Value operand, llvm::SmallVector<int64_t, 4> axes, bool keepDims) {
+    Value operand, llvm::SmallVector<int64_t, 4> axes) {
   int64_t rank = mlir::cast<RankedTensorType>(operand.getType()).getRank();
   SmallVector<int64_t> dims;
   for (int64_t i = 0; i < rank; i++) {
     if (!(std::find(axes.begin(), axes.end(), i) != axes.end())) {
       dims.push_back(i);
     }
-    else if(keepDims)
-      dims.push_back(1);
   }
 
   return dims;
@@ -123,8 +121,7 @@ struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern {
       ConversionPatternRewriter &rewriter) const final {
 
     Value operand = operands[0];
-    // assert(
-    //     hasStaticShape(operand.getType()) && "Only Static shapes are accepted");
+    bool isStaticShape = hasStaticShape(operand.getType());
 
     Location loc = op->getLoc();
     Type outputType = *op->result_type_begin();
@@ -154,57 +151,50 @@ struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern {
     SmallVector<int64_t> reducedShape =
         getReductionShape(ExpOutputType, axes, false);
     ShapedType ReducedShapeType;
-    if(hasStaticShape(operand.getType()))
-    {
-      ReducedShapeType = mlir::cast<ShapedType>(
-          RankedTensorType::get(reducedShape, ExpOutputType.getElementType()));
-    }
-    else
-    {
-      SmallVector<int64_t> reducedShape_with_dims = getReductionShape(ExpOutputType, axes, true);
+    if (isStaticShape) {
       ReducedShapeType = mlir::cast<ShapedType>(
-          RankedTensorType::get(reducedShape_with_dims, ExpOutputType.getElementType()));
+          RankedTensorType::get(reducedShape, ExpOutputType.getElementType()));
+    } else {
+      SmallVector<int64_t> ReducedShapeVector =
+          getReductionShape(ExpOutputType, axes, true);
+      ReducedShapeType = mlir::cast<ShapedType>(RankedTensorType::get(
+          ReducedShapeVector, ExpOutputType.getElementType()));
     }
     Value identity = rewriter.create<stablehlo::ConstantOp>(
         loc, rewriter.getZeroAttr(ExpOutputType.getElementType()));
     Value ReduceSum = computeReduceSum(loc, ElementwiseExpStableHLO, identity,
-        reducedShape, axes, rewriter, !(hasStaticShape(operand.getType())), ReducedShapeType);
+        reducedShape, axes, rewriter, !isStaticShape, ReducedShapeType);
+
     if (ReduceSum == nullptr)
       return failure();
 
-    SmallVector<int64_t> broadcast_dims =
-        getBroadcastDims(ElementwiseExpStableHLO, axes, !(hasStaticShape(operand.getType())));
-
     Value BroadCastOp;
-    if(hasStaticShape(operand.getType()))
+    if (isStaticShape) {
+      SmallVector<int64_t> broadcast_dims =
+          getBroadcastDims(ElementwiseExpStableHLO, axes);
       BroadCastOp =
           rewriter.create<stablehlo::BroadcastInDimOp>(loc, ExpOutputType,
              ReduceSum, rewriter.getDenseI64ArrayAttr(broadcast_dims));
-    else{
-      //mlir::Value ReshapeOp = rewriter.create<stablehlo::DynamicReshapeOp>(loc, mlir::cast<RankedTensorType>(operand.getType()).getElementType(), ReduceSum, rewriter.getDenseI64ArrayAttr(broadcast_dims));
-      // llvm::ArrayRef<int64_t> output_dimensions = mlir::cast<mlir::RankedTensorType>(op->getResultTypes()[0]).getShape();
-      // mlir::Type i64_type = rewriter.getIntegerType(64);
-      // mlir::RankedTensorType output_rank = mlir::RankedTensorType::get({ExpOutputType.getRank()}, i64_type);
-      // mlir::DenseElementsAttr DenseOutputDimensions = mlir::DenseElementsAttr::get(output_rank, output_dimensions);
-      mlir::Value OutputDimensions = rewriter.create<shape::ShapeOfOp>(loc, operand);
-      llvm::outs() << OutputDimensions << "\n";
-      llvm::outs() << ReduceSum << "\n";
-      SmallVector<int64_t> dims;
-      for(int64_t i = 0; i < ExpOutputType.getRank(); i++)
-        dims.push_back(i);
-
-      BroadCastOp = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(loc, ExpOutputType, ReduceSum, OutputDimensions, rewriter.getDenseI64ArrayAttr(dims));//, rewriter.getDenseI64ArrayAttr(known_expanding_dims), rewriter.getDenseI64ArrayAttr(broadcast_dims));
-      llvm::outs() << BroadCastOp << "\n";
+    } else {
+      mlir::Value OutputDimensions =
+          rewriter.create<shape::ShapeOfOp>(loc, operand);
+      SmallVector<int64_t> DimIndex;
+      for (int64_t i = 0; i < ExpOutputType.getRank(); i++)
+        DimIndex.push_back(i);
+      BroadCastOp = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(loc,
+          ExpOutputType, ReduceSum, OutputDimensions,
+          rewriter.getDenseI64ArrayAttr(DimIndex));
     }
     if (BroadCastOp == nullptr)
       return failure();
 
-    Value Softmax_output = rewriter.create<stablehlo::DivOp>(
+    Value SoftmaxOutput = rewriter.create<stablehlo::DivOp>(
        loc, ElementwiseExpStableHLO, BroadCastOp);
-    if (Softmax_output == nullptr)
+
+    if (SoftmaxOutput == nullptr)
       return failure();
 
-    rewriter.replaceOp(op, Softmax_output);
+    rewriter.replaceOp(op, SoftmaxOutput);
     return success();
   }
 };
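In short, the change above keeps the static-shape path on a plain stablehlo.broadcast_in_dim over the non-reduced axes, while the dynamic-shape path now reduces with keepDims semantics, takes the runtime shape via shape.shape_of on the operand, and broadcasts with stablehlo.dynamic_broadcast_in_dim using an identity dims list. A minimal hand-written MLIR sketch of the two resulting patterns for a rank-3 input reduced over axis 1 (the %input, %sum, %sum_keepdims names and the 2x20x30 static shape are illustrative placeholders, not taken from this commit):

// Static shape: the reduced sum has rank 2 and is broadcast back
// along the non-reduced dims [0, 2].
%bcast = stablehlo.broadcast_in_dim %sum, dims = [0, 2]
    : (tensor<2x30xf32>) -> tensor<2x20x30xf32>

// Dynamic shape: the reduced sum is kept at full rank (reduced dim == 1),
// the output extents come from shape.shape_of on the operand, and the
// broadcast dimensions are the identity list [0, 1, 2].
%out_shape = shape.shape_of %input : tensor<?x20x30xf32> -> tensor<3xindex>
%bcast_dyn = stablehlo.dynamic_broadcast_in_dim %sum_keepdims, %out_shape, dims = [0, 1, 2]
    : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>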

test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir

+40 −40
@@ -32,46 +32,46 @@ func.func @test_softmax_dynamic(%arg0 : tensor<?x20x30xf32>) -> tensor<?x20x30xf
   "func.return"(%0) : (tensor<?x20x30xf32>) -> ()
 }
 
-//TODO: Renable dynamic shape test
-// func.func @test_softmax_dynamic
-// ([[PARAM_0_:%.+]]: tensor<?x20x30xf32>) -> tensor<?x20x30xf32> {
-// [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-// [[CST_2_:%.+]] = arith.constant 2 : index
-// [[CST_1_:%.+]] = arith.constant 1 : index
-// [[CST_0_:%.+]] = arith.constant 0 : index
-// [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
-// separator of consecutive DAGs
-// [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
-// [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
-// separator of consecutive DAGs
-// [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index
-// [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index
-// [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index
-// [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex>
-// [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
-// [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
-// [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor<?x1x30xf32> -> tensor<3xindex>
-// [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
-// [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
-// [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
-// [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor<?x20x30xf32>
-// [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor<?x20x30xf32>
-// [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
-// [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
-// separator of consecutive DAGs
-// [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index
-// [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index
-// [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index
-// [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex>
-// [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
-// [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
-// [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor<?x1x30xf32> -> tensor<3xindex>
-// [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
-// [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
-// [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
-// [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor<?x20x30xf32>
-// return [[VAR_28_]] : tensor<?x20x30xf32>
-// }
+// CHECK-LABEL: func.func @test_softmax_dynamic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x20x30xf32>) -> tensor<?x20x30xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
+// CHECK-DAG: [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index
+// CHECK-DAG: [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index
+// CHECK: [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index
+// CHECK: [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex>
+// CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
+// CHECK-DAG: [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
+// CHECK: [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor<?x1x30xf32> -> tensor<3xindex>
+// CHECK: [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
+// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
+// CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
+// CHECK: [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor<?x20x30xf32>
+// CHECK: [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor<?x20x30xf32>
+// CHECK-DAG: [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
+// CHECK-DAG: [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index
+// CHECK-DAG: [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index
+// CHECK: [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index
+// CHECK: [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex>
+// CHECK-DAG: [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
+// CHECK-DAG: [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
+// CHECK: [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor<?x1x30xf32> -> tensor<3xindex>
+// CHECK: [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
+// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
+// CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
+// CHECK: [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor<?x20x30xf32>
+// CHECK: return [[VAR_28_]] : tensor<?x20x30xf32>
+// CHECK: }
+
 
 // -----
 