Skip to content

Commit 7e4c8cc

Browse files
committed
Initial commit
1 parent 8571a80 commit 7e4c8cc

File tree

4 files changed

+148
-16
lines changed

4 files changed

+148
-16
lines changed

include/triton/backend/backend_common.h

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -320,24 +320,59 @@ std::string ShapeToString(const std::vector<int64_t>& shape);
320320
///
321321
/// \param dims The shape dimensions.
322322
/// \param dims_count The number of dimensions.
323-
/// \return The number of elements.
323+
/// \return The number of elements,
324+
/// -1 if unable to determine the number,
325+
/// -2 if the shape contains an invalid dim,
326+
/// or -3 if the number is too large to represent as an int64_t.
324327
int64_t GetElementCount(const int64_t* dims, const size_t dims_count);
325328

326329
/// Return the number of elements of a shape.
327330
///
328331
/// \param shape The shape as a vector of dimensions.
329-
/// \return The number of elements.
332+
/// \return The number of elements,
333+
/// -1 if unable to determine the number,
334+
/// -2 if the shape contains an invalid dim,
335+
/// or -3 if the number is too large to represent as an int64_t.
330336
int64_t GetElementCount(const std::vector<int64_t>& shape);
331337

338+
/// Return the number of elements of a shape with error checking.
339+
///
340+
/// \param dims The shape dimensions.
341+
/// \param dims_count The number of dimensions.
342+
/// \param cnt Returns the number of elements.
343+
/// \return a TRITONSERVER_Error indicating success or failure.
344+
TRITONSERVER_Error* GetElementCount(
345+
const int64_t* dims, const size_t dims_count, int64_t* cnt);
346+
347+
/// Return the number of elements of a shape with error checking.
348+
///
349+
/// \param shape The shape as a vector of dimensions.
350+
/// \param cnt Returns the number of elements.
351+
/// \return a TRITONSERVER_Error indicating success or failure.
352+
TRITONSERVER_Error* GetElementCount(
353+
const std::vector<int64_t>& shape, int64_t* cnt);
354+
332355
/// Get the size, in bytes, of a tensor based on datatype and
333356
/// shape.
334357
/// \param dtype The data-type.
335358
/// \param dims The shape.
336-
/// \return The size, in bytes, of the corresponding tensor, or -1 if
337-
/// unable to determine the size.
359+
/// \return The size, in bytes, of the corresponding tensor,
360+
/// -1 if unable to determine the size,
361+
/// -2 if the shape contains an invalid dim,
362+
/// or -3 if the size is too large to represent as an int64_t.
338363
int64_t GetByteSize(
339364
const TRITONSERVER_DataType& dtype, const std::vector<int64_t>& dims);
340365

366+
/// Get the size, in bytes, of a tensor based on datatype and
367+
/// shape with error checking.
368+
/// \param dtype The data-type.
369+
/// \param dims The shape.
370+
/// \param size Returns the size, in bytes, of the corresponding tensor.
371+
/// \return a TRITONSERVER_Error indicating success or failure.
372+
TRITONSERVER_Error* GetByteSize(
373+
const TRITONSERVER_DataType& dtype, const std::vector<int64_t>& dims,
374+
int64_t* size);
375+
341376
/// Get an input tensor's contents into a buffer. This overload expects
342377
/// both 'buffer' and buffers of the input to be in CPU.
343378
///

src/backend_common.cc

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,18 @@ GetElementCount(const int64_t* dims, const size_t dims_count)
166166
for (size_t i = 0; i < dims_count; i++) {
167167
if (dims[i] == WILDCARD_DIM) {
168168
return -1;
169+
} else if (dims[i] < 0) {
170+
return -2;
169171
}
170172

171173
if (first) {
172174
cnt = dims[i];
173175
first = false;
174176
} else {
177+
// Check for overflow before multiplication
178+
if (cnt > INT64_MAX / dims[i]) {
179+
return -3;
180+
}
175181
cnt *= dims[i];
176182
}
177183
}
@@ -185,6 +191,42 @@ GetElementCount(const std::vector<int64_t>& shape)
185191
return GetElementCount(shape.data(), shape.size());
186192
}
187193

