Skip to content

Commit f11a21c

Browse files
ttjostmgehre-amd
andauthored
Constprop: More unary/binary ops. Also don't crash in divide by Zero. (#2862)
* feat: add constant propagation for different unary and binary ops. Signed-off-by: Tiago Trevisan Jost <[email protected]> --------- Signed-off-by: Tiago Trevisan Jost <[email protected]> Co-authored-by: Matthias Gehre <[email protected]>
1 parent 225c8c4 commit f11a21c

File tree

3 files changed

+689
-1
lines changed

3 files changed

+689
-1
lines changed

src/Dialect/ONNX/Transforms/ConstProp.cpp

+132-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Pass/Pass.h"
2020
#include "mlir/Transforms/DialectConversion.h"
2121
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22+
#include "llvm/ADT/APFloat.h"
2223
#include "llvm/ADT/STLExtras.h"
2324
#include "llvm/ADT/StringSet.h"
2425
#include "llvm/Support/Debug.h"
@@ -122,6 +123,16 @@ Value createReplacingConstantOp(
122123
template <typename T>
123124
using EnableNotBool = std::enable_if_t<!std::is_same_v<T, bool>>;
124125

126+
template <typename T>
127+
using EnableBool = std::enable_if_t<std::is_same_v<T, bool>>;
128+
129+
template <typename T>
130+
using EnableInteger =
131+
std::enable_if_t<std::is_integral_v<T> && !std::is_same_v<T, bool>>;
132+
133+
template <typename T>
134+
using EnableFloatingPoint = std::enable_if_t<std::is_floating_point_v<T>>;
135+
125136
/// Checks whether a variadic value is produced by dense ONNXConstantOps.
126137
bool isVariadicOperandFromDenseONNXConstantOp(ValueRange operands) {
127138
return llvm::all_of(operands, [](Value v) { return isDenseONNXConstant(v); });
@@ -134,6 +145,47 @@ Value ConstZeroTensor(
134145
type, rewriter.getZeroAttr(type.getElementType())));
135146
}
136147

