19
19
#include " mlir/Pass/Pass.h"
20
20
#include " mlir/Transforms/DialectConversion.h"
21
21
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
22
+ #include " llvm/ADT/APFloat.h"
22
23
#include " llvm/ADT/STLExtras.h"
23
24
#include " llvm/ADT/StringSet.h"
24
25
#include " llvm/Support/Debug.h"
@@ -122,6 +123,16 @@ Value createReplacingConstantOp(
122
123
/// SFINAE helper: names `void` for every T except `bool`; for `bool` the alias
/// is ill-formed, which removes the specialization using it from consideration.
template <typename T>
using EnableNotBool =
    typename std::enable_if<!std::is_same<T, bool>::value>::type;
124
125
126
/// SFINAE helper: names `void` only when T is exactly `bool`; for any other T
/// the alias is ill-formed.
template <typename T>
using EnableBool = typename std::enable_if<std::is_same<T, bool>::value>::type;
128
+
129
/// SFINAE helper: names `void` when T is an integral type other than `bool`;
/// otherwise the alias is ill-formed.
template <typename T>
using EnableInteger = typename std::enable_if<std::conjunction<
    std::is_integral<T>, std::negation<std::is_same<T, bool>>>::value>::type;
132
+
133
/// SFINAE helper: names `void` when T is a floating-point type; otherwise the
/// alias is ill-formed.
template <typename T>
using EnableFloatingPoint =
    typename std::enable_if<std::is_floating_point<T>::value>::type;
135
+
125
136
// / Checks whether a variadic value is produced by dense ONNXConstantOps.
126
137
bool isVariadicOperandFromDenseONNXConstantOp (ValueRange operands) {
127
138
return llvm::all_of (operands, [](Value v) { return isDenseONNXConstant (v); });
@@ -134,6 +145,47 @@ Value ConstZeroTensor(
134
145
type, rewriter.getZeroAttr (type.getElementType ())));
135
146
}
136
147
148
+ template <typename GetFPConstFunc =
149
+ std::function<APFloat(const llvm::fltSemantics &, bool )>,
150
+ typename GetIntConstFunc = std::function<APInt(unsigned )>>
151
+ Value getClipConstantOfType (PatternRewriter &rewriter, ShapedType type,
152
+ Location loc, GetFPConstFunc fpConstantFunc, bool isNegative,
153
+ GetIntConstFunc intConstantFunc) {
154
+ OnnxBuilder create (rewriter, loc);
155
+ auto elemType = type.getElementType ();
156
+ if (auto floatType = dyn_cast<FloatType>(elemType)) {
157
+ auto fpValue =
158
+ fpConstantFunc (floatType.getFloatSemantics (), /* Negative=*/ isNegative);
159
+ return create.constant (DenseElementsAttr::get (
160
+ RankedTensorType::get ({}, elemType), llvm::ArrayRef (fpValue)));
161
+ }
162
+ auto intValue = intConstantFunc (elemType.getIntOrFloatBitWidth ());
163
+ return create.constant (DenseElementsAttr::get (
164
+ RankedTensorType::get ({}, elemType), llvm::ArrayRef (intValue)));
165
+ }
166
+
167
+ Value createMaximumValueForClip (
168
+ PatternRewriter &rewriter, ShapedType type, Value value) {
169
+
170
+ // Return 'value' if exists, as there is no need to clip to largest.
171
+ if (!isNoneValue (value))
172
+ return value;
173
+
174
+ return getClipConstantOfType (rewriter, type, value.getLoc (),
175
+ llvm::APFloat::getLargest, false , llvm::APInt::getMaxValue);
176
+ }
177
+
178
+ Value createMinimumValueForClip (
179
+ PatternRewriter &rewriter, ShapedType type, Value value) {
180
+
181
+ // Return 'value' if exists, as there is no need to clip to lowest.
182
+ if (!isNoneValue (value))
183
+ return value;
184
+
185
+ return getClipConstantOfType (rewriter, type, value.getLoc (),
186
+ llvm::APFloat::getLargest, true , llvm::APInt::getMinValue);
187
+ }
188
+
137
189
WideNum asWideNum (double n, Type elemType) {
138
190
return wideZeroDispatch (elemType, [n](auto wideZero) {
139
191
using cpptype = decltype (wideZero);
@@ -203,7 +255,41 @@ struct ElementWiseBinaryOpImpl<ONNXMulOp, T, EnableNotBool<T>> {
203
255
204
256
template <typename T>
205
257
struct ElementWiseBinaryOpImpl <ONNXDivOp, T, EnableNotBool<T>> {
206
- static T eval (T lhs, T rhs) { return lhs / rhs; }
258
+ static T eval (T lhs, T rhs) {
259
+ if constexpr (std::is_integral_v<T>) {
260
+ if (rhs == 0 ) {
261
+ // Undefined behavior. We can return any value.
262
+ // Performing the divison would crash.
263
+ return lhs;
264
+ }
265
+ }
266
+ return lhs / rhs;
267
+ }
268
+ };
269
+
270
+ template <typename T>
271
+ struct ElementWiseBinaryOpImpl <ONNXBitwiseAndOp, T, EnableInteger<T>> {
272
+ static T eval (T lhs, T rhs) { return lhs & rhs; }
273
+ };
274
+
275
+ template <typename T>
276
+ struct ElementWiseBinaryOpImpl <ONNXBitwiseOrOp, T, EnableInteger<T>> {
277
+ static T eval (T lhs, T rhs) { return lhs | rhs; }
278
+ };
279
+
280
+ template <typename T>
281
+ struct ElementWiseBinaryOpImpl <ONNXAndOp, T, EnableBool<T>> {
282
+ static T eval (T lhs, T rhs) { return lhs && rhs; }
283
+ };
284
+
285
+ template <typename T>
286
+ struct ElementWiseBinaryOpImpl <ONNXOrOp, T, EnableBool<T>> {
287
+ static T eval (T lhs, T rhs) { return lhs || rhs; }
288
+ };
289
+
290
+ template <typename T>
291
+ struct ElementWiseBinaryOpImpl <ONNXXorOp, T, EnableBool<T>> {
292
+ static T eval (T lhs, T rhs) { return lhs != rhs; }
207
293
};
208
294
209
295
template <typename T>
@@ -340,11 +426,56 @@ struct ElementWiseUnaryOpImpl {
340
426
static T eval (T val) { llvm_unreachable (" unsupported op or type" ); }
341
427
};
342
428
429
+ template <typename T>
430
+ struct ElementWiseUnaryOpImpl <ONNXBitwiseNotOp, T, EnableInteger<T>> {
431
+ static T eval (T val) { return ~val; }
432
+ };
433
+
434
+ template <typename T>
435
+ struct ElementWiseUnaryOpImpl <ONNXCeilOp, T, EnableNotBool<T>> {
436
+ static T eval (T val) { return ceil (val); }
437
+ };
438
+
439
+ template <typename T>
440
+ struct ElementWiseUnaryOpImpl <ONNXCosOp, T, EnableFloatingPoint<T>> {
441
+ static T eval (T val) { return cos (val); }
442
+ };
443
+
444
+ template <typename T>
445
+ struct ElementWiseUnaryOpImpl <ONNXErfOp, T, EnableNotBool<T>> {
446
+ static T eval (T val) { return std::erf (val); }
447
+ };
448
+
449
+ template <typename T>
450
+ struct ElementWiseUnaryOpImpl <ONNXExpOp, T, EnableFloatingPoint<T>> {
451
+ static T eval (T val) { return std::exp (val); }
452
+ };
453
+
454
+ template <typename T>
455
+ struct ElementWiseUnaryOpImpl <ONNXFloorOp, T, EnableNotBool<T>> {
456
+ static T eval (T val) { return floor (val); }
457
+ };
458
+
459
+ template <typename T>
460
+ struct ElementWiseUnaryOpImpl <ONNXLogOp, T, EnableFloatingPoint<T>> {
461
+ static T eval (T val) { return std::log (val); }
462
+ };
463
+
343
464
template <typename T>
344
465
struct ElementWiseUnaryOpImpl <ONNXNegOp, T, EnableNotBool<T>> {
345
466
static T eval (T val) { return -val; }
346
467
};
347
468
469
+ template <typename T>
470
+ struct ElementWiseUnaryOpImpl <ONNXNotOp, T, EnableBool<T>> {
471
+ static T eval (T val) { return !val; }
472
+ };
473
+
474
+ template <typename T>
475
+ struct ElementWiseUnaryOpImpl <ONNXSinOp, T, EnableFloatingPoint<T>> {
476
+ static T eval (T val) { return sin (val); }
477
+ };
478
+
348
479
template <>
349
480
struct ElementWiseUnaryOpImpl <ONNXSqrtOp, double > {
350
481
static double eval (double val) { return sqrt (val); }
0 commit comments