@@ -244,6 +244,12 @@ RESULT_TYPE getScalarValue(mlir::ElementsAttr denseAttr, mlir::Type type);
244
244
template <typename RESULT_TYPE>
245
245
RESULT_TYPE getScalarValue (mlir::ONNXConstantOp constantOp);
246
246
247
+ // / Return the wide type of a value.
248
+ WideNum asWideNum (double n, mlir::Type elemType);
249
+
250
+ // / Checks whether a constant tensor's elements are all equal to a given scalar.
251
+ bool isConstOf (mlir::Value constValue, double n);
252
+
247
253
mlir::Type convertONNXTypeToMLIRType (
248
254
mlir::Builder &builder, onnx::TensorProto_DataType onnxType);
249
255
@@ -277,6 +283,43 @@ bool operandOfOpDefinedBy(mlir::Operation *&matchOp, mlir::Operation *op,
277
283
mlir::Value &matchOperand0, mlir::Value &matchOperand1,
278
284
int64_t matchThisOperandIndex);
279
285
286
+ // This is to recognize a binary op, e.g. A*B where one of A and B is a constant
287
+ // and the other one is defined by OP.
288
+ // Note: this function can handle the communitive property of the binary op.
289
+ //
290
+ // For example, to recognize this pattern:
291
+ // %x = "onnx.Tanh"()
292
+ // %y = 0.5 * %x // or %x * 0.5
293
+ //
294
+ // we call
295
+ // ```
296
+ // ONNXTanhOp tanhOp;
297
+ // bool found = matchConstAndOp<ONNXTanhOp>(A, B, 0.5, tanhOp);
298
+ // ```
299
+ // where `A` and `B` are operands of ONNXMul that produces %y.
300
+ template <typename OP>
301
+ bool matchConstAndOp (mlir::Value A, mlir::Value B, double cst, OP &op);
302
+
303
+ // This is to recognize a binary op, e.g. A*B where one of A and B is the given
304
+ // value and the other one is defined by OP.
305
+ // Note: this function can handle the communitive property of the binary op.
306
+ //
307
+ // For example, to recognize this pattern where %z is one of the inputs of *,
308
+ // and the other input of * is defined by onnx.Tanh:
309
+ // %x = "onnx.Tanh"()
310
+ // %y = %z * %x // or %x * %z
311
+ //
312
+ // we call
313
+ // ```
314
+ // Value z;
315
+ // ONNXTanhOp tanhOp;
316
+ // bool found = matchConstAndOp<ONNXTanhOp>(A, B, z, tanhOp);
317
+ // ```
318
+ // where `A` and `B` are operands of ONNXMul that produces %y.
319
+ template <typename OP>
320
+ bool matchValueAndOp (
321
+ mlir::Value A, mlir::Value B, mlir::Value matchValue, OP &matchOp);
322
+
280
323
// / Check if a value is to store dimensions, meaning it is a tensor of one
281
324
// / element or concatenation of one-element tensors.
282
325
bool areDims (mlir::Value val);
0 commit comments