24
24
#include " llvm/ADT/SmallPtrSet.h"
25
25
#include " llvm/Support/raw_ostream.h"
26
26
27
+ #include " src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
27
28
#include " src/Dialect/ONNX/ONNXOps.hpp"
28
29
#include " src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
29
30
#include " src/Interface/ShapeInferenceOpInterface.hpp"
@@ -48,9 +49,16 @@ class InstrumentONNXSignaturePass
48
49
InstrumentONNXSignaturePass () = default ;
49
50
InstrumentONNXSignaturePass (const InstrumentONNXSignaturePass &pass)
50
51
: mlir::PassWrapper<InstrumentONNXSignaturePass,
51
- OperationPass<func::FuncOp>>() {}
52
+ OperationPass<func::FuncOp>>() {
53
+ signaturePattern = pass.signaturePattern ;
54
+ }
55
+ InstrumentONNXSignaturePass (const std::string pattern) {
56
+ signaturePattern = pattern;
57
+ }
52
58
53
59
private:
60
+ std::string signaturePattern;
61
+
54
62
public:
55
63
StringRef getArgument () const override {
56
64
return " instrument-onnx-runtime-signature" ;
@@ -62,25 +70,31 @@ class InstrumentONNXSignaturePass
62
70
}
63
71
64
72
void runOnOperation () override {
73
+ onnx_mlir::EnableByRegexOption traceSpecificOpPattern (
74
+ /* emptyIsNone*/ false );
75
+ traceSpecificOpPattern.setRegexString (signaturePattern);
65
76
// Iterate on the operations nested in this function.
66
77
getOperation ().walk ([&](mlir::Operation *op) {
67
- if (isa<ONNXDialect>(op->getDialect ())) {
68
- if (!isa<ONNXPrintSignatureOp>(op)) {
69
- Location loc = op->getLoc ();
70
- OpBuilder builder (op);
71
- std::string opName = op->getName ().getStringRef ().str ();
72
- std::string nodeName = onnx_mlir::getNodeNameInPresenceOfOpt (op);
73
- std::string fullName = opName + " , " + nodeName;
74
- StringAttr fullNameAttr = builder.getStringAttr (fullName);
75
- // Enqueue all input operands, and then the results.
76
- llvm::SmallVector<Value, 6 > operAndRes (op->getOperands ());
77
- for (Value res : op->getResults ())
78
- operAndRes.emplace_back (res);
79
- // Since we may use the result of an operation, we must insert the
80
- // print operation after the operation.
81
- builder.setInsertionPointAfter (op);
82
- builder.create <ONNXPrintSignatureOp>(loc, fullNameAttr, operAndRes);
83
- }
78
+ std::string opName = op->getName ().getStringRef ().str ();
79
+ auto dialect = op->getDialect ();
80
+ if (isa<func::FuncDialect>(dialect) || isa<ONNXPrintSignatureOp>(op)) {
81
+ // Always skip function dialects (such as function call/return), as well
82
+ // as ONNX print signature ops.
83
+ } else if (traceSpecificOpPattern.isEnabled (opName)) {
84
+ // Add signature printing op.
85
+ Location loc = op->getLoc ();
86
+ OpBuilder builder (op);
87
+ std::string nodeName = onnx_mlir::getNodeNameInPresenceOfOpt (op);
88
+ std::string fullName = opName + " , " + nodeName;
89
+ StringAttr fullNameAttr = builder.getStringAttr (fullName);
90
+ // Enqueue all input operands, and then the results.
91
+ llvm::SmallVector<Value, 6 > operAndRes (op->getOperands ());
92
+ for (Value res : op->getResults ())
93
+ operAndRes.emplace_back (res);
94
+ // Since we may use the result of an operation, we must insert the
95
+ // print operation after the operation.
96
+ builder.setInsertionPointAfter (op);
97
+ builder.create <ONNXPrintSignatureOp>(loc, fullNameAttr, operAndRes);
84
98
}
85
99
});
86
100
}
@@ -90,6 +104,7 @@ class InstrumentONNXSignaturePass
90
104
/* !
91
105
* Create an instrumentation pass.
92
106
*/
93
- std::unique_ptr<mlir::Pass> onnx_mlir::createInstrumentONNXSignaturePass () {
94
- return std::make_unique<InstrumentONNXSignaturePass>();
107
+ std::unique_ptr<mlir::Pass> onnx_mlir::createInstrumentONNXSignaturePass (
108
+ const std::string pattern) {
109
+ return std::make_unique<InstrumentONNXSignaturePass>(pattern);
95
110
}
0 commit comments