Skip to content

Commit 7c58751

Browse files
authored
Recompose multiple ops into a single ONNXGelu (#2965)
Recompose multiple ops into a single ONNXGelu (#2965) Signed-off-by: Tung D. Le <[email protected]> --------- Signed-off-by: Tung D. Le <[email protected]>
1 parent 265ee60 commit 7c58751

File tree

8 files changed

+480
-20
lines changed

8 files changed

+480
-20
lines changed

src/Dialect/ONNX/DialectBuilder.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ Value OnnxBuilder::expand(Type outputType, Value input, Value shape) const {
150150
outputType, toTensor(input), toTensor(shape));
151151
}
152152

153+
Value OnnxBuilder::gelu(Value input, StringAttr approximateAttr) const {
154+
return createOpAndInferShapes<ONNXGeluOp>(
155+
toTensor(input.getType()), input, approximateAttr);
156+
}
157+
153158
// ONNXLayerNormalizationOp, version with one output only (Y).
154159
Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale,
155160
Value bias, int64_t axis, FloatAttr epsilon) const {

src/Dialect/ONNX/DialectBuilder.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ struct OnnxBuilder : DialectBuilder {
8787
mlir::Value expand(
8888
mlir::Type outputType, mlir::Value input, mlir::Value shape) const;
8989

90+
// ONNXGeluOp
91+
mlir::Value gelu(mlir::Value input, mlir::StringAttr approximateAttr) const;
92+
9093
// ONNXLayerNormalizationOp, version with one output only (Y).
9194
mlir::Value layerNorm(mlir::Type outputType, mlir::Value input,
9295
mlir::Value scale, mlir::Value bias, int64_t axis,

src/Dialect/ONNX/ONNXOps/OpHelper.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,24 @@ RESULT_TYPE getScalarValue(ONNXConstantOp constantOp) {
579579
template double getScalarValue<double>(ONNXConstantOp constantOp);
580580
template int64_t getScalarValue<int64_t>(ONNXConstantOp constantOp);
581581

582+
/// Return the wide type of a value.
583+
WideNum asWideNum(double n, Type elemType) {
584+
return wideZeroDispatch(elemType, [n](auto wideZero) {
585+
using cpptype = decltype(wideZero);
586+
constexpr BType TAG = toBType<cpptype>;
587+
return WideNum::widen<TAG>(static_cast<cpptype>(n));
588+
});
589+
}
590+
591+
/// Checks whether a constant tensor's elements are all equal to a given scalar.
592+
bool isConstOf(Value constValue, double n) {
593+
ElementsAttr constElements = getElementAttributeFromONNXValue(constValue);
594+
Type elemType = constElements.getElementType();
595+
assert(!elemType.isInteger(1) && "booleans are not supported");
596+
WideNum w = asWideNum(n, elemType);
597+
return ElementsAttrBuilder::allEqual(constElements, w);
598+
}
599+
582600
// Convert type to MLIR type.
583601
// A complete list of types can be found in:
584602
// <onnx-mlir-build-folder>/third_party/onnx/onnx/onnx.pb.h

src/Dialect/ONNX/ONNXOps/OpHelper.hpp

+43
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,12 @@ RESULT_TYPE getScalarValue(mlir::ElementsAttr denseAttr, mlir::Type type);
244244
template <typename RESULT_TYPE>
245245
RESULT_TYPE getScalarValue(mlir::ONNXConstantOp constantOp);
246246

247+
/// Return the wide type of a value.
248+
WideNum asWideNum(double n, mlir::Type elemType);
249+
250+
/// Checks whether a constant tensor's elements are all equal to a given scalar.
251+
bool isConstOf(mlir::Value constValue, double n);
252+
247253
mlir::Type convertONNXTypeToMLIRType(
248254
mlir::Builder &builder, onnx::TensorProto_DataType onnxType);
249255

@@ -277,6 +283,43 @@ bool operandOfOpDefinedBy(mlir::Operation *&matchOp, mlir::Operation *op,
277283
mlir::Value &matchOperand0, mlir::Value &matchOperand1,
278284
int64_t matchThisOperandIndex);
279285

286+
// This is to recognize a binary op, e.g. A*B where one of A and B is a constant
287+
// and the other one is defined by OP.
288+
// Note: this function can handle the communitive property of the binary op.
289+
//
290+
// For example, to recognize this pattern:
291+
// %x = "onnx.Tanh"()
292+
// %y = 0.5 * %x // or %x * 0.5
293+
//
294+
// we call
295+
// ```
296+
// ONNXTanhOp tanhOp;
297+
// bool found = matchConstAndOp<ONNXTanhOp>(A, B, 0.5, tanhOp);
298+
// ```
299+
// where `A` and `B` are operands of ONNXMul that produces %y.
300+
template <typename OP>
301+
bool matchConstAndOp(mlir::Value A, mlir::Value B, double cst, OP &op);
302+
303+
// This is to recognize a binary op, e.g. A*B where one of A and B is the given
304+
// value and the other one is defined by OP.
305+
// Note: this function can handle the communitive property of the binary op.
306+
//
307+
// For example, to recognize this pattern where %z is one of the inputs of *,
308+
// and the other input of * is defined by onnx.Tanh:
309+
// %x = "onnx.Tanh"()
310+
// %y = %z * %x // or %x * %z
311+
//
312+
// we call
313+
// ```
314+
// Value z;
315+
// ONNXTanhOp tanhOp;
316+
// bool found = matchConstAndOp<ONNXTanhOp>(A, B, z, tanhOp);
317+
// ```
318+
// where `A` and `B` are operands of ONNXMul that produces %y.
319+
template <typename OP>
320+
bool matchValueAndOp(
321+
mlir::Value A, mlir::Value B, mlir::Value matchValue, OP &matchOp);
322+
280323
/// Check if a value is to store dimensions, meaning it is a tensor of one
281324
/// element or concatenation of one-element tensors.
282325
bool areDims(mlir::Value val);

src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc

+62
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,65 @@ bool operandOfOpDefinedBy(mlir::Operation *&matchOp, mlir::Operation *op,
8383
}
8484
return false;
8585
}
86+
87+
// This is to recognize a binary op, e.g. A*B where one of A and B is a constant
88+
// and the other one is defined by OP.
89+
// Note: this function can handle the communitive property of the binary op.
90+
//
91+
// For example, to recognize this pattern:
92+
// %x = "onnx.Tanh"()
93+
// %y = 0.5 * %x // or %x * 0.5
94+
//
95+
// we call
96+
// ```
97+
// ONNXTanhOp tanhOp;
98+
// bool found = matchConstAndOp<ONNXTanhOp>(A, B, 0.5, tanhOp);
99+
// ```
100+
// where `A` and `B` are operands of ONNXMul that produces %y.
101+
template <typename OP>
102+
bool matchConstAndOp(mlir::Value A, mlir::Value B, double cst, OP &matchOp) {
103+
auto opA = A.getDefiningOp<OP>();
104+
auto opB = B.getDefiningOp<OP>();
105+
if (onnx_mlir::isDenseONNXConstant(A) && onnx_mlir::isConstOf(A, cst) && opB)
106+
{
107+
matchOp = opB;
108+
return true;
109+
}
110+
if (opA && onnx_mlir::isDenseONNXConstant(B) && onnx_mlir::isConstOf(B, cst))
111+
{
112+
matchOp = opA;
113+
return true;
114+
}
115+
return false;
116+
}
117+
118+
// This is to recognize a binary op, e.g. A*B where one of A and B is the given
119+
// value and the other one is defined by OP.
120+
// Note: this function can handle the communitive property of the binary op.
121+
//
122+
// For example, to recognize this pattern where %z is one of the inputs of *,
123+
// and the other input of * is defined by onnx.Tanh:
124+
// %x = "onnx.Tanh"()
125+
// %y = %z * %x // or %x * %z
126+
//
127+
// we call
128+
// ```
129+
// Value z;
130+
// ONNXTanhOp tanhOp;
131+
// bool found = matchConstAndOp<ONNXTanhOp>(A, B, z, tanhOp);
132+
// ```
133+
// where `A` and `B` are operands of ONNXMul that produces %y.
134+
template <typename OP>
135+
bool matchValueAndOp(mlir::Value A, mlir::Value B, mlir::Value matchValue, OP &matchOp) {
136+
auto opA = A.getDefiningOp<OP>();
137+
auto opB = B.getDefiningOp<OP>();
138+
if ((A == matchValue) && opB) {
139+
matchOp = opB;
140+
return true;
141+
}
142+
if (opA && (B == matchValue)) {
143+
matchOp = opA;
144+
return true;
145+
}
146+
return false;
147+
}

src/Dialect/ONNX/Transforms/ConstProp.cpp

-17
Original file line numberDiff line numberDiff line change
@@ -186,23 +186,6 @@ Value createMinimumValueForClip(
186186
llvm::APFloat::getLargest, true, llvm::APInt::getMinValue);
187187
}
188188

189-
WideNum asWideNum(double n, Type elemType) {
190-
return wideZeroDispatch(elemType, [n](auto wideZero) {
191-
using cpptype = decltype(wideZero);
192-
constexpr BType TAG = toBType<cpptype>;
193-
return WideNum::widen<TAG>(static_cast<cpptype>(n));
194-
});
195-
}
196-
197-
/// Checks whether a constant tensor's elements are all equal to a given scalar.
198-
bool isConstOf(Value constValue, double n) {
199-
ElementsAttr constElements = getConstValueElements(constValue);
200-
Type elemType = constElements.getElementType();
201-
assert(!elemType.isInteger(1) && "booleans are not supported");
202-
WideNum w = asWideNum(n, elemType);
203-
return ElementsAttrBuilder::allEqual(constElements, w);
204-
}
205-
206189
// Extracts number from a scalar constant value.
207190
WideNum getScalarNum(Value constValue) {
208191
ElementsAttr elements = getConstValueElements(constValue);

0 commit comments

Comments
 (0)