194+
TRITONSERVER_Error*
195+
GetElementCount(const int64_t* dims, const size_t dims_count, int64_t* cnt)
196+
{
197+
*cnt = GetElementCount(dims, dims_count);
198+
if (*cnt == -2) {
199+
return TRITONSERVER_ErrorNew(
200+
TRITONSERVER_ERROR_INVALID_ARG,
201+
(std::string("shape") + ShapeToString(dims, dims_count) +
202+
" contains an invalid dim.")
203+
.c_str());
204+
} else if (*cnt == -3) {
205+
return TRITONSERVER_ErrorNew(
206+
TRITONSERVER_ERROR_INVALID_ARG,
207+
"unexpected integer overflow while calculating element count.");
208+
}
209+
return nullptr; // success
210+
}
211+
212+
TRITONSERVER_Error*
213+
GetElementCount(const std::vector<int64_t>& shape, int64_t* cnt)
214+
{
215+
*cnt = GetElementCount(shape.data(), shape.size());
216+
if (*cnt == -2) {
217+
return TRITONSERVER_ErrorNew(
218+
TRITONSERVER_ERROR_INVALID_ARG,
219+
(std::string("shape") + ShapeToString(shape) +
220+
" contains an invalid dim.")
221+
.c_str());
222+
} else if (*cnt == -3) {
223+
return TRITONSERVER_ErrorNew(
224+
TRITONSERVER_ERROR_INVALID_ARG,
225+
"unexpected integer overflow while calculating element count.");
226+
}
227+
return nullptr; // success
228+
}
229+
188230
int64_t
189231
GetByteSize(
190232
const TRITONSERVER_DataType& dtype, const std::vector<int64_t>& dims)
@@ -195,13 +237,36 @@ GetByteSize(
195237
}
196238

197239
int64_t cnt = GetElementCount(dims);
198-
if (cnt == -1) {
199-
return -1;
240+
if (cnt < 0) {
241+
return cnt;
200242
}
201243

244+
if ((cnt > INT64_MAX / dt_size)) {
245+
return -3;
246+
}
202247
return cnt * dt_size;
203248
}
204249

250+
TRITONSERVER_Error*
251+
GetByteSize(
252+
const TRITONSERVER_DataType& dtype, const std::vector<int64_t>& dims,
253+
int64_t* size)
254+
{
255+
*size = GetByteSize(dtype, dims);
256+
if (*size == -2) {
257+
return TRITONSERVER_ErrorNew(
258+
TRITONSERVER_ERROR_INVALID_ARG,
259+
(std::string("shape") + ShapeToString(dims) +
260+
" contains an invalid dim.")
261+
.c_str());
262+
} else if (*size == -3) {
263+
return TRITONSERVER_ErrorNew(
264+
TRITONSERVER_ERROR_INVALID_ARG,
265+
"unexpected integer overflow while calculating byte size.");
266+
}
267+
return nullptr; // success
268+
}
269+
205270
TRITONSERVER_Error*
206271
ReadInputTensor(
207272
TRITONBACKEND_Request* request, const std::string& input_name, char* buffer,

src/backend_input_collector.cc

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -762,11 +762,12 @@ BackendInputCollector::BatchInputShape(
762762
requests_[req_idx], source_input.c_str(), &input));
763763
const int64_t* shape_arr;
764764
uint32_t dims_count;
765+
int64_t element_cnt = 0;
765766
RETURN_IF_ERROR(TRITONBACKEND_InputPropertiesForHostPolicy(
766767
input, host_policy_cstr_, nullptr, nullptr, &shape_arr, &dims_count,
767768
nullptr, nullptr));
768-
(*shape)[0] =
769-
std::max((*shape)[0], GetElementCount(shape_arr, dims_count));
769+
RETURN_IF_ERROR(GetElementCount(shape_arr, dims_count, &element_cnt));
770+
(*shape)[0] = std::max((*shape)[0], element_cnt);
770771
}
771772
break;
772773
}
@@ -841,7 +842,9 @@ BackendInputCollector::ProcessBatchInput(
841842
// Calculate the byte size of the buffer
842843
std::vector<int64_t> shape;
843844
RETURN_IF_ERROR(BatchInputShape(batch_input, &shape));
844-
*dst_buffer_byte_size = GetByteSize(batch_input.DataType(), shape);
845+
RETURN_IF_ERROR(GetByteSize(
846+
batch_input.DataType(), shape,
847+
reinterpret_cast<int64_t*>(dst_buffer_byte_size)));
845848
BackendMemory* backend_memory = nullptr;
846849
for (const auto& allowed_type : allowed_input_types) {
847850
std::vector<BackendMemory::AllocationType> alloc_types;
@@ -945,11 +948,31 @@ BackendInputCollector::ProcessBatchInput(
945948
const auto& source_input = batch_input.SourceInputs()[0];
946949
if (data_type == TRITONSERVER_TYPE_FP32) {
947950
*reinterpret_cast<float*>(input_buffer) = 0;
951+
if (*dst_buffer_byte_size < sizeof(float)) {
952+
return TRITONSERVER_ErrorNew(
953+
TRITONSERVER_ERROR_INVALID_ARG,
954+
(std::string(
955+
"Unexpected total byte size for batch input. Expect >= ") +
956+
std::to_string(sizeof(float)) + ", got " +
957+
std::to_string(*dst_buffer_byte_size))
958+
.c_str());
959+
}
960+
948961
RETURN_IF_ERROR(SetAccumulatedElementCount<float>(
949962
source_input, input_buffer + sizeof(float),
950963
*dst_buffer_byte_size - sizeof(float)));
951964
} else {
952965
*reinterpret_cast<int32_t*>(input_buffer) = 0;
966+
if (*dst_buffer_byte_size < sizeof(int32_t)) {
967+
return TRITONSERVER_ErrorNew(
968+
TRITONSERVER_ERROR_INVALID_ARG,
969+
(std::string(
970+
"Unexpected total byte size for batch input. Expect >= ") +
971+
std::to_string(sizeof(int32_t)) + ", got " +
972+
std::to_string(*dst_buffer_byte_size))
973+
.c_str());
974+
}
975+
953976
RETURN_IF_ERROR(SetAccumulatedElementCount<int32_t>(
954977
source_input, input_buffer + sizeof(int32_t),
955978
*dst_buffer_byte_size - sizeof(int32_t)));
@@ -1011,11 +1034,12 @@ BackendInputCollector::SetElementCount(
10111034
requests_[req_idx], source_input.c_str(), &input));
10121035
const int64_t* shape;
10131036
uint32_t dims_count;
1037+
int64_t element_cnt = 0;
10141038
RETURN_IF_ERROR(TRITONBACKEND_InputPropertiesForHostPolicy(
10151039
input, host_policy_cstr_, nullptr, nullptr, &shape, &dims_count,
10161040
nullptr, nullptr));
1017-
*(reinterpret_cast<T*>(buffer) + req_idx) =
1018-
GetElementCount(shape, dims_count);
1041+
RETURN_IF_ERROR(GetElementCount(shape, dims_count, &element_cnt));
1042+
*(reinterpret_cast<T*>(buffer) + req_idx) = element_cnt;
10191043
buffer_offset += sizeof(T);
10201044
}
10211045
// Set the rest of the buffer to 0
@@ -1046,10 +1070,12 @@ BackendInputCollector::SetAccumulatedElementCount(
10461070
requests_[req_idx], source_input.c_str(), &input));
10471071
const int64_t* shape;
10481072
uint32_t dims_count;
1073+
int64_t element_cnt = 0;
10491074
RETURN_IF_ERROR(TRITONBACKEND_InputPropertiesForHostPolicy(
10501075
input, host_policy_cstr_, nullptr, nullptr, &shape, &dims_count,
10511076
nullptr, nullptr));
1052-
accumulated_element_count += GetElementCount(shape, dims_count);
1077+
RETURN_IF_ERROR(GetElementCount(shape, dims_count, &element_cnt));
1078+
accumulated_element_count += element_cnt;
10531079
*(reinterpret_cast<T*>(buffer) + req_idx) = accumulated_element_count;
10541080
buffer_offset += sizeof(T);
10551081
}

