Skip to content

Commit c38e5a8

Browse files
authored
Merge branch 'main' into ferdinand.update_llvm_april_2024
2 parents a9708e1 + 733dfac commit c38e5a8

14 files changed

+281
-84
lines changed

src/Builder/FrontendDialectTransformer.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ SUPPRESS_WARNINGS_POP
5151
#include <array>
5252
#include <fstream>
5353
#include <map>
54+
#include <sstream>
55+
#include <string>
5456
#include <unordered_map>
5557
#include <vector>
5658

@@ -450,6 +452,27 @@ class FrontendGenImpl {
450452
return attributes;
451453
}
452454

455+
// Generate a string vector from the dimParams option string
456+
void getInputDimParamsMapFromOption(std::string optionStr,
457+
std::map<int, std::string> &paramStrMap,
458+
std::string &paramStrForAllArgs) {
459+
std::stringstream paramStrStream(optionStr);
460+
std::string dimParamStr;
461+
while (std::getline(paramStrStream, dimParamStr, '|')) {
462+
size_t pos = dimParamStr.find(':');
463+
assert((pos > 0) && "invalid dimParams option string");
464+
int idx = stoi(dimParamStr.substr(0, pos));
465+
dimParamStr = dimParamStr.substr(pos + 1);
466+
std::replace(dimParamStr.begin(), dimParamStr.end(), '=', ':');
467+
if (idx < 0) // set all arguments
468+
paramStrForAllArgs = dimParamStr;
469+
else {
470+
paramStrMap[idx] = dimParamStr;
471+
}
472+
}
473+
return;
474+
}
475+
453476
/*!
454477
* An alternative graph importing procedure for importing ONNX subgraphs.
455478
* ONNX subgraphs, unlike the main computation graph, are imported as regions
@@ -490,6 +513,10 @@ class FrontendGenImpl {
490513
// See https://github.com/onnx/onnx/blob/main/docs/IR.md for more
491514
// information about dim_param.
492515
llvm::SmallVector<std::string, 4> inputDimParams, outputDimParams;
516+
std::map<int, std::string> inputDimParamsFromOption;
517+
std::string inputDimParamsFromOptionForAllArgs;
518+
getInputDimParamsMapFromOption(options_.dimParams, inputDimParamsFromOption,
519+
inputDimParamsFromOptionForAllArgs);
493520

494521
// Import the input tensor types that are not constant and not initialized.
495522
int inputIndex = 0;
@@ -500,7 +527,16 @@ class FrontendGenImpl {
500527
std::string dimParams = "";
501528
Type argTy = ImportType(input.type(), &dimParams);
502529
argTy = modelInputShaper_.reshape(inputIndex, argTy);
503-
if (!dimParams.empty())
530+
// For each input tensor, use either all dimensions by the compiler
531+
// option OR all dimensions in the original onnx model. Dimensions
532+
// from the option and the model in a single input tensor are not
533+
// merged.
534+
if (inputDimParamsFromOption.find(inputIndex) !=
535+
inputDimParamsFromOption.end())
536+
inputDimParams.emplace_back(inputDimParamsFromOption[inputIndex]);
537+
else if (!inputDimParamsFromOptionForAllArgs.empty())
538+
inputDimParams.emplace_back(inputDimParamsFromOptionForAllArgs);
539+
else if (!dimParams.empty())
504540
inputDimParams.emplace_back(dimParams);
505541

506542
argTypes.emplace_back(argTy);

src/Builder/FrontendDialectTransformer.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ struct ImportOptions {
5555
// - (arg0: tensor<3x4x5xf32>, arg1: tensor<10x5xf32>)
5656
//
5757
std::string shapeInformation = "";
58+
// Custom onnx.dim_params attributes for the graph inputs for specifying
59+
// relationship among their dynamic dimensions.
60+
// Its format is 'input_id:dim_id=sym,dim_id=sym,...|input_id:
61+
// dim_id=sym,dim_id=sym,...|input_id...'
62+
// E.g. An ONNX model has two dynamic inputs
63+
// - (arg0: tensor<?x5xf32>, arg1: tensor<?x5xf32>)
64+
// If we want to specify that the first unknown dimension of arg0 and the
65+
// first unknown dimension of arg1 are the same, we can assign the two
66+
// dimensions to the same symbol "batch" as follows.
67+
// - dimParams = '0:0=batch|1:0=batch'
68+
//
69+
std::string dimParams = "";
5870
// Directory to look for external data if any tensor has external
5971
// data location. If empty then external data is disabled.
6072
std::string externalDataDir = "";

src/Compiler/CompilerOptions.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ std::vector<std::string> onnxConstPropDisablePatterns; // common for both
4141
bool enableONNXHybridPass; // common for both
4242
std::vector<std::string> functionsToDecompose; // common for both
4343
std::string opsForCall; // common for both
44+
bool disableKrnlOpFusion; // common for both
4445
EmissionTargetType emissionTarget; // onnx-mlir only
4546
bool invokeOnnxVersionConverter; // onnx-mlir only
4647
bool preserveLocations; // onnx-mlir only
@@ -51,6 +52,7 @@ bool preserveMLIR; // onnx-mlir only
5152
bool useOnnxModelTypes; // onnx-mlir only
5253
int repeatOnnxTransform; // onnx-mlir only
5354
std::string shapeInformation; // onnx-mlir only
55+
std::string dimParams; // onnx-mlir only
5456
ModelSize modelSize; // onnx-mlir only
5557
bool storeConstantsToFile; // onnx-mlir only
5658
float constantsToFileTotalThreshold; // onnx-mlir only
@@ -201,6 +203,13 @@ static llvm::cl::list<std::string, std::vector<std::string>>
201203
llvm::cl::location(functionsToDecompose),
202204
llvm::cl::cat(OnnxMlirCommonOptions));
203205

206+
static llvm::cl::opt<bool, true> disableKrnlOpFusionOpt(
207+
"disable-krnl-op-fusion",
208+
llvm::cl::desc("disable op fusion in onnx-to-krnl pass (default=false)\n"
209+
"Set to 'true' if you want to disable fusion."),
210+
llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false),
211+
llvm::cl::cat(OnnxMlirCommonOptions));
212+
204213
static llvm::cl::opt<bool, true> disableRecomposeOptionOpt("disable-recompose",
205214
llvm::cl::desc("Disable recomposition of ONNX operations."),
206215
llvm::cl::location(disableRecomposeOption), llvm::cl::init(false),
@@ -282,6 +291,23 @@ static llvm::cl::opt<std::string, true> shapeInformationOpt("shapeInformation",
282291
llvm::cl::value_desc("value"), llvm::cl::location(shapeInformation),
283292
llvm::cl::cat(OnnxMlirOptions));
284293

294+
static llvm::cl::opt<std::string, true> dimParamsOpt("dimParams",
295+
llvm::cl::desc(
296+
"Custom onnx.dim_params attributes for the inputs of the ONNX model for"
297+
"specifying relationship among dynamic dimensions of the inputs.\n"
298+
"\"value\" is in the format of "
299+
"\"INPUT_ID1:D1=S1,D2=S2,...,Dn=Sn|INPUT_ID2:D1=T1,D2=T2,...Dn=Tn|"
300+
"...\" where \"INPUT_ID1, INPUT_ID2, ...\" are input indices "
301+
"(starting from 0 or being -1 for all input indices), and\n"
302+
"\"S1, S2, ...\" and \"T2, T2, ...\" are symbols to specify that same "
303+
"symbols have the same value. "
304+
"All dimensions of onnx.dim_params for a specified input index in "
305+
"the original onnx model are cleared and repalced by this option. "
306+
"onnx.dim_params for other input indices in the original onnx model "
307+
"are not cleared"),
308+
llvm::cl::value_desc("value"), llvm::cl::location(dimParams),
309+
llvm::cl::cat(OnnxMlirOptions));
310+
285311
// Default value is defined by the OnnxMlirEnvOptionName constant string
286312
// variable, but the default setting mechanism here cannot be used here as we
287313
// need to evaluate this value prior to the compiler options being set. Proper

src/Compiler/CompilerOptions.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ extern std::vector<std::string> onnxConstPropDisablePatterns; // common for both
8484
extern bool enableONNXHybridPass; // common for both
8585
extern std::vector<std::string> functionsToDecompose; // common for both
8686
extern std::string opsForCall; // common for both
87+
extern bool disableKrnlOpFusion; // common for both
8788
extern EmissionTargetType emissionTarget; // onnx-mlir only
8889
extern bool invokeOnnxVersionConverter; // onnx-mlir only
8990
extern bool preserveLocations; // onnx-mlir only
@@ -94,6 +95,7 @@ extern bool preserveMLIR; // onnx-mlir only
9495
extern bool useOnnxModelTypes; // onnx-mlir only
9596
extern int repeatOnnxTransform; // onnx-mlir only
9697
extern std::string shapeInformation; // onnx-mlir only
98+
extern std::string dimParams; // onnx-mlir only
9799
extern ModelSize modelSize; // onnx-mlir only
98100
extern bool storeConstantsToFile; // onnx-mlir only
99101
extern float constantsToFileTotalThreshold; // onnx-mlir only

src/Compiler/CompilerUtils.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,7 @@ int processInputFile(StringRef inputFilename, mlir::MLIRContext &context,
636636
options.useOnnxModelTypes = useOnnxModelTypes;
637637
options.invokeOnnxVersionConverter = invokeOnnxVersionConverter;
638638
options.shapeInformation = shapeInformation;
639+
options.dimParams = dimParams;
639640
options.allowSorting = allowSorting;
640641
options.externalDataDir = dirName(inputFilename);
641642
options.functionsToDecompose.insert(options.functionsToDecompose.end(),

src/Conversion/ONNXToKrnl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ add_onnx_mlir_library(OMONNXToKrnl
8080

8181
LINK_LIBS PUBLIC
8282
OMAccelerator
83+
OMCompilerOptions
8384
OMONNXConversionCommon
8485
OMONNXOps
8586
OMSupport

src/Conversion/ONNXToKrnl/Math/Elementwise.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "llvm/Support/Debug.h"
1919

20+
#include "src/Compiler/CompilerOptions.hpp"
2021
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
2122
#include "src/Dialect/Krnl/DialectBuilder.hpp"
2223
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
@@ -1884,6 +1885,9 @@ bool OpFusionHelper::areInputsValidForFusion(
18841885
// A successor op (user) is fusible if it is the only user, it is in the
18851886
// fusible elementwise op list, and its inputs are valid for fusion.
18861887
void OpFusionHelper::findFusibleOps() {
1888+
// Direct return if fusion is disabled
1889+
if (disableKrnlOpFusion)
1890+
return;
18871891
Operation *defOp = rootOp;
18881892
while (defOp->hasOneUse()) {
18891893
// the possible ONNX Ops.

0 commit comments

Comments
 (0)