Skip to content

Commit 5a9b10a

Browse files
committed
Remove cleanup from get_transfer_status(). Use TransferStatus class
1 parent 983bdb5 commit 5a9b10a

File tree

1 file changed

+25
-18
lines changed
  • python/sglang/srt/disaggregation/nixl

1 file changed

+25
-18
lines changed

python/sglang/srt/disaggregation/nixl/conn.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,23 @@ def from_zmq(cls, msg: List[bytes]):
6060
)
6161

6262

63+
@dataclasses.dataclass
64+
class TransferStatus:
65+
"""Used by KV Receiver to know when a transfer is done."""
66+
67+
# KV chunk IDs that have been received.
68+
received_kvs: Set[int] = dataclasses.field(default_factory=set)
69+
# Number of kv chunks to expect, will know this after last chunk is received.
70+
num_kvs_expected: Optional[int] = None
71+
# Whether aux data has been received.
72+
received_aux: bool = False
73+
74+
def is_done(self):
75+
if self.num_kvs_expected is None:
76+
return False
77+
return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
78+
79+
6380
class NixlKVManager(BaseKVManager):
6481
def __init__(
6582
self,
@@ -94,11 +111,9 @@ def __init__(
94111
self._register_to_bootstrap()
95112
elif self.disaggregation_mode == DisaggregationMode.DECODE:
96113
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
97-
# Map of room to kv chunk IDs that have been received.
98-
self.received_kvs: Dict[int, Set[int]] = defaultdict(set)
99-
# Map of room to number of kv chunks to expect, will know this after last chunk is received.
100-
self.num_kvs_expected: Dict[int, int] = {}
101-
self.received_aux: Dict[int, bool] = {}
114+
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
115+
TransferStatus
116+
)
102117
else:
103118
raise ValueError(
104119
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
@@ -264,24 +279,16 @@ def update_transfer_status(self):
264279
if components[1] == "kv":
265280
chunk_id = int(components[2])
266281
is_last = bool(components[3])
267-
self.received_kvs[room].add(chunk_id)
282+
self.transfer_statuses[room].received_kvs.add(chunk_id)
268283
if is_last:
269-
self.num_kvs_expected[room] = chunk_id + 1
284+
self.transfer_statuses[room].num_kvs_expected = chunk_id + 1
270285
elif components[1] == "aux":
271-
self.received_aux[room] = True
286+
self.transfer_statuses[room].received_aux = True
272287

273288
def check_transfer_done(self, room: int):
274-
if room not in self.num_kvs_expected:
289+
if room not in self.transfer_statuses:
275290
return False
276-
if self.num_kvs_expected[room] == len(
277-
self.received_kvs[room]
278-
) and self.received_aux.get(room, False):
279-
# Cleanup
280-
del self.num_kvs_expected[room]
281-
del self.received_kvs[room]
282-
del self.received_aux[room]
283-
return True
284-
return False
291+
return self.transfer_statuses[room].is_done()
285292

286293
def _register_to_bootstrap(self):
287294
"""Register KVSender to bootstrap server via HTTP POST."""

0 commit comments

Comments
 (0)