Skip to content

Commit 7db166f

Browse files
authored
Add request input to profile export data (#549)
* Add request_inputs field in profile export file * Add request input data to profile export * Add test cases * Remove comment * Add ticket and remove comments * Add tests * Remove first 4 bytes * Add warning * handle empty input * Avoid access to buffer when empty
1 parent e2c000c commit 7db166f

17 files changed

+288
-86
lines changed

src/c++/library/common.cc

+14
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,20 @@ InferInput::AppendFromString(const std::vector<std::string>& input)
182182
return AppendRaw(reinterpret_cast<const uint8_t*>(&sbuf[0]), sbuf.size());
183183
}
184184

185+
Error
186+
InferInput::RawData(const uint8_t** buf, size_t* byte_size)
187+
{
188+
if (bufs_.size()) {
189+
// TMA-1775 - handle multi-batch case
190+
*buf = bufs_[0];
191+
*byte_size = buf_byte_sizes_[0];
192+
} else {
193+
*buf = nullptr;
194+
*byte_size = 0;
195+
}
196+
return Error::Success;
197+
}
198+
185199
Error
186200
InferInput::ByteSize(size_t* byte_size) const
187201
{

src/c++/library/common.h

+9
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,15 @@ class InferInput {
334334
/// \return Error object indicating success or failure.
335335
Error AppendFromString(const std::vector<std::string>& input);
336336

337+
/// Get access to the buffer holding raw input. Note the buffer is owned by
338+
/// InferInput instance. Users can copy out the data if required to extend
339+
/// the lifetime.
340+
/// \param buf Returns the pointer to the start of the buffer.
341+
/// \param byte_size Returns the size of buffer in bytes.
342+
/// \return Error object indicating success or failure of the
343+
/// request.
344+
Error RawData(const uint8_t** buf, size_t* byte_size);
345+
337346
/// Gets the size of data added into this input in bytes.
338347
/// \param byte_size The size of data added in bytes.
339348
/// \return Error object indicating success or failure.

src/c++/perf_analyzer/client_backend/client_backend.cc

+8
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,14 @@ InferInput::SetSharedMemory(
505505
pa::GENERIC_ERROR);
506506
}
507507

508+
Error
509+
InferInput::RawData(const uint8_t** buf, size_t* byte_size)
510+
{
511+
return Error(
512+
"client backend of kind " + BackendKindToString(kind_) +
513+
" does not support RawData() for InferInput",
514+
pa::GENERIC_ERROR);
515+
}
508516

509517
InferInput::InferInput(
510518
const BackendKind kind, const std::string& name,

src/c++/perf_analyzer/client_backend/client_backend.h

+9
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,15 @@ class InferInput {
558558
virtual Error SetSharedMemory(
559559
const std::string& name, size_t byte_size, size_t offset = 0);
560560

561+
/// Get access to the buffer holding raw input. Note the buffer is owned by
562+
/// InferInput instance. Users can copy out the data if required to extend
563+
/// the lifetime.
564+
/// \param buf Returns the pointer to the start of the buffer.
565+
/// \param byte_size Returns the size of buffer in bytes.
566+
/// \return Error object indicating success or failure of the
567+
/// request.
568+
virtual Error RawData(const uint8_t** buf, size_t* byte_size);
569+
561570
protected:
562571
InferInput(
563572
const BackendKind kind, const std::string& name,

src/c++/perf_analyzer/client_backend/openai/openai_infer_input.cc

+9
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,15 @@ OpenAiInferInput::AppendRaw(const uint8_t* input, size_t input_byte_size)
7171
return Error::Success;
7272
}
7373

74+
Error
75+
OpenAiInferInput::RawData(const uint8_t** buf, size_t* byte_size)
76+
{
77+
// TMA-1775 - handle multi-batch case
78+
*buf = bufs_[0];
79+
*byte_size = buf_byte_sizes_[0];
80+
return Error::Success;
81+
}
82+
7483
Error
7584
OpenAiInferInput::PrepareForRequest()
7685
{

src/c++/perf_analyzer/client_backend/openai/openai_infer_input.h

+2
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ class OpenAiInferInput : public InferInput {
5151
Error Reset() override;
5252
/// See InferInput::AppendRaw()
5353
Error AppendRaw(const uint8_t* input, size_t input_byte_size) override;
54+
/// See InferInput::RawData()
55+
Error RawData(const uint8_t** buf, size_t* byte_size) override;
5456
/// Prepare the input to be in the form expected by an OpenAI client,
5557
/// must call before accessing the data.
5658
Error PrepareForRequest();

src/c++/perf_analyzer/client_backend/triton/triton_client_backend.cc

+7
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,13 @@ TritonInferInput::SetSharedMemory(
756756
return Error::Success;
757757
}
758758

759+
Error
760+
TritonInferInput::RawData(const uint8_t** buf, size_t* byte_size)
761+
{
762+
RETURN_IF_TRITON_ERROR(input_->RawData(buf, byte_size));
763+
return Error::Success;
764+
}
765+
759766
TritonInferInput::TritonInferInput(
760767
const std::string& name, const std::string& datatype)
761768
: InferInput(BackendKind::TRITON, name, datatype)

src/c++/perf_analyzer/client_backend/triton/triton_client_backend.h

+2
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,8 @@ class TritonInferInput : public InferInput {
283283
/// See InferInput::SetSharedMemory()
284284
Error SetSharedMemory(
285285
const std::string& name, size_t byte_size, size_t offset = 0) override;
286+
/// See InferInput::RawData()
287+
Error RawData(const uint8_t** buf, size_t* byte_size) override;
286288

287289
private:
288290
explicit TritonInferInput(

src/c++/perf_analyzer/infer_context.cc

+37-10
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ InferContext::SendRequest(
112112
}
113113

114114
thread_stat_->num_sent_requests_++;
115+
116+
// Parse the request inputs to save in the profile export file
117+
RequestRecord::RequestInput request_inputs{GetInputs()};
118+
115119
if (async_) {
116120
uint64_t unique_request_id{(thread_id_ << 48) | ((request_id << 16) >> 16)};
117121
infer_data_.options_->request_id_ = std::to_string(unique_request_id);
@@ -120,6 +124,7 @@ InferContext::SendRequest(
120124
auto it = async_req_map_
121125
.emplace(infer_data_.options_->request_id_, RequestRecord())
122126
.first;
127+
it->second.request_inputs_ = {request_inputs};
123128
it->second.start_time_ = std::chrono::system_clock::now();
124129
it->second.sequence_end_ = infer_data_.options_->sequence_end_;
125130
it->second.delayed_ = delayed;
@@ -149,10 +154,10 @@ InferContext::SendRequest(
149154
&results, *(infer_data_.options_), infer_data_.valid_inputs_,
150155
infer_data_.outputs_);
151156
thread_stat_->idle_timer.Stop();
152-
RequestRecord::ResponseOutput response_output{};
157+
RequestRecord::ResponseOutput response_outputs{};
153158
if (results != nullptr) {
154159
if (thread_stat_->status_.IsOk()) {
155-
response_output = GetOutput(*results);
160+
response_outputs = GetOutputs(*results);
156161
thread_stat_->status_ = ValidateOutputs(results);
157162
}
158163
delete results;
@@ -169,8 +174,9 @@ InferContext::SendRequest(
169174
std::lock_guard<std::mutex> lock(thread_stat_->mu_);
170175
auto total = end_time_sync - start_time_sync;
171176
thread_stat_->request_records_.emplace_back(RequestRecord(
172-
start_time_sync, std::move(end_time_syncs), {response_output},
173-
infer_data_.options_->sequence_end_, delayed, sequence_id, false));
177+
start_time_sync, std::move(end_time_syncs), {request_inputs},
178+
{response_outputs}, infer_data_.options_->sequence_end_, delayed,
179+
sequence_id, false));
174180
thread_stat_->status_ =
175181
infer_backend_->ClientInferStat(&(thread_stat_->contexts_stat_[id_]));
176182
if (!thread_stat_->status_.IsOk()) {
@@ -180,15 +186,36 @@ InferContext::SendRequest(
180186
}
181187
}
182188

189+
const RequestRecord::RequestInput
190+
InferContext::GetInputs()
191+
{
192+
RequestRecord::RequestInput input{};
193+
for (const auto& request_input : infer_data_.valid_inputs_) {
194+
const uint8_t* buf{nullptr};
195+
size_t byte_size{0};
196+
std::string data_type{request_input->Datatype()};
197+
request_input->RawData(&buf, &byte_size);
198+
199+
// The first 4 bytes of BYTES data is a 32-bit integer to indicate the size
200+
// of the rest of the data (which we already know based on byte_size). It
201+
// should be ignored here, as it isn't part of the actual request
202+
if (data_type == "BYTES" && byte_size >= 4) {
203+
buf += 4;
204+
byte_size -= 4;
205+
}
206+
input.emplace(request_input->Name(), RecordData(buf, byte_size, data_type));
207+
}
208+
return input;
209+
}
210+
183211
const RequestRecord::ResponseOutput
184-
InferContext::GetOutput(const cb::InferResult& infer_result)
212+
InferContext::GetOutputs(const cb::InferResult& infer_result)
185213
{
186214
RequestRecord::ResponseOutput output{};
187215
for (const auto& requested_output : infer_data_.outputs_) {
188216
const uint8_t* buf{nullptr};
189217
size_t byte_size{0};
190218
infer_result.RawData(requested_output->Name(), &buf, &byte_size);
191-
192219
// The first 4 bytes of BYTES data is a 32-bit integer to indicate the size
193220
// of the rest of the data (which we already know based on byte_size). It
194221
// should be ignored here, as it isn't part of the actual response
@@ -282,7 +309,7 @@ InferContext::AsyncCallbackFuncImpl(cb::InferResult* result)
282309
}
283310
it->second.response_timestamps_.push_back(
284311
std::chrono::system_clock::now());
285-
it->second.response_outputs_.push_back(GetOutput(*result));
312+
it->second.response_outputs_.push_back(GetOutputs(*result));
286313
num_responses_++;
287314
if (is_null_response == true) {
288315
it->second.has_null_last_response_ = true;
@@ -296,9 +323,9 @@ InferContext::AsyncCallbackFuncImpl(cb::InferResult* result)
296323
has_received_final_response_ = is_final_response;
297324
thread_stat_->request_records_.emplace_back(
298325
it->second.start_time_, it->second.response_timestamps_,
299-
it->second.response_outputs_, it->second.sequence_end_,
300-
it->second.delayed_, it->second.sequence_id_,
301-
it->second.has_null_last_response_);
326+
it->second.request_inputs_, it->second.response_outputs_,
327+
it->second.sequence_end_, it->second.delayed_,
328+
it->second.sequence_id_, it->second.has_null_last_response_);
302329
infer_backend_->ClientInferStat(&(thread_stat_->contexts_stat_[id_]));
303330
thread_stat_->cb_status_ = ValidateOutputs(result);
304331
async_req_map_.erase(request_id);

src/c++/perf_analyzer/infer_context.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,9 @@ class InferContext {
185185
std::function<void(uint32_t)> async_callback_finalize_func_ = nullptr;
186186

187187
private:
188-
const RequestRecord::ResponseOutput GetOutput(
188+
const RequestRecord::RequestInput GetInputs();
189+
190+
const RequestRecord::ResponseOutput GetOutputs(
189191
const cb::InferResult& infer_result);
190192

191193
const uint32_t id_{0};

src/c++/perf_analyzer/profile_data_exporter.cc

+45
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,11 @@ ProfileDataExporter::AddRequests(
122122
request.AddMember("sequence_id", sequence_id, document_.GetAllocator());
123123
}
124124

125+
rapidjson::Value request_inputs(rapidjson::kObjectType);
126+
AddRequestInputs(request_inputs, raw_request.request_inputs_);
127+
request.AddMember(
128+
"request_inputs", request_inputs, document_.GetAllocator());
129+
125130
rapidjson::Value response_timestamps(rapidjson::kArrayType);
126131
AddResponseTimestamps(
127132
response_timestamps, raw_request.response_timestamps_);
@@ -151,6 +156,45 @@ ProfileDataExporter::AddResponseTimestamps(
151156
}
152157
}
153158

159+
void
160+
ProfileDataExporter::AddRequestInputs(
161+
rapidjson::Value& request_inputs_json,
162+
const std::vector<RequestRecord::RequestInput>& request_inputs)
163+
{
164+
for (const auto& request_input : request_inputs) {
165+
for (const auto& input : request_input) {
166+
const auto& name{input.first};
167+
const auto& buf{input.second.data_.get()};
168+
const auto& byte_size{input.second.size_};
169+
const auto& data_type{input.second.data_type_};
170+
rapidjson::Value name_json(name.c_str(), document_.GetAllocator());
171+
rapidjson::Value input_json{};
172+
// TMA-1777: support other data types
173+
if (buf != nullptr) {
174+
if (data_type == "BYTES" || data_type == "JSON") {
175+
input_json.SetString(
176+
reinterpret_cast<const char*>(buf), byte_size,
177+
document_.GetAllocator());
178+
} else if (data_type == "INT32") {
179+
auto* val = reinterpret_cast<int32_t*>(buf);
180+
input_json.SetInt(*val);
181+
} else if (data_type == "BOOL") {
182+
bool is_true = (*buf > 0);
183+
input_json.SetBool(is_true);
184+
} else {
185+
std::cerr << "WARNING: data type '" + data_type +
186+
"' is not supported with JSON."
187+
<< std::endl;
188+
}
189+
} else {
190+
input_json.SetString("", 0, document_.GetAllocator());
191+
}
192+
request_inputs_json.AddMember(
193+
name_json, input_json, document_.GetAllocator());
194+
}
195+
}
196+
}
197+
154198
void
155199
ProfileDataExporter::AddResponseOutputs(
156200
rapidjson::Value& outputs_json,
@@ -164,6 +208,7 @@ ProfileDataExporter::AddResponseOutputs(
164208
const auto& byte_size{output.second.size_};
165209
rapidjson::Value name_json(name.c_str(), document_.GetAllocator());
166210
rapidjson::Value output_json{};
211+
// TMA-1777: support other data types
167212
if (buf != nullptr) {
168213
output_json.SetString(
169214
reinterpret_cast<const char*>(buf), byte_size,

src/c++/perf_analyzer/profile_data_exporter.h

+3
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ class ProfileDataExporter {
6969
void AddRequests(
7070
rapidjson::Value& entry, rapidjson::Value& requests,
7171
const Experiment& raw_experiment);
72+
void AddRequestInputs(
73+
rapidjson::Value& inputs_json,
74+
const std::vector<RequestRecord::RequestInput>& inputs);
7275
void AddResponseTimestamps(
7376
rapidjson::Value& timestamps_json,
7477
const std::vector<std::chrono::time_point<std::chrono::system_clock>>&

src/c++/perf_analyzer/request_record.h

+13-8
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,19 @@
3333

3434
namespace triton { namespace perfanalyzer {
3535

36-
/// A record containing the data of a single response
37-
struct ResponseData {
38-
ResponseData(const uint8_t* buf, size_t size)
36+
/// A record containing the data of a single request input or response output
37+
struct RecordData {
38+
RecordData(const uint8_t* buf, size_t size, std::string data_type = "")
3939
{
4040
uint8_t* array = new uint8_t[size];
4141
std::memcpy(array, buf, size);
4242
data_ = std::shared_ptr<uint8_t>(array, [](uint8_t* p) { delete[] p; });
4343
size_ = size;
44+
data_type_ = data_type;
4445
}
4546

4647
// Define equality comparison operator so it can be inserted into maps
47-
bool operator==(const ResponseData& other) const
48+
bool operator==(const RecordData& other) const
4849
{
4950
if (size_ != other.size_)
5051
return false;
@@ -54,24 +55,28 @@ struct ResponseData {
5455

5556
std::shared_ptr<uint8_t> data_;
5657
size_t size_;
58+
std::string data_type_;
5759
};
5860

5961

6062
/// A record of an individual request
6163
struct RequestRecord {
62-
using ResponseOutput = std::unordered_map<std::string, ResponseData>;
64+
using RequestInput = std::unordered_map<std::string, RecordData>;
65+
using ResponseOutput = std::unordered_map<std::string, RecordData>;
6366

6467
RequestRecord(
6568
std::chrono::time_point<std::chrono::system_clock> start_time =
6669
std::chrono::time_point<std::chrono::system_clock>(),
6770
std::vector<std::chrono::time_point<std::chrono::system_clock>>
6871
response_timestamps = {},
72+
std::vector<RequestInput> request_inputs = {},
6973
std::vector<ResponseOutput> response_outputs = {},
7074
bool sequence_end = true, bool delayed = false, uint64_t sequence_id = 0,
7175
bool has_null_last_response = false)
7276
: start_time_(start_time), response_timestamps_(response_timestamps),
73-
response_outputs_(response_outputs), sequence_end_(sequence_end),
74-
delayed_(delayed), sequence_id_(sequence_id),
77+
request_inputs_(request_inputs), response_outputs_(response_outputs),
78+
sequence_end_(sequence_end), delayed_(delayed),
79+
sequence_id_(sequence_id),
7580
has_null_last_response_(has_null_last_response)
7681
{
7782
}
@@ -81,7 +86,7 @@ struct RequestRecord {
8186
std::vector<std::chrono::time_point<std::chrono::system_clock>>
8287
response_timestamps_;
8388

84-
// Collection of response outputs
89+
std::vector<RequestInput> request_inputs_;
8590
std::vector<ResponseOutput> response_outputs_;
8691
// Whether or not the request is at the end of a sequence.
8792
bool sequence_end_;

0 commit comments

Comments
 (0)