@@ -133,9 +133,9 @@ def __init__(self, hs: "HomeServer"):
133
133
if hs .should_send_federation ():
134
134
self .send_handler = FederationSenderHandler (hs )
135
135
136
- # Map from stream to list of deferreds waiting for the stream to
136
+ # Map from stream and instance to list of deferreds waiting for the stream to
137
137
# arrive at a particular position. The lists are sorted by stream position.
138
- self ._streams_to_waiters : Dict [str , List [Tuple [int , Deferred ]]] = {}
138
+ self ._streams_to_waiters : Dict [Tuple [ str , str ] , List [Tuple [int , Deferred ]]] = {}
139
139
140
140
async def on_rdata (
141
141
self , stream_name : str , instance_name : str , token : int , rows : list
@@ -270,7 +270,7 @@ async def on_rdata(
270
270
# Notify any waiting deferreds. The list is ordered by position so we
271
271
# just iterate through the list until we reach a position that is
272
272
# greater than the received row position.
273
- waiting_list = self ._streams_to_waiters .get (stream_name , [])
273
+ waiting_list = self ._streams_to_waiters .get (( stream_name , instance_name ) , [])
274
274
275
275
# Index of first item with a position after the current token, i.e we
276
276
# have called all deferreds before this index. If not overwritten by
@@ -279,14 +279,13 @@ async def on_rdata(
279
279
# `len(list)` works for both cases.
280
280
index_of_first_deferred_not_called = len (waiting_list )
281
281
282
+ # We don't fire the deferreds until after we finish iterating over the
283
+ # list, to avoid the list changing when we fire the deferreds.
284
+ deferreds_to_callback = []
285
+
282
286
for idx , (position , deferred ) in enumerate (waiting_list ):
283
287
if position <= token :
284
- try :
285
- with PreserveLoggingContext ():
286
- deferred .callback (None )
287
- except Exception :
288
- # The deferred has been cancelled or timed out.
289
- pass
288
+ deferreds_to_callback .append (deferred )
290
289
else :
291
290
# The list is sorted by position so we don't need to continue
292
291
# checking any further entries in the list.
@@ -297,6 +296,14 @@ async def on_rdata(
297
296
# loop. (This maintains the order so no need to resort)
298
297
waiting_list [:] = waiting_list [index_of_first_deferred_not_called :]
299
298
299
+ for deferred in deferreds_to_callback :
300
+ try :
301
+ with PreserveLoggingContext ():
302
+ deferred .callback (None )
303
+ except Exception :
304
+ # The deferred has been cancelled or timed out.
305
+ pass
306
+
300
307
async def on_position (
301
308
self , stream_name : str , instance_name : str , token : int
302
309
) -> None :
@@ -349,7 +356,9 @@ async def wait_for_stream_position(
349
356
deferred , _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS , self ._reactor
350
357
)
351
358
352
- waiting_list = self ._streams_to_waiters .setdefault (stream_name , [])
359
+ waiting_list = self ._streams_to_waiters .setdefault (
360
+ (stream_name , instance_name ), []
361
+ )
353
362
354
363
waiting_list .append ((position , deferred ))
355
364
waiting_list .sort (key = lambda t : t [0 ])
0 commit comments