Commit 0b6b36a

Add scalar support in ORT backend
1 parent 4b88138 commit 0b6b36a

1 file changed (+111, -20 lines)

src/onnxruntime.cc

@@ -885,7 +885,9 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
     triton::common::TritonJson::Value reshape_dims(
         ModelConfig(), triton::common::TritonJson::ValueType::ARRAY);
     RETURN_IF_ERROR(reshape.Add("shape", std::move(reshape_dims)));
-    RETURN_IF_ERROR(io.Add("reshape", std::move(reshape)));
+    if (MaxBatchSize() > 0) {
+      RETURN_IF_ERROR(io.Add("reshape", std::move(reshape)));
+    }
   }
   RETURN_IF_ERROR(io.Add("dims", std::move(dims)));
   RETURN_IF_ERROR(ios.Append(std::move(io)));
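
For context: in a Triton model configuration, an empty reshape maps a configured dimension of [ 1 ] onto a rank-0 (scalar) ONNX tensor. A plausible reading of this hunk: autocomplete still emits that reshape when batching is enabled, but with max_batch_size == 0 it now leaves dims: [ 1 ] alone and defers to the scalar handling added further down. An illustrative config.pbtxt fragment (the tensor name is hypothetical, not from the commit):

```
input [
  {
    name: "INPUT0"            # hypothetical tensor name
    data_type: TYPE_FP32
    dims: [ 1 ]
    # After this change, autocomplete adds the empty reshape below only
    # when max_batch_size > 0; it maps the per-element [ 1 ] onto the
    # model's rank-0 tensor.
    reshape { shape: [ ] }
  }
]
```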
@@ -998,6 +1000,15 @@ class ModelInstanceState : public BackendModelInstance {
   // map of output name -> tensor info
   OnnxTensorInfoMap output_tensor_infos_;

+  // map of input name -> tensor info
+  OnnxTensorInfoMap input_tensor_infos_;
+
+  // A map from scalar input tensors to the dimension specified in model config
+  std::unordered_map<std::string, std::vector<int64_t>> scalar_inputs_;
+
+  // A map from scalar output tensors to the dimension specified in model config
+  std::unordered_map<std::string, std::vector<int64_t>> scalar_outputs_;
+
   // Onnx Runtime variables that will be reset and used for every run
   // on this instance.
   std::vector<OrtValue*> input_tensors_;
@@ -1313,9 +1324,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 {
   std::set<std::string> input_tensor_names;
   RETURN_IF_ERROR(InputNames(session_, input_tensor_names));
-
-  OnnxTensorInfoMap input_tensor_infos;
-  RETURN_IF_ERROR(InputInfos(session_, default_allocator_, input_tensor_infos));
+  RETURN_IF_ERROR(
+      InputInfos(session_, default_allocator_, input_tensor_infos_));

   std::set<std::string> overridable_initializer_tensor_names;
   RETURN_IF_ERROR(OverridableInitializerNames(
@@ -1325,12 +1335,13 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   RETURN_IF_ERROR(OverridableInitializerInfos(
       session_, default_allocator_, overridable_initializer_tensor_infos));

-  if (input_tensor_infos.size() != expected_input_cnt) {
+  if (input_tensor_infos_.size() != expected_input_cnt) {
     return TRITONSERVER_ErrorNew(
         TRITONSERVER_ERROR_INVALID_ARG,
         (std::string("unable to load model '") + model_state_->Name() +
          "', configuration expects " + std::to_string(expected_input_cnt) +
-         " inputs, model provides " + std::to_string(input_tensor_infos.size()))
+         " inputs, model provides " +
+         std::to_string(input_tensor_infos_.size()))
            .c_str());
   }

@@ -1357,8 +1368,9 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)

     const auto& tensor_names =
         io_optional ? overridable_initializer_tensor_names : input_tensor_names;
-    const auto& tensor_infos =
-        io_optional ? overridable_initializer_tensor_infos : input_tensor_infos;
+    const auto& tensor_infos = io_optional
+                                   ? overridable_initializer_tensor_infos
+                                   : input_tensor_infos_;
     auto iit = tensor_infos.find(io_name);
     if (iit == tensor_infos.end()) {
       RETURN_IF_ERROR(CheckAllowedModelInput(io, tensor_names));
@@ -1419,9 +1431,30 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
               .c_str());
       }
     } else {
-      RETURN_IF_ERROR(CompareDimsSupported(
-          model_state_->Name(), io_name, iit->second.dims_, dims,
-          model_state_->MaxBatchSize(), false /* compare_exact */));
+      if (model_state_->MaxBatchSize() != 0 || iit->second.dims_.size() > 0) {
+        RETURN_IF_ERROR(CompareDimsSupported(
+            model_state_->Name(), io_name, iit->second.dims_, dims,
+            model_state_->MaxBatchSize(), false /* compare_exact */));
+      } else {
+        // if max_batch_size == 0 and this is a scalar tensor, all the
+        // dimensions specified must be equal to 1
+        for (auto& dim : dims) {
+          if (dim != 1) {
+            return TRITONSERVER_ErrorNew(
+                TRITONSERVER_ERROR_INVALID_ARG,
+                (std::string("unable to load model '") + model_state_->Name() +
+                 "', scalar tensor '" + io_name +
+                 "', should only provide 1 in the model configuration when the "
+                 "model doesn't support batching. Model configuration "
+                 "provided: " +
+                 ShapeToString(dims) + ".")
+                    .c_str());
+          }
+        }
+
+        // store the dimension for reference.
+        scalar_inputs_[io_name] = dims;
+      }
     }
   }
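The rule this hunk enforces (and that the ValidateOutputs hunk below repeats for outputs) can be read in isolation. A minimal sketch, with ScalarConfigDimsValid as a hypothetical free function that is not part of the backend:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper, not in the commit: when the model does not batch
// (max_batch_size == 0) and the ONNX tensor is rank-0, every dimension in
// the model configuration must be exactly 1. Every other combination is
// left to the regular CompareDimsSupported path.
bool ScalarConfigDimsValid(
    int64_t max_batch_size, size_t onnx_rank,
    const std::vector<int64_t>& config_dims)
{
  if (max_batch_size != 0 || onnx_rank > 0) {
    return true;  // non-scalar case: normal dimension comparison applies
  }
  for (const int64_t dim : config_dims) {
    if (dim != 1) {
      return false;  // e.g. dims [ 2 ] is rejected for a scalar tensor
    }
  }
  return true;  // e.g. dims [ 1 ] or [ 1, 1 ] pass and are recorded
}
```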

@@ -1482,9 +1515,30 @@ ModelInstanceState::ValidateOutputs()

     // The batch output shape doesn't necessarily match the model
     if (model_state_->FindBatchOutput(io_name) == nullptr) {
-      RETURN_IF_ERROR(CompareDimsSupported(
-          model_state_->Name(), io_name, iit->second.dims_, dims,
-          model_state_->MaxBatchSize(), true /* compare_exact */));
+      // if max_batch_size == 0 and this is a scalar tensor, all the
+      // dimensions specified must be equal to 1
+      if (model_state_->MaxBatchSize() > 0 || iit->second.dims_.size() > 0) {
+        RETURN_IF_ERROR(CompareDimsSupported(
+            model_state_->Name(), io_name, iit->second.dims_, dims,
+            model_state_->MaxBatchSize(), true /* compare_exact */));
+      } else {
+        for (auto& dim : dims) {
+          if (dim != 1) {
+            return TRITONSERVER_ErrorNew(
+                TRITONSERVER_ERROR_INVALID_ARG,
+                (std::string("unable to load model '") + model_state_->Name() +
+                 "', scalar tensor '" + io_name +
+                 "', should only provide 1 in the model configuration when the "
+                 "model doesn't support batching. Model configuration "
+                 "provided: " +
+                 ShapeToString(dims) + ".")
+                    .c_str());
+          }
+        }
+
+        // store the dimension for reference.
+        scalar_outputs_[io_name] = dims;
+      }
     }
   }

@@ -1900,13 +1954,34 @@ ModelInstanceState::SetInputTensors(
           input_name, nullptr, 0, allowed_input_types, &input_buffer,
           &batchn_byte_size, &memory_type, &memory_type_id));

+      auto iti = input_tensor_infos_.find(input_name);
+      if (iti == input_tensor_infos_.end()) {
+        return TRITONSERVER_ErrorNew(
+            TRITONSERVER_ERROR_INTERNAL,
+            std::string(
+                std::string(
+                    "Failed to retrieve the ONNX input tensor info from '") +
+                input_name + "'.")
+                .c_str());
+      }
+
       // Create ORT Tensor
-      RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
-          memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
-                                                 : cpu_allocator_info_,
-          (void*)input_buffer, batchn_byte_size, batchn_shape.data(),
-          batchn_shape.size(), ConvertToOnnxDataType(input_datatype),
-          &input_tensors_.back()));
+      if (iti->second.dims_.size() == 0) {
+        // scalar tensor
+        RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
+            memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
+                                                   : cpu_allocator_info_,
+            (void*)input_buffer, batchn_byte_size, nullptr /* scalar */,
+            0 /* number of dims */, ConvertToOnnxDataType(input_datatype),
+            &input_tensors_.back()));
+      } else {
+        RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
+            memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
+                                                   : cpu_allocator_info_,
+            (void*)input_buffer, batchn_byte_size, batchn_shape.data(),
+            batchn_shape.size(), ConvertToOnnxDataType(input_datatype),
+            &input_tensors_.back()));
+      }
       RETURN_IF_ORT_ERROR(
           ort_api->BindInput(io_binding_, input_name, input_tensors_.back()));
     } else {
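
The nullptr / 0 shape arguments are how the ONNX Runtime C API expresses a rank-0 tensor. A self-contained sketch of that call outside the backend (error handling reduced to bare status checks):

```cpp
#include <onnxruntime_c_api.h>
#include <cstdio>

int main() {
  const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  OrtMemoryInfo* mem_info = nullptr;
  if (ort->CreateCpuMemoryInfo(
          OrtArenaAllocator, OrtMemTypeDefault, &mem_info) != nullptr) {
    return 1;
  }

  // A scalar: pass nullptr for the shape and 0 for the dim count, exactly
  // as the hunk above does when the input's ONNX rank is 0.
  float value = 42.0f;
  OrtValue* scalar = nullptr;
  if (ort->CreateTensorWithDataAsOrtValue(
          mem_info, &value, sizeof(value), nullptr /* shape */,
          0 /* num dims */, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
          &scalar) != nullptr) {
    return 1;
  }

  // Confirm the resulting value is rank-0.
  OrtTensorTypeAndShapeInfo* info = nullptr;
  size_t rank = 0;
  ort->GetTensorTypeAndShape(scalar, &info);
  ort->GetDimensionsCount(info, &rank);
  std::printf("rank = %zu\n", rank);  // prints: rank = 0

  ort->ReleaseTensorTypeAndShapeInfo(info);
  ort->ReleaseValue(scalar);
  ort->ReleaseMemoryInfo(mem_info);
  return 0;
}
```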
@@ -2283,6 +2358,22 @@ ModelInstanceState::ReadOutputTensors(
         batchn_shape, dtype, output_tensor, &output_buffer, string_buffers,
         offsets));

+    // If the number of dimensions is equal to zero, it means that it is a
+    // scalar and it will use the dimensions specified in the model
+    // configuration.
+    if (batchn_shape.size() == 0) {
+      auto scalar_output_dims_it = scalar_outputs_.find(name);
+      if (scalar_output_dims_it == scalar_outputs_.end()) {
+        return TRITONSERVER_ErrorNew(
+            TRITONSERVER_ERROR_INTERNAL,
+            std::string(
+                "Failed to find the scalar output dimension for " + name +
+                " in the model configuration.")
+                .c_str());
+      }
+      batchn_shape = scalar_output_dims_it->second;
+    }
+
     if (output_tensor_pair.first != -1) {
       if (dtype == TRITONSERVER_TYPE_BYTES) {
         auto content = string_buffers.back().data();
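
A rank-0 output carries no shape of its own, so the dims recorded from the model configuration during validation stand in for it when the response is built. A minimal standalone sketch (ResolveOutputShape and the Shape alias are hypothetical, not backend names):

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical helper mirroring the logic above: non-scalar outputs keep
// the shape ONNX Runtime reported, while a rank-0 output is reported with
// the all-ones dims stored in scalar_outputs during validation.
Shape ResolveOutputShape(
    const std::string& name, const Shape& ort_shape,
    const std::unordered_map<std::string, Shape>& scalar_outputs)
{
  if (!ort_shape.empty()) {
    return ort_shape;
  }
  auto it = scalar_outputs.find(name);
  // In the real backend a miss here surfaces as an internal error.
  return (it != scalar_outputs.end()) ? it->second : Shape{};
}
```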
