Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 1781bbe

Browse files
authored
Add type hints to response cache. (#8507)
1 parent 66ac4b1 commit 1781bbe

File tree

9 files changed

+48
-34
lines changed

9 files changed

+48
-34
lines changed

changelog.d/8507.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add type hints to various parts of the code base.

mypy.ini

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ files =
6565
synapse/types.py,
6666
synapse/util/async_helpers.py,
6767
synapse/util/caches/descriptors.py,
68+
synapse/util/caches/response_cache.py,
6869
synapse/util/caches/stream_change_cache.py,
6970
synapse/util/metrics.py,
7071
tests/replication,

synapse/appservice/api.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
import logging
1616
import urllib
17-
from typing import TYPE_CHECKING, Optional
17+
from typing import TYPE_CHECKING, Optional, Tuple
1818

1919
from prometheus_client import Counter
2020

@@ -93,7 +93,7 @@ def __init__(self, hs):
9393

9494
self.protocol_meta_cache = ResponseCache(
9595
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
96-
)
96+
) # type: ResponseCache[Tuple[str, str]]
9797

9898
async def query_user(self, service, user_id):
9999
if service.url is None:

synapse/federation/federation_server.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -116,18 +116,20 @@ def __init__(self, hs):
116116
# We cache results for transaction with the same ID
117117
self._transaction_resp_cache = ResponseCache(
118118
hs, "fed_txn_handler", timeout_ms=30000
119-
)
119+
) # type: ResponseCache[Tuple[str, str]]
120120

121121
self.transaction_actions = TransactionActions(self.store)
122122

123123
self.registry = hs.get_federation_registry()
124124

125125
# We cache responses to state queries, as they take a while and often
126126
# come in waves.
127-
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
127+
self._state_resp_cache = ResponseCache(
128+
hs, "state_resp", timeout_ms=30000
129+
) # type: ResponseCache[Tuple[str, str]]
128130
self._state_ids_resp_cache = ResponseCache(
129131
hs, "state_ids_resp", timeout_ms=30000
130-
)
132+
) # type: ResponseCache[Tuple[str, str]]
131133

