Skip to content

Commit 2f8924a

Browse files
committed
Merge branch 'main' into pr_extend_dynamic_backend_test
2 parents 2fe6881 + bc893dd commit 2f8924a

File tree

7 files changed

+59
-9
lines changed

7 files changed

+59
-9
lines changed

src/Compiler/CompilerOptions.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ std::vector<std::string> onnxConstPropDisablePatterns; // common for both
4141
bool enableONNXHybridPass; // common for both
4242
std::vector<std::string> functionsToDecompose; // common for both
4343
std::string opsForCall; // common for both
44+
bool disableKrnlOpFusion; // common for both
4445
EmissionTargetType emissionTarget; // onnx-mlir only
4546
bool invokeOnnxVersionConverter; // onnx-mlir only
4647
bool preserveLocations; // onnx-mlir only
@@ -83,6 +84,7 @@ std::vector<std::string> extraLibs; // onnx-mlir only
8384
ProfileIRs profileIR; // onnx-mlir only
8485
OptReport optReport; // onnx-mlir only
8586
bool useOldBufferization; // onnx-mlir only
87+
bool enableTiming; // onnx-mlir only
8688
bool split_input_file; // onnx-mlir-opt only
8789
bool verify_diagnostics; // onnx-mlir-opt only
8890
bool verify_passes; // onnx-mlir-opt only
@@ -201,6 +203,13 @@ static llvm::cl::list<std::string, std::vector<std::string>>
201203
llvm::cl::location(functionsToDecompose),
202204
llvm::cl::cat(OnnxMlirCommonOptions));
203205

206+
static llvm::cl::opt<bool, true> disableKrnlOpFusionOpt(
207+
"disable-krnl-op-fusion",
208+
llvm::cl::desc("disable op fusion in onnx-to-krnl pass (default=false)\n"
209+
"Set to 'true' if you want to disable fusion."),
210+
llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false),
211+
llvm::cl::cat(OnnxMlirCommonOptions));
212+
204213
static llvm::cl::opt<bool, true> disableRecomposeOptionOpt("disable-recompose",
205214
llvm::cl::desc("Disable recomposition of ONNX operations."),
206215
llvm::cl::location(disableRecomposeOption), llvm::cl::init(false),
@@ -564,6 +573,12 @@ static llvm::cl::opt<OptReport, true> optReportOpt("opt-report",
564573
clEnumVal(Simd, "Provide report on how SIMD is applied to ONNX ops.")),
565574
llvm::cl::init(OptReport::NoReport), llvm::cl::cat(OnnxMlirOptions));
566575

