@@ -28,41 +28,78 @@ Value getIdentityValue(
28
28
return nullptr ;
29
29
}
30
30
31
- template <>
32
- Value getIdentityValue<ONNXReduceMaxV13Op>(
31
+ Value getReduceMaxIdentityValue (
33
32
ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
34
33
MathBuilder createMath (rewriter, loc);
35
34
return rewriter.create <stablehlo::ConstantOp>(
36
35
loc, createMath.negativeInfAttr (elemType));
37
36
}
38
37
39
- template <>
40
- Value getIdentityValue<ONNXReduceMinV13Op>(
38
+ Value getReduceMinIdentityValue (
41
39
ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
42
40
MathBuilder createMath (rewriter, loc);
43
41
return rewriter.create <stablehlo::ConstantOp>(
44
42
loc, createMath.positiveInfAttr (elemType));
45
43
}
46
44
47
- template <>
48
- Value getIdentityValue<ONNXReduceSumOp>(
45
+ Value getReduceSumIdentityValue (
49
46
ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
50
47
return rewriter.create <stablehlo::ConstantOp>(
51
48
loc, rewriter.getZeroAttr (elemType));
52
49
}
53
50
54
- template <>
55
- Value getIdentityValue<ONNXReduceSumV11Op>(
51
+ Value getReduceMeanIdentityValue (
56
52
ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
57
53
return rewriter.create <stablehlo::ConstantOp>(
58
54
loc, rewriter.getZeroAttr (elemType));
59
55
}
60
56
57
+ template <>
58
+ Value getIdentityValue<ONNXReduceMaxOp>(
59
+ ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
60
+ return getReduceMaxIdentityValue (rewriter, loc, elemType);
61
+ }
62
+
63
+ template <>
64
+ Value getIdentityValue<ONNXReduceMaxV13Op>(
65
+ ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
66
+ return getReduceMaxIdentityValue (rewriter, loc, elemType);
67
+ }
68
+
69
+ template <>
70
+ Value getIdentityValue<ONNXReduceMinOp>(
71
+ ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
72
+ return getReduceMinIdentityValue (rewriter, loc, elemType);
73
+ }
74
+
75
+ template <>
76
+ Value getIdentityValue<ONNXReduceMinV13Op>(
77
+ ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
78
+ return getReduceMinIdentityValue (rewriter, loc, elemType);
79
+ }
80
+
81
+ template <>
82
+ Value getIdentityValue<ONNXReduceSumOp>(
83
+ ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
84
+ return getReduceSumIdentityValue (rewriter, loc, elemType);
85
+ }
86
+
87
+ template <>
88
+ Value getIdentityValue<ONNXReduceSumV11Op>(
89
+ ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
90
+ return getReduceSumIdentityValue (rewriter, loc, elemType);
91
+ }
92
+
93
+ template <>
94
+ Value getIdentityValue<ONNXReduceMeanOp>(
95
+ ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
96
+ return getReduceMeanIdentityValue (rewriter, loc, elemType);
97
+ }
98
+
61
99
template <>
62
100
Value getIdentityValue<ONNXReduceMeanV13Op>(
63
101
ConversionPatternRewriter &rewriter, Location loc, Type elemType) {
64
- return rewriter.create <stablehlo::ConstantOp>(
65
- loc, rewriter.getZeroAttr (elemType));
102
+ return getReduceMeanIdentityValue (rewriter, loc, elemType);
66
103
}
67
104
68
105
template <typename ONNXReductionOp>
@@ -78,12 +115,9 @@ llvm::SmallVector<int64_t, 4> getDefinedAxes(Operation *op) {
78
115
return definedAxes;
79
116
}
80
117
81
- template <>
82
- llvm::SmallVector< int64_t , 4 > getDefinedAxes<ONNXReduceSumOp>( Operation *op) {
118
+ llvm::SmallVector< int64_t , 4 > getDefinedAxesFromConstAxes (
119
+ Operation *op, Value axesValue, bool keepDims ) {
83
120
llvm::SmallVector<int64_t , 4 > definedAxes;
84
- ONNXReduceSumOp reduceSumOp = cast<ONNXReduceSumOp>(op);
85
- Value axesValue = reduceSumOp.getAxes ();
86
-
87
121
// Assume it is verified that axes are known. Convert DenseElementsAttr to
88
122
// ArrayAttr.
89
123
if (!isNoneValue (axesValue) && getONNXConstantOp (axesValue)) {
@@ -104,7 +138,7 @@ llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceSumOp>(Operation *op) {
104
138
assert (inputType != nullptr && outputType != nullptr &&
105
139
" not implemented for dynamic axes when either input or output is not "
106
140
" ranked" );
107
- bool keepDims = reduceSumOp. getKeepdims () == 1 ;
141
+
108
142
int64_t inputRank = inputType.getRank ();
109
143
int64_t outputRank = outputType.getRank ();
110
144
llvm::ArrayRef<int64_t > inputShape = inputType.getShape ();
@@ -127,22 +161,69 @@ llvm::SmallVector<int64_t, 4> getDefinedAxes<ONNXReduceSumOp>(Operation *op) {
127
161
return definedAxes;
128
162
}
129
163
164
+ template <>
165
+ llvm::SmallVector<int64_t , 4 > getDefinedAxes<ONNXReduceMaxOp>(Operation *op) {
166
+ ONNXReduceMaxOp reduceMaxOp = cast<ONNXReduceMaxOp>(op);
167
+ Value axesValue = reduceMaxOp.getAxes ();
168
+ bool keepDims = reduceMaxOp.getKeepdims () == 1 ;
169
+ return getDefinedAxesFromConstAxes (op, axesValue, keepDims);
170
+ }
171
+
172
+ template <>
173
+ llvm::SmallVector<int64_t , 4 > getDefinedAxes<ONNXReduceMinOp>(Operation *op) {
174
+ ONNXReduceMinOp reduceMinOp = cast<ONNXReduceMinOp>(op);
175
+ Value axesValue = reduceMinOp.getAxes ();
176
+ bool keepDims = reduceMinOp.getKeepdims () == 1 ;
177
+ return getDefinedAxesFromConstAxes (op, axesValue, keepDims);
178
+ }
179
+
180
+ template <>
181
+ llvm::SmallVector<int64_t , 4 > getDefinedAxes<ONNXReduceSumOp>(Operation *op) {
182
+ ONNXReduceSumOp reduceSumOp = cast<ONNXReduceSumOp>(op);
183
+ Value axesValue = reduceSumOp.getAxes ();
184
+ bool keepDims = reduceSumOp.getKeepdims () == 1 ;
185
+ return getDefinedAxesFromConstAxes (op, axesValue, keepDims);
186
+ }
187
+
188
+ template <>
189
+ llvm::SmallVector<int64_t , 4 > getDefinedAxes<ONNXReduceMeanOp>(Operation *op) {
190
+ ONNXReduceMeanOp reduceMeanOp = cast<ONNXReduceMeanOp>(op);
191
+ Value axesValue = reduceMeanOp.getAxes ();
192
+ bool keepDims = reduceMeanOp.getKeepdims () == 1 ;
193
+ return getDefinedAxesFromConstAxes (op, axesValue, keepDims);
194
+ }
195
+
130
196
// Block reduce ops
131
197
template <typename ReductionOp>
132
198
struct BlockReduceOp {
133
199
using Op = void ;
134
200
};
135
201
202
+ template <>
203
+ struct BlockReduceOp <ONNXReduceMaxOp> {
204
+ using Op = stablehlo::MaxOp;
205
+ };
206
+
136
207
template <>
137
208
struct BlockReduceOp <ONNXReduceMaxV13Op> {
138
209
using Op = stablehlo::MaxOp;
139
210
};
140
211
212
+ template <>
213
+ struct BlockReduceOp <ONNXReduceMinOp> {
214
+ using Op = stablehlo::MinOp;
215
+ };
216
+
141
217
template <>
142
218
struct BlockReduceOp <ONNXReduceMinV13Op> {
143
219
using Op = stablehlo::MinOp;
144
220
};
145
221
222
+ template <>
223
+ struct BlockReduceOp <ONNXReduceMeanOp> {
224
+ using Op = stablehlo::AddOp;
225
+ };
226
+
146
227
template <>
147
228
struct BlockReduceOp <ONNXReduceMeanV13Op> {
148
229
using Op = stablehlo::AddOp;
@@ -355,10 +436,14 @@ struct ONNXReductionOpLoweringToStablehlo : public ConversionPattern {
355
436
356
437
void populateLoweringONNXReductionOpToStablehloPattern (
357
438
RewritePatternSet &patterns, MLIRContext *ctx) {
358
- patterns.insert <ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceMaxV13Op>,
439
+ patterns.insert <ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceMaxOp>,
440
+ ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceMaxV13Op>,
441
+ ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceMinOp>,
359
442
ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceMinV13Op>,
360
443
ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceSumOp>,
361
444
ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceSumV11Op>>(ctx);
445
+ patterns.insert <ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceMeanOp>>(
446
+ ctx, /* computeMean=*/ true );
362
447
patterns
363
448
.insert <ONNXReductionOpLoweringToStablehlo<mlir::ONNXReduceMeanV13Op>>(
364
449
ctx, /* computeMean=*/ true );
0 commit comments