Commit 8a1d1a3

Add scalar support in ORT backend (#213)
* Add scalar support in ORT backend
* Review edits
* Review edit
1 parent: 4b88138

1 file changed (+107, -20)

src/onnxruntime.cc

@@ -885,7 +885,11 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
     triton::common::TritonJson::Value reshape_dims(
         ModelConfig(), triton::common::TritonJson::ValueType::ARRAY);
     RETURN_IF_ERROR(reshape.Add("shape", std::move(reshape_dims)));
-    RETURN_IF_ERROR(io.Add("reshape", std::move(reshape)));
+    // An empty reshape when `max_batch_size` is 0 would indicate a scalar
+    // tensor in the model configuration, which is not a valid model
+    // configuration.
+    if (MaxBatchSize() > 0) {
+      RETURN_IF_ERROR(io.Add("reshape", std::move(reshape)));
+    }
   }
   RETURN_IF_ERROR(io.Add("dims", std::move(dims)));
   RETURN_IF_ERROR(ios.Append(std::move(io)));
@@ -998,6 +1002,12 @@ class ModelInstanceState : public BackendModelInstance {
   // map of output name -> tensor info
   OnnxTensorInfoMap output_tensor_infos_;

+  // map of input name -> tensor info
+  OnnxTensorInfoMap input_tensor_infos_;
+
+  // map of scalar output name -> the dims specified in the model config
+  std::unordered_map<std::string, std::vector<int64_t>> scalar_outputs_;
+
   // Onnx Runtime variables that will be reset and used for every run
   // on this instance.
   std::vector<OrtValue*> input_tensors_;
@@ -1313,9 +1323,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 {
   std::set<std::string> input_tensor_names;
   RETURN_IF_ERROR(InputNames(session_, input_tensor_names));
-
-  OnnxTensorInfoMap input_tensor_infos;
-  RETURN_IF_ERROR(InputInfos(session_, default_allocator_, input_tensor_infos));
+  RETURN_IF_ERROR(
+      InputInfos(session_, default_allocator_, input_tensor_infos_));

   std::set<std::string> overridable_initializer_tensor_names;
   RETURN_IF_ERROR(OverridableInitializerNames(
@@ -1325,12 +1334,13 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   RETURN_IF_ERROR(OverridableInitializerInfos(
       session_, default_allocator_, overridable_initializer_tensor_infos));

-  if (input_tensor_infos.size() != expected_input_cnt) {
+  if (input_tensor_infos_.size() != expected_input_cnt) {
     return TRITONSERVER_ErrorNew(
         TRITONSERVER_ERROR_INVALID_ARG,
         (std::string("unable to load model '") + model_state_->Name() +
          "', configuration expects " + std::to_string(expected_input_cnt) +
-         " inputs, model provides " + std::to_string(input_tensor_infos.size()))
+         " inputs, model provides " +
+         std::to_string(input_tensor_infos_.size()))
            .c_str());
   }

@@ -1357,8 +1367,9 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)

     const auto& tensor_names =
         io_optional ? overridable_initializer_tensor_names : input_tensor_names;
-    const auto& tensor_infos =
-        io_optional ? overridable_initializer_tensor_infos : input_tensor_infos;
+    const auto& tensor_infos = io_optional
+                                   ? overridable_initializer_tensor_infos
+                                   : input_tensor_infos_;
     auto iit = tensor_infos.find(io_name);
     if (iit == tensor_infos.end()) {
       RETURN_IF_ERROR(CheckAllowedModelInput(io, tensor_names));
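
With input_tensor_infos_ populated once at load time, the scalar handling in the rest of the commit only needs to check whether a tensor's recorded rank is zero. For reference, a minimal sketch of how a rank-0 (scalar) input can be detected through the plain ONNX Runtime C API; InputIsScalar is a hypothetical helper, not part of this backend, and error handling is reduced to status propagation:

    #include <onnxruntime_c_api.h>

    // Report whether the session input at `index` is a scalar (rank-0)
    // tensor. The backend's InputInfos helper records the same rank
    // information in the dims_ field of each OnnxTensorInfo.
    OrtStatus*
    InputIsScalar(
        const OrtApi* ort, const OrtSession* session, size_t index,
        bool* is_scalar)
    {
      OrtTypeInfo* type_info = nullptr;
      OrtStatus* status =
          ort->SessionGetInputTypeInfo(session, index, &type_info);
      if (status != nullptr) {
        return status;
      }

      const OrtTensorTypeAndShapeInfo* tensor_info = nullptr;
      status = ort->CastTypeInfoToTensorInfo(type_info, &tensor_info);
      if (status == nullptr) {
        size_t num_dims = 0;
        status = ort->GetDimensionsCount(tensor_info, &num_dims);
        if (status == nullptr) {
          *is_scalar = (num_dims == 0);  // rank 0 <=> scalar ONNX tensor
        }
      }

      ort->ReleaseTypeInfo(type_info);  // also releases the casted tensor_info
      return status;
    }

The OrtApi table would be obtained once via OrtGetApiBase()->GetApi(ORT_API_VERSION).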
@@ -1419,9 +1430,28 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
               .c_str());
       }
     } else {
-      RETURN_IF_ERROR(CompareDimsSupported(
-          model_state_->Name(), io_name, iit->second.dims_, dims,
-          model_state_->MaxBatchSize(), false /* compare_exact */));
+      // Only compare the dimensions if the tensor is not a scalar
+      if (iit->second.dims_.size() != 0) {
+        RETURN_IF_ERROR(CompareDimsSupported(
+            model_state_->Name(), io_name, iit->second.dims_, dims,
+            model_state_->MaxBatchSize(), false /* compare_exact */));
+      } else {
+        // If max_batch_size == 0 and the tensor is a scalar, all the
+        // dimensions specified in the model configuration must be equal to 1.
+        for (auto& dim : dims) {
+          if (dim != 1) {
+            return TRITONSERVER_ErrorNew(
+                TRITONSERVER_ERROR_INVALID_ARG,
+                (std::string("unable to load model '") + model_state_->Name() +
+                 "', scalar tensor '" + io_name +
+                 "', should only provide 1 in the model configuration when the "
+                 "model doesn't support batching. Model configuration "
+                 "provided: " +
+                 ShapeToString(dims) + ".")
+                    .c_str());
+          }
+        }
+      }
     }
   }

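
The rule encoded by this hunk (and mirrored for outputs below): when batching is disabled (max_batch_size == 0) and the ONNX model declares a scalar, the dims in the model configuration may contain only 1s, for example [1] or [1, 1]; anything else, including variable dims (-1), is rejected at load time. The same check in isolation, as a sketch under the hypothetical name ScalarConfigDimsValid:

    #include <cstdint>
    #include <vector>

    // Hypothetical mirror of the validation above: for a scalar ONNX
    // tensor, every dimension declared in the model configuration must
    // equal 1.
    static bool
    ScalarConfigDimsValid(const std::vector<int64_t>& config_dims)
    {
      for (const int64_t dim : config_dims) {
        if (dim != 1) {
          return false;  // e.g. [2] or [-1] (variable) is not a valid scalar
        }
      }
      return true;  // [1], [1, 1], ... are accepted
    }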
@@ -1482,9 +1512,29 @@ ModelInstanceState::ValidateOutputs()

     // The batch output shape doesn't necessarily match the model
     if (model_state_->FindBatchOutput(io_name) == nullptr) {
-      RETURN_IF_ERROR(CompareDimsSupported(
-          model_state_->Name(), io_name, iit->second.dims_, dims,
-          model_state_->MaxBatchSize(), true /* compare_exact */));
+      // Only compare the dimensions if the tensor is not a scalar
+      if (iit->second.dims_.size() != 0) {
+        RETURN_IF_ERROR(CompareDimsSupported(
+            model_state_->Name(), io_name, iit->second.dims_, dims,
+            model_state_->MaxBatchSize(), true /* compare_exact */));
+      } else {
+        for (auto& dim : dims) {
+          if (dim != 1) {
+            return TRITONSERVER_ErrorNew(
+                TRITONSERVER_ERROR_INVALID_ARG,
+                (std::string("unable to load model '") + model_state_->Name() +
+                 "', scalar tensor '" + io_name +
+                 "', should only provide 1 in the model configuration when the "
+                 "model doesn't support batching. Model configuration "
+                 "provided: " +
+                 ShapeToString(dims) + ".")
+                    .c_str());
+          }
+        }
+
+        // Store the configured dimensions for later reference.
+        scalar_outputs_[io_name] = dims;
+      }
     }
   }

@@ -1900,13 +1950,34 @@ ModelInstanceState::SetInputTensors(
           input_name, nullptr, 0, allowed_input_types, &input_buffer,
           &batchn_byte_size, &memory_type, &memory_type_id));

+      auto iti = input_tensor_infos_.find(input_name);
+      if (iti == input_tensor_infos_.end()) {
+        return TRITONSERVER_ErrorNew(
+            TRITONSERVER_ERROR_INTERNAL,
+            std::string(
+                std::string(
+                    "Failed to retrieve the ONNX input tensor info from '") +
+                input_name + "'.")
+                .c_str());
+      }
+
       // Create ORT Tensor
-      RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
-          memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
-                                                 : cpu_allocator_info_,
-          (void*)input_buffer, batchn_byte_size, batchn_shape.data(),
-          batchn_shape.size(), ConvertToOnnxDataType(input_datatype),
-          &input_tensors_.back()));
+      if (iti->second.dims_.size() == 0) {
+        // scalar tensor
+        RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
+            memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
+                                                   : cpu_allocator_info_,
+            (void*)input_buffer, batchn_byte_size, nullptr /* scalar */,
+            0 /* number of dims */, ConvertToOnnxDataType(input_datatype),
+            &input_tensors_.back()));
+      } else {
+        RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
+            memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
+                                                   : cpu_allocator_info_,
+            (void*)input_buffer, batchn_byte_size, batchn_shape.data(),
+            batchn_shape.size(), ConvertToOnnxDataType(input_datatype),
+            &input_tensors_.back()));
+      }
       RETURN_IF_ORT_ERROR(
           ort_api->BindInput(io_binding_, input_name, input_tensors_.back()));
     } else {
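
The scalar branch above relies on CreateTensorWithDataAsOrtValue accepting a null shape pointer with zero dimensions to produce a rank-0 tensor over the caller's buffer. A self-contained sketch of the same call against the ORT C API, using CPU memory and a single float; WrapScalar and its parameters are illustrative, not backend code:

    #include <onnxruntime_c_api.h>

    // Wrap a single float as a rank-0 (scalar) OrtValue without copying.
    // Minimal sketch: `value` must outlive the returned OrtValue.
    OrtStatus*
    WrapScalar(const OrtApi* ort, float* value, OrtValue** scalar)
    {
      OrtMemoryInfo* memory_info = nullptr;
      OrtStatus* status = ort->CreateCpuMemoryInfo(
          OrtArenaAllocator, OrtMemTypeDefault, &memory_info);
      if (status != nullptr) {
        return status;
      }

      // A null shape with 0 dims requests a rank-0 tensor, exactly as the
      // backend does above for scalar inputs.
      status = ort->CreateTensorWithDataAsOrtValue(
          memory_info, value, sizeof(float), /* shape */ nullptr,
          /* shape_len */ 0, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, scalar);
      ort->ReleaseMemoryInfo(memory_info);
      return status;
    }

Because the data is wrapped rather than copied, the buffer has to stay alive for as long as the OrtValue is used; the backend satisfies this by keeping input_tensors_ alive for the duration of the run.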
@@ -2283,6 +2354,22 @@ ModelInstanceState::ReadOutputTensors(
           batchn_shape, dtype, output_tensor, &output_buffer, string_buffers,
           offsets));

+      // If the number of dimensions is zero, the output is a scalar and
+      // the dimensions specified in the model configuration are used
+      // instead.
+      if (batchn_shape.size() == 0) {
+        auto scalar_output_dims_it = scalar_outputs_.find(name);
+        if (scalar_output_dims_it == scalar_outputs_.end()) {
+          return TRITONSERVER_ErrorNew(
+              TRITONSERVER_ERROR_INTERNAL,
+              std::string(
+                  "Failed to find the scalar output dimension for " + name +
+                  " in the model configuration.")
+                  .c_str());
+        }
+        batchn_shape = scalar_output_dims_it->second;
+      }
+
       if (output_tensor_pair.first != -1) {
         if (dtype == TRITONSERVER_TYPE_BYTES) {
           auto content = string_buffers.back().data();
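
Reading outputs inverts the substitution: ONNX Runtime reports an empty shape for a scalar output, and the backend replaces it with the all-ones dims recorded in scalar_outputs_ so the response shape matches the model configuration. The same logic in isolation, as a sketch with hypothetical names:

    #include <cstdint>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Hypothetical stand-in for the substitution above: map an ORT-reported
    // scalar shape ([]) back to the all-ones dims declared in the model
    // configuration, and leave non-scalar shapes untouched.
    std::vector<int64_t>
    ResolveOutputShape(
        const std::vector<int64_t>& ort_shape, const std::string& name,
        const std::unordered_map<std::string, std::vector<int64_t>>&
            scalar_outputs)
    {
      if (!ort_shape.empty()) {
        return ort_shape;  // non-scalar: trust the shape ORT reported
      }
      auto it = scalar_outputs.find(name);
      // e.g. [] -> [1] when the configuration declared dims: [ 1 ]
      return (it != scalar_outputs.end()) ? it->second
                                          : std::vector<int64_t>{};
    }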
