Commit 3ac4eb1

Add support for tracing child models invoked from a BLS model (#277)
* Add tracing for BLS
* Add access to the trace from BLS request creation
* Add tracing to decoupled models
* Run clang-format
* Add an InferenceTrace object
1 parent a9e6a77 commit 3ac4eb1

File tree

5 files changed: +53 −10

- src/infer_request.cc (+11 −2)
- src/infer_request.h (+16 −1)
- src/pb_stub.cc (+9 −3)
- src/python_be.cc (+9 −2)
- src/request_executor.cc (+8 −2)

src/infer_request.cc (+11 −2)

@@ -44,12 +44,13 @@ InferRequest::InferRequest(
     const std::string& model_name, const int64_t model_version,
     const std::string& parameters, const uint32_t flags, const int32_t timeout,
     const intptr_t response_factory_address, const intptr_t request_address,
-    const PreferredMemory& preferred_memory)
+    const PreferredMemory& preferred_memory, const InferenceTrace& trace)
     : request_id_(request_id), correlation_id_(correlation_id), inputs_(inputs),
       requested_output_names_(requested_output_names), model_name_(model_name),
       model_version_(model_version), parameters_(parameters), flags_(flags),
       timeout_(timeout), response_factory_address_(response_factory_address),
-      request_address_(request_address), preferred_memory_(preferred_memory)
+      request_address_(request_address), preferred_memory_(preferred_memory),
+      trace_(trace)
 {
   for (auto& input : inputs) {
     if (!input) {
@@ -166,6 +167,12 @@ InferRequest::GetPreferredMemory()
   return preferred_memory_;
 }
 
+InferenceTrace&
+InferRequest::Trace()
+{
+  return trace_;
+}
+
 void
 InferRequest::SaveToSharedMemory(std::unique_ptr<SharedMemoryManager>& shm_pool)
 {
@@ -191,6 +198,7 @@ InferRequest::SaveToSharedMemory(std::unique_ptr<SharedMemoryManager>& shm_pool)
   infer_request_shm_ptr_->is_decoupled = is_decoupled_;
   infer_request_shm_ptr_->timeout = timeout_;
   infer_request_shm_ptr_->preferred_memory = preferred_memory_;
+  infer_request_shm_ptr_->trace = trace_;
 
   output_names_handle_shm_ptr_ =
       reinterpret_cast<bi::managed_external_buffer::handle_t*>(
@@ -368,6 +376,7 @@ InferRequest::InferRequest(
   is_decoupled_ = infer_request_shm_ptr_->is_decoupled;
   timeout_ = infer_request_shm_ptr_->timeout;
   preferred_memory_ = infer_request_shm_ptr_->preferred_memory;
+  trace_ = infer_request_shm_ptr_->trace;
 
 #ifdef TRITON_PB_STUB
   response_sender_ = std::make_shared<ResponseSender>(

src/infer_request.h (+16 −1)

@@ -41,6 +41,17 @@ namespace triton { namespace backend { namespace python {
 
 class Stub;
 
+//
+// Inference Trace
+//
+struct InferenceTrace {
+#ifndef TRITON_PB_STUB
+  TRITONSERVER_InferenceTrace* triton_trace_;
+#else
+  void* triton_trace_;
+#endif
+};
+
 //
 // Inference Request
 //
@@ -55,6 +66,7 @@ struct InferRequestShm {
   bool is_decoupled;
   int32_t timeout;
   PreferredMemory preferred_memory;
+  InferenceTrace trace;
 };
 
 class InferRequest {
@@ -68,7 +80,8 @@ class InferRequest {
       const int32_t timeout = 0, const intptr_t response_factory_address = 0,
       const intptr_t request_address = 0,
       const PreferredMemory& preferred_memory =
-          PreferredMemory(PreferredMemory::DEFAULT, 0));
+          PreferredMemory(PreferredMemory::DEFAULT, 0),
+      const InferenceTrace& trace = {.triton_trace_ = nullptr});
 
   const std::vector<std::shared_ptr<PbTensor>>& Inputs();
   const std::string& RequestId();
@@ -84,6 +97,7 @@ class InferRequest {
   bool IsDecoupled();
   void SetIsDecoupled(const bool is_decoupled);
   PreferredMemory& GetPreferredMemory();
+  InferenceTrace& Trace();
 
 #ifdef TRITON_PB_STUB
   std::shared_ptr<InferResponse> Exec(const bool is_decoupled);
@@ -139,6 +153,7 @@ class InferRequest {
   intptr_t request_address_;
   bool is_decoupled_;
   PreferredMemory preferred_memory_;
+  InferenceTrace trace_;
 
   // Shared Memory Data Structures
   AllocatedSharedMemory<char> infer_request_shm_;

src/pb_stub.cc (+9 −3)

@@ -1362,6 +1362,9 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
       .value("TRITONSERVER_MEMORY_CPU", PreferredMemory::MemoryType::CPU)
       .export_values();
 
+  py::class_<InferenceTrace, std::shared_ptr<InferenceTrace>>(
+      module, "InferenceTrace");
+
   py::class_<InferRequest, std::shared_ptr<InferRequest>>(
       module, "InferenceRequest")
       .def(
@@ -1371,7 +1374,8 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
              const std::string& model_name,
              const int64_t model_version, const uint32_t flags,
              const int32_t timeout,
-             const PreferredMemory& preferred_memory) {
+             const PreferredMemory& preferred_memory,
+             const InferenceTrace& trace) {
            std::set<std::string> requested_outputs;
            for (auto& requested_output_name : requested_output_names) {
              requested_outputs.emplace(requested_output_name);
@@ -1381,7 +1385,7 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
                request_id, correlation_id, inputs, requested_outputs,
                model_name, model_version, "" /*parameters*/, flags, timeout,
                0 /*response_factory_address*/, 0 /*request_address*/,
-               preferred_memory);
+               preferred_memory, trace);
          }),
          py::arg("request_id").none(false) = "",
          py::arg("correlation_id").none(false) = 0,
@@ -1391,7 +1395,8 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
          py::arg("model_version").none(false) = -1,
          py::arg("flags").none(false) = 0, py::arg("timeout").none(false) = 0,
          py::arg("preferred_memory").none(false) =
-             PreferredMemory(PreferredMemory::DEFAULT, 0))
+             PreferredMemory(PreferredMemory::DEFAULT, 0),
+         py::arg("trace").none(false) = nullptr)
       .def(
           "inputs", &InferRequest::Inputs,
           py::return_value_policy::reference_internal)
@@ -1401,6 +1406,7 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
       .def("set_flags", &InferRequest::SetFlags)
       .def("timeout", &InferRequest::Timeout)
       .def("parameters", &InferRequest::Parameters)
+      .def("trace", &InferRequest::Trace)
       .def(
           "exec",
           [](std::shared_ptr<InferRequest>& infer_request,
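
With this binding in place, a BLS model's `execute` function can read the incoming request's trace and forward it when constructing a child request. A minimal Python sketch of the intended usage; the `trace()` accessor and the `trace` keyword argument are the ones added above, while the model and tensor names (`child_model`, `OUTPUT0`) are hypothetical:

```python
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def execute(self, requests):
        responses = []
        for request in requests:
            # Forward the parent's trace so the nested inference is
            # recorded as a child trace of the originating request.
            bls_request = pb_utils.InferenceRequest(
                model_name="child_model",            # hypothetical child model
                requested_output_names=["OUTPUT0"],  # hypothetical output name
                inputs=request.inputs(),
                trace=request.trace())
            bls_response = bls_request.exec()
            # ... build a pb_utils.InferenceResponse from bls_response ...
        return responses
```

When the forwarded trace is non-null, `RequestExecutor::Infer` (see src/request_executor.cc below) spawns a child trace before submitting the nested inference, which is what links the child model's activity to the parent request's trace.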

src/python_be.cc (+9 −2)

@@ -364,6 +364,11 @@ ModelInstanceState::SaveRequestsToSharedMemory(
   uint32_t flags;
   RETURN_IF_ERROR(TRITONBACKEND_RequestFlags(request, &flags));
 
+  TRITONSERVER_InferenceTrace* triton_trace;
+  RETURN_IF_ERROR(TRITONBACKEND_RequestTrace(request, &triton_trace));
+
+  InferenceTrace trace = {triton_trace};
+
   std::unique_ptr<InferRequest> infer_request;
   if (model_state->IsDecoupled()) {
     TRITONBACKEND_ResponseFactory* factory_ptr;
@@ -372,13 +377,15 @@ ModelInstanceState::SaveRequestsToSharedMemory(
         id, correlation_id, pb_input_tensors, requested_output_names,
         model_state->Name(), model_state->Version(), parameters_string, flags,
         0 /* BLS request timeout*/, reinterpret_cast<intptr_t>(factory_ptr),
-        reinterpret_cast<intptr_t>(request));
+        reinterpret_cast<intptr_t>(request),
+        PreferredMemory(PreferredMemory::DEFAULT, 0), trace);
   } else {
     infer_request = std::make_unique<InferRequest>(
         id, correlation_id, pb_input_tensors, requested_output_names,
         model_state->Name(), model_state->Version(), parameters_string, flags,
         0 /* BLS request timeout*/, 0 /* response_factory_address */,
-        reinterpret_cast<intptr_t>(request));
+        reinterpret_cast<intptr_t>(request),
+        PreferredMemory(PreferredMemory::DEFAULT, 0), trace);
   }
 
   RETURN_IF_EXCEPTION(infer_request->SaveToSharedMemory(Stub()->ShmPool()));

src/request_executor.cc (+8 −2)

@@ -359,6 +359,12 @@ RequestExecutor::Infer(
   THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetReleaseCallback(
       irequest, InferRequestComplete, nullptr /* request_release_userp */));
 
+  TRITONSERVER_InferenceTrace* trace = nullptr;
+  if (infer_request->Trace().triton_trace_ != nullptr) {
+    THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceTraceSpawnChildTrace(
+        infer_request->Trace().triton_trace_, &trace));
+  }
+
   for (auto& infer_input : infer_request->Inputs()) {
     THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestAddInput(
         irequest, infer_input->Name().c_str(),
@@ -388,8 +394,8 @@ RequestExecutor::Infer(
         reinterpret_cast<void*>(infer_payload->ResponseAllocUserp().get()),
         InferResponseComplete, reinterpret_cast<void*>(infer_payload.get())));
 
-    THROW_IF_TRITON_ERROR(TRITONSERVER_ServerInferAsync(
-        server_, irequest, nullptr /* trace */));
+    THROW_IF_TRITON_ERROR(
+        TRITONSERVER_ServerInferAsync(server_, irequest, trace));
   }
 }
 catch (const PythonBackendException& pb_exception) {
