Skip to content

Commit e19cfe7

Browse files
authored
Fix notify state destruction and inflight states tracking (#6451)
* Ensure notify_state_ gets properly destructed
* Fix inflight state tracking to properly erase states
* Prevent the notify_state from being erased
* Wrap notify_state_ object within unique_ptr
1 parent ccdb26b commit e19cfe7

File tree

3 files changed

+24
-9
lines changed

3 files changed

+24
-9
lines changed

src/grpc/infer_handler.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,8 @@ ModelInferHandler::InferResponseComplete(
973973
return;
974974
}
975975

976+
state->context_->EraseInflightState(state);
977+
976978
#ifdef TRITON_ENABLE_TRACING
977979
state->trace_timestamps_.emplace_back(std::make_pair(
978980
"INFER_RESPONSE_COMPLETE", TraceManager::CaptureTimestamp()));
@@ -987,7 +989,6 @@ ModelInferHandler::InferResponseComplete(
987989
"deleting GRPC inference response");
988990

989991
state->step_ = Steps::CANCELLED;
990-
state->context_->EraseInflightState(state);
991992

992993
LOG_VERBOSE(1) << "ModelInferHandler::InferResponseComplete, "
993994
<< state->unique_id_

src/grpc/infer_handler.h

+13-6
Original file line numberDiff line numberDiff line change
@@ -640,9 +640,9 @@ class InferHandlerState {
640640

641641
void GrpcContextAsyncNotifyWhenDone(InferHandlerStateType* state)
642642
{
643-
InferHandlerStateType* wrapped_state =
644-
new InferHandlerStateType(Steps::WAITING_NOTIFICATION, state);
645-
ctx_->AsyncNotifyWhenDone(wrapped_state);
643+
notify_state_ = std::unique_ptr<InferHandlerStateType>(
644+
new InferHandlerStateType(Steps::WAITING_NOTIFICATION, state));
645+
ctx_->AsyncNotifyWhenDone(notify_state_.get());
646646
}
647647

648648
// Records whether the async notification has been delivered. Stores the
// caller-supplied flag instead of unconditionally setting it to true.
void SetReceivedNotification(bool value) { received_notification_ = value; }
@@ -666,8 +666,12 @@ class InferHandlerState {
666666
all_states_.insert(state);
667667
}
668668

669-
// Adds the state object created on this context
670-
void EraseState(InferHandlerStateType* state) { all_states_.erase(state); }
669+
// Erases the state object created on this context
670+
void EraseState(InferHandlerStateType* state)
671+
{
672+
EraseInflightState(state);
673+
all_states_.erase(state);
674+
}
671675

672676
bool HandleCompletion()
673677
{
@@ -975,6 +979,10 @@ class InferHandlerState {
975979
// True if there is an ongoing write to the grpc stream
976980
std::atomic<bool> ongoing_write_;
977981

982+
// The state object that is sent to grpc async notification
983+
// for tracking the gRPC stream.
984+
std::unique_ptr<InferHandlerState> notify_state_;
985+
978986
// Tracks whether the async notification has been delivered by
979987
// completion queue.
980988
bool received_notification_;
@@ -1274,7 +1282,6 @@ InferHandler<
12741282
state->context_->SetReceivedNotification(true);
12751283
LOG_VERBOSE(1) << "Received notification for " << Name() << ", "
12761284
<< state->unique_id_;
1277-
delete state_wrapper;
12781285
}
12791286
LOG_VERBOSE(2) << "Grpc::CQ::Next() "
12801287
<< state->context_->DebugString(state);

src/grpc/stream_infer_handler.cc

+9-2
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,15 @@ ModelStreamInferHandler::StreamInferResponseComplete(
576576
}
577577
}
578578

579+
// When the final callback is received, erase the state from the inflight
580+
// state data structure so that cancellation is not issued on the request.
581+
// However, if this state was handed to the gRPC async notification
582+
// mechanism, do not erase it: the state is still needed to handle a
583+
// cancellation if one is detected.
584+
if (state->complete_ && (!state->IsAsyncNotifyState())) {
585+
state->context_->EraseInflightState(state);
586+
}
587+
579588
if (state->IsGrpcContextCancelled()) {
580589
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
581590
// Clean-up the received response object.
@@ -593,7 +602,6 @@ ModelStreamInferHandler::StreamInferResponseComplete(
593602
// that state object can be released.
594603
if (state->complete_) {
595604
state->step_ = Steps::CANCELLED;
596-
state->context_->EraseInflightState(state);
597605
state->context_->PutTaskBackToQueue(state);
598606
}
599607

@@ -692,7 +700,6 @@ ModelStreamInferHandler::StreamInferResponseComplete(
692700
// that state object can be released.
693701
if (state->complete_) {
694702
state->step_ = Steps::CANCELLED;
695-
state->context_->EraseInflightState(state);
696703
state->context_->PutTaskBackToQueue(state);
697704
}
698705

0 commit comments

Comments
 (0)