Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 2ee0b6e

Browse files
authored
Safe async event cache (#13308)
Fix race conditions in the async cache invalidation logic, by separating the async & local invalidation calls and ensuring any async call is executed first. Signed off by Nick @ Beeper (@Fizzadar).
1 parent 7864f33 commit 2ee0b6e

File tree

8 files changed

+102
-21
lines changed

8 files changed

+102
-21
lines changed

changelog.d/13308.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Use an asynchronous cache wrapper for the get event cache. Contributed by Nick @ Beeper (@fizzadar).

synapse/storage/_base.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ def _attempt_to_invalidate_cache(
9696
cache doesn't exist. Mainly used for invalidating caches on workers,
9797
where they may not have the cache.
9898
99+
Note that this function does not invalidate any remote caches, only the
100+
local in-memory ones. Any remote invalidation must be performed before
101+
calling this.
102+
99103
Args:
100104
cache_name
101105
key: Entry to invalidate. If None then invalidates the entire
@@ -112,7 +116,10 @@ def _attempt_to_invalidate_cache(
112116
if key is None:
113117
cache.invalidate_all()
114118
else:
115-
cache.invalidate(tuple(key))
119+
# Prefer any local-only invalidation method. Invalidating any non-local
120+
# cache must be done before this.
121+
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
122+
invalidate_method(tuple(key))
116123

117124

118125
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:

synapse/storage/database.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from typing import (
2424
TYPE_CHECKING,
2525
Any,
26+
Awaitable,
2627
Callable,
2728
Collection,
2829
Dict,
@@ -57,7 +58,7 @@
5758
from synapse.storage.background_updates import BackgroundUpdater
5859
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
5960
from synapse.storage.types import Connection, Cursor
60-
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
61+
from synapse.util.async_helpers import delay_cancellation
6162
from synapse.util.iterutils import batch_iter
6263

6364
if TYPE_CHECKING:
@@ -168,6 +169,7 @@ def cursor(
168169
*,
169170
txn_name: Optional[str] = None,
170171
after_callbacks: Optional[List["_CallbackListEntry"]] = None,
172+
async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None,
171173
exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
172174
) -> "LoggingTransaction":
173175
if not txn_name:
@@ -178,6 +180,7 @@ def cursor(
178180
name=txn_name,
179181
database_engine=self.engine,
180182
after_callbacks=after_callbacks,
183+
async_after_callbacks=async_after_callbacks,
181184
exception_callbacks=exception_callbacks,
182185
)
183186

@@ -209,6 +212,9 @@ def __getattr__(self, name: str) -> Any:
209212

210213
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
211214
_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
215+
_AsyncCallbackListEntry = Tuple[
216+
Callable[..., Awaitable], Tuple[object, ...], Dict[str, object]
217+
]
212218

213219
P = ParamSpec("P")
214220
R = TypeVar("R")
@@ -227,6 +233,10 @@ class LoggingTransaction:
227233
that have been added by `call_after` which should be run on
228234
successful completion of the transaction. None indicates that no
229235
callbacks should be allowed to be scheduled to run.
236+
async_after_callbacks: A list that asynchronous callbacks will be appended
237+
to by `async_call_after` which should run, before after_callbacks, on
238+
successful completion of the transaction. None indicates that no
239+
callbacks should be allowed to be scheduled to run.
230240
exception_callbacks: A list that callbacks will be appended
231241
to that have been added by `call_on_exception` which should be run
232242
if transaction ends with an error. None indicates that no callbacks
@@ -238,6 +248,7 @@ class LoggingTransaction:
238248
"name",
239249
"database_engine",
240250
"after_callbacks",
251+
"async_after_callbacks",
241252
"exception_callbacks",
242253
]
243254

@@ -247,12 +258,14 @@ def __init__(
247258
name: str,
248259
database_engine: BaseDatabaseEngine,
249260
after_callbacks: Optional[List[_CallbackListEntry]] = None,
261+
async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None,
250262
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
251263
):
252264
self.txn = txn
253265
self.name = name
254266
self.database_engine = database_engine
255267
self.after_callbacks = after_callbacks
268+
self.async_after_callbacks = async_after_callbacks
256269
self.exception_callbacks = exception_callbacks
257270

258271
def call_after(
@@ -277,6 +290,28 @@ def call_after(
277290
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
278291
self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
279292

293+
def async_call_after(
294+
self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
295+
) -> None:
296+
"""Call the given asynchronous callback on the main twisted thread after
297+
the transaction has finished (but before those added in `call_after`).
298+
299+
Mostly used to invalidate remote caches after transactions.
300+
301+
Note that transactions may be retried a few times if they encounter database
302+
errors such as serialization failures. Callbacks given to `async_call_after`
303+
will accumulate across transaction attempts and will _all_ be called once a
304+
transaction attempt succeeds, regardless of whether previous transaction
305+
attempts failed. Otherwise, if all transaction attempts fail, all
306+
`call_on_exception` callbacks will be run instead.
307+
"""
308+
# if self.async_after_callbacks is None, that means that whatever constructed the
309+
# LoggingTransaction isn't expecting there to be any callbacks; assert that
310+
# is not the case.
311+
assert self.async_after_callbacks is not None
312+
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
313+
self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
314+
280315
def call_on_exception(
281316
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
282317
) -> None:
@@ -574,6 +609,7 @@ def new_transaction(
574609
conn: LoggingDatabaseConnection,
575610
desc: str,
576611
after_callbacks: List[_CallbackListEntry],
612+
async_after_callbacks: List[_AsyncCallbackListEntry],
577613
exception_callbacks: List[_CallbackListEntry],
578614
func: Callable[Concatenate[LoggingTransaction, P], R],
579615
*args: P.args,
@@ -597,6 +633,7 @@ def new_transaction(
597633
conn
598634
desc
599635
after_callbacks
636+
async_after_callbacks
600637
exception_callbacks
601638
func
602639
*args
@@ -659,6 +696,7 @@ def new_transaction(
659696
cursor = conn.cursor(
660697
txn_name=name,
661698
after_callbacks=after_callbacks,
699+
async_after_callbacks=async_after_callbacks,
662700
exception_callbacks=exception_callbacks,
663701
)
664702
try:
@@ -798,6 +836,7 @@ async def runInteraction(
798836

799837
async def _runInteraction() -> R:
800838
after_callbacks: List[_CallbackListEntry] = []
839+
async_after_callbacks: List[_AsyncCallbackListEntry] = []
801840
exception_callbacks: List[_CallbackListEntry] = []
802841

803842
if not current_context():
@@ -809,6 +848,7 @@ async def _runInteraction() -> R:
809848
self.new_transaction,
810849
desc,
811850
after_callbacks,
851+
async_after_callbacks,
812852
exception_callbacks,
813853
func,
814854
*args,
@@ -817,15 +857,17 @@ async def _runInteraction() -> R:
817857
**kwargs,
818858
)
819859

860+
# We order these assuming that async functions call out to external
861+
# systems (e.g. to invalidate a cache) and the sync functions make these
862+
# changes on any local in-memory caches/similar, and thus must be second.
863+
for async_callback, async_args, async_kwargs in async_after_callbacks:
864+
await async_callback(*async_args, **async_kwargs)
820865
for after_callback, after_args, after_kwargs in after_callbacks:
821-
await maybe_awaitable(after_callback(*after_args, **after_kwargs))
822-
866+
after_callback(*after_args, **after_kwargs)
823867
return cast(R, result)
824868
except Exception:
825869
for exception_callback, after_args, after_kwargs in exception_callbacks:
826-
await maybe_awaitable(
827-
exception_callback(*after_args, **after_kwargs)
828-
)
870+
exception_callback(*after_args, **after_kwargs)
829871
raise
830872

831873
# To handle cancellation, we ensure that `after_callback`s and

synapse/storage/databases/main/censor_events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def delete_expired_event_txn(txn: LoggingTransaction) -> None:
194194
# changed its content in the database. We can't call
195195
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the
196196
# right type.
197-
txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
197+
self.invalidate_get_event_cache_after_txn(txn, event.event_id)
198198
# Send that invalidation to replication so that other workers also invalidate
199199
# the event cache.
200200
self._send_invalidation_to_replication(

synapse/storage/databases/main/events.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,7 +1293,7 @@ def _update_room_depths_txn(
12931293
depth_updates: Dict[str, int] = {}
12941294
for event, context in events_and_contexts:
12951295
# Remove any existing cache entries for the event_ids
1296-
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
1296+
self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
12971297
# Then update the `stream_ordering` position to mark the latest
12981298
# event as the front of the room. This should not be done for
12991299
# backfilled events because backfilled events have negative
@@ -1675,7 +1675,7 @@ async def prefill() -> None:
16751675
(cache_entry.event.event_id,), cache_entry
16761676
)
16771677

1678-
txn.call_after(prefill)
1678+
txn.async_call_after(prefill)
16791679

16801680
def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
16811681
"""Invalidate the caches for the redacted event.
@@ -1684,7 +1684,7 @@ def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
16841684
_invalidate_caches_for_event.
16851685
"""
16861686
assert event.redacts is not None
1687-
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
1687+
self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
16881688
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
16891689
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
16901690

synapse/storage/databases/main/events_worker.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -712,17 +712,41 @@ async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
712712

713713
return event_entry_map
714714

715-
async def _invalidate_get_event_cache(self, event_id: str) -> None:
716-
# First we invalidate the asynchronous cache instance. This may include
717-
# out-of-process caches such as Redis/memcache. Once complete we can
718-
# invalidate any in memory cache. The ordering is important here to
719-
# ensure we don't pull in any remote invalid value after we invalidate
720-
# the in-memory cache.
715+
def invalidate_get_event_cache_after_txn(
716+
self, txn: LoggingTransaction, event_id: str
717+
) -> None:
718+
"""
719+
Prepares a database transaction to invalidate the get event cache for a given
720+
event ID when executed successfully. This is achieved by attaching two callbacks
721+
to the transaction, one to invalidate the async cache and one for the in memory
722+
sync cache (importantly called in that order).
723+
724+
Arguments:
725+
txn: the database transaction to attach the callbacks to
726+
event_id: the event ID to be invalidated from caches
727+
"""
728+
729+
txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
730+
txn.call_after(self._invalidate_local_get_event_cache, event_id)
731+
732+
async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
733+
"""
734+
Invalidates an event in the asynchronous get event cache, which may be remote.
735+
736+
Arguments:
737+
event_id: the event ID to invalidate
738+
"""
739+
721740
await self._get_event_cache.invalidate((event_id,))
722-
self._event_ref.pop(event_id, None)
723-
self._current_event_fetches.pop(event_id, None)
724741

725742
def _invalidate_local_get_event_cache(self, event_id: str) -> None:
743+
"""
744+
Invalidates an event in local in-memory get event caches.
745+
746+
Arguments:
747+
event_id: the event ID to invalidate
748+
"""
749+
726750
self._get_event_cache.invalidate_local((event_id,))
727751
self._event_ref.pop(event_id, None)
728752
self._current_event_fetches.pop(event_id, None)
@@ -958,7 +982,13 @@ def _fetch_event_list(
958982
}
959983

960984
row_dict = self.db_pool.new_transaction(
961-
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
985+
conn,
986+
"do_fetch",
987+
[],
988+
[],
989+
[],
990+
self._fetch_event_rows,
991+
events_to_fetch,
962992
)
963993

964994
# We only want to resolve deferreds from the main thread

synapse/storage/databases/main/monthly_active_users.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666
"initialise_mau_threepids",
6767
[],
6868
[],
69+
[],
6970
self._initialise_reserved_users,
7071
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
7172
)

synapse/storage/databases/main/purge_events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def _purge_history_txn(
304304
self._invalidate_cache_and_stream(
305305
txn, self.have_seen_event, (room_id, event_id)
306306
)
307-
txn.call_after(self._invalidate_get_event_cache, event_id)
307+
self.invalidate_get_event_cache_after_txn(txn, event_id)
308308

309309
logger.info("[purge] done")
310310

0 commit comments

Comments
 (0)