@@ -46,13 +46,15 @@ Value getReductionShapeValue(Location loc, PatternRewriter &rewriter,
46
46
47
47
// Calutes Broadcast dimensions
48
48
SmallVector<int64_t > getBroadcastDims (
49
- Value operand, llvm::SmallVector<int64_t , 4 > axes) {
49
+ Value operand, llvm::SmallVector<int64_t , 4 > axes, bool keepDims ) {
50
50
int64_t rank = mlir::cast<RankedTensorType>(operand.getType ()).getRank ();
51
51
SmallVector<int64_t > dims;
52
52
for (int64_t i = 0 ; i < rank; i++) {
53
53
if (!(std::find (axes.begin (), axes.end (), i) != axes.end ())) {
54
54
dims.push_back (i);
55
55
}
56
+ else if (keepDims)
57
+ dims.push_back (1 );
56
58
}
57
59
58
60
return dims;
@@ -121,8 +123,8 @@ struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern {
121
123
ConversionPatternRewriter &rewriter) const final {
122
124
123
125
Value operand = operands[0 ];
124
- assert (
125
- hasStaticShape (operand.getType ()) && " Only Static shapes are accepted" );
126
+ // assert(
127
+ // hasStaticShape(operand.getType()) && "Only Static shapes are accepted");
126
128
127
129
Location loc = op->getLoc ();
128
130
Type outputType = *op->result_type_begin ();
@@ -151,20 +153,49 @@ struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern {
151
153
// Sum of the all the exponents for the denominator
152
154
SmallVector<int64_t > reducedShape =
153
155
getReductionShape (ExpOutputType, axes, false );
154
- ShapedType ReducedShapeType = mlir::cast<ShapedType>(
156
+ ShapedType ReducedShapeType;
157
+ if (hasStaticShape (operand.getType ()))
158
+ {
159
+ ReducedShapeType = mlir::cast<ShapedType>(
155
160
RankedTensorType::get (reducedShape, ExpOutputType.getElementType ()));
161
+ }
162
+ else
163
+ {
164
+ SmallVector<int64_t > reducedShape_with_dims = getReductionShape (ExpOutputType, axes, true );
165
+ ReducedShapeType = mlir::cast<ShapedType>(
166
+ RankedTensorType::get (reducedShape_with_dims, ExpOutputType.getElementType ()));
167
+ }
156
168
Value identity = rewriter.create <stablehlo::ConstantOp>(
157
169
loc, rewriter.getZeroAttr (ExpOutputType.getElementType ()));
158
170
Value ReduceSum = computeReduceSum (loc, ElementwiseExpStableHLO, identity,
159
- reducedShape, axes, rewriter, false , ReducedShapeType);
171
+ reducedShape, axes, rewriter, !( hasStaticShape (operand. getType ())) , ReducedShapeType);
160
172
if (ReduceSum == nullptr )
161
173
return failure ();
162
174
163
175
SmallVector<int64_t > broadcast_dims =
164
- getBroadcastDims (ElementwiseExpStableHLO, axes);
165
- Value BroadCastOp =
166
- rewriter.create <stablehlo::BroadcastInDimOp>(loc, ExpOutputType,
167
- ReduceSum, rewriter.getDenseI64ArrayAttr (broadcast_dims));
176
+ getBroadcastDims (ElementwiseExpStableHLO, axes, !(hasStaticShape (operand.getType ())));
177
+
178
+ Value BroadCastOp;
179
+ if (hasStaticShape (operand.getType ()))
180
+ BroadCastOp =
181
+ rewriter.create <stablehlo::BroadcastInDimOp>(loc, ExpOutputType,
182
+ ReduceSum, rewriter.getDenseI64ArrayAttr (broadcast_dims));
183
+ else {
184
+ // mlir::Value ReshapeOp = rewriter.create<stablehlo::DynamicReshapeOp>(loc, mlir::cast<RankedTensorType>(operand.getType()).getElementType(), ReduceSum, rewriter.getDenseI64ArrayAttr(broadcast_dims));
185
+ // llvm::ArrayRef<int64_t> output_dimensions = mlir::cast<mlir::RankedTensorType>(op->getResultTypes()[0]).getShape();
186
+ // mlir::Type i64_type = rewriter.getIntegerType(64);
187
+ // mlir::RankedTensorType output_rank = mlir::RankedTensorType::get({ExpOutputType.getRank()}, i64_type);
188
+ // mlir::DenseElementsAttr DenseOutputDimensions = mlir::DenseElementsAttr::get(output_rank, output_dimensions);
189
+ mlir::Value OutputDimensions = rewriter.create <shape::ShapeOfOp>(loc, operand);
190
+ llvm::outs () << OutputDimensions << " \n " ;
191
+ llvm::outs () << ReduceSum << " \n " ;
192
+ SmallVector<int64_t > dims;
193
+ for (int64_t i = 0 ; i < ExpOutputType.getRank (); i++)
194
+ dims.push_back (i);
195
+
196
+ BroadCastOp = rewriter.create <stablehlo::DynamicBroadcastInDimOp>(loc, ExpOutputType, ReduceSum, OutputDimensions, rewriter.getDenseI64ArrayAttr (dims));// , rewriter.getDenseI64ArrayAttr(known_expanding_dims), rewriter.getDenseI64ArrayAttr(broadcast_dims));
197
+ llvm::outs () << BroadCastOp << " \n " ;
198
+ }
168
199
if (BroadCastOp == nullptr )
169
200
return failure ();
170
201
0 commit comments