Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 63d90f1

Browse files
authored
Add missing type hints to synapse.replication.http. (#11856)
1 parent 8b309ad commit 63d90f1

File tree

13 files changed

+258
-162
lines changed

13 files changed

+258
-162
lines changed

changelog.d/11856.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add missing type hints to replication code.

synapse/replication/http/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, hs: "HomeServer"):
4040
super().__init__(hs, canonical_json=False, extract_context=True)
4141
self.register_servlets(hs)
4242

43-
def register_servlets(self, hs: "HomeServer"):
43+
def register_servlets(self, hs: "HomeServer") -> None:
4444
send_event.register_servlets(hs, self)
4545
federation.register_servlets(hs, self)
4646
presence.register_servlets(hs, self)

synapse/replication/http/_base.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,20 @@
1515
import abc
1616
import logging
1717
import re
18-
import urllib
18+
import urllib.parse
1919
from inspect import signature
2020
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
2121

2222
from prometheus_client import Counter, Gauge
2323

24+
from twisted.web.server import Request
25+
2426
from synapse.api.errors import HttpResponseException, SynapseError
2527
from synapse.http import RequestTimedOutError
28+
from synapse.http.server import HttpServer
2629
from synapse.logging import opentracing
2730
from synapse.logging.opentracing import trace
31+
from synapse.types import JsonDict
2832
from synapse.util.caches.response_cache import ResponseCache
2933
from synapse.util.stringutils import random_string
3034

@@ -113,10 +117,12 @@ def __init__(self, hs: "HomeServer"):
113117
if hs.config.worker.worker_replication_secret:
114118
self._replication_secret = hs.config.worker.worker_replication_secret
115119

116-
def _check_auth(self, request) -> None:
120+
def _check_auth(self, request: Request) -> None:
117121
# Get the authorization header.
118122
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
119123

124+
if not auth_headers:
125+
raise RuntimeError("Missing Authorization header.")
120126
if len(auth_headers) > 1:
121127
raise RuntimeError("Too many Authorization headers.")
122128
parts = auth_headers[0].split(b" ")
@@ -129,7 +135,7 @@ def _check_auth(self, request) -> None:
129135
raise RuntimeError("Invalid Authorization header.")
130136

131137
@abc.abstractmethod
132-
async def _serialize_payload(**kwargs):
138+
async def _serialize_payload(**kwargs) -> JsonDict:
133139
"""Static method that is called when creating a request.
134140
135141
Concrete implementations should have explicit parameters (rather than
@@ -144,19 +150,20 @@ async def _serialize_payload(**kwargs):
144150
return {}
145151

146152
@abc.abstractmethod
147-
async def _handle_request(self, request, **kwargs):
153+
async def _handle_request(
154+
self, request: Request, **kwargs: Any
155+
) -> Tuple[int, JsonDict]:
148156
"""Handle incoming request.
149157
150158
This is called with the request object and PATH_ARGS.
151159
152160
Returns:
153-
tuple[int, dict]: HTTP status code and a JSON serialisable dict
154-
to be used as response body of request.
161+
HTTP status code and a JSON serialisable dict to be used as response
162+
body of request.
155163
"""
156-
pass
157164

158165
@classmethod
159-
def make_client(cls, hs: "HomeServer"):
166+
def make_client(cls, hs: "HomeServer") -> Callable:
160167
"""Create a client that makes requests.
161168
162169
Returns a callable that accepts the same parameters as
@@ -182,7 +189,7 @@ def make_client(cls, hs: "HomeServer"):
182189
)
183190

184191
@trace(opname="outgoing_replication_request")
185-
async def send_request(*, instance_name="master", **kwargs):
192+
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
186193
with outgoing_gauge.track_inprogress():
187194
if instance_name == local_instance_name:
188195
raise Exception("Trying to send HTTP request to self")
@@ -268,7 +275,7 @@ async def send_request(*, instance_name="master", **kwargs):
268275

269276
return send_request
270277

271-
def register(self, http_server):
278+
def register(self, http_server: HttpServer) -> None:
272279
"""Called by the server to register this as a handler to the
273280
appropriate path.
274281
"""
@@ -289,7 +296,9 @@ def register(self, http_server):
289296
self.__class__.__name__,
290297
)
291298