132134
self._federation_metrics_domains = (
133135
hs.get_config().federation.federation_metrics_domains

synapse/handlers/initial_sync.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515

1616
import logging
17-
from typing import TYPE_CHECKING
17+
from typing import TYPE_CHECKING, Optional, Tuple
1818

1919
from twisted.internet import defer
2020

@@ -47,12 +47,14 @@ def __init__(self, hs: "HomeServer"):
4747
self.state = hs.get_state_handler()
4848
self.clock = hs.get_clock()
4949
self.validator = EventValidator()
50-
self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
50+
self.snapshot_cache = ResponseCache(
51+
hs, "initial_sync_cache"
52+
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
5153
self._event_serializer = hs.get_event_client_serializer()
5254
self.storage = hs.get_storage()
5355
self.state_store = self.storage.state
5456

55-
def snapshot_all_rooms(
57+
async def snapshot_all_rooms(
5658
self,
5759
user_id: str,
5860
pagin_config: PaginationConfig,
@@ -84,7 +86,7 @@ def snapshot_all_rooms(
8486
include_archived,
8587
)
8688

87-
return self.snapshot_cache.wrap(
89+
return await self.snapshot_cache.wrap(
8890
key,
8991
self._snapshot_all_rooms,
9092
user_id,

synapse/handlers/room.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def __init__(self, hs: "HomeServer"):
120120
# subsequent requests
121121
self._upgrade_response_cache = ResponseCache(
122122
hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
123-
)
123+
) # type: ResponseCache[Tuple[str, str]]
124124
self._server_notices_mxid = hs.config.server_notices_mxid
125125

126126
self.third_party_event_rules = hs.get_third_party_event_rules()

synapse/handlers/sync.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ def __init__(self, hs: "HomeServer"):
243243
self.presence_handler = hs.get_presence_handler()
244244
self.event_sources = hs.get_event_sources()
245245
self.clock = hs.get_clock()
246-
self.response_cache = ResponseCache(hs, "sync")
246+
self.response_cache = ResponseCache(
247+
hs, "sync"
248+
) # type: ResponseCache[Tuple[Any, ...]]
247249
self.state = hs.get_state_handler()
248250
self.auth = hs.get_auth()
249251
self.storage = hs.get_storage()

synapse/replication/http/_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def __init__(self, hs):
9292
if self.CACHE:
9393
self.response_cache = ResponseCache(
9494
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
95-
)
95+
) # type: ResponseCache[str]
9696

9797
# We reserve `instance_name` as a parameter to sending requests, so we
9898
# assert here that sub classes don't try and use the name.

synapse/util/caches/response_cache.py

+28-22
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,47 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import logging
16+
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
1617

1718
from twisted.internet import defer
1819

1920
from synapse.logging.context import make_deferred_yieldable, run_in_background
2021
from synapse.util.async_helpers import ObservableDeferred
2122
from synapse.util.caches import register_cache
2223

24+
if TYPE_CHECKING:
25+
from synapse.app.homeserver import HomeServer
26+
2327
logger = logging.getLogger(__name__)
2428

29+
T = TypeVar("T")
30+
2531

26-
class ResponseCache:
32+
class ResponseCache(Generic[T]):
2733
"""
2834
This caches a deferred response. Until the deferred completes it will be
2935
returned from the cache. This means that if the client retries the request
3036
while the response is still being computed, that original response will be
3137
used rather than trying to compute a new response.
3238
"""
3339

34-
def __init__(self, hs, name, timeout_ms=0):
35-
self.pending_result_cache = {} # Requests that haven't finished yet.
40+
def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
41+
# Requests that haven't finished yet.
42+
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
3643

3744
self.clock = hs.get_clock()
3845
self.timeout_sec = timeout_ms / 1000.0
3946

4047
self._name = name
4148
self._metrics = register_cache("response_cache", name, self, resizable=False)
4249

43-
def size(self):
50+
def size(self) -> int:
4451
return len(self.pending_result_cache)
4552

46-
def __len__(self):
53+
def __len__(self) -> int:
4754
return self.size()
4855

49-
def get(self, key):
56+
def get(self, key: T) -> Optional[defer.Deferred]:
5057
"""Look up the given key.
5158
5259
Can return either a new Deferred (which also doesn't follow the synapse
@@ -58,12 +65,11 @@ def get(self, key):
5865
from an absent cache entry.
5966
6067
Args:
61-
key (hashable):
68+
key: key to get/set in the cache
6269
6370
Returns:
64-
twisted.internet.defer.Deferred|None|E: None if there is no entry
65-
for this key; otherwise either a deferred result or the result
66-
itself.
71+
None if there is no entry for this key; otherwise a deferred which
72+
resolves to the result.
6773
"""
6874
result = self.pending_result_cache.get(key)
6975
if result is not None:
@@ -73,7 +79,7 @@ def get(self, key):
7379
self._metrics.inc_misses()
7480
return None
7581

76-
def set(self, key, deferred):
82+
def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred:
7783
"""Set the entry for the given key to the given deferred.
7884
7985
*deferred* should run its callbacks in the sentinel logcontext (ie,
@@ -85,12 +91,11 @@ def set(self, key, deferred):
8591
result. You will probably want to make_deferred_yieldable the result.
8692
8793
Args:
88-
key (hashable):
89-
deferred (twisted.internet.defer.Deferred[T):
94+
key: key to get/set in the cache
95+
deferred: The deferred which resolves to the result.
9096
9197
Returns:
92-
twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
93-
result.
98+
A new deferred which resolves to the actual result.
9499
"""
95100
result = ObservableDeferred(deferred, consumeErrors=True)
96101
self.pending_result_cache[key] = result
@@ -107,7 +112,9 @@ def remove(r):
107112
result.addBoth(remove)
108113
return result.observe()
109114

110-
def wrap(self, key, callback, *args, **kwargs):
115+
def wrap(
116+
self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
117+
) -> defer.Deferred:
111118
"""Wrap together a *get* and *set* call, taking care of logcontexts
112119
113120
First looks up the key in the cache, and if it is present makes it
@@ -118,29 +125,28 @@ def wrap(self, key, callback, *args, **kwargs):
118125
119126
Example usage:
120127
121-
@defer.inlineCallbacks
122-
def handle_request(request):
128+
async def handle_request(request):
123129
# etc
124130
return result
125131
126-
result = yield response_cache.wrap(
132+
result = await response_cache.wrap(
127133
key,
128134
handle_request,
129135
request,
130136
)
131137
132138
Args:
133-
key (hashable): key to get/set in the cache
139+
key: key to get/set in the cache
134140
135-
callback (callable): function to call if the key is not found in
141+
callback: function to call if the key is not found in
136142
the cache
137143
138144
*args: positional parameters to pass to the callback, if it is used
139145
140146
**kwargs: named parameters to pass to the callback, if it is used
141147
142148
Returns:
143-
twisted.internet.defer.Deferred: yieldable result
149+
Deferred which resolves to the result
144150
"""
145151
result = self.get(key)
146152
if not result:

0 commit comments

Comments
 (0)