@@ -60,6 +60,23 @@ def from_zmq(cls, msg: List[bytes]):
60
60
)
61
61
62
62
63
@dataclasses.dataclass
class TransferStatus:
    """Per-transfer bookkeeping that lets the KV receiver detect completion.

    A transfer counts as finished once the expected number of KV chunks is
    known, that many distinct chunk IDs have been recorded, and the aux data
    has been flagged as received.
    """

    # IDs of the KV chunks received so far (a set, so repeats collapse).
    received_kvs: Set[int] = dataclasses.field(default_factory=set)
    # Total number of KV chunks to expect; stays None until the last chunk
    # has been received and the count becomes known.
    num_kvs_expected: Optional[int] = None
    # True once the aux data has been received.
    received_aux: bool = False

    def is_done(self):
        """Return whether every expected KV chunk and the aux data arrived."""
        expected = self.num_kvs_expected
        if expected is None:
            # The total is still unknown, so the transfer cannot be complete.
            return False
        all_chunks_arrived = len(self.received_kvs) == expected
        return all_chunks_arrived and self.received_aux
78
+
79
+
63
80
class NixlKVManager (BaseKVManager ):
64
81
def __init__ (
65
82
self ,
@@ -94,11 +111,9 @@ def __init__(
94
111
self ._register_to_bootstrap ()
95
112
elif self .disaggregation_mode == DisaggregationMode .DECODE :
96
113
self .connection_pool : Dict [str , Dict [str , Union [str , int ]]] = {}
97
- # Map of room to kv chunk IDs that have been received.
98
- self .received_kvs : Dict [int , Set [int ]] = defaultdict (set )
99
- # Map of room to number of kv chunks to expect, will know this after last chunk is received.
100
- self .num_kvs_expected : Dict [int , int ] = {}
101
- self .received_aux : Dict [int , bool ] = {}
114
+ self .transfer_statuses : Dict [int , TransferStatus ] = defaultdict (
115
+ TransferStatus
116
+ )
102
117
else :
103
118
raise ValueError (
104
119
f"Unsupported DisaggregationMode: { self .disaggregation_mode } "
@@ -264,24 +279,16 @@ def update_transfer_status(self):
264
279
if components [1 ] == "kv" :
265
280
chunk_id = int (components [2 ])
266
281
is_last = bool (components [3 ])
267
- self .received_kvs [room ].add (chunk_id )
282
+ self .transfer_statuses [room ]. received_kvs .add (chunk_id )
268
283
if is_last :
269
- self .num_kvs_expected [room ] = chunk_id + 1
284
+ self .transfer_statuses [room ]. num_kvs_expected = chunk_id + 1
270
285
elif components [1 ] == "aux" :
271
- self .received_aux [room ] = True
286
+ self .transfer_statuses [room ]. received_aux = True
272
287
273
288
def check_transfer_done(self, room: int):
    """Report whether the transfer tracked for ``room`` has completed.

    Returns False when no status has been recorded for ``room``; otherwise
    returns the result of that status's ``is_done()``.
    """
    # Use .get() rather than indexing: transfer_statuses is a defaultdict, and
    # indexing an unknown room would silently create an empty entry for it.
    status = self.transfer_statuses.get(room)
    # NOTE(review): the entry for a finished room is not removed here --
    # presumably the caller cleans it up; confirm to avoid unbounded growth.
    return status is not None and status.is_done()
285
292
286
293
def _register_to_bootstrap (self ):
287
294
"""Register KVSender to bootstrap server via HTTP POST."""
0 commit comments