Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Speed up @cachedList #13591

Merged
merged 10 commits into from
Aug 23, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 106 additions & 89 deletions synapse/util/caches/deferred_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import enum
import threading
from typing import (
Any,
Callable,
Collection,
Dict,
Generic,
Iterable,
MutableMapping,
Optional,
Set,
Sized,
TypeVar,
Union,
Expand All @@ -31,7 +36,6 @@
from prometheus_client import Gauge

from twisted.internet import defer
from twisted.python import failure
from twisted.python.failure import Failure

from synapse.util.async_helpers import ObservableDeferred
Expand Down Expand Up @@ -159,15 +163,16 @@ def get(
Raises:
KeyError if the key is not found in the cache
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
val.callbacks.update(callbacks)
val.add_callback(key, callback)
if update_metrics:
m = self.cache.metrics
assert m # we always have a name, so should always have metrics
m.inc_hits()
return val.deferred.observe()
return val.deferred(key)

callbacks = (callback,) if callback else ()

val2 = self.cache.get(
key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
Expand Down Expand Up @@ -218,84 +223,70 @@ def set(
value: a deferred which will complete with a result to add to the cache
callback: An optional callback to be called when the entry is invalidated
"""
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")

callbacks = [callback] if callback else []
self.check_thread()

existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache.pop(key, None)

# XXX: why don't we invalidate the entry in `self.cache` yet?

# we can save a whole load of effort if the deferred is ready.
if value.called:
result = value.result
if not isinstance(result, failure.Failure):
self.cache.set(key, cast(VT, result), callbacks)
return value

# otherwise, we'll add an entry to the _pending_deferred_cache for now,
# and add callbacks to add it to the cache properly later.
entry = CacheEntrySingle[KT, VT](value)
entry.add_callback(key, callback)
self._pending_deferred_cache[key] = entry
deferred = entry.deferred(key).addCallbacks(
self._set_completed_callback,
self._error_callback,
callbackArgs=(entry, key),
errbackArgs=(entry, key),
)

observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
# we return a new Deferred which will be called before any subsequent observers.
return deferred

def _set_completed_callback(
self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
) -> VT:
"""Called when a deferred is completed."""
# We check if the current entry matches the entry associated with the
# deferred. If they don't match then it got invalidated.
current_entry = self._pending_deferred_cache.pop(key, None)
if current_entry is not entry:
if current_entry:
self._pending_deferred_cache[key] = current_entry
return value

self._pending_deferred_cache[key] = entry
self.cache.set(key, value, entry.get_callbacks(key))

def compare_and_pop() -> bool:
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.

Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True

# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry

return False

def cb(result: VT) -> None:
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()

def eb(_fail: Failure) -> None:
compare_and_pop()
entry.invalidate()

# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return value

# we return a new Deferred which will be called before any subsequent observers.
return observable.observe()
def _error_callback(
self,
failure: Failure,
entry: "CacheEntry[KT, VT]",
key: KT,
) -> Failure:
"""Called when a deferred errors."""

# We check if the current entry matches the entry associated with the
# deferred. If they don't match then it got invalidated.
current_entry = self._pending_deferred_cache.pop(key, None)
if current_entry is not entry:
if current_entry:
self._pending_deferred_cache[key] = current_entry
return failure

for cb in entry.get_callbacks(key):
cb()

return failure

def prefill(
self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
) -> None:
callbacks = [callback] if callback else []
callbacks = (callback,) if callback else ()
self.cache.set(key, value, callbacks=callbacks)
self._pending_deferred_cache.pop(key, None)

def invalidate(self, key: KT) -> None:
"""Delete a key, or tree of entries
Expand All @@ -311,41 +302,67 @@ def invalidate(self, key: KT) -> None:
self.cache.del_multi(key)

# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned
# for future queries and (b) stop it being persisted as a proper entry
# _pending_deferred_cache, which will (a) stop it being returned for
# future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)

# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry:
# _pending_deferred_cache.pop should either return a CacheEntry, or, in the
# case of a TreeCache, a dict of keys to cache entries. Either way calling
# iterate_tree_cache_entry on it will do the right thing.
for entry in iterate_tree_cache_entry(entry):
entry.invalidate()
for cb in entry.get_callbacks(key):
cb()

def invalidate_all(self) -> None:
self.check_thread()
self.cache.clear()
for entry in self._pending_deferred_cache.values():
entry.invalidate()
for key, entry in self._pending_deferred_cache.items():
for cb in entry.get_callbacks(key):
cb()

self._pending_deferred_cache.clear()


class CacheEntry:
__slots__ = ["deferred", "callbacks", "invalidated"]
class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
"""Abstract class for entries in `DeferredCache[KT, VT]`"""

def __init__(
self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
):
self.deferred = deferred
self.callbacks = set(callbacks)
self.invalidated = False

def invalidate(self) -> None:
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
@abc.abstractmethod
def deferred(self, key: KT) -> "defer.Deferred[VT]":
"""Get a deferred that a caller can wait on to get the value at the
given key"""
...

@abc.abstractmethod
def add_callback(self, key: KT, callback: Optional[Callable[[], None]]) -> None:
"""Add an invalidation callback"""
...

@abc.abstractmethod
def get_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
"""Get all invalidation callbacks"""
...


class CacheEntrySingle(CacheEntry[KT, VT]):
"""An implementation of `CacheEntry` wrapping a deferred that results in a
single cache entry.
"""

__slots__ = ["_deferred", "_callbacks"]

def __init__(self, deferred: "defer.Deferred[VT]") -> None:
self._deferred = ObservableDeferred(deferred, consumeErrors=True)
self._callbacks: Set[Callable[[], None]] = set()

def deferred(self, key: KT) -> "defer.Deferred[VT]":
return self._deferred.observe()

def add_callback(self, key: KT, callback: Optional[Callable[[], None]]) -> None:
if callback is None:
return

self._callbacks.add(callback)

def get_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
return self._callbacks
3 changes: 3 additions & 0 deletions synapse/util/caches/treecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def pop(self, key, default=None):
def values(self):
return iterate_tree_cache_entry(self.root)

def items(self):
return iterate_tree_cache_items((), self.root)

def __len__(self) -> int:
return self.size

Expand Down