576+
static llvm::cl::opt<bool, true> enable_timing("enable-timing",
577+
llvm::cl::desc("Enable compile timing (default is false)\n"
578+
"Set to 'true' if you want to enable compile timing."),
579+
llvm::cl::location(enableTiming), llvm::cl::init(false),
580+
llvm::cl::cat(OnnxMlirOptions));
581+
567582
// Options for onnx-mlir-opt only
568583
static llvm::cl::opt<bool, true> split_input_file_opt("split-input-file",
569584
llvm::cl::desc("Split the input file into pieces and process each "

src/Compiler/CompilerOptions.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ extern std::vector<std::string> onnxConstPropDisablePatterns; // common for both
8484
extern bool enableONNXHybridPass; // common for both
8585
extern std::vector<std::string> functionsToDecompose; // common for both
8686
extern std::string opsForCall; // common for both
87+
extern bool disableKrnlOpFusion; // common for both
8788
extern EmissionTargetType emissionTarget; // onnx-mlir only
8889
extern bool invokeOnnxVersionConverter; // onnx-mlir only
8990
extern bool preserveLocations; // onnx-mlir only
@@ -126,6 +127,7 @@ extern std::vector<std::string> extraLibs; // onnx-mlir only
126127
extern ProfileIRs profileIR; // onnx-mlir only
127128
extern OptReport optReport; // onnx-mlir only
128129
extern bool useOldBufferization; // onnx-mlir only
130+
extern bool enableTiming; // onnx-mlir only
129131
extern bool split_input_file; // onnx-mlir-opt only
130132
extern bool verify_diagnostics; // onnx-mlir-opt only
131133
extern bool verify_passes; // onnx-mlir-opt only

src/Compiler/CompilerUtils.cpp

+23-3
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515
#include "CompilerUtils.hpp"
1616

1717
#include <fstream>
18+
#include <memory>
1819
#include <regex>
1920

2021
#include "mlir/Dialect/Func/IR/FuncOps.h"
2122
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2223
#include "mlir/Parser/Parser.h"
24+
#include "mlir/Pass/Pass.h"
2325
#include "mlir/Pass/PassManager.h"
2426
#include "mlir/Support/FileUtilities.h"
27+
#include "mlir/Support/Timing.h"
2528
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
2629
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
2730
#include "mlir/Target/LLVMIR/Export.h"
@@ -35,6 +38,7 @@
3538
#include "llvm/Support/SourceMgr.h"
3639
#include "llvm/Support/TargetSelect.h"
3740
#include "llvm/Support/ToolOutputFile.h"
41+
#include "llvm/Support/raw_ostream.h"
3842
#include "llvm/Target/TargetMachine.h"
3943

4044
#include "src/Accelerators/Accelerator.hpp"
@@ -49,6 +53,8 @@
4953
using namespace mlir;
5054
using namespace onnx_mlir;
5155

56+
mlir::DefaultTimingManager timingManager;
57+
mlir::TimingScope rootTimingScope;
5258
namespace onnx_mlir {
5359

5460
// Make a function that forces preserving all files using the runtime arguments
@@ -327,6 +333,8 @@ std::string getTargetFilename(
327333
// Returns 0 on success, error code on failure.
328334
static int genLLVMBitcode(const mlir::OwningOpRef<ModuleOp> &module,
329335
std::string outputNameNoExt, std::string optimizedBitcodeNameWithExt) {
336+
auto llvmTiming = rootTimingScope.nest(
337+
"[onnx-mlir] Compiling MLIR module to LLVM Optimized Bitcode");
330338
std::error_code error;
331339

332340
// Write bitcode to a file.
@@ -397,7 +405,8 @@ static int genLLVMBitcode(const mlir::OwningOpRef<ModuleOp> &module,
397405
// Return 0 on success, error code on failure.
398406
static int genModelObject(
399407
std::string bitcodeNameWithExt, std::string &modelObjNameWithExt) {
400-
408+
auto objectTiming =
409+
rootTimingScope.nest("[onnx-mlir] Compiling LLVM Bitcode to Object File");
401410
std::string llcPath = getToolPath("llc");
402411
Command llvmToObj(/*exePath=*/llcPath);
403412
setXllcOption({"--code-model", modelSizeStr[modelSize]});
@@ -418,6 +427,8 @@ static int genModelObject(
418427
// Return 0 on success, error code on failure.
419428
static int genJniObject(const mlir::OwningOpRef<ModuleOp> &module,
420429
std::string jniSharedLibPath, std::string jniObjPath) {
430+
auto jniTiming =
431+
rootTimingScope.nest("[onnx-mlir] Compiling JNI Object File");
421432
Command ar(/*exePath=*/getToolPath("ar", true));
422433
int rc = ar.appendStr("x")
423434
// Older versions of ar do not support --output, so it is commented out.
@@ -436,7 +447,8 @@ static int genJniObject(const mlir::OwningOpRef<ModuleOp> &module,
436447
static int genSharedLib(std::string sharedLibNameWithExt,
437448
std::vector<std::string> opts, std::vector<std::string> objs,
438449
std::vector<std::string> libs, std::vector<std::string> libDirs) {
439-
450+
auto sharedLibTiming =
451+
rootTimingScope.nest("[onnx-mlir] Linking Shared Library");
440452
#ifdef _WIN32
441453
std::vector<std::string> outputOpt = {"/Fe:" + sharedLibNameWithExt};
442454
// link has to be before libpath since they need to be passed through to the
@@ -486,6 +498,7 @@ static int genSharedLib(std::string sharedLibNameWithExt,
486498
// Return 0 on success, error code on failure.
487499
static int genJniJar(const mlir::OwningOpRef<ModuleOp> &module,
488500
std::string modelSharedLibPath, std::string modelJniJarPath) {
501+
auto jniJarTiming = rootTimingScope.nest("[onnx-mlir] Creating JNI Jar");
489502
llvm::SmallString<8> libraryPath(getLibraryPath());
490503
llvm::sys::path::append(libraryPath, "javaruntime.jar");
491504
std::string javaRuntimeJarPath = llvm::StringRef(libraryPath).str();
@@ -880,6 +893,9 @@ static int emitOutput(mlir::OwningOpRef<ModuleOp> &module,
880893
int compileModule(mlir::OwningOpRef<ModuleOp> &module,
881894
mlir::MLIRContext &context, std::string outputNameNoExt,
882895
EmissionTargetType emissionTarget) {
896+
auto compileModuleTiming =
897+
rootTimingScope.nest("[onnx-mlir] Compiling Module using MLIR");
898+
883899
int rc = setupModule(module, context, outputNameNoExt);
884900
if (rc != CompilerSuccess)
885901
return rc;
@@ -905,10 +921,14 @@ int compileModule(mlir::OwningOpRef<ModuleOp> &module,
905921
heapLogFileame, reportHeapBefore, reportHeapAfter));
906922
}
907923
(void)mlir::applyPassManagerCLOptions(pm);
908-
mlir::applyDefaultTimingPassManagerCLOptions(pm);
924+
925+
if (enableTiming) {
926+
pm.enableTiming(compileModuleTiming);
927+
}
909928

910929
if (mlir::failed(pm.run(*module)))
911930
return CompilerFailure;
931+
compileModuleTiming.stop();
912932
return emitOutput(module, context, outputNameNoExt, pm, emissionTarget);
913933
}
914934

src/Compiler/CompilerUtils.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@
1818

1919
#include "mlir/IR/BuiltinOps.h"
2020
#include "mlir/IR/OwningOpRef.h"
21+
#include "mlir/Support/Timing.h"
2122
#include "llvm/ADT/StringRef.h"
2223
#include "llvm/Support/Path.h"
2324

2425
#include <optional>
2526
#include <string>
2627
#include <vector>
2728

29+
extern mlir::DefaultTimingManager timingManager;
30+
extern mlir::TimingScope rootTimingScope;
31+
2832
namespace onnx_mlir {
2933

3034
struct Command {

src/Conversion/ONNXToKrnl/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ add_onnx_mlir_library(OMONNXToKrnl
8080

8181
LINK_LIBS PUBLIC
8282
OMAccelerator
83+
OMCompilerOptions
8384
OMONNXConversionCommon
8485
OMONNXOps
8586
OMSupport

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "llvm/Support/Debug.h"
1919

20+
#include "src/Compiler/CompilerOptions.hpp"
2021
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
2122
#include "src/Dialect/Krnl/DialectBuilder.hpp"
2223
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
@@ -1884,6 +1885,9 @@ bool OpFusionHelper::areInputsValidForFusion(
18841885
// A successor op (user) is fusible if it is the only user, it is in the
18851886
// fusible elementwise op list, and its inputs are valid for fusion.
18861887
void OpFusionHelper::findFusibleOps() {
1888+
// Return immediately when fusion is disabled (--disable-krnl-op-fusion).
1889+
if (disableKrnlOpFusion)
1890+
return;
18871891
Operation *defOp = rootOp;
18881892
while (defOp->hasOneUse()) {
18891893
// the possible ONNX Ops.

src/onnx-mlir.cpp

+10-6
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
// Implements main for onnx-mlir driver.
1212
//===----------------------------------------------------------------------===//
1313

14-
#include <iostream>
1514
#include <regex>
1615

16+
#include "mlir/Support/Timing.h"
1717
#include "src/Compiler/CompilerOptions.hpp"
1818
#include "src/Compiler/CompilerUtils.hpp"
1919
#include "src/Version/Version.hpp"
@@ -24,12 +24,10 @@
2424
using namespace onnx_mlir;
2525

2626
int main(int argc, char *argv[]) {
27-
2827
// Register MLIR command line options.
2928
mlir::registerAsmPrinterCLOptions();
3029
mlir::registerMLIRContextCLOptions();
3130
mlir::registerPassManagerCLOptions();
32-
mlir::registerDefaultTimingManagerCLOptions();
3331
mlir::registerAsmPrinterCLOptions();
3432

3533
llvm::cl::SetVersionPrinter(getVersionPrinter);
@@ -44,9 +42,13 @@ int main(int argc, char *argv[]) {
4442
llvm::errs() << "Failed to parse options\n";
4543
return 1;
4644
}
47-
4845
initCompilerConfig();
4946

47+
// Enable timing-manager reporting only when the "--enable-timing" compiler flag is set.
48+
timingManager.setEnabled(enableTiming);
49+
rootTimingScope = timingManager.getRootScope();
50+
auto setupTiming = rootTimingScope.nest("[onnx-mlir] Loading Dialects");
51+
5052
// Special handling of outputBaseName to derive output filename.
5153
// outputBaseName must specify a file, so ignore invalid values
5254
// such as ".", "..", "./", "/.", etc.
@@ -71,7 +73,9 @@ int main(int argc, char *argv[]) {
7173
LLVM_DEBUG(llvm::dbgs() << "multithreading is disabled\n");
7274
}
7375
loadDialects(context);
74-
76+
setupTiming.stop();
77+
auto inputFileTiming =
78+
rootTimingScope.nest("[onnx-mlir] Importing Input Model to MLIR");
7579
mlir::OwningOpRef<mlir::ModuleOp> module;
7680
std::string errorMessage;
7781
int rc = processInputFile(inputFilename, context, module, &errorMessage);
@@ -80,6 +84,6 @@ int main(int argc, char *argv[]) {
8084
llvm::errs() << errorMessage << "\n";
8185
return 1;
8286
}
83-
87+
inputFileTiming.stop();
8488
return compileModule(module, context, outputBaseName, emissionTarget);
8589
}

0 commit comments

Comments
 (0)