Skip to content

Commit e5b73a7

Browse files
Signature for selected ops (#2884)
* Trace signature for ops that fit a given explicit pattern Signed-off-by: Alexandre Eichenberger <[email protected]> --------- Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent 58572e0 commit e5b73a7

File tree

8 files changed

+61
-37
lines changed

8 files changed

+61
-37
lines changed

src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,8 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
218218
else if (optStr == "-O3")
219219
optLevel = OptLevel::O3;
220220
// Lower ONNX to Krnl, ZHigh to ZLow.
221-
addONNXToKrnlPasses(pm, optLevel, /*enableCSE*/ true,
222-
instrumentONNXSignature, ONNXOpStats);
221+
addONNXToKrnlPasses(
222+
pm, optLevel, /*enableCSE*/ true, instrumentSignatures, ONNXOpStats);
223223

224224
if (nnpaEmissionTarget >= EmitZLowIR)
225225
emissionTarget = EmitMLIR;

src/Compiler/CompilerOptions.cpp

+12-5
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ std::string mllvm; // onnx-mlir only
6464
std::string instrumentOps; // onnx-mlir only
6565
unsigned instrumentControlBits; // onnx-mlir only
6666
std::string parallelizeOps; // onnx-mlir only
67-
bool instrumentONNXSignature; // onnx-mlir only
67+
std::string instrumentSignatures; // onnx-mlir only
6868
std::string ONNXOpStats; // onnx-mlir only
6969
int onnxOpTransformThreshold; // onnx-mlir only
7070
bool onnxOpTransformReport; // onnx-mlir only
@@ -432,10 +432,17 @@ static llvm::cl::opt<std::string, true> parallelizeOpsOpt("parallelize-ops",
432432
llvm::cl::location(parallelizeOps), llvm::cl::init(""),
433433
llvm::cl::cat(OnnxMlirOptions));
434434

435-
static llvm::cl::opt<bool, true> instrumentONNXSignatureOpt(
436-
"instrument-onnx-signature",
437-
llvm::cl::desc("Instrument ONNX ops to print the type of their inputs"),
438-
llvm::cl::location(instrumentONNXSignature), llvm::cl::init(false),
435+
static llvm::cl::opt<std::string, true> instrumentSignatureOpt(
436+
"instrument-signature",
437+
llvm::cl::desc("Specify which high-level operations should print their"
438+
" input type(s) and shape(s)\n"
439+
"\"ALL\" or \"\" for all available operations,\n"
440+
"\"NONE\" for no instrument (default),\n"
441+
"\"ops1,ops2, ...\" for the multiple ops.\n"
442+
"e.g. \"onnx.MatMul,onnx.Add\" for MatMul and Add ops.\n"
443+
"Asterisk is also available.\n"
444+
"e.g. \"onnx.*\" for all onnx operations.\n"),
445+
llvm::cl::location(instrumentSignatures), llvm::cl::init("NONE"),
439446
llvm::cl::cat(OnnxMlirOptions));
440447

441448
static llvm::cl::opt<std::string, true> ONNXOpStatsOpt("onnx-op-stats",

src/Compiler/CompilerOptions.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ extern std::string mllvm; // onnx-mlir only
108108
extern std::string instrumentOps; // onnx-mlir only
109109
extern unsigned instrumentControlBits; // onnx-mlir only
110110
extern std::string parallelizeOps; // onnx-mlir only
111-
extern bool instrumentONNXSignature; // onnx-mlir only
111+
extern std::string instrumentSignatures; // onnx-mlir only
112112
extern std::string ONNXOpStats; // onnx-mlir only
113113
extern int onnxOpTransformThreshold; // onnx-mlir only
114114
extern bool onnxOpTransformReport; // onnx-mlir only

src/Compiler/CompilerPasses.cpp

+7-6
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) {
157157
}
158158

159159
void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
160-
bool enableInstrumentONNXSignature, std::string ONNXOpsStatFormat) {
160+
std::string instrumentSignatureString, std::string ONNXOpsStatFormat) {
161161
if (enableCSE)
162162
// Eliminate common sub-expressions before lowering to Krnl.
163163
// TODO: enable this by default when we make sure it works flawlessly.
@@ -182,10 +182,11 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
182182
}
183183

184184
// Print Signatures of each op at runtime if enabled. Should not run
185-
// signature and instrument passes at the same time.
186-
if (enableInstrumentONNXSignature)
187-
pm.addNestedPass<func::FuncOp>(
188-
onnx_mlir::createInstrumentONNXSignaturePass());
185+
// signature and instrument passes at the same time as time may include printf
186+
// overheads.
187+
if (instrumentSignatureString != "NONE")
188+
pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentONNXSignaturePass(
189+
instrumentSignatureString));
189190
pm.addPass(onnx_mlir::createLowerToKrnlPass(/*enableTiling*/ optLevel >= 3,
190191
/*enableSIMD*/ optLevel >= 3 && !disableSimdOption, enableParallel,
191192
/*opsToCall*/ opsForCall));
@@ -304,7 +305,7 @@ void addPasses(mlir::OwningOpRef<ModuleOp> &module, mlir::PassManager &pm,
304305
if (emissionTarget >= EmitMLIR) {
305306
if (inputIRLevel <= ONNXLevel)
306307
addONNXToKrnlPasses(pm, OptimizationLevel, /*enableCSE*/ true,
307-
instrumentONNXSignature, ONNXOpStats);
308+
instrumentSignatures, ONNXOpStats);
308309
if (inputIRLevel <= MLIRLevel)
309310
addKrnlToAffinePasses(pm);
310311
}

src/Compiler/CompilerPasses.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ void configurePasses();
2121

