Skip to content

Commit 0813d63

Browse files
committed
Code updated according to the PR comments.
Signed-off-by: Yasushi Negishi <[email protected]>
1 parent 48a3b06 commit 0813d63

File tree

5 files changed

+33
-31
lines changed

5 files changed

+33
-31
lines changed

src/Builder/FrontendDialectTransformer.cpp

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -453,10 +453,9 @@ class FrontendGenImpl {
453453
}
454454

455455
// Generate a string vector from the dimParams option string
456-
void getInputDimParamsVecFromOption(std::string optionStr, size_t numOfArgs,
457-
SmallVector<std::string> &paramStrVec) {
458-
std::stringstream paramStrStream;
459-
paramStrStream << optionStr;
456+
void getInputDimParamsVecFromOption(std::string optionStr,
457+
SmallVector<std::string> &paramStrVec, std::string &paramStrForAllArgs) {
458+
std::stringstream paramStrStream(optionStr);
460459
std::string dimParamStr;
461460
while (std::getline(paramStrStream, dimParamStr, '|')) {
462461
size_t pos = dimParamStr.find(':');
@@ -465,10 +464,12 @@ class FrontendGenImpl {
465464
dimParamStr = dimParamStr.substr(pos + 1);
466465
std::replace(dimParamStr.begin(), dimParamStr.end(), '=', ':');
467466
if (idx < 0) // set all arguments
468-
for (size_t i = 0; i < numOfArgs; i++)
469-
paramStrVec[i] = dimParamStr;
470-
else
467+
paramStrForAllArgs = dimParamStr;
468+
else {
469+
while ((int)paramStrVec.size() <= idx) // Expand paramStrVec
470+
paramStrVec.emplace_back("");
471471
paramStrVec[idx] = dimParamStr;
472+
}
472473
}
473474
return;
474475
}
@@ -513,10 +514,10 @@ class FrontendGenImpl {
513514
// See https://github.com/onnx/onnx/blob/main/docs/IR.md for more
514515
// information about dim_param.
515516
llvm::SmallVector<std::string, 4> inputDimParams, outputDimParams;
516-
size_t numOfArgs = graph.input().size();
517-
llvm::SmallVector<std::string> inputDimParamsFromOption(numOfArgs);
518-
getInputDimParamsVecFromOption(
519-
options_.dimParams, numOfArgs, inputDimParamsFromOption);
517+
llvm::SmallVector<std::string> inputDimParamsFromOption;
518+
std::string inputDimParamsFromOptionForAllArgs;
519+
getInputDimParamsVecFromOption(options_.dimParams, inputDimParamsFromOption,
520+
inputDimParamsFromOptionForAllArgs);
520521

521522
// Import the input tensor types that are not constant and not initialized.
522523
int inputIndex = 0;
@@ -527,8 +528,11 @@ class FrontendGenImpl {
527528
std::string dimParams = "";
528529
Type argTy = ImportType(input.type(), &dimParams);
529530
argTy = modelInputShaper_.reshape(inputIndex, argTy);
530-
if (!inputDimParamsFromOption[inputIndex].empty())
531+
if (inputIndex < (int)inputDimParamsFromOption.size() &&
532+
!inputDimParamsFromOption[inputIndex].empty())
531533
inputDimParams.emplace_back(inputDimParamsFromOption[inputIndex]);
534+
else if (!inputDimParamsFromOptionForAllArgs.empty())
535+
inputDimParams.emplace_back(inputDimParamsFromOptionForAllArgs);
532536
else if (!dimParams.empty())
533537
inputDimParams.emplace_back(dimParams);
534538

src/Builder/FrontendDialectTransformer.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,14 @@ struct ImportOptions {
5656
//
5757
std::string shapeInformation = "";
5858
// Custom onnx.dim_params attributes for the graph inputs for specifying
59-
// relationship among their ranks.
60-
// Its format is 'input_id:dim=sym,dim=sym,dim=sym|input_id:dim=sym,dim=sym..'
59+
// relationship among their dynamic dimensions.
60+
// Its format is 'input_id:dim_id=sym,dim_id=sym,...|input_id:
61+
// dim_id=sym,dim_id=sym,...|input_id...'
6162
// E.g. An ONNX model has two dynamic inputs
6263
// - (arg0: tensor<?x5xf32>, arg1: tensor<?x5xf32>)
63-
// If we want to specify that the first ranks of the first and second
64-
// arguments are the same, we can use:
64+
// If we want to specify that the first unknown dimension of arg0 and the
65+
// first unknown dimension of arg1 are the same, we can assign the two
66+
// dimensions to the same symbol "batch" as follows.
6567
// - dimParams = '0:0=batch|1:0=batch'
6668
//
6769
std::string dimParams = "";

src/Compiler/CompilerOptions.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,11 +285,11 @@ static llvm::cl::opt<std::string, true> shapeInformationOpt("shapeInformation",
285285
static llvm::cl::opt<std::string, true> dimParamsOpt("dimParams",
286286
llvm::cl::desc(
287287
"Custom onnx.dim_params attributes for the inputs of the ONNX model for"
288-
"specifying relationship among ranks of the inputs.\n"
288+
"specifying relationship among dynamic dimensions of the inputs.\n"
289289
"\"value\" is in the format of "
290-
"\"INPUT_ID1:D1=S1,D2=S2,...,Dn=Sn|INPUT_ID2:D1=T1,D2=T2,...Dn=Tn\""
291-
"where \"INPUT_ID1, INPUT_ID2, ...\" are input indices (starting from "
292-
"0 or being -1 for all input indices), and\n"
290+
"\"INPUT_ID1:D1=S1,D2=S2,...,Dn=Sn|INPUT_ID2:D1=T1,D2=T2,...Dn=Tn|"
291+
"...\" where \"INPUT_ID1, INPUT_ID2, ...\" are input indices "
292+
"(starting from 0 or being -1 for all input indices), and\n"
293293
"\"S1, S2, ...\" and \"T2, T2, ...\" are symbols to specify that same "
294294
"symbols have the same value"),
295295
llvm::cl::value_desc("value"), llvm::cl::location(dimParams),

test/accelerators/NNPA/backend/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,7 @@ endif()
100100
# "zdnn_add" to check the function is called, (3) when dimParams is
101101
# "NO_DYNAMIC_SHAPE_TEST", backend test is skipped, otherwise the string is
102102
# passed as --dimParams option. "0:0=a,1=b,2=c|1:0=a,1=b,2=c" means that the
103-
# first ranks of the first, second and third ranks of the first and second
104-
# input arguments are the same respectively.
103+
# first, second and third dimensions of the first and second input arguments # are the same respectively.
105104
set(NNPA_TEST_LIST
106105

107106
# ==ARCH== NNPA

test/backend/common.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,6 @@ def execute_commands(cmds, dynamic_inputs_dims):
6767
subprocess.run(cmds, env=my_env, check=True)
6868

6969

70-
def get_compile_option(test_name):
71-
if args.dynamic and test_name in variables.test_to_enable_dimparams_dict:
72-
return ["--dimParams=" + variables.test_to_enable_dimparams_dict[test_name]]
73-
else:
74-
return []
75-
76-
7770
def check_instruction(test_name, exec_name):
7871
if args.instruction_check and test_name in variables.test_to_enable_symbol_dict:
7972
symbol_name = variables.test_to_enable_symbol_dict[test_name]
@@ -115,6 +108,7 @@ def compile_model(model, emit):
115108

116109
exec_base = os.path.join(model_dir, name)
117110
exec_name = exec_base + suffix[emit]
111+
test_name_cpu = name + "_cpu"
118112

119113
# Command
120114
command_list = [TEST_DRIVER]
@@ -139,7 +133,10 @@ def compile_model(model, emit):
139133
"--constants-to-file-total-threshold="
140134
+ str(args.constants_to_file_total_threshold)
141135
)
142-
command_list += get_compile_option(name + "_cpu")
136+
if args.dynamic and test_name_cpu in variables.test_to_enable_dimparams_dict:
137+
command_list.append(
138+
"--dimParams=" + variables.test_to_enable_dimparams_dict[test_name_cpu]
139+
)
143140

144141
command_list.append(target[emit])
145142
command_list.append(model_name)
@@ -161,6 +158,6 @@ def compile_model(model, emit):
161158
# in execute_commands when calling subprocess.run.
162159

163160
# Check if specific instruction are included in the compiled model.
164-
check_instruction(name + "_cpu", exec_name)
161+
check_instruction(test_name_cpu, exec_name)
165162

166163
return exec_name

0 commit comments

Comments
 (0)