148+
template <typename GetFPConstFunc =
149+
std::function<APFloat(const llvm::fltSemantics &, bool)>,
150+
typename GetIntConstFunc = std::function<APInt(unsigned)>>
151+
Value getClipConstantOfType(PatternRewriter &rewriter, ShapedType type,
152+
Location loc, GetFPConstFunc fpConstantFunc, bool isNegative,
153+
GetIntConstFunc intConstantFunc) {
154+
OnnxBuilder create(rewriter, loc);
155+
auto elemType = type.getElementType();
156+
if (auto floatType = dyn_cast<FloatType>(elemType)) {
157+
auto fpValue =
158+
fpConstantFunc(floatType.getFloatSemantics(), /*Negative=*/isNegative);
159+
return create.constant(DenseElementsAttr::get(
160+
RankedTensorType::get({}, elemType), llvm::ArrayRef(fpValue)));
161+
}
162+
auto intValue = intConstantFunc(elemType.getIntOrFloatBitWidth());
163+
return create.constant(DenseElementsAttr::get(
164+
RankedTensorType::get({}, elemType), llvm::ArrayRef(intValue)));
165+
}
166+
167+
Value createMaximumValueForClip(
168+
PatternRewriter &rewriter, ShapedType type, Value value) {
169+
170+
// Return 'value' if exists, as there is no need to clip to largest.
171+
if (!isNoneValue(value))
172+
return value;
173+
174+
return getClipConstantOfType(rewriter, type, value.getLoc(),
175+
llvm::APFloat::getLargest, false, llvm::APInt::getMaxValue);
176+
}
177+
178+
Value createMinimumValueForClip(
179+
PatternRewriter &rewriter, ShapedType type, Value value) {
180+
181+
// Return 'value' if exists, as there is no need to clip to lowest.
182+
if (!isNoneValue(value))
183+
return value;
184+
185+
return getClipConstantOfType(rewriter, type, value.getLoc(),
186+
llvm::APFloat::getLargest, true, llvm::APInt::getMinValue);
187+
}
188+
137189
WideNum asWideNum(double n, Type elemType) {
138190
return wideZeroDispatch(elemType, [n](auto wideZero) {
139191
using cpptype = decltype(wideZero);
@@ -203,7 +255,41 @@ struct ElementWiseBinaryOpImpl<ONNXMulOp, T, EnableNotBool<T>> {
203255

204256
template <typename T>
205257
struct ElementWiseBinaryOpImpl<ONNXDivOp, T, EnableNotBool<T>> {
206-
static T eval(T lhs, T rhs) { return lhs / rhs; }
258+
static T eval(T lhs, T rhs) {
259+
if constexpr (std::is_integral_v<T>) {
260+
if (rhs == 0) {
261+
// Undefined behavior. We can return any value.
262+
// Performing the divison would crash.
263+
return lhs;
264+
}
265+
}
266+
return lhs / rhs;
267+
}
268+
};
269+
270+
template <typename T>
271+
struct ElementWiseBinaryOpImpl<ONNXBitwiseAndOp, T, EnableInteger<T>> {
272+
static T eval(T lhs, T rhs) { return lhs & rhs; }
273+
};
274+
275+
template <typename T>
276+
struct ElementWiseBinaryOpImpl<ONNXBitwiseOrOp, T, EnableInteger<T>> {
277+
static T eval(T lhs, T rhs) { return lhs | rhs; }
278+
};
279+
280+
template <typename T>
281+
struct ElementWiseBinaryOpImpl<ONNXAndOp, T, EnableBool<T>> {
282+
static T eval(T lhs, T rhs) { return lhs && rhs; }
283+
};
284+
285+
template <typename T>
286+
struct ElementWiseBinaryOpImpl<ONNXOrOp, T, EnableBool<T>> {
287+
static T eval(T lhs, T rhs) { return lhs || rhs; }
288+
};
289+
290+
template <typename T>
291+
struct ElementWiseBinaryOpImpl<ONNXXorOp, T, EnableBool<T>> {
292+
static T eval(T lhs, T rhs) { return lhs != rhs; }
207293
};
208294

209295
template <typename T>
@@ -340,11 +426,56 @@ struct ElementWiseUnaryOpImpl {
340426
static T eval(T val) { llvm_unreachable("unsupported op or type"); }
341427
};
342428

429+
template <typename T>
430+
struct ElementWiseUnaryOpImpl<ONNXBitwiseNotOp, T, EnableInteger<T>> {
431+
static T eval(T val) { return ~val; }
432+
};
433+
434+
template <typename T>
435+
struct ElementWiseUnaryOpImpl<ONNXCeilOp, T, EnableNotBool<T>> {
436+
static T eval(T val) { return ceil(val); }
437+
};
438+
439+
template <typename T>
440+
struct ElementWiseUnaryOpImpl<ONNXCosOp, T, EnableFloatingPoint<T>> {
441+
static T eval(T val) { return cos(val); }
442+
};
443+
444+
template <typename T>
445+
struct ElementWiseUnaryOpImpl<ONNXErfOp, T, EnableNotBool<T>> {
446+
static T eval(T val) { return std::erf(val); }
447+
};
448+
449+
template <typename T>
450+
struct ElementWiseUnaryOpImpl<ONNXExpOp, T, EnableFloatingPoint<T>> {
451+
static T eval(T val) { return std::exp(val); }
452+
};
453+
454+
template <typename T>
455+
struct ElementWiseUnaryOpImpl<ONNXFloorOp, T, EnableNotBool<T>> {
456+
static T eval(T val) { return floor(val); }
457+
};
458+
459+
template <typename T>
460+
struct ElementWiseUnaryOpImpl<ONNXLogOp, T, EnableFloatingPoint<T>> {
461+
static T eval(T val) { return std::log(val); }
462+
};
463+
343464
template <typename T>
344465
struct ElementWiseUnaryOpImpl<ONNXNegOp, T, EnableNotBool<T>> {
345466
static T eval(T val) { return -val; }
346467
};
347468

469+
template <typename T>
470+
struct ElementWiseUnaryOpImpl<ONNXNotOp, T, EnableBool<T>> {
471+
static T eval(T val) { return !val; }
472+
};
473+
474+
template <typename T>
475+
struct ElementWiseUnaryOpImpl<ONNXSinOp, T, EnableFloatingPoint<T>> {
476+
static T eval(T val) { return sin(val); }
477+
};
478+
348479
template <>
349480
struct ElementWiseUnaryOpImpl<ONNXSqrtOp, double> {
350481
static double eval(double val) { return sqrt(val); }

0 commit comments

Comments
 (0)