2222
void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU);
2323
void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
24-
bool enableInstrumentONNXSignature, std::string ONNXOpsStatFilename);
24+
std::string instrumentSignatureString, std::string ONNXOpsStatFilename);
2525
void addKrnlToAffinePasses(mlir::PassManager &pm);
2626
void addKrnlToLLVMPasses(
2727
mlir::OpPassManager &pm, std::string outputNameNoExt, bool enableCSE);

src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp

+35-20
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "llvm/ADT/SmallPtrSet.h"
2525
#include "llvm/Support/raw_ostream.h"
2626

27+
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
2728
#include "src/Dialect/ONNX/ONNXOps.hpp"
2829
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
2930
#include "src/Interface/ShapeInferenceOpInterface.hpp"
@@ -48,9 +49,16 @@ class InstrumentONNXSignaturePass
4849
InstrumentONNXSignaturePass() = default;
4950
InstrumentONNXSignaturePass(const InstrumentONNXSignaturePass &pass)
5051
: mlir::PassWrapper<InstrumentONNXSignaturePass,
51-
OperationPass<func::FuncOp>>() {}
52+
OperationPass<func::FuncOp>>() {
53+
signaturePattern = pass.signaturePattern;
54+
}
55+
InstrumentONNXSignaturePass(const std::string pattern) {
56+
signaturePattern = pattern;
57+
}
5258

5359
private:
60+
std::string signaturePattern;
61+
5462
public:
5563
StringRef getArgument() const override {
5664
return "instrument-onnx-runtime-signature";
@@ -62,25 +70,31 @@ class InstrumentONNXSignaturePass
6270
}
6371

6472
void runOnOperation() override {
73+
onnx_mlir::EnableByRegexOption traceSpecificOpPattern(
74+
/*emptyIsNone*/ false);
75+
traceSpecificOpPattern.setRegexString(signaturePattern);
6576
// Iterate on the operations nested in this function.
6677
getOperation().walk([&](mlir::Operation *op) {
67-
if (isa<ONNXDialect>(op->getDialect())) {
68-
if (!isa<ONNXPrintSignatureOp>(op)) {
69-
Location loc = op->getLoc();
70-
OpBuilder builder(op);
71-
std::string opName = op->getName().getStringRef().str();
72-
std::string nodeName = onnx_mlir::getNodeNameInPresenceOfOpt(op);
73-
std::string fullName = opName + ", " + nodeName;
74-
StringAttr fullNameAttr = builder.getStringAttr(fullName);
75-
// Enqueue all input operands, and then the results.
76-
llvm::SmallVector<Value, 6> operAndRes(op->getOperands());
77-
for (Value res : op->getResults())
78-
operAndRes.emplace_back(res);
79-
// Since we may use the result of an operation, we must insert the
80-
// print operation after the operation.
81-
builder.setInsertionPointAfter(op);
82-
builder.create<ONNXPrintSignatureOp>(loc, fullNameAttr, operAndRes);
83-
}
78+
std::string opName = op->getName().getStringRef().str();
79+
auto dialect = op->getDialect();
80+
if (isa<func::FuncDialect>(dialect) || isa<ONNXPrintSignatureOp>(op)) {
81+
// Always skip function dialects (such as function call/return), as well
82+
// as ONNX print signature ops.
83+
} else if (traceSpecificOpPattern.isEnabled(opName)) {
84+
// Add signature printing op.
85+
Location loc = op->getLoc();
86+
OpBuilder builder(op);
87+
std::string nodeName = onnx_mlir::getNodeNameInPresenceOfOpt(op);
88+
std::string fullName = opName + ", " + nodeName;
89+
StringAttr fullNameAttr = builder.getStringAttr(fullName);
90+
// Enqueue all input operands, and then the results.
91+
llvm::SmallVector<Value, 6> operAndRes(op->getOperands());
92+
for (Value res : op->getResults())
93+
operAndRes.emplace_back(res);
94+
// Since we may use the result of an operation, we must insert the
95+
// print operation after the operation.
96+
builder.setInsertionPointAfter(op);
97+
builder.create<ONNXPrintSignatureOp>(loc, fullNameAttr, operAndRes);
8498
}
8599
});
86100
}
@@ -90,6 +104,7 @@ class InstrumentONNXSignaturePass
90104
/*!
91105
* Create an instrumentation pass.
92106
*/
93-
std::unique_ptr<mlir::Pass> onnx_mlir::createInstrumentONNXSignaturePass() {
94-
return std::make_unique<InstrumentONNXSignaturePass>();
107+
std::unique_ptr<mlir::Pass> onnx_mlir::createInstrumentONNXSignaturePass(
108+
const std::string pattern) {
109+
return std::make_unique<InstrumentONNXSignaturePass>(pattern);
95110
}

src/Pass/Passes.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ std::unique_ptr<mlir::Pass> createInstrumentPass(
5959

6060
/// Passes for instrumenting the ONNX ops to print their operand type
6161
/// signatures at runtime.
62-
std::unique_ptr<mlir::Pass> createInstrumentONNXSignaturePass();
62+
std::unique_ptr<mlir::Pass> createInstrumentONNXSignaturePass(
63+
const std::string pattern);
6364

6465
/// Pass for simplifying shape-related ONNX operations.
6566
std::unique_ptr<mlir::Pass> createSimplifyShapeRelatedOpsPass();

src/Tools/onnx-mlir-opt/RegisterPasses.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ void registerOMPasses(int optLevel) {
7171
[]() -> std::unique_ptr<mlir::Pass> { return createInstrumentPass(); });
7272

7373
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
74-
return createInstrumentONNXSignaturePass();
74+
return createInstrumentONNXSignaturePass("NONE");
7575
});
7676

7777
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {

0 commit comments

Comments
 (0)