Skip to content

Commit a855fa5

Browse files
committed
Add scalar support in ORT backend
1 parent 4b88138 commit a855fa5

File tree

1 file changed

+108
-20
lines changed

1 file changed

+108
-20
lines changed

src/onnxruntime.cc

+108-20
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,9 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
885885
triton::common::TritonJson::Value reshape_dims(
886886
ModelConfig(), triton::common::TritonJson::ValueType::ARRAY);
887887
RETURN_IF_ERROR(reshape.Add("shape", std::move(reshape_dims)));
888-
RETURN_IF_ERROR(io.Add("reshape", std::move(reshape)));
888+
if (MaxBatchSize() > 0) {
889+
RETURN_IF_ERROR(io.Add("reshape", std::move(reshape)));
890+
}
889891
}
890892
RETURN_IF_ERROR(io.Add("dims", std::move(dims)));
891893
RETURN_IF_ERROR(ios.Append(std::move(io)));
@@ -998,6 +1000,12 @@ class ModelInstanceState : public BackendModelInstance {
9981000
// map of output name -> tensor info
9991001
OnnxTensorInfoMap output_tensor_infos_;
10001002

1003+
// map of input name -> tensor info
1004+
OnnxTensorInfoMap input_tensor_infos_;
1005+
1006+
// A map from scalar output tensors to the dimension specified in model config
1007+
std::unordered_map<std::string, std::vector<int64_t>> scalar_outputs_;
1008+
10011009
// Onnx Runtime variables that will be reset and used for every run
10021010
// on this instance.
10031011
std::vector<OrtValue*> input_tensors_;
@@ -1313,9 +1321,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
13131321
{
13141322
std::set<std::string> input_tensor_names;
13151323
RETURN_IF_ERROR(InputNames(session_, input_tensor_names));
1316-
1317-
OnnxTensorInfoMap input_tensor_infos;
1318-
RETURN_IF_ERROR(InputInfos(session_, default_allocator_, input_tensor_infos));
1324+
RETURN_IF_ERROR(
1325+
InputInfos(session_, default_allocator_, input_tensor_infos_));
13191326

13201327
std::set<std::string> overridable_initializer_tensor_names;
13211328
RETURN_IF_ERROR(OverridableInitializerNames(
@@ -1325,12 +1332,13 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
13251332
RETURN_IF_ERROR(OverridableInitializerInfos(
13261333
session_, default_allocator_, overridable_initializer_tensor_infos));
13271334

1328-
if (input_tensor_infos.size() != expected_input_cnt) {
1335+
if (input_tensor_infos_.size() != expected_input_cnt) {
13291336
return TRITONSERVER_ErrorNew(
13301337
TRITONSERVER_ERROR_INVALID_ARG,
13311338
(std::string("unable to load model '") + model_state_->Name() +
13321339
"', configuration expects " + std::to_string(expected_input_cnt) +
1333-
" inputs, model provides " + std::to_string(input_tensor_infos.size()))
1340+
" inputs, model provides " +
1341+
std::to_string(input_tensor_infos_.size()))
13341342
.c_str());
13351343
}
13361344

@@ -1357,8 +1365,9 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
13571365

13581366
const auto& tensor_names =
13591367
io_optional ? overridable_initializer_tensor_names : input_tensor_names;
1360-
const auto& tensor_infos =
1361-
io_optional ? overridable_initializer_tensor_infos : input_tensor_infos;
1368+
const auto& tensor_infos = io_optional
1369+
? overridable_initializer_tensor_infos
1370+
: input_tensor_infos_;
13621371
auto iit = tensor_infos.find(io_name);
13631372
if (iit == tensor_infos.end()) {
13641373
RETURN_IF_ERROR(CheckAllowedModelInput(io, tensor_names));
@@ -1419,9 +1428,30 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
14191428
.c_str());
14201429
}
14211430
} else {
1422-
RETURN_IF_ERROR(CompareDimsSupported(
1423-
model_state_->Name(), io_name, iit->second.dims_, dims,
1424-
model_state_->MaxBatchSize(), false /* compare_exact */));
1431+
if (model_state_->MaxBatchSize() != 0 || iit->second.dims_.size() > 0) {
1432+
RETURN_IF_ERROR(CompareDimsSupported(
1433+
model_state_->Name(), io_name, iit->second.dims_, dims,
1434+
model_state_->MaxBatchSize(), false /* compare_exact */));
1435+
} else {
1436+
// if max_batch_size == 0 and is a scalar tensor all the
1437+
// dimensions specified must be equal to 1
1438+
for (auto& dim : dims) {
1439+
if (dim != 1) {
1440+
return TRITONSERVER_ErrorNew(
1441+
TRITONSERVER_ERROR_INVALID_ARG,
1442+
(std::string("unable to load model '") + model_state_->Name() +
1443+
"', scalar tensor '" + io_name +
1444+
"', should only provide 1 in the model configuration when the "
1445+
"model doesn't support batching. Model configuration "
1446+
"provided: " +
1447+
ShapeToString(dims) + ".")
1448+
.c_str());
1449+
}
1450+
}
1451+
1452+
// store the dimensions for reference.
1453+
scalar_inputs_[io_name] = dims;
1454+
}
14251455
}
14261456
}
14271457

@@ -1482,9 +1512,30 @@ ModelInstanceState::ValidateOutputs()
14821512

14831513
// The batch output shape doesn't necessarily match the model
14841514
if (model_state_->FindBatchOutput(io_name) == nullptr) {
1485-
RETURN_IF_ERROR(CompareDimsSupported(
1486-
model_state_->Name(), io_name, iit->second.dims_, dims,
1487-
model_state_->MaxBatchSize(), true /* compare_exact */));
1515+
// if max_batch_size == 0 and is a scalar tensor all the
1516+
// dimensions specified must be equal to 1
1517+
if (model_state_->MaxBatchSize() > 0 || iit->second.dims_.size() > 0) {
1518+
RETURN_IF_ERROR(CompareDimsSupported(
1519+
model_state_->Name(), io_name, iit->second.dims_, dims,
1520+
model_state_->MaxBatchSize(), true /* compare_exact */));
1521+
} else {
1522+
for (auto& dim : dims) {
1523+
if (dim != 1) {
1524+
return TRITONSERVER_ErrorNew(
1525+
TRITONSERVER_ERROR_INVALID_ARG,
1526+
(std::string("unable to load model '") + model_state_->Name() +
1527+
"', scalar tensor '" + io_name +
1528+
"', should only provide 1 in the model configuration when the "
1529+
"model doesn't support batching. Model configuration "
1530+
"provided: " +
1531+
ShapeToString(dims) + ".")
1532+
.c_str());
1533+
}
1534+
}
1535+
1536+
// store the dimensions for reference.
1537+
scalar_outputs_[io_name] = dims;
1538+
}
14881539
}
14891540
}
14901541

