
Introduce --dimParams option to set onnx.dim_params attrs of dynamic model inputs, and extend dynamic backend tests on NNPA with this option to utilize NNPA #2781


Merged · 24 commits · Apr 30, 2024

Conversation

@negiyas (Collaborator) commented Apr 2, 2024

===== 2024/04/19 updated according to local discussions =====
This PR adds the following two features.

(1) It introduces a --dimParams option for onnx-mlir to set the onnx.dim_params attributes of dynamic model inputs, which specify relationships among the dimensions of the model inputs. For example, onnx-mlir cannot use NNPA for the following dynamic model, because broadcasting of %arg0 or %arg1 may be required, and broadcasting is not supported by NNPA.

```
module {
  func.func @main_graph(%arg0: tensor<?x?x?xf32> {onnx.name = "x"}, %arg1: tensor<?x?x?xf32> {onnx.name = "y"})
      -> (tensor<?x?x?xf32> {onnx.name = "sum"}) {
    %0 = "onnx.Add"(%arg0, %arg1) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    onnx.Return %0 : tensor<?x?x?xf32>
  }
  "onnx.EntryPoint"() {func = @main_graph} : () -> ()
}
```

We can use the --dimParams option to specify that all corresponding dimensions of the two dynamic inputs %arg0 and %arg1 are the same, by passing --dimParams="0:0=a,1=b,2=c|1:0=a,1=b,2=c". The compiler then knows that no broadcasting is required and can use NNPA in this case:

```
module {
  func.func @main_graph(%arg0: tensor<?x?x?xf32> {onnx.dim_params = "0:a,1:b,2:c", onnx.name = "x"}, %arg1: tensor<?x?x?xf32> {onnx.dim_params = "0:a,1:b,2:c", onnx.name = "y"}) -> (tensor<?x?x?xf32> {onnx.name = "sum"}) {
    %0 = "onnx.Add"(%arg0, %arg1) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    onnx.Return %0 : tensor<?x?x?xf32>
  }
  "onnx.EntryPoint"() {func = @main_graph} : () -> ()
}
```
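As an illustrative sketch only (Python, not code from this PR; the function name is hypothetical), the mapping from a --dimParams value to the per-input onnx.dim_params attribute values shown above can be expressed as:

```python
def parse_dim_params(option):
    """Map a --dimParams string such as '0:0=a,1=b,2=c|1:0=a,1=b,2=c'
    to per-input onnx.dim_params attribute values, e.g.
    {0: "0:a,1:b,2:c", 1: "0:a,1:b,2:c"}."""
    result = {}
    for entry in option.split("|"):           # one entry per input
        input_id, dims = entry.split(":", 1)  # 'inputIndex:dim=sym,dim=sym,...'
        # Each 'dimIndex=symbol' pair becomes 'dimIndex:symbol' in the attribute.
        result[int(input_id)] = ",".join(
            pair.replace("=", ":") for pair in dims.split(","))
    return result
```

Inputs that share a symbol (here a, b, c) are declared to have equal dimension sizes, which is what lets the compiler rule out broadcasting.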

(2) It also extends the dynamic backend tests for NNPA by using the --dimParams option.
The backend tests on NNPA now support a new notation for setting the option's arguments in the NNPA_TEST_LIST variable in test/accelerators/NNPA/backend/CMakeLists.txt. The examples below mean the following.

    test_add_cpu,zdnn_add,"0:0=a,1=b,2=c|1:0=a,1=b,2=c"
    test_averagepool_2d_default_cpu,zdnn_avgpool2d,NO_DYNAMIC_SHAPE_TEST
  • In the "test_add_cpu" test, "zdnn_add" should be called, and the first, second, and third dimensions of the first and second inputs are the same.
  • In the "test_averagepool_2d_default_cpu" test, the dynamic shape test on NNPA should not be executed.
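A minimal sketch (hypothetical helper, not the actual test driver) of how such a NNPA_TEST_LIST entry could be interpreted:

```python
def parse_test_entry(line):
    """Split a 'test_name,check_func,dimParams' entry and interpret the
    sentinel NO_DYNAMIC_SHAPE_TEST as 'skip the dynamic shape test'."""
    name, check_func, dim_params = line.split(",", 2)
    if dim_params == "NO_DYNAMIC_SHAPE_TEST":
        dim_params = None  # dynamic shape test is skipped for this case
    else:
        dim_params = dim_params.strip('"')  # value passed via --dimParams
    return {"test": name, "check_func": check_func, "dim_params": dim_params}
```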

@negiyas negiyas marked this pull request as draft April 2, 2024 09:55
@tungld (Collaborator) left a comment:

Some first comments.

docs/Testing.md Outdated
```
IMPORTER_FORCE_DYNAMIC='-1:-1' all dimensions of all the inputs will be changed
IMPORTER_FORCE_DYNAMIC='0:-1' all dimensions of the first input will be changed
IMPORTER_FORCE_DYNAMIC='0:-1|1:0,1' all dimensions of the first input and the 1st and 2nd dimensions of the second input will be changed
IMPORTER_FORCE_DYNAMIC='0:0:a,1:b,2:c|1:0:a,1:b,2:c' the first three dimensions of the first and the second inputs will be changed. Also assume that the first dimensions of the first and second arguments are the same, and likewise for the second and third dimensions.
Collaborator:

I know that by using : we can pass dimIndex directly to an argument attribute, but I recommend using a different symbol than : for dimIndex, because : has already been used to separate inputIndex and dimString, which confuses the parser. How about - or =?

This example is quite trivial. Can we shuffle the symbols a bit, e.g. 0:0=a,1=b,2=c|1:0,1=c,2=b, and we can have a mix of <index> and <index> '=' <symbol>

Collaborator Author:

@tungld I changed the notation according to this comment. Thanks for the suggestions!

docs/Testing.md Outdated
<index> ::= -1 | <number>
<number> ::= <digit> | <digit><number>
<digit> ::= 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
<symbol> ::= 'a', 'b', 'c', ...
Collaborator:

I would add 'z' at the end to finish the list, though it is still not perfect, e.g.
<symbol> ::= 'a', 'b', 'c', ..., 'z'

Collaborator Author:

Fixed. Thanks.
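For illustration only (not part of the PR), one dimString of the finished grammar, allowing a bare <index> or <index>=<symbol> pairs, could be checked with a sketch like:

```python
import re

# Hypothetical validator: pairs separated by ',', each being '-1' or a
# number, optionally followed by '=' and a single letter 'a'..'z'.
_PAIR = re.compile(r"^(-1|\d+)(=[a-z])?$")

def is_valid_dim_string(dim_string):
    """Return True if every comma-separated part fits the grammar above."""
    return all(_PAIR.match(part) for part in dim_string.split(","))
```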

docs/Testing.md Outdated
@@ -76,15 +80,18 @@ func @main_graph(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x4x5xf32>) -> tensor<3
```
with `IMPORTER_FORCE_DYNAMIC='-1:-1'`, the result is:
```
func @main_graph(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
func @main_graph(%arg0: tensor<?x?x?xf32>{onnx.name = "x", onnx.dim_params = "0:a,1:b,2:c"}, %arg1: tensor<?x?x?xf32>{onnx.name = "y", onnx.dim_params = "0:a,1:b,2:c"}) -> tensor<?x?x?xf32>
Collaborator:

-1 for dimIndex, to say all arguments have the same dynamic dims, does not make sense; it is too restrictive.
Let's consider the dynamic tests for CPU, where broadcasting is supported, e.g. add(tensor<3x5xf32>, tensor<3x1xf32>): using -1:-1 will force all dynamic dimensions to be the same, and the lowering will generate different code that does not take care of broadcasting.

In summary, I recommend not supporting this, and letting users specify the same dynamic dimensions manually for the cases they need.

Collaborator Author:

I removed the assumption about default relationships among input data ranks, and specified the relationship manually in the test/accelerators/NNPA/backend/CMakeLists.txt file for each dynamic backend test on NNPA.
Thanks for the suggestions.

@negiyas negiyas marked this pull request as ready for review April 15, 2024 09:18
@negiyas negiyas changed the title [WIP] Extend models for dynamic backend test with "onnx.dim_params" attributes in input arguments. Extend models for dynamic backend test with "onnx.dim_params" attributes in input arguments. Apr 15, 2024
@tungld (Collaborator) left a comment:

Just some comments besides the main thing about introducing --dimParams as we discussed locally.

}
} else {
getParamStr(envStr, paramStrVec[idx]);
}
Collaborator:

It looks like we can simplify this to avoid calling getParamStr, e.g.

envStr = envStr.substr(pos + 1);
std::replace(envStr.begin(), envStr.end(), '=', ':');

if (idx >= 0) {
  paramStrVec[idx] = envStr;
  return;
}

// Idx < 0 means setting all arguments with the same value.
for (size_t i = 0; i < numOfArgs; i++)
  paramStrVec[i]  = envStr;
return;

Collaborator Author:

Fixed. Thanks!


// Get named attributes from IMPORTER_FORCE_DYNAMIC environment.
SmallVector<NamedAttribute> getNamedAttrFromImporterForceDynamic(
func::FuncOp funcOp, size_t numOfArgs) {
Collaborator:

To avoid copying the result, it's better to pass argNamedAttrs as an argument, e.g. SmallVectorImpl<NamedAttribute> &argNamedAttrs.

Collaborator Author:

This code was cleaned up. Thanks.

getInputDimParamStrVec(
envInputString, funcOp, numOfArgs, inputDimParamStrVec);
for (size_t i = 0; i < numOfArgs; ++i) {
if (inputDimParamStrVec[i].length() != 0) {
Collaborator:

Can this really happen, i.e. can an input dim param be an empty string?

Collaborator Author:

This code was cleaned up. Thanks.

@negiyas negiyas changed the title Extend models for dynamic backend test with "onnx.dim_params" attributes in input arguments. Introduce --dimParams option to set onnx.dim_params attrs of dynamic model inputs, and extend dynamic backend tests on NNPA with this option to utilze NNPA Apr 19, 2024
@tungld (Collaborator) left a comment:

Thank you for the update! The PR is in a good shape. I just put some comments to make the explanation easier to follow.

void getInputDimParamsVecFromOption(std::string optionStr, size_t numOfArgs,
SmallVector<std::string> &paramStrVec) {
std::stringstream paramStrStream;
paramStrStream << optionStr;
Collaborator:

For brevity: std::stringstream paramStrStream(optionStr);

Collaborator Author:

Fixed. Thanks!

@@ -490,6 +513,10 @@ class FrontendGenImpl {
// See https://github.com/onnx/onnx/blob/main/docs/IR.md for more
// information about dim_param.
llvm::SmallVector<std::string, 4> inputDimParams, outputDimParams;
size_t numOfArgs = graph.input().size();
llvm::SmallVector<std::string> inputDimParamsFromOption(numOfArgs);
Collaborator:

numOfArgs here may contain initializers also, so it does not count exactly the number of actual arguments to the main function. Please see inputIndex below that correctly represents the number of actual arguments.

Collaborator Author:

Changed the code not to use numOfArgs in inputDimParamsFromOption. Thanks!

@@ -55,6 +55,16 @@ struct ImportOptions {
// - (arg0: tensor<3x4x5xf32>, arg1: tensor<10x5xf32>)
//
std::string shapeInformation = "";
// Custom onnx.dim_params attributes for the graph inputs for specifying
// relationship among their ranks.
Collaborator:

their ranks looks confusing. To be precise, it would be their dynamic dimensions, given that dim in dim_params means dimension.

Collaborator Author:

Fixed. Thanks!

// Its format is 'input_id:dim=sym,dim=sym,dim=sym|input_id:dim=sym,dim=sym..'
// E.g. An ONNX model has two dynamic inputs
// - (arg0: tensor<?x5xf32>, arg1: tensor<?x5xf32>)
// If we want to specify that the first ranks of the first and second
Collaborator:

If we want to specify that the first ranks of the first and second arguments are the same => If we want to specify that the first unknown dimension of arg0 and the first unknown dimension of arg1 are the same, we can assign the two dimensions to the same symbol "batch" as follows.

Collaborator Author:

Fixed. Thanks.

static llvm::cl::opt<std::string, true> dimParamsOpt("dimParams",
llvm::cl::desc(
"Custom onnx.dim_params attributes for the inputs of the ONNX model for"
"specifying relationship among ranks of the inputs.\n"
Collaborator:

among ranks of => among dynamic dimensions of

Collaborator Author:

Fixed. Thanks.

# "zdnn_add" to check the function is called, (3) when dimParams is
# "NO_DYNAMIC_SHAPE_TEST", backend test is skipped, otherwise the string is
# passed as --dimParams option. "0:0=a,1=b,2=c|1:0=a,1=b,2=c" means that the
# first ranks of the first, second and third ranks of the first and second
Collaborator:

Could you make this explanation easier to follow? Don't use rank but dimension. Ranks in MLIR means the number of dimensions (in this case rank = 3).

Collaborator Author:

Thanks for the comments. Fixed. I confirmed that all uses of "rank" introduced by this PR have been removed.

@@ -132,6 +139,7 @@ def compile_model(model, emit):
"--constants-to-file-total-threshold="
+ str(args.constants_to_file_total_threshold)
)
command_list += get_compile_option(name + "_cpu")
Collaborator:

I would put an if directly here; it's easier to follow and faster:

if args.dynamic and test_name in variables.test_to_enable_dimparams_dict:
        command_list.append("--dimParams=" + variables.test_to_enable_dimparams_dict[test_name])

Collaborator Author:

Fixed. Thanks!

@@ -0,0 +1,23 @@
// RUN: onnx-mlir --EmitONNXBasic --dimParams="0:0=a,1=b,2=c|1:0=a,1=b,2=c" --printIR %s | FileCheck %s
Collaborator:

Thank you for adding this test! Really appreciate it.

@@ -55,6 +55,16 @@ struct ImportOptions {
// - (arg0: tensor<3x4x5xf32>, arg1: tensor<10x5xf32>)
//
std::string shapeInformation = "";
// Custom onnx.dim_params attributes for the graph inputs for specifying
// relationship among their ranks.
// Its format is 'input_id:dim=sym,dim=sym,dim=sym|input_id:dim=sym,dim=sym..'
Collaborator:

dim_id would be better than dim

Collaborator:

.. at the end would be three dots ...? Also that ... is for | meaning multiple input_id:dim=sym,dim=sym,dim=sym, or for , meaning multiple dim=sym,dim=sym?

Collaborator Author:

Thanks. Fixed.

"Custom onnx.dim_params attributes for the inputs of the ONNX model for"
"specifying relationship among ranks of the inputs.\n"
"\"value\" is in the format of "
"\"INPUT_ID1:D1=S1,D2=S2,...,Dn=Sn|INPUT_ID2:D1=T1,D2=T2,...Dn=Tn\""
Collaborator:

missing a , before that last Dn.

It's better to add | ... at the end to denote multiple INPUT_ID, e.g.

INPUT_ID1:D1=S1,D2=S2,...,Dn=Sn | INPUT_ID2:D1=T1,D2=T2,...,Dn=Tn | ...

Collaborator Author:

Fixed. Thanks.

"All dimensions of onnx.dim_params for a specified input index in "
"the original onnx model are cleared and replaced by this option. "
"onnx.dim_params for other input indices in the original onnx model "
"are not cleared"),
Collaborator:

Thank you for adding the explanation!

paramStrForAllArgs = dimParamStr;
else {
while ((int)paramStrVec.size() <= idx) // Expand paramStrVec
paramStrVec.emplace_back("");
Collaborator:

In the option definition you wrote onnx.dim_params for other input indices in the original onnx model are not cleared, but it looks to me like only the inputs whose index is larger than the max idx here are not cleared, while the inputs whose index <= the max idx will be cleared.

For example, we have 4 input tensors

arg0: tensor<?x?xf32> {onnx.dim_param=0:x},
arg1: tensor<?x?xf32> {onnx.dim_param=0:y},
arg2: tensor<?x?xf32>  {onnx.dim_param=0:z}
arg3: tensor<?x?xf32>  {onnx.dim_param=0:c,1:d}

with --dimParams="0:0=b,1=a|2:0=a,1=b", it becomes

arg0: tensor<?x?xf32> {onnx.dim_param=0:b,1:a},
arg1: tensor<?x?xf32>,
arg2: tensor<?x?xf32>  {onnx.dim_param=0:a,1:b}
arg3: tensor<?x?xf32>  {onnx.dim_param=0:c,1:d}

Is it correct?

Collaborator Author:

@tungld Thanks for the comment.
It is not correct. onnx.dim_param of arg1 is not cleared with --dimParams="0:0=b,1=a|2:0=a,1=b".

       // For each input tensor, use either all dimensions by the compiler
       // option OR all dimensions in the original onnx model. Dimensions
       // from the option and the model in a single input tensor are not
       // merged.
       if (inputIndex < (int)inputDimParamsFromOption.size() &&
           !inputDimParamsFromOption[inputIndex].empty())
         inputDimParams.emplace_back(inputDimParamsFromOption[inputIndex]);
       else if (!inputDimParamsFromOptionForAllArgs.empty())
         inputDimParams.emplace_back(inputDimParamsFromOptionForAllArgs);
       else if (!dimParams.empty())
           inputDimParams.emplace_back(dimParams);

The code checks if inputDimParamsFromOption[inputIndex] is empty string, and if it's empty, it uses dimParams (=onnx.dim_param in the original model), so that onnx.dim_param attr of arg1 in the original model is not cleared.

I added test/mlir/onnx/parse/dim_param_and_option.json to verify this situation.
In the test, with --dimParams="1:1=b", the onnx.dim_params="0:A,1:B" of the first input argument in the original model is not cleared by the option.

Original model

module {
  func.func @main_graph(%arg0: tensor<?x?xf32> {onnx.dim_params = "0:A,1:B", onnx.name = "X"}, %arg1: tensor<?x?xf32> {onnx.dim_params = "0:A,1:B", onnx.name = "Y"}) -> (tensor<?x?xf32> {onnx.dim_params = "0:A,1:B", onnx.name = "Z"}) {
    %0 = "onnx.Add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
    onnx.Return %0 : tensor<?x?xf32>
  }
  "onnx.EntryPoint"() {func = @main_graph} : () -> ()
}

Imported model

module {
  func.func @main_graph(%arg0: tensor<?x?xf32> {onnx.dim_params = "0:A,1:B", onnx.name = "X"}, %arg1: tensor<?x?xf32> {onnx.dim_params = "1:b", onnx.name = "Y"}) -> (tensor<?x?xf32> {onnx.dim_params = "0:A,1:B", onnx.name = "Z"}) {
    %0 = "onnx.Add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
    onnx.Return %0 : tensor<?x?xf32>
  }
  "onnx.EntryPoint"() {func = @main_graph} : () -> ()
}
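The per-input precedence that the quoted C++ implements (per-index option value, else the option value for all arguments, else the model's own onnx.dim_params) can be sketched in Python as follows (names are illustrative, not from the PR):

```python
def resolve_dim_params(from_option, for_all_args, model_dim_params, input_index):
    """Choose the onnx.dim_params value for one input argument."""
    if from_option.get(input_index, ""):  # per-index --dimParams entry wins
        return from_option[input_index]
    if for_all_args:                      # --dimParams value given for all inputs
        return for_all_args
    return model_dim_params               # otherwise keep the model's attribute
```

For the test above, resolve_dim_params({1: "1:b"}, "", "0:A,1:B", 0) keeps the model's "0:A,1:B" for the first input, while index 1 gets "1:b".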

@tungld (Collaborator) commented Apr 26, 2024:

I see. Thank you for the confirmation! Just a suggestion: if you use a map instead of a vector for inputDimParamsFromOption, the code is easier to read, and you don't need to expand the vector and check if (inputIndex < (int)inputDimParamsFromOption.size() && !inputDimParamsFromOption[inputIndex].empty()). A vector is not ideal here; it's a bit tricky to use an empty string as a sentinel.

@negiyas (Collaborator Author) commented Apr 26, 2024:

I changed the code to use std::map<int, std::string> instead of SmallVector<std::string>.
Thanks for the suggestion. It makes the code simpler.

@tungld (Collaborator) left a comment:

LGTM.

negiyas added 2 commits April 26, 2024 08:47
…rc/Builder/FrontendDialectTransformer.cpp

Signed-off-by: Yasushi Negishi <[email protected]>
Signed-off-by: Yasushi Negishi <[email protected]>
@negiyas negiyas merged commit 733dfac into onnx:main Apr 30, 2024
6 of 7 checks passed
@jenkins-droid:

Jenkins Linux ppc64le Build #13764 [push] Introduce --dimParams op... started at 22:48
Jenkins Linux amd64 Build #14739 [push] Introduce --dimParams op... started at 21:36
Jenkins Linux s390x Build #14769 [push] Introduce --dimParams op... started at 22:36
Jenkins Linux amd64 Build #14739 [push] Introduce --dimParams op... passed after 1 hr 5 min
Jenkins Linux s390x Build #14769 [push] Introduce --dimParams op... passed after 1 hr 29 min
Jenkins Linux ppc64le Build #13764 [push] Introduce --dimParams op... passed after 2 hr 3 min

@AlexandreEichenberger (Collaborator):

@hamptonm1 @Sunny-Anand Suggestion: create issues so that the root causes of the model features can be evaluated, with the goal of increasing our share of model zoo models successfully running with / without dynamic bounds.
