Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 6bd8763

Browse files
author
Mathieu Velten
authored
Add cache invalidation across workers to module API (#13667)
Signed-off-by: Mathieu Velten <[email protected]>
1 parent 16e1a9d commit 6bd8763

File tree

7 files changed

+153
-21
lines changed

7 files changed

+153
-21
lines changed

changelog.d/13667.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add cache invalidation across workers to module API.

scripts-dev/mypy_synapse_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def get_method_signature_hook(
2929
self, fullname: str
3030
) -> Optional[Callable[[MethodSigContext], CallableType]]:
3131
if fullname.startswith(
32-
"synapse.util.caches.descriptors._CachedFunction.__call__"
32+
"synapse.util.caches.descriptors.CachedFunction.__call__"
3333
) or fullname.startswith(
3434
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
3535
):
@@ -38,7 +38,7 @@ def get_method_signature_hook(
3838

3939

4040
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
41-
"""Fixes the `_CachedFunction.__call__` signature to be correct.
41+
"""Fixes the `CachedFunction.__call__` signature to be correct.
4242
4343
It already has *almost* the correct signature, except:
4444

synapse/module_api/__init__.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@
125125
)
126126
from synapse.util import Clock
127127
from synapse.util.async_helpers import maybe_awaitable
128-
from synapse.util.caches.descriptors import cached
128+
from synapse.util.caches.descriptors import CachedFunction, cached
129129
from synapse.util.frozenutils import freeze
130130

131131
if TYPE_CHECKING:
@@ -836,6 +836,37 @@ def run_db_interaction(
836836
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
837837
)
838838

839+
def register_cached_function(self, cached_func: CachedFunction) -> None:
840+
"""Register a cached function that should be invalidated across workers.
841+
Invalidation local to a worker can be done directly using `cached_func.invalidate`,
842+
however invalidation that needs to go to other workers needs to call `invalidate_cache`
843+
on the module API instead.
844+
845+
Args:
846+
cached_func: The cached function that will be registered to receive invalidation
847+
locally and from other workers.
848+
"""
849+
self._store.register_external_cached_function(
850+
f"{cached_func.__module__}.{cached_func.__name__}", cached_func
851+
)
852+
853+
async def invalidate_cache(
854+
self, cached_func: CachedFunction, keys: Tuple[Any, ...]
855+
) -> None:
856+
"""Invalidate a cache entry of a cached function across workers. The cached function
857+
needs to be registered on all workers first with `register_cached_function`.
858+
859+
Args:
860+
cached_func: The cached function that needs an invalidation
861+
keys: keys of the entry to invalidate, usually matching the arguments of the
862+
cached function.
863+
"""
864+
cached_func.invalidate(keys)
865+
await self._store.send_invalidation_to_replication(
866+
f"{cached_func.__module__}.{cached_func.__name__}",
867+
keys,
868+
)
869+
839870
async def complete_sso_login_async(
840871
self,
841872
registered_user_id: str,

synapse/storage/_base.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
# limitations under the License.
1616
import logging
1717
from abc import ABCMeta
18-
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
18+
from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union
1919

2020
from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401
2121
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
2222
from synapse.types import get_domain_from_id
2323
from synapse.util import json_decoder
24+
from synapse.util.caches.descriptors import CachedFunction
2425

2526
if TYPE_CHECKING:
2627
from synapse.server import HomeServer
@@ -47,6 +48,8 @@ def __init__(
4748
self.database_engine = database.engine
4849
self.db_pool = database
4950

51+
self.external_cached_functions: Dict[str, CachedFunction] = {}
52+
5053
def process_replication_rows(
5154
self,
5255
stream_name: str,
@@ -95,7 +98,7 @@ def _invalidate_state_caches(
9598

9699
def _attempt_to_invalidate_cache(
97100
self, cache_name: str, key: Optional[Collection[Any]]
98-
) -> None:
101+
) -> bool:
99102
"""Attempts to invalidate the cache of the given name, ignoring if the
100103
cache doesn't exist. Mainly used for invalidating caches on workers,
101104
where they may not have the cache.
@@ -113,9 +116,12 @@ def _attempt_to_invalidate_cache(
113116
try:
114117
cache = getattr(self, cache_name)
115118
except AttributeError:
116-
# We probably haven't pulled in the cache in this worker,
117-
# which is fine.
118-
return
119+
# Check if an externally defined module cache has been registered
120+
cache = self.external_cached_functions.get(cache_name)
121+
if not cache:
122+
# We probably haven't pulled in the cache in this worker,
123+
# which is fine.
124+
return False
119125

120126
if key is None:
121127
cache.invalidate_all()
@@ -125,6 +131,13 @@ def _attempt_to_invalidate_cache(
125131
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
126132
invalidate_method(tuple(key))
127133

134+
return True
135+
136+
def register_external_cached_function(
137+
self, cache_name: str, func: CachedFunction
138+
) -> None:
139+
self.external_cached_functions[cache_name] = func
140+
128141

129142
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
130143
"""

synapse/storage/databases/main/cache.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
)
3434
from synapse.storage.engines import PostgresEngine
3535
from synapse.storage.util.id_generators import MultiWriterIdGenerator
36-
from synapse.util.caches.descriptors import _CachedFunction
36+
from synapse.util.caches.descriptors import CachedFunction
3737
from synapse.util.iterutils import batch_iter
3838

3939
if TYPE_CHECKING:
@@ -269,17 +269,15 @@ async def invalidate_cache_and_stream(
269269
return
270270

271271
cache_func.invalidate(keys)
272-
await self.db_pool.runInteraction(
273-
"invalidate_cache_and_stream",
274-
self._send_invalidation_to_replication,
272+
await self.send_invalidation_to_replication(
275273
cache_func.__name__,
276274
keys,
277275
)
278276

279277
def _invalidate_cache_and_stream(
280278
self,
281279
txn: LoggingTransaction,
282-
cache_func: _CachedFunction,
280+
cache_func: CachedFunction,
283281
keys: Tuple[Any, ...],
284282
) -> None:
285283
"""Invalidates the cache and adds it to the cache stream so slaves
@@ -293,7 +291,7 @@ def _invalidate_cache_and_stream(
293291
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
294292

295293
def _invalidate_all_cache_and_stream(
296-
self, txn: LoggingTransaction, cache_func: _CachedFunction
294+
self, txn: LoggingTransaction, cache_func: CachedFunction
297295
) -> None:
298296
"""Invalidates the entire cache and adds it to the cache stream so slaves
299297
will know to invalidate their caches.
@@ -334,6 +332,16 @@ def _invalidate_state_caches_and_stream(
334332
txn, CURRENT_STATE_CACHE_NAME, [room_id]
335333
)
336334

335+
async def send_invalidation_to_replication(
336+
self, cache_name: str, keys: Optional[Collection[Any]]
337+
) -> None:
338+
await self.db_pool.runInteraction(
339+
"send_invalidation_to_replication",
340+
self._send_invalidation_to_replication,
341+
cache_name,
342+
keys,
343+
)
344+
337345
def _send_invalidation_to_replication(
338346
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
339347
) -> None:

synapse/util/caches/descriptors.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
F = TypeVar("F", bound=Callable[..., Any])
5454

5555

56-
class _CachedFunction(Generic[F]):
56+
class CachedFunction(Generic[F]):
5757
invalidate: Any = None
5858
invalidate_all: Any = None
5959
prefill: Any = None
@@ -242,7 +242,7 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any:
242242

243243
return ret2
244244

245-
wrapped = cast(_CachedFunction, _wrapped)
245+
wrapped = cast(CachedFunction, _wrapped)
246246
wrapped.cache = cache
247247
obj.__dict__[self.name] = wrapped
248248

@@ -363,7 +363,7 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any:
363363

364364
return make_deferred_yieldable(ret)
365365

366-
wrapped = cast(_CachedFunction, _wrapped)
366+
wrapped = cast(CachedFunction, _wrapped)
367367

368368
if self.num_args == 1:
369369
assert not self.tree
@@ -572,7 +572,7 @@ def cached(
572572
iterable: bool = False,
573573
prune_unread_entries: bool = True,
574574
name: Optional[str] = None,
575-
) -> Callable[[F], _CachedFunction[F]]:
575+
) -> Callable[[F], CachedFunction[F]]:
576576
func = lambda orig: DeferredCacheDescriptor(
577577
orig,
578578
max_entries=max_entries,
@@ -585,7 +585,7 @@ def cached(
585585
name=name,
586586
)
587587

588-
return cast(Callable[[F], _CachedFunction[F]], func)
588+
return cast(Callable[[F], CachedFunction[F]], func)
589589

590590

591591
def cachedList(
@@ -594,7 +594,7 @@ def cachedList(
594594
list_name: str,
595595
num_args: Optional[int] = None,
596596
name: Optional[str] = None,
597-
) -> Callable[[F], _CachedFunction[F]]:
597+
) -> Callable[[F], CachedFunction[F]]:
598598
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
599599
600600
Used to do batch lookups for an already created cache. One of the arguments
@@ -631,7 +631,7 @@ def batch_do_something(self, first_arg, second_args):
631631
name=name,
632632
)
633633

634-
return cast(Callable[[F], _CachedFunction[F]], func)
634+
return cast(Callable[[F], CachedFunction[F]], func)
635635

636636

637637
def _get_cache_key_builder(
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2022 The Matrix.org Foundation C.I.C.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import logging
15+
16+
import synapse
17+
from synapse.module_api import cached
18+
19+
from tests.replication._base import BaseMultiWorkerStreamTestCase
20+
21+
logger = logging.getLogger(__name__)
22+
23+
FIRST_VALUE = "one"
24+
SECOND_VALUE = "two"
25+
26+
KEY = "mykey"
27+
28+
29+
class TestCache:
30+
current_value = FIRST_VALUE
31+
32+
@cached()
33+
async def cached_function(self, user_id: str) -> str:
34+
return self.current_value
35+
36+
37+
class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
38+
servlets = [
39+
synapse.rest.admin.register_servlets,
40+
]
41+
42+
def test_module_cache_full_invalidation(self):
43+
main_cache = TestCache()
44+
self.hs.get_module_api().register_cached_function(main_cache.cached_function)
45+
46+
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
47+
48+
worker_cache = TestCache()
49+
worker_hs.get_module_api().register_cached_function(
50+
worker_cache.cached_function
51+
)
52+
53+
self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
54+
self.assertEqual(
55+
FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
56+
)
57+
58+
main_cache.current_value = SECOND_VALUE
59+
worker_cache.current_value = SECOND_VALUE
60+
# No invalidation yet, should return the cached value on both the main process and the worker
61+
self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
62+
self.assertEqual(
63+
FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
64+
)
65+
66+
# Full invalidation on the main process, should be replicated on the worker that
67+
# should return the updated value too
68+
self.get_success(
69+
self.hs.get_module_api().invalidate_cache(
70+
main_cache.cached_function, (KEY,)
71+
)
72+
)
73+
74+
self.assertEqual(
75+
SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))
76+
)
77+
self.assertEqual(
78+
SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY))
79+
)

0 commit comments

Comments
 (0)