@@ -1900,13 +1951,34 @@ ModelInstanceState::SetInputTensors(
19001951
input_name, nullptr, 0, allowed_input_types, &input_buffer,
19011952
&batchn_byte_size, &memory_type, &memory_type_id));
19021953

1954+
auto iti = input_tensor_infos_.find(input_name);
1955+
if (iti == input_tensor_infos_.end()) {
1956+
return TRITONSERVER_ErrorNew(
1957+
TRITONSERVER_ERROR_INTERNAL,
1958+
std::string(
1959+
std::string(
1960+
"Failed to retrieve the ONNX input tensor info from '") +
1961+
input_name + "'.")
1962+
.c_str());
1963+
}
1964+
19031965
// Create ORT Tensor
1904-
RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
1905-
memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
1906-
: cpu_allocator_info_,
1907-
(void*)input_buffer, batchn_byte_size, batchn_shape.data(),
1908-
batchn_shape.size(), ConvertToOnnxDataType(input_datatype),
1909-
&input_tensors_.back()));
1966+
if (iti->second.dims_.size() == 0) {
1967+
// scalar tensor
1968+
RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
1969+
memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
1970+
: cpu_allocator_info_,
1971+
(void*)input_buffer, batchn_byte_size, nullptr /* scalar */,
1972+
0 /* number of dims */, ConvertToOnnxDataType(input_datatype),
1973+
&input_tensors_.back()));
1974+
} else {
1975+
RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
1976+
memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
1977+
: cpu_allocator_info_,
1978+
(void*)input_buffer, batchn_byte_size, batchn_shape.data(),
1979+
batchn_shape.size(), ConvertToOnnxDataType(input_datatype),
1980+
&input_tensors_.back()));
1981+
}
19101982
RETURN_IF_ORT_ERROR(
19111983
ort_api->BindInput(io_binding_, input_name, input_tensors_.back()));
19121984
} else {
@@ -2283,6 +2355,22 @@ ModelInstanceState::ReadOutputTensors(
22832355
batchn_shape, dtype, output_tensor, &output_buffer, string_buffers,
22842356
offsets));
22852357

2358+
// If the number of dimensions is equal to zero, it means that it is a
2359+
// scalar and it would use the dimensions specified in the model
2360+
// configuration.
2361+
if (batchn_shape.size() == 0) {
2362+
auto scalar_output_dims_it = scalar_outputs_.find(name);
2363+
if (scalar_output_dims_it == scalar_outputs_.end()) {
2364+
return TRITONSERVER_ErrorNew(
2365+
TRITONSERVER_ERROR_INTERNAL,
2366+
std::string(
2367+
"Failed to find the scalar output dimension for " + name +
2368+
" in the model configuration.")
2369+
.c_str());
2370+
}
2371+
batchn_shape = scalar_output_dims_it->second;
2372+
}
2373+
22862374
if (output_tensor_pair.first != -1) {
22872375
if (dtype == TRITONSERVER_TYPE_BYTES) {
22882376
auto content = string_buffers.back().data();

0 commit comments

Comments
 (0)