Commit ea8e277
[stablehlo] Reduction upgrade (#2745)
* [Stablehlo] Reduction Op Upgrade
* add test for reduction ops

Signed-off-by: yan.xu0210 <[email protected]>
1 parent 18f4e07 commit ea8e277
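For context on the upgrade: in ONNX opset 18, ReduceMax, ReduceMin, and ReduceMean take their reduction axes as a runtime tensor operand instead of an attribute, matching what ReduceSum has done since opset 13. This commit adds lowerings for those newer ops (ONNXReduceMaxOp, ONNXReduceMinOp, ONNXReduceMeanOp) alongside the existing V13/V11 patterns. A minimal sketch of the two forms in the ONNX dialect; the opset-18 form matches the tests in this diff, while the versioned-op spelling and attribute syntax of the opset-13 form are illustrative, not taken from this diff:

// Opset-13 form: axes carried as an attribute (handled by ONNXReduceMaxV13Op).
%r13 = "onnx.ReduceMaxV13"(%x) {axes = [2, 3], keepdims = 1 : si64}
    : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32>

// Opset-18 form: axes passed as a tensor operand (handled by ONNXReduceMaxOp).
%axes = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
%r18 = "onnx.ReduceMax"(%x, %axes) : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32>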

File tree: 2 files changed (+156, -17)

src/Conversion/ONNXToStablehlo/Math/Reduction.cpp (+102, -17)
@@ -28,41 +28,78 @@ Value getIdentityValue(
   return nullptr;
 }
 
-template <>
-Value getIdentityValue<ONNXReduceMaxV13Op>(
+Value getReduceMaxIdentityValue(
     ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
   MathBuilder createMath(rewriter, loc);
   return rewriter.create<stablehlo::ConstantOp>(
       loc, createMath.negativeInfAttr(elemType));
 }
 
-template <>
-Value getIdentityValue<ONNXReduceMinV13Op>(
+Value getReduceMinIdentityValue(
     ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
   MathBuilder createMath(rewriter, loc);
   return rewriter.create<stablehlo::ConstantOp>(
       loc, createMath.positiveInfAttr(elemType));
 }
 
-template <>
-Value getIdentityValue<ONNXReduceSumOp>(
+Value getReduceSumIdentityValue(
     ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
   return rewriter.create<stablehlo::ConstantOp>(
       loc, rewriter.getZeroAttr(elemType));
 }
 
-template <>
-Value getIdentityValue<ONNXReduceSumV11Op>(
+Value getReduceMeanIdentityValue(
     ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
   return rewriter.create<stablehlo::ConstantOp>(
       loc, rewriter.getZeroAttr(elemType));
 }
 
+template <>
+Value getIdentityValue<ONNXReduceMaxOp>(
+    ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
+  return getReduceMaxIdentityValue(rewriter, loc, elemType);
+}
+
+template <>
+Value getIdentityValue<ONNXReduceMaxV13Op>(
+    ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
+  return getReduceMaxIdentityValue(rewriter, loc, elemType);
+}
+
+template <>
+Value getIdentityValue<ONNXReduceMinOp>(
+    ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
+  return getReduceMinIdentityValue(rewriter, loc, elemType);
+}
+
+template <>
+Value getIdentityValue<ONNXReduceMinV13Op>(
+    ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
+  return getReduceMinIdentityValue(rewriter, loc, elemType);
+}
+
+template <>
+Value getIdentityValue<ONNXReduceSumOp>(
+    ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
+  return getReduceSumIdentityValue(rewriter, loc, elemType);
+}
+
+template <>
+Value getIdentityValue<ONNXReduceSumV11Op>(
+    ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
+  return getReduceSumIdentityValue(rewriter, loc, elemType);
+}
+
+template <>
+Value getIdentityValue<ONNXReduceMeanOp>(
+    ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
+  return getReduceMeanIdentityValue(rewriter, loc, elemType);
+}
+
 template <>
 Value getIdentityValue<ONNXReduceMeanV13Op>(
     ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
-  return rewriter.create<stablehlo::ConstantOp>(
-      loc, rewriter.getZeroAttr(elemType));
+  return getReduceMeanIdentityValue(rewriter, loc, elemType);
 }
 
 template <typename ONNXReductionOp>
@@ -78,12 +115,9 @@ llvm::SmallVector<int64_t, 4> getDefinedAxes(Operation *op) {
   return definedAxes;
 }
 
-template <>
-llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceSumOp>(Operation *op) {
+llvm::SmallVector<int64_t, 4> getDefinedAxesFromConstAxes(
+    Operation *op, Value axesValue, bool keepDims) {
   llvm::SmallVector<int64_t, 4> definedAxes;
-  ONNXReduceSumOp reduceSumOp = cast<ONNXReduceSumOp>(op);
-  Value axesValue = reduceSumOp.getAxes();
-
   // Assume it is verified that axes are known. Convert DenseElementsAttr to
   // ArrayAttr.
   if (!isNoneValue(axesValue) && getONNXConstantOp(axesValue)) {
@@ -104,7 +138,7 @@ llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceSumOp>(Operation *op) {
   assert(inputType != nullptr && outputType != nullptr &&
          "not implemented for dynamic axes when either input or output is not "
          "ranked");
-  bool keepDims = reduceSumOp.getKeepdims() == 1;
+
   int64_t inputRank = inputType.getRank();
   int64_t outputRank = outputType.getRank();
   llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -127,22 +161,69 @@ llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceSumOp>(Operation *op) {
   return definedAxes;
 }
 
+template <>
+llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceMaxOp>(Operation *op) {
+  ONNXReduceMaxOp reduceMaxOp = cast<ONNXReduceMaxOp>(op);
+  Value axesValue = reduceMaxOp.getAxes();
+  bool keepDims = reduceMaxOp.getKeepdims() == 1;
+  return getDefinedAxesFromConstAxes(op, axesValue, keepDims);
+}
+
+template <>
+llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceMinOp>(Operation *op) {
+  ONNXReduceMinOp reduceMinOp = cast<ONNXReduceMinOp>(op);
+  Value axesValue = reduceMinOp.getAxes();
+  bool keepDims = reduceMinOp.getKeepdims() == 1;
+  return getDefinedAxesFromConstAxes(op, axesValue, keepDims);
+}
+
+template <>
+llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceSumOp>(Operation *op) {
+  ONNXReduceSumOp reduceSumOp = cast<ONNXReduceSumOp>(op);
+  Value axesValue = reduceSumOp.getAxes();
+  bool keepDims = reduceSumOp.getKeepdims() == 1;
+  return getDefinedAxesFromConstAxes(op, axesValue, keepDims);
+}
+
+template <>
+llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceMeanOp>(Operation *op) {
+  ONNXReduceMeanOp reduceMeanOp = cast<ONNXReduceMeanOp>(op);
+  Value axesValue = reduceMeanOp.getAxes();
+  bool keepDims = reduceMeanOp.getKeepdims() == 1;
+  return getDefinedAxesFromConstAxes(op, axesValue, keepDims);
+}
+
 // Block reduce ops
 template <typename ReductionOp>
 struct BlockReduceOp {
   using Op = void;
 };
 
+template <>
+struct BlockReduceOp<ONNXReduceMaxOp> {
+  using Op = stablehlo::MaxOp;
+};
+
 template <>
 struct BlockReduceOp<ONNXReduceMaxV13Op> {
   using Op = stablehlo::MaxOp;
 };
 
+template <>
+struct BlockReduceOp<ONNXReduceMinOp> {
+  using Op = stablehlo::MinOp;
+};
+
 template <>
 struct BlockReduceOp<ONNXReduceMinV13Op> {
   using Op = stablehlo::MinOp;
 };
 
+template <>
+struct BlockReduceOp<ONNXReduceMeanOp> {
+  using Op = stablehlo::AddOp;
+};
+
 template <>
 struct BlockReduceOp<ONNXReduceMeanV13Op> {
   using Op = stablehlo::AddOp;
@@ -355,10 +436,14 @@ struct ONNXReductionOpLoweringToStablehlo : public ConversionPattern {
 
 void populateLoweringONNXReductionOpToStablehloPattern(
     RewritePatternSet &patterns, MLIRContext *ctx) {
-  patterns.insert<ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceMaxV13Op>,
+  patterns.insert<ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceMaxOp>,
+      ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceMaxV13Op>,
+      ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceMinOp>,
       ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceMinV13Op>,
       ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceSumOp>,
       ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceSumV11Op>>(ctx);
+  patterns.insert<ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceMeanOp>>(
+      ctx, /*computeMean=*/true);
   patterns
       .insert<ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceMeanV13Op>>(
          ctx, /*computeMean=*/true);
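Both ReduceMean patterns are registered separately because they pass /*computeMean=*/true: the lowering emits a stablehlo.add reduction and then divides by the number of reduced elements. Assembled from the reduce_mean CHECK lines in the test file below (reducing dims 2 and 3 of a 2x5x9x11 tensor covers 9 * 11 = 99 elements); %x, %zero, %shape, and %count are illustrative names, not taken from the diff:

%sum  = stablehlo.reduce(%x init: %zero) applies stablehlo.add across dimensions = [2, 3] : (tensor<2x5x9x11xf32>, tensor<f32>) -> tensor<2x5xf32>
%kept = stablehlo.dynamic_reshape %sum, %shape : (tensor<2x5xf32>, tensor<4xindex>) -> tensor<2x5x1x1xf32>
// %count is a splat of 9.900000e+01 (= 99.0, the reduced element count)
%mean = stablehlo.divide %kept, %count : tensor<2x5x1x1xf32>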

test/mlir/conversion/onnx_to_stablehlo/Math/Reduction.mlir (+54)
@@ -175,3 +175,57 @@ func.func @test_reducemean_v13_2(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
 // CHECK: [[VAR_6_:%.+]] = stablehlo.divide [[VAR_2_]], [[VAR_5_]] : tensor<?x?xf32>
 // CHECK: return [[VAR_6_]] : tensor<?x?xf32>
 // CHECK: }
+
+// -----
+
+func.func @reduce_mean(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> {
+  %0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
+  %1 = "onnx.ReduceMean"(%arg0, %0) : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32>
+  return %1 : tensor<2x5x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_mean
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = shape.const_shape [2, 5, 1, 1] : tensor<4xindex>
+// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<9.900000e+01> : tensor<2x5x1x1xf32>
+// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: [[VAR_3_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_2_]]) applies stablehlo.add across dimensions = [2, 3] : (tensor<2x5x9x11xf32>, tensor<f32>) -> tensor<2x5xf32>
+// CHECK: [[VAR_4_:%.+]] = stablehlo.dynamic_reshape [[VAR_3_]], [[VAR_0_]] : (tensor<2x5xf32>, tensor<4xindex>) -> tensor<2x5x1x1xf32>
+// CHECK: [[VAR_5_:%.+]] = stablehlo.divide [[VAR_4_]], [[VAR_1_]] : tensor<2x5x1x1xf32>
+// CHECK: return [[VAR_5_]] : tensor<2x5x1x1xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_max(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> {
+  %0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
+  %1 = "onnx.ReduceMax"(%arg0, %0) : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32>
+  return %1 : tensor<2x5x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_max
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = shape.const_shape [2, 5, 1, 1] : tensor<4xindex>
+// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
+// CHECK: [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [2, 3] : (tensor<2x5x9x11xf32>, tensor<f32>) -> tensor<2x5xf32>
+// CHECK: [[VAR_3_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_0_]] : (tensor<2x5xf32>, tensor<4xindex>) -> tensor<2x5x1x1xf32>
+// CHECK: return [[VAR_3_]] : tensor<2x5x1x1xf32>
+// CHECK: }
+
+// -----
+
+func.func @reduce_min(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> {
+  %0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
+  %1 = "onnx.ReduceMin"(%arg0, %0) : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32>
+  return %1 : tensor<2x5x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_min
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = shape.const_shape [2, 5, 1, 1] : tensor<4xindex>
+// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0x7F800000> : tensor<f32>
+// CHECK: [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.minimum across dimensions = [2, 3] : (tensor<2x5x9x11xf32>, tensor<f32>) -> tensor<2x5xf32>
+// CHECK: [[VAR_3_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_0_]] : (tensor<2x5xf32>, tensor<4xindex>) -> tensor<2x5x1x1xf32>
+// CHECK: return [[VAR_3_]] : tensor<2x5x1x1xf32>
+// CHECK: }
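These lit tests are driven by a RUN header at the top of the file, which this hunk does not show. An assumed header for reference only; the actual flags in the repository may differ:

// RUN: onnx-mlir-opt --convert-onnx-to-stablehlo --canonicalize %s -split-input-file | FileCheck %s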
