Skip to content

Commit 3bde101

Browse files
authored
[PD] Abort request if transfer fails (#6504)
1 parent 7513558 commit 3bde101

File tree

5 files changed

+84
-4
lines changed

5 files changed

+84
-4
lines changed

python/sglang/srt/disaggregation/decode.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
is_mla_backend,
4242
kv_to_page_indices,
4343
poll_and_all_reduce,
44+
prepare_abort,
4445
)
4546
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
4647
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
@@ -178,7 +179,17 @@ def _update_handshake_waiters(self) -> None:
178179
elif poll == KVPoll.WaitingForInput:
179180
decode_req.waiting_for_input = True
180181
elif poll == KVPoll.Failed:
181-
raise Exception("Handshake failed")
182+
error_message = f"Decode handshake failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
183+
try:
184+
decode_req.kv_receiver.failure_exception()
185+
except Exception as e:
186+
error_message += f" with exception {e}"
187+
logger.error(error_message)
188+
prepare_abort(
189+
decode_req.req,
190+
error_message,
191+
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
192+
)
182193

183194
def pop_preallocated(self) -> List[DecodeRequest]:
184195
"""Pop the preallocated requests from the pending queue (FIFO)."""
@@ -333,7 +344,24 @@ def pop_transferred(self) -> List[DecodeRequest]:
333344
indices_to_remove = set()
334345
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
335346
if poll == KVPoll.Failed:
336-
raise Exception("Transfer failed")
347+
error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
348+
try:
349+
decode_req.kv_receiver.failure_exception()
350+
except Exception as e:
351+
error_message += f" with exception {e}"
352+
logger.error(error_message)
353+
prepare_abort(
354+
decode_req.req,
355+
error_message,
356+
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
357+
)
358+
self.scheduler.stream_output(
359+
[decode_req.req], decode_req.req.return_logprob
360+
)
361+
# unlock the kv cache or it will have memory leak
362+
self.tree_cache.cache_finished_req(decode_req.req)
363+
indices_to_remove.add(i)
364+
continue
337365
elif poll == KVPoll.Success:
338366
# pop and push it to waiting queue
339367
idx = decode_req.metadata_buffer_index

python/sglang/srt/disaggregation/mooncake/conn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@ def poll(self) -> KVPoll:
496496
return self.kv_mgr.check_status(self.bootstrap_room)
497497

498498
def failure_exception(self):
499+
# TODO: raise a real exception
499500
raise Exception("Fake KVSender Exception")
500501

501502

@@ -723,6 +724,7 @@ def poll(self) -> KVPoll:
723724
return self.kv_mgr.check_status(self.bootstrap_room)
724725

725726
def failure_exception(self):
727+
# TODO: raise a real exception
726728
raise Exception("Fake KVReceiver Exception")
727729

728730

python/sglang/srt/disaggregation/prefill.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
kv_to_page_indices,
3939
kv_to_page_num,
4040
poll_and_all_reduce,
41+
prepare_abort,
4142
)
4243
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
4344

@@ -157,7 +158,18 @@ def pop_bootstrapped(self) -> List[Req]:
157158
if poll == KVPoll.Bootstrapping:
158159
continue
159160
elif poll == KVPoll.Failed:
160-
raise Exception("Bootstrap failed")
161+
error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
162+
try:
163+
req.disagg_kv_sender.failure_exception()
164+
except Exception as e:
165+
error_message += f" with exception {e}"
166+
logger.error(error_message)
167+
prepare_abort(
168+
req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
169+
)
170+
self.scheduler.stream_output([req], req.return_logprob)
171+
indices_to_remove.add(i)
172+
continue
161173

162174
# KV.WaitingForInput
163175
num_kv_indices = len(req.origin_input_ids)
@@ -335,7 +347,17 @@ def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
335347
# FIXME: clean up req's data in transfer engine
336348
done_reqs.append(req)
337349
elif poll == KVPoll.Failed:
338-
raise Exception("Transferring failed")
350+
error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
351+
try:
352+
req.disagg_kv_sender.failure_exception()
353+
except Exception as e:
354+
error_message += f" with exception {e}"
355+
logger.warning(error_message)
356+
self.tree_cache.cache_finished_req(req) # unlock the tree
357+
prepare_abort(
358+
req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
359+
)
360+
done_reqs.append(req)
339361

340362
for req in done_reqs:
341363
self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(

python/sglang/srt/disaggregation/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,18 @@ def is_mla_backend(target_kv_pool) -> bool:
167167
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
168168

169169
return isinstance(target_kv_pool, MLATokenToKVPool)
170+
171+
172+
def prepare_abort(req: Req, error_message: str, status_code=None):
173+
from sglang.srt.managers.schedule_batch import FINISH_ABORT
174+
175+
# populate finish metadata and stream output
176+
req.finished_reason = FINISH_ABORT(error_message, status_code)
177+
178+
if req.return_logprob:
179+
req.input_token_logprobs_val = []
180+
req.input_token_logprobs_idx = []
181+
req.input_top_logprobs_val = []
182+
req.input_top_logprobs_idx = []
183+
req.input_token_ids_logprobs_val = []
184+
req.input_token_ids_logprobs_idx = []

python/sglang/srt/managers/scheduler.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
DisaggregationMode,
5151
ReqToMetadataIdxAllocator,
5252
TransferBackend,
53+
prepare_abort,
5354
)
5455
from sglang.srt.distributed import get_pp_group, get_world_group
5556
from sglang.srt.hf_transformers_utils import (
@@ -935,6 +936,18 @@ def handle_generate_request(
935936
)
936937
req.tokenizer = self.tokenizer
937938

939+
if self.disaggregation_mode != DisaggregationMode.NULL:
940+
# Invalid request for disaggregated mode
941+
if recv_req.bootstrap_room is None:
942+
error_message = (
943+
f"Invalid request: Disaggregated request received without "
944+
f"boostrap room id. {req.rid=}"
945+
)
946+
logger.error(error_message)
947+
prepare_abort(req, error_message)
948+
self.stream_output([req], req.return_logprob)
949+
return
950+
938951
if (
939952
recv_req.session_params is not None
940953
and recv_req.session_params.id is not None

0 commit comments

Comments
 (0)