292-
async def _check_auth_and_handle(self, request, **kwargs):
299+
async def _check_auth_and_handle(
300+
self, request: Request, **kwargs: Any
301+
) -> Tuple[int, JsonDict]:
293302
"""Called on new incoming requests when caching is enabled. Checks
294303
if there is a cached response for the request and returns that,
295304
otherwise calls `_handle_request` and caches its response.

synapse/replication/http/account_data.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import TYPE_CHECKING
16+
from typing import TYPE_CHECKING, Tuple
1717

18+
from twisted.web.server import Request
19+
20+
from synapse.http.server import HttpServer
1821
from synapse.http.servlet import parse_json_object_from_request
1922
from synapse.replication.http._base import ReplicationEndpoint
23+
from synapse.types import JsonDict
2024

2125
if TYPE_CHECKING:
2226
from synapse.server import HomeServer
@@ -48,14 +52,18 @@ def __init__(self, hs: "HomeServer"):
4852
self.clock = hs.get_clock()
4953

5054
@staticmethod
51-
async def _serialize_payload(user_id, account_data_type, content):
55+
async def _serialize_payload( # type: ignore[override]
56+
user_id: str, account_data_type: str, content: JsonDict
57+
) -> JsonDict:
5258
payload = {
5359
"content": content,
5460
}
5561

5662
return payload
5763

58-
async def _handle_request(self, request, user_id, account_data_type):
64+
async def _handle_request( # type: ignore[override]
65+
self, request: Request, user_id: str, account_data_type: str
66+
) -> Tuple[int, JsonDict]:
5967
content = parse_json_object_from_request(request)
6068

6169
max_stream_id = await self.handler.add_account_data_for_user(
@@ -89,14 +97,18 @@ def __init__(self, hs: "HomeServer"):
8997
self.clock = hs.get_clock()
9098

9199
@staticmethod
92-
async def _serialize_payload(user_id, room_id, account_data_type, content):
100+
async def _serialize_payload( # type: ignore[override]
101+
user_id: str, room_id: str, account_data_type: str, content: JsonDict
102+
) -> JsonDict:
93103
payload = {
94104
"content": content,
95105
}
96106

97107
return payload
98108

99-
async def _handle_request(self, request, user_id, room_id, account_data_type):
109+
async def _handle_request( # type: ignore[override]
110+
self, request: Request, user_id: str, room_id: str, account_data_type: str
111+
) -> Tuple[int, JsonDict]:
100112
content = parse_json_object_from_request(request)
101113

102114
max_stream_id = await self.handler.add_account_data_to_room(
@@ -130,14 +142,18 @@ def __init__(self, hs: "HomeServer"):
130142
self.clock = hs.get_clock()
131143

132144
@staticmethod
133-
async def _serialize_payload(user_id, room_id, tag, content):
145+
async def _serialize_payload( # type: ignore[override]
146+
user_id: str, room_id: str, tag: str, content: JsonDict
147+
) -> JsonDict:
134148
payload = {
135149
"content": content,
136150
}
137151

138152
return payload
139153

140-
async def _handle_request(self, request, user_id, room_id, tag):
154+
async def _handle_request( # type: ignore[override]
155+
self, request: Request, user_id: str, room_id: str, tag: str
156+
) -> Tuple[int, JsonDict]:
141157
content = parse_json_object_from_request(request)
142158

143159
max_stream_id = await self.handler.add_tag_to_room(
@@ -173,11 +189,13 @@ def __init__(self, hs: "HomeServer"):
173189
self.clock = hs.get_clock()
174190

175191
@staticmethod
176-
async def _serialize_payload(user_id, room_id, tag):
192+
async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override]
177193

178194
return {}
179195

180-
async def _handle_request(self, request, user_id, room_id, tag):
196+
async def _handle_request( # type: ignore[override]
197+
self, request: Request, user_id: str, room_id: str, tag: str
198+
) -> Tuple[int, JsonDict]:
181199
max_stream_id = await self.handler.remove_tag_from_room(
182200
user_id,
183201
room_id,
@@ -187,7 +205,7 @@ async def _handle_request(self, request, user_id, room_id, tag):
187205
return 200, {"max_stream_id": max_stream_id}
188206

189207

190-
def register_servlets(hs: "HomeServer", http_server):
208+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
191209
ReplicationUserAccountDataRestServlet(hs).register(http_server)
192210
ReplicationRoomAccountDataRestServlet(hs).register(http_server)
193211
ReplicationAddTagRestServlet(hs).register(http_server)

synapse/replication/http/devices.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,13 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import TYPE_CHECKING
16+
from typing import TYPE_CHECKING, Tuple
1717

18+
from twisted.web.server import Request
19+
20+
from synapse.http.server import HttpServer
1821
from synapse.replication.http._base import ReplicationEndpoint
22+
from synapse.types import JsonDict
1923

2024
if TYPE_CHECKING:
2125
from synapse.server import HomeServer
@@ -63,14 +67,16 @@ def __init__(self, hs: "HomeServer"):
6367
self.clock = hs.get_clock()
6468

6569
@staticmethod
66-
async def _serialize_payload(user_id):
70+
async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override]
6771
return {}
6872

69-
async def _handle_request(self, request, user_id):
73+
async def _handle_request( # type: ignore[override]
74+
self, request: Request, user_id: str
75+
) -> Tuple[int, JsonDict]:
7076
user_devices = await self.device_list_updater.user_device_resync(user_id)
7177

7278
return 200, user_devices
7379

7480

75-
def register_servlets(hs: "HomeServer", http_server):
81+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
7682
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)

0 commit comments

Comments
 (0)