src/backend_output_responder.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ BackendOutputResponder::ProcessTensor(
109109
batch_size_offset += shape[0];
110110
}
111111

112-
const size_t tensor_byte_size = GetByteSize(datatype, batchn_shape);
112+
int64_t tensor_byte_size = 0;
113+
RESPOND_AND_SET_NULL_IF_ERROR(
114+
&response, GetByteSize(datatype, batchn_shape, &tensor_byte_size));
113115

114116
TRITONBACKEND_Output* response_output;
115117
if (response != nullptr) {
@@ -218,7 +220,9 @@ BackendOutputResponder::ProcessStateTensor(
218220
batch_size_offset += shape[0];
219221
}
220222

221-
const size_t tensor_byte_size = GetByteSize(datatype, batchn_shape);
223+
int64_t tensor_byte_size = 0;
224+
RESPOND_AND_SET_NULL_IF_ERROR(
225+
&response, GetByteSize(datatype, batchn_shape, &tensor_byte_size));
222226

223227
TRITONBACKEND_State* output_state;
224228
if (response != nullptr) {
@@ -554,8 +558,10 @@ BackendOutputResponder::ProcessBatchOutput(
554558
}
555559
}
556560

557-
const size_t tensor_byte_size =
558-
GetByteSize(datatype, output_batchn_shape);
561+
int64_t tensor_byte_size = 0;
562+
RESPOND_AND_SET_NULL_IF_ERROR(
563+
&response,
564+
GetByteSize(datatype, output_batchn_shape, &tensor_byte_size));
559565

560566
TRITONBACKEND_Output* response_output;
561567
if (response != nullptr) {

0 commit comments

Comments
 (0)