@@ -885,7 +885,9 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
      triton::common::TritonJson::Value reshape_dims(
          ModelConfig(), triton::common::TritonJson::ValueType::ARRAY);
      RETURN_IF_ERROR(reshape.Add("shape", std::move(reshape_dims)));
-     RETURN_IF_ERROR(io.Add("reshape", std::move(reshape)));
+     if (MaxBatchSize() > 0) {
+       RETURN_IF_ERROR(io.Add("reshape", std::move(reshape)));
+     }
    }
    RETURN_IF_ERROR(io.Add("dims", std::move(dims)));
    RETURN_IF_ERROR(ios.Append(std::move(io)));
@@ -998,6 +1000,15 @@ class ModelInstanceState : public BackendModelInstance {
   // map of output name -> tensor info
   OnnxTensorInfoMap output_tensor_infos_;

+  // map of input name -> tensor info
+  OnnxTensorInfoMap input_tensor_infos_;
+
+  // A map from scalar input tensors to the dimension specified in model config
+  std::unordered_map<std::string, std::vector<int64_t>> scalar_inputs_;
+
+  // A map from scalar output tensors to the dimension specified in model config
+  std::unordered_map<std::string, std::vector<int64_t>> scalar_outputs_;
+
   // Onnx Runtime variables that will be reset and used for every run
   // on this instance.
   std::vector<OrtValue*> input_tensors_;
@@ -1313,9 +1324,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 {
   std::set<std::string> input_tensor_names;
   RETURN_IF_ERROR(InputNames(session_, input_tensor_names));
-
-  OnnxTensorInfoMap input_tensor_infos;
-  RETURN_IF_ERROR(InputInfos(session_, default_allocator_, input_tensor_infos));
+  RETURN_IF_ERROR(
+      InputInfos(session_, default_allocator_, input_tensor_infos_));

   std::set<std::string> overridable_initializer_tensor_names;
   RETURN_IF_ERROR(OverridableInitializerNames(
@@ -1325,12 +1335,13 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   RETURN_IF_ERROR(OverridableInitializerInfos(
       session_, default_allocator_, overridable_initializer_tensor_infos));

-  if (input_tensor_infos.size() != expected_input_cnt) {
+  if (input_tensor_infos_.size() != expected_input_cnt) {
     return TRITONSERVER_ErrorNew(
         TRITONSERVER_ERROR_INVALID_ARG,
         (std::string("unable to load model '") + model_state_->Name() +
          "', configuration expects " + std::to_string(expected_input_cnt) +
-         " inputs, model provides " + std::to_string(input_tensor_infos.size()))
+         " inputs, model provides " +
+         std::to_string(input_tensor_infos_.size()))
            .c_str());
   }

@@ -1357,8 +1368,9 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)

    const auto& tensor_names =
        io_optional ? overridable_initializer_tensor_names : input_tensor_names;
-   const auto& tensor_infos =
-       io_optional ? overridable_initializer_tensor_infos : input_tensor_infos;
+   const auto& tensor_infos = io_optional
+                                  ? overridable_initializer_tensor_infos
+                                  : input_tensor_infos_;
    auto iit = tensor_infos.find(io_name);
    if (iit == tensor_infos.end()) {
      RETURN_IF_ERROR(CheckAllowedModelInput(io, tensor_names));
@@ -1419,9 +1431,30 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
             .c_str());
      }
    } else {
-     RETURN_IF_ERROR(CompareDimsSupported(
-         model_state_->Name(), io_name, iit->second.dims_, dims,
-         model_state_->MaxBatchSize(), false /* compare_exact */));
+     if (model_state_->MaxBatchSize() != 0 || iit->second.dims_.size() > 0) {
+       RETURN_IF_ERROR(CompareDimsSupported(
+           model_state_->Name(), io_name, iit->second.dims_, dims,
+           model_state_->MaxBatchSize(), false /* compare_exact */));
+     } else {
+       // if max_batch_size == 0 and the tensor is a scalar, all the
+       // dimensions specified must be equal to 1
+       for (auto& dim : dims) {
+         if (dim != 1) {
+           return TRITONSERVER_ErrorNew(
+               TRITONSERVER_ERROR_INVALID_ARG,
+               (std::string("unable to load model '") + model_state_->Name() +
+                "', scalar tensor '" + io_name +
+                "', should only provide 1 in the model configuration when the "
+                "model doesn't support batching. Model configuration "
+                "provided: " +
+                ShapeToString(dims) + ".")
+                   .c_str());
+         }
+       }
+
+       // store the dimension for reference.
+       scalar_inputs_[io_name] = dims;
+     }
    }
  }

@@ -1482,9 +1515,30 @@ ModelInstanceState::ValidateOutputs()

    // The batch output shape doesn't necessarily match the model
    if (model_state_->FindBatchOutput(io_name) == nullptr) {
-     RETURN_IF_ERROR(CompareDimsSupported(
-         model_state_->Name(), io_name, iit->second.dims_, dims,
-         model_state_->MaxBatchSize(), true /* compare_exact */));
+     // if max_batch_size == 0 and the tensor is a scalar, all the
+     // dimensions specified must be equal to 1
+     if (model_state_->MaxBatchSize() > 0 || iit->second.dims_.size() > 0) {
+       RETURN_IF_ERROR(CompareDimsSupported(
+           model_state_->Name(), io_name, iit->second.dims_, dims,
+           model_state_->MaxBatchSize(), true /* compare_exact */));
+     } else {
+       for (auto& dim : dims) {
+         if (dim != 1) {
+           return TRITONSERVER_ErrorNew(
+               TRITONSERVER_ERROR_INVALID_ARG,
+               (std::string("unable to load model '") + model_state_->Name() +
+                "', scalar tensor '" + io_name +
+                "', should only provide 1 in the model configuration when the "
+                "model doesn't support batching. Model configuration "
+                "provided: " +
+                ShapeToString(dims) + ".")
+                   .c_str());
+         }
+       }
+
+       // store the dimension for reference.
+       scalar_outputs_[io_name] = dims;
+     }
    }
  }

@@ -1900,13 +1954,34 @@ ModelInstanceState::SetInputTensors(
            input_name, nullptr, 0, allowed_input_types, &input_buffer,
            &batchn_byte_size, &memory_type, &memory_type_id));

+   auto iti = input_tensor_infos_.find(input_name);
+   if (iti == input_tensor_infos_.end()) {
+     return TRITONSERVER_ErrorNew(
+         TRITONSERVER_ERROR_INTERNAL,
+         std::string(
+             std::string(
+                 "Failed to retrieve the ONNX input tensor info from '") +
+             input_name + "'.")
+             .c_str());
+   }
+
    // Create ORT Tensor
-   RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
-       memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
-                                              : cpu_allocator_info_,
-       (void*)input_buffer, batchn_byte_size, batchn_shape.data(),
-       batchn_shape.size(), ConvertToOnnxDataType(input_datatype),
-       &input_tensors_.back()));
+   if (iti->second.dims_.size() == 0) {
+     // scalar tensor
+     RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
+         memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
+                                                : cpu_allocator_info_,
+         (void*)input_buffer, batchn_byte_size, nullptr /* scalar */,
+         0 /* number of dims */, ConvertToOnnxDataType(input_datatype),
+         &input_tensors_.back()));
+   } else {
+     RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
+         memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
+                                                : cpu_allocator_info_,
+         (void*)input_buffer, batchn_byte_size, batchn_shape.data(),
+         batchn_shape.size(), ConvertToOnnxDataType(input_datatype),
+         &input_tensors_.back()));
+   }
    RETURN_IF_ORT_ERROR(
        ort_api->BindInput(io_binding_, input_name, input_tensors_.back()));
  } else {
@@ -2283,6 +2358,22 @@ ModelInstanceState::ReadOutputTensors(
        batchn_shape, dtype, output_tensor, &output_buffer, string_buffers,
        offsets));

+   // If the number of dimensions is equal to zero, it means that it is a
+   // scalar and it will use the dimensions specified in the model
+   // configuration.
+   if (batchn_shape.size() == 0) {
+     auto scalar_output_dims_it = scalar_outputs_.find(name);
+     if (scalar_output_dims_it == scalar_outputs_.end()) {
+       return TRITONSERVER_ErrorNew(
+           TRITONSERVER_ERROR_INTERNAL,
+           std::string(
+               "Failed to find the scalar output dimension for " + name +
+               " in the model configuration.")
+               .c_str());
+     }
+     batchn_shape = scalar_output_dims_it->second;
+   }
+
    if (output_tensor_pair.first != -1) {
      if (dtype == TRITONSERVER_TYPE_BYTES) {
        auto content = string_buffers.back().data();
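
For context on the scalar branches above: the ORT C API treats a null shape pointer with zero dimensions as a 0-dim (scalar) tensor, which is exactly what the new SetInputTensors() path passes. Below is a minimal standalone sketch of that call, assuming a CPU-only setup; the float payload and the abort()-based error handling are illustrative only, not part of this change.

#include <onnxruntime_c_api.h>
#include <cstdlib>

int main() {
  const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  // CPU memory info, analogous to cpu_allocator_info_ in the backend.
  OrtMemoryInfo* mem_info = nullptr;
  if (ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info))
    abort();

  float value = 42.0f;  // hypothetical scalar payload
  OrtValue* scalar = nullptr;
  // nullptr shape + 0 dims yields a 0-dimensional tensor, matching the
  // "scalar tensor" branch added to SetInputTensors() above.
  if (ort->CreateTensorWithDataAsOrtValue(
          mem_info, &value, sizeof(value), nullptr /* shape */,
          0 /* number of dims */, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
          &scalar))
    abort();

  ort->ReleaseValue(scalar);
  ort->ReleaseMemoryInfo(mem_info);
  return 0;
}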