
Commit c71199e

Update all stream IDs after processing replication rows (matrix-org#14723) (#52)
* Update all stream IDs after processing replication rows (matrix-org#14723)

  This creates a new store method, `process_replication_position`, that is
  called after `process_replication_rows`. Moving stream ID advances there
  guarantees that any relevant cache invalidations have been applied before
  the stream is advanced.

  This avoids race conditions where Python switches between threads midway
  through the `process_replication_rows` method and stream IDs are advanced
  before caches are invalidated, due to class resolution ordering. See this
  comment/issue for further discussion: matrix-org#14158 (comment)

  # Conflicts:
  #	synapse/storage/databases/main/devices.py
  #	synapse/storage/databases/main/events_worker.py

* Fix bad cherry-picking

* Remove leftover stream advance
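For orientation before the per-file diffs, here is a minimal sketch of the two-phase contract this commit introduces. The names (`WorkerStoreSketch`, the free-standing `on_rdata`) are illustrative stand-ins, not Synapse's real classes:

from typing import Any, Dict, Iterable, List


class WorkerStoreSketch:
    """Toy stand-in for a Synapse worker store (illustrative only)."""

    def __init__(self) -> None:
        self.current_token = 0           # stand-in for an ID generator
        self.cache: Dict[Any, Any] = {}  # stand-in for an invalidatable cache

    def process_replication_rows(
        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
    ) -> None:
        # Phase 1: cache invalidation only -- must not advance ID generators.
        for row in rows:
            self.cache.pop(row, None)

    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        # Phase 2: advance the stream position, once all caches are clean.
        self.current_token = max(self.current_token, token)


def on_rdata(
    store: WorkerStoreSketch, stream: str, instance: str, token: int, rows: List[Any]
) -> None:
    # Mirrors the ordering ReplicationDataHandler.on_rdata enforces below:
    # rows first, then position. A thread switch between the two calls can
    # no longer expose an advanced token alongside a stale cache entry.
    store.process_replication_rows(stream, instance, token, rows)
    store.process_replication_position(stream, instance, token)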
1 parent 90878d6 · commit c71199e

File tree: 15 files changed (+115, -66 lines)


changelog.d/14723.bugfix

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Ensure stream IDs are always updated after caches get invalidated with workers. Contributed by Nick @ Beeper (@fizzadar).

synapse/replication/tcp/client.py

Lines changed: 3 additions & 0 deletions
@@ -148,6 +148,9 @@ async def on_rdata(
             rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
         self.store.process_replication_rows(stream_name, instance_name, token, rows)
+        # NOTE: this must be called after process_replication_rows to ensure any
+        # cache invalidations are first handled before any stream ID advances.
+        self.store.process_replication_position(stream_name, instance_name, token)

         if self.send_handler:
             await self.send_handler.process_replication_rows(stream_name, token, rows)
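Design note: doing the split at the handler level, rather than auditing each store's internal ordering, sidesteps the class-resolution-ordering problem described in the commit message — however the store mixins chain their `super()` calls, every cache invalidation now completes before any ID generator moves.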

synapse/storage/_base.py

Lines changed: 16 additions & 1 deletion
@@ -59,7 +59,22 @@ def process_replication_rows(  # noqa: B027 (no-op by design)
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        pass
+        """
+        Used by storage classes to invalidate caches based on incoming replication data. These
+        must not update any ID generators, use `process_replication_position`.
+        """
+
+    def process_replication_position(  # noqa: B027 (no-op by design)
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+    ) -> None:
+        """
+        Used by storage classes to advance ID generators based on incoming replication data. This
+        is called after process_replication_rows such that caches are invalidated before any token
+        positions advance.
+        """

     def _invalidate_state_caches(
         self, room_id: str, members_changed: Collection[str]

synapse/storage/databases/main/account_data.py

Lines changed: 10 additions & 4 deletions
@@ -415,10 +415,7 @@ def process_replication_rows(
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-        elif stream_name == AccountDataStream.NAME:
+        if stream_name == AccountDataStream.NAME:
             for row in rows:
                 if not row.room_id:
                     self.get_global_account_data_by_type_for_user.invalidate(
@@ -433,6 +430,15 @@ def process_replication_rows(

         super().process_replication_rows(stream_name, instance_name, token, rows)

+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+        elif stream_name == AccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
     ) -> int:

synapse/storage/databases/main/cache.py

Lines changed: 16 additions & 5 deletions
@@ -164,9 +164,6 @@ def process_replication_rows(
                 backfilled=True,
             )
         elif stream_name == CachesStream.NAME:
-            if self._cache_id_gen:
-                self._cache_id_gen.advance(instance_name, token)
-
             for row in rows:
                 if row.cache_func == CURRENT_STATE_CACHE_NAME:
                     if row.keys is None:
@@ -182,6 +179,14 @@ def process_replication_rows(

         super().process_replication_rows(stream_name, instance_name, token, rows)

+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == CachesStream.NAME:
+            if self._cache_id_gen:
+                self._cache_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
         data = row.data

@@ -198,8 +203,14 @@ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
                 backfilled=False,
             )
         elif row.type == EventsStreamCurrentStateRow.TypeId:
-            # TODO: Nothing to do here, handled in events_worker, cleanup?
-            pass
+            assert isinstance(data, EventsStreamCurrentStateRow)
+            self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
+
+            if data.type == EventTypes.Member:
+                self.get_rooms_for_user_with_stream_ordering.invalidate(
+                    (data.state_key,)
+                )
+                self.get_rooms_for_user.invalidate((data.state_key,))
         else:
             raise Exception("Unknown events stream row type %s" % (row.type,))

synapse/storage/databases/main/deviceinbox.py

Lines changed: 7 additions & 2 deletions
@@ -160,10 +160,15 @@ def process_replication_rows(
                 self._device_federation_outbox_stream_cache.entity_has_changed(
                     row.entity, token
                 )
-        # Important that the ID gen advances after stream change caches
-        self._device_inbox_id_gen.advance(instance_name, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)

+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == ToDeviceStream.NAME:
+            self._device_inbox_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def get_to_device_stream_token(self) -> int:
         return self._device_inbox_id_gen.get_current_token()

synapse/storage/databases/main/devices.py

Lines changed: 9 additions & 4 deletions
@@ -163,15 +163,20 @@ def process_replication_rows(
     ) -> None:
         if stream_name == DeviceListsStream.NAME:
             self._invalidate_caches_for_devices(token, rows)
-            # Important that the ID gen advances after stream change caches
-            self._device_list_id_gen.advance(instance_name, token)
         elif stream_name == UserSignatureStream.NAME:
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
-            # Important that the ID gen advances after stream change caches
-            self._device_list_id_gen.advance(instance_name, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)

+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == DeviceListsStream.NAME:
+            self._device_list_id_gen.advance(instance_name, token)
+        elif stream_name == UserSignatureStream.NAME:
+            self._device_list_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def _invalidate_caches_for_devices(
         self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
     ) -> None:

synapse/storage/databases/main/event_federation.py

Lines changed: 1 addition & 1 deletion
@@ -1187,7 +1187,7 @@ async def get_forward_extremities_for_room_at_stream_ordering(
         """
         # We want to make the cache more effective, so we clamp to the last
        # change before the given ordering.
-        last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+        last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)  # type: ignore[attr-defined]

        # We don't always have a full stream_to_exterm_id table, e.g. after
        # the upgrade that introduced it, so we make sure we never ask for a
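Note: the `type: ignore[attr-defined]` is presumably needed because `_events_stream_cache` is now constructed in `StreamWorkerStore` (see synapse/storage/databases/main/stream.py below) rather than in `EventsWorkerStore`, so mypy can no longer see the attribute from this class's declared bases.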

synapse/storage/databases/main/events_worker.py

Lines changed: 3 additions & 40 deletions
@@ -249,22 +249,6 @@ def __init__(
             prefilled_cache=curr_state_delta_prefill,
         )

-        event_cache_prefill, min_event_val = self.db_pool.get_cache_dict(
-            db_conn,
-            "events",
-            entity_column="room_id",
-            stream_column="stream_ordering",
-            max_value=events_max,
-        )
-        self._events_stream_cache = StreamChangeCache(
-            "EventsRoomStreamChangeCache",
-            min_event_val,
-            prefilled_cache=event_cache_prefill,
-        )
-        self._membership_stream_cache = StreamChangeCache(
-            "MembershipStreamChangeCache", events_max
-        )
-
         if hs.config.worker.run_background_tasks:
             # We periodically clean out old transaction ID mappings
             self._clock.looping_call(
@@ -325,35 +309,14 @@ def get_chain_id_txn(txn: Cursor) -> int:
             id_column="chain_id",
         )

-    def process_replication_rows(
-        self,
-        stream_name: str,
-        instance_name: str,
-        token: int,
-        rows: Iterable[Any],
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
     ) -> None:
-        # Process event stream replication rows, handling both the ID generators from the events
-        # worker store and the stream change caches in this store as the two are interlinked.
         if stream_name == EventsStream.NAME:
-            for row in rows:
-                if row.type == EventsStreamEventRow.TypeId:
-                    self._events_stream_cache.entity_has_changed(
-                        row.data.room_id, token
-                    )
-                    if row.data.type == EventTypes.Member:
-                        self._membership_stream_cache.entity_has_changed(
-                            row.data.state_key, token
-                        )
-                if row.type == EventsStreamCurrentStateRow.TypeId:
-                    self._curr_state_delta_stream_cache.entity_has_changed(
-                        row.data.room_id, token
-                    )
-            # Important that the ID gen advances after stream change caches
             self._stream_id_gen.advance(instance_name, token)
         elif stream_name == BackfillStream.NAME:
             self._backfill_id_gen.advance(instance_name, -token)
-
-        super().process_replication_rows(stream_name, instance_name, token, rows)
+        super().process_replication_position(stream_name, instance_name, token)

     async def have_censored_event(self, event_id: str) -> bool:
         """Check if an event has been censored, i.e. if the content of the event has been erased

synapse/storage/databases/main/presence.py

Lines changed: 7 additions & 1 deletion
@@ -439,8 +439,14 @@ def process_replication_rows(
         rows: Iterable[Any],
     ) -> None:
         if stream_name == PresenceStream.NAME:
-            self._presence_id_gen.advance(instance_name, token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
                 self._get_presence_for_user.invalidate((row.user_id,))
         return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == PresenceStream.NAME:
+            self._presence_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)

synapse/storage/databases/main/push_rule.py

Lines changed: 7 additions & 1 deletion
@@ -148,12 +148,18 @@ def process_replication_rows(
         self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
     ) -> None:
         if stream_name == PushRulesStream.NAME:
-            self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)

+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == PushRulesStream.NAME:
+            self._push_rules_stream_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     @cached(max_entries=5000)
     async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
         rows = await self.db_pool.simple_select_list(

synapse/storage/databases/main/pusher.py

Lines changed: 3 additions & 3 deletions
@@ -111,12 +111,12 @@ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
     def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()

-    def process_replication_rows(
-        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
     ) -> None:
         if stream_name == PushersStream.NAME:
             self._pushers_id_gen.advance(instance_name, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+        super().process_replication_position(stream_name, instance_name, token)

     async def get_pushers_by_app_id_and_pushkey(
         self, app_id: str, pushkey: str

synapse/storage/databases/main/receipts.py

Lines changed: 7 additions & 3 deletions
@@ -600,11 +600,15 @@ def process_replication_rows(
                     row.room_id, row.receipt_type, row.user_id
                 )
                 self._receipts_stream_cache.entity_has_changed(row.room_id, token)
-            # Important that the ID gen advances after stream change caches
-            self._receipts_id_gen.advance(instance_name, token)
-
         return super().process_replication_rows(stream_name, instance_name, token, rows)

+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == ReceiptsStream.NAME:
+            self._receipts_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def _insert_linearized_receipt_txn(
         self,
         txn: LoggingTransaction,

synapse/storage/databases/main/stream.py

Lines changed: 18 additions & 0 deletions
@@ -71,6 +71,7 @@
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import PersistedEventPosition, RoomStreamToken
 from synapse.util.caches.descriptors import cached
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.cancellation import cancellable

 if TYPE_CHECKING:
@@ -396,6 +397,23 @@ def __init__(
         # during startup which would cause one to die.
         self._need_to_reset_federation_stream_positions = self._send_federation

+        events_max = self.get_room_max_stream_ordering()
+        event_cache_prefill, min_event_val = self.db_pool.get_cache_dict(
+            db_conn,
+            "events",
+            entity_column="room_id",
+            stream_column="stream_ordering",
+            max_value=events_max,
+        )
+        self._events_stream_cache = StreamChangeCache(
+            "EventsRoomStreamChangeCache",
+            min_event_val,
+            prefilled_cache=event_cache_prefill,
+        )
+        self._membership_stream_cache = StreamChangeCache(
+            "MembershipStreamChangeCache", events_max
+        )
+
         self._stream_order_on_start = self.get_room_max_stream_ordering()
         self._min_stream_order_on_start = self.get_room_min_stream_ordering()
synapse/storage/databases/main/tags.py

Lines changed: 7 additions & 1 deletion
@@ -300,13 +300,19 @@ def process_replication_rows(
         rows: Iterable[Any],
     ) -> None:
         if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
             for row in rows:
                 self.get_tags_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)

         super().process_replication_rows(stream_name, instance_name, token, rows)

+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+

 class TagsStore(TagsWorkerStore):
     pass
