Skip to content

Commit f4365cd

Browse files
Added support for Softmax dynamic shapes
1 parent b861652 commit f4365cd

File tree

1 file changed

+40
-9
lines changed

1 file changed

+40
-9
lines changed

src/Conversion/ONNXToStablehlo/Math/Softmax.cpp

+40-9
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,15 @@ Value getReductionShapeValue(Location loc, PatternRewriter &rewriter,
4646

4747
// Calculates broadcast dimensions
4848
SmallVector<int64_t> getBroadcastDims(
49-
Value operand, llvm::SmallVector<int64_t, 4> axes) {
49+
Value operand, llvm::SmallVector<int64_t, 4> axes, bool keepDims) {
5050
int64_t rank = mlir::cast<RankedTensorType>(operand.getType()).getRank();
5151
SmallVector<int64_t> dims;
5252
for (int64_t i = 0; i < rank; i++) {
5353
if (!(std::find(axes.begin(), axes.end(), i) != axes.end())) {
5454
dims.push_back(i);
5555
}
56+
else if(keepDims)
57+
dims.push_back(1);
5658
}
5759

5860
return dims;
@@ -121,8 +123,8 @@ struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern {
121123
ConversionPatternRewriter &rewriter) const final {
122124

123125
Value operand = operands[0];
124-
assert(
125-
hasStaticShape(operand.getType()) && "Only Static shapes are accepted");
126+
// assert(
127+
// hasStaticShape(operand.getType()) && "Only Static shapes are accepted");
126128

127129
Location loc = op->getLoc();
128130
Type outputType = *op->result_type_begin();
@@ -151,20 +153,49 @@ struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern {
151153
// Sum of the all the exponents for the denominator
152154
SmallVector<int64_t> reducedShape =
153155
getReductionShape(ExpOutputType, axes, false);
154-
ShapedType ReducedShapeType = mlir::cast<ShapedType>(
156+
ShapedType ReducedShapeType;
157+
if(hasStaticShape(operand.getType()))
158+
{
159+
ReducedShapeType = mlir::cast<ShapedType>(
155160
RankedTensorType::get(reducedShape, ExpOutputType.getElementType()));
161+
}
162+
else
163+
{
164+
SmallVector<int64_t> reducedShape_with_dims = getReductionShape(ExpOutputType, axes, true);
165+
ReducedShapeType = mlir::cast<ShapedType>(
166+
RankedTensorType::get(reducedShape_with_dims, ExpOutputType.getElementType()));
167+
}
156168
Value identity = rewriter.create<stablehlo::ConstantOp>(
157169
loc, rewriter.getZeroAttr(ExpOutputType.getElementType()));
158170
Value ReduceSum = computeReduceSum(loc, ElementwiseExpStableHLO, identity,
159-
reducedShape, axes, rewriter, false, ReducedShapeType);
171+
reducedShape, axes, rewriter, !(hasStaticShape(operand.getType())), ReducedShapeType);
160172
if (ReduceSum == nullptr)
161173
return failure();
162174

163175
SmallVector<int64_t> broadcast_dims =
164-
getBroadcastDims(ElementwiseExpStableHLO, axes);
165-
Value BroadCastOp =
166-
rewriter.create<stablehlo::BroadcastInDimOp>(loc, ExpOutputType,
167-
ReduceSum, rewriter.getDenseI64ArrayAttr(broadcast_dims));
176+
getBroadcastDims(ElementwiseExpStableHLO, axes, !(hasStaticShape(operand.getType())));
177+
178+
Value BroadCastOp;
179+
if(hasStaticShape(operand.getType()))
180+
BroadCastOp =
181+
rewriter.create<stablehlo::BroadcastInDimOp>(loc, ExpOutputType,
182+
ReduceSum, rewriter.getDenseI64ArrayAttr(broadcast_dims));
183+
else{
184+
//mlir::Value ReshapeOp = rewriter.create<stablehlo::DynamicReshapeOp>(loc, mlir::cast<RankedTensorType>(operand.getType()).getElementType(), ReduceSum, rewriter.getDenseI64ArrayAttr(broadcast_dims));
185+
// llvm::ArrayRef<int64_t> output_dimensions = mlir::cast<mlir::RankedTensorType>(op->getResultTypes()[0]).getShape();
186+
// mlir::Type i64_type = rewriter.getIntegerType(64);
187+
// mlir::RankedTensorType output_rank = mlir::RankedTensorType::get({ExpOutputType.getRank()}, i64_type);
188+
// mlir::DenseElementsAttr DenseOutputDimensions = mlir::DenseElementsAttr::get(output_rank, output_dimensions);
189+
mlir::Value OutputDimensions = rewriter.create<shape::ShapeOfOp>(loc, operand);
190+
llvm::outs() << OutputDimensions << "\n";
191+
llvm::outs() << ReduceSum << "\n";
192+
SmallVector<int64_t> dims;
193+
for(int64_t i = 0; i < ExpOutputType.getRank(); i++)
194+
dims.push_back(i);
195+
196+
BroadCastOp = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(loc, ExpOutputType, ReduceSum, OutputDimensions, rewriter.getDenseI64ArrayAttr(dims));//, rewriter.getDenseI64ArrayAttr(known_expanding_dims), rewriter.getDenseI64ArrayAttr(broadcast_dims));
197+
llvm::outs() << BroadCastOp << "\n";
198+
}
168199
if (BroadCastOp == nullptr)
169200
return failure();
170201

0 commit comments

Comments
 (0)