Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 47bc84d

Browse files
authored
Pass the Requester down to the HttpTransactionCache. (#15200)
1 parent 820f02b commit 47bc84d

File tree

6 files changed

+215
-129
lines changed

6 files changed

+215
-129
lines changed

changelog.d/15200.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make the `HttpTransactionCache` use the `Requester` in addition of the just the `Request` to build the transaction key.

synapse/rest/admin/server_notice_servlet.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from http import HTTPStatus
15-
from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
15+
from typing import TYPE_CHECKING, Optional, Tuple
1616

1717
from synapse.api.constants import EventTypes
1818
from synapse.api.errors import NotFoundError, SynapseError
@@ -23,10 +23,10 @@
2323
parse_json_object_from_request,
2424
)
2525
from synapse.http.site import SynapseRequest
26-
from synapse.rest.admin import assert_requester_is_admin
27-
from synapse.rest.admin._base import admin_patterns
26+
from synapse.logging.opentracing import set_tag
27+
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
2828
from synapse.rest.client.transactions import HttpTransactionCache
29-
from synapse.types import JsonDict, UserID
29+
from synapse.types import JsonDict, Requester, UserID
3030

3131
if TYPE_CHECKING:
3232
from synapse.server import HomeServer
@@ -70,10 +70,13 @@ def register(self, json_resource: HttpServer) -> None:
7070
self.__class__.__name__,
7171
)
7272

73-
async def on_POST(
74-
self, request: SynapseRequest, txn_id: Optional[str] = None
73+
async def _do(
74+
self,
75+
request: SynapseRequest,
76+
requester: Requester,
77+
txn_id: Optional[str],
7578
) -> Tuple[int, JsonDict]:
76-
await assert_requester_is_admin(self.auth, request)
79+
await assert_user_is_admin(self.auth, requester)
7780
body = parse_json_object_from_request(request)
7881
assert_params_in_dict(body, ("user_id", "content"))
7982
event_type = body.get("type", EventTypes.Message)
@@ -106,9 +109,18 @@ async def on_POST(
106109

107110
return HTTPStatus.OK, {"event_id": event.event_id}
108111

109-
def on_PUT(
112+
async def on_POST(
113+
self,
114+
request: SynapseRequest,
115+
) -> Tuple[int, JsonDict]:
116+
requester = await self.auth.get_user_by_req(request)
117+
return await self._do(request, requester, None)
118+
119+
async def on_PUT(
110120
self, request: SynapseRequest, txn_id: str
111-
) -> Awaitable[Tuple[int, JsonDict]]:
112-
return self.txns.fetch_or_execute_request(
113-
request, self.on_POST, request, txn_id
121+
) -> Tuple[int, JsonDict]:
122+
requester = await self.auth.get_user_by_req(request)
123+
set_tag("txn_id", txn_id)
124+
return await self.txns.fetch_or_execute_request(
125+
request, requester, self._do, request, requester, txn_id
114126
)

synapse/rest/client/room.py

Lines changed: 108 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
from synapse.rest.client._base import client_patterns
5858
from synapse.rest.client.transactions import HttpTransactionCache
5959
from synapse.streams.config import PaginationConfig
60-
from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID
60+
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
6161
from synapse.types.state import StateFilter
6262
from synapse.util import json_decoder
6363
from synapse.util.cancellation import cancellable
@@ -151,15 +151,22 @@ def register(self, http_server: HttpServer) -> None:
151151
PATTERNS = "/createRoom"
152152
register_txn_path(self, PATTERNS, http_server)
153153

154-
def on_PUT(
154+
async def on_PUT(
155155
self, request: SynapseRequest, txn_id: str
156-
) -> Awaitable[Tuple[int, JsonDict]]:
156+
) -> Tuple[int, JsonDict]:
157+
requester = await self.auth.get_user_by_req(request)
157158
set_tag("txn_id", txn_id)
158-
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
159+
return await self.txns.fetch_or_execute_request(
160+
request, requester, self._do, request, requester
161+
)
159162

160163
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
161164
requester = await self.auth.get_user_by_req(request)
165+
return await self._do(request, requester)
162166

167+
async def _do(
168+
self, request: SynapseRequest, requester: Requester
169+
) -> Tuple[int, JsonDict]:
163170
room_id, _, _ = await self._room_creation_handler.create_room(
164171
requester, self.get_room_config(request)
165172
)
@@ -172,9 +179,9 @@ def get_room_config(self, request: Request) -> JsonDict:
172179

173180

174181
# TODO: Needs unit testing for generic events
175-
class RoomStateEventRestServlet(TransactionRestServlet):
182+
class RoomStateEventRestServlet(RestServlet):
176183
def __init__(self, hs: "HomeServer"):
177-
super().__init__(hs)
184+
super().__init__()
178185
self.event_creation_handler = hs.get_event_creation_handler()
179186
self.room_member_handler = hs.get_room_member_handler()
180187
self.message_handler = hs.get_message_handler()
@@ -324,16 +331,16 @@ def __init__(self, hs: "HomeServer"):
324331
def register(self, http_server: HttpServer) -> None:
325332
# /rooms/$roomid/send/$event_type[/$txn_id]
326333
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
327-
register_txn_path(self, PATTERNS, http_server, with_get=True)
334+
register_txn_path(self, PATTERNS, http_server)
328335

329-
async def on_POST(
336+
async def _do(
330337
self,
331338
request: SynapseRequest,
339+
requester: Requester,
332340
room_id: str,
333341
event_type: str,
334-
txn_id: Optional[str] = None,
342+
txn_id: Optional[str],
335343
) -> Tuple[int, JsonDict]:
336-
requester = await self.auth.get_user_by_req(request, allow_guest=True)
337344
content = parse_json_object_from_request(request)
338345

339346
event_dict: JsonDict = {
@@ -362,18 +369,30 @@ async def on_POST(
362369
set_tag("event_id", event_id)
363370
return 200, {"event_id": event_id}
364371

365-
def on_GET(
366-
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
367-
) -> Tuple[int, str]:
368-
return 200, "Not implemented"
372+
async def on_POST(
373+
self,
374+
request: SynapseRequest,
375+
room_id: str,
376+
event_type: str,
377+
) -> Tuple[int, JsonDict]:
378+
requester = await self.auth.get_user_by_req(request, allow_guest=True)
379+
return await self._do(request, requester, room_id, event_type, None)
369380

370-
def on_PUT(
381+
async def on_PUT(
371382
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
372-
) -> Awaitable[Tuple[int, JsonDict]]:
383+
) -> Tuple[int, JsonDict]:
384+
requester = await self.auth.get_user_by_req(request, allow_guest=True)
373385
set_tag("txn_id", txn_id)
374386

375-
return self.txns.fetch_or_execute_request(
376-
request, self.on_POST, request, room_id, event_type, txn_id
387+
return await self.txns.fetch_or_execute_request(
388+
request,
389+
requester,
390+
self._do,
391+
request,
392+
requester,
393+
room_id,
394+
event_type,
395+
txn_id,
377396
)
378397

379398

@@ -389,14 +408,13 @@ def register(self, http_server: HttpServer) -> None:
389408
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
390409
register_txn_path(self, PATTERNS, http_server)
391410

392-
async def on_POST(
411+
async def _do(
393412
self,
394413
request: SynapseRequest,
414+
requester: Requester,
395415
room_identifier: str,
396-
txn_id: Optional[str] = None,
416+
txn_id: Optional[str],
397417
) -> Tuple[int, JsonDict]:
398-
requester = await self.auth.get_user_by_req(request, allow_guest=True)
399-
400418
content = parse_json_object_from_request(request, allow_empty_body=True)
401419

402420
# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
@@ -420,22 +438,31 @@ async def on_POST(
420438

421439
return 200, {"room_id": room_id}
422440

423-
def on_PUT(
441+
async def on_POST(
442+
self,
443+
request: SynapseRequest,
444+
room_identifier: str,
445+
) -> Tuple[int, JsonDict]:
446+
requester = await self.auth.get_user_by_req(request, allow_guest=True)
447+
return await self._do(request, requester, room_identifier, None)
448+
449+
async def on_PUT(
424450
self, request: SynapseRequest, room_identifier: str, txn_id: str
425-
) -> Awaitable[Tuple[int, JsonDict]]:
451+
) -> Tuple[int, JsonDict]:
452+
requester = await self.auth.get_user_by_req(request, allow_guest=True)
426453
set_tag("txn_id", txn_id)
427454

428-
return self.txns.fetch_or_execute_request(
429-
request, self.on_POST, request, room_identifier, txn_id
455+
return await self.txns.fetch_or_execute_request(
456+
request, requester, self._do, request, requester, room_identifier, txn_id
430457
)
431458

432459

433460
# TODO: Needs unit testing
434-
class PublicRoomListRestServlet(TransactionRestServlet):
461+
class PublicRoomListRestServlet(RestServlet):
435462
PATTERNS = client_patterns("/publicRooms$", v1=True)
436463

437464
def __init__(self, hs: "HomeServer"):
438-
super().__init__(hs)
465+
super().__init__()
439466
self.hs = hs
440467
self.auth = hs.get_auth()
441468

@@ -907,22 +934,25 @@ def register(self, http_server: HttpServer) -> None:
907934
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
908935
register_txn_path(self, PATTERNS, http_server)
909936

910-
async def on_POST(
911-
self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
912-
) -> Tuple[int, JsonDict]:
913-
requester = await self.auth.get_user_by_req(request, allow_guest=False)
914-
937+
async def _do(self, requester: Requester, room_id: str) -> Tuple[int, JsonDict]:
915938
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
916939

917940
return 200, {}
918941

919-
def on_PUT(
942+
async def on_POST(
943+
self, request: SynapseRequest, room_id: str
944+
) -> Tuple[int, JsonDict]:
945+
requester = await self.auth.get_user_by_req(request, allow_guest=False)
946+
return await self._do(requester, room_id)
947+
948+
async def on_PUT(
920949
self, request: SynapseRequest, room_id: str, txn_id: str
921-
) -> Awaitable[Tuple[int, JsonDict]]:
950+
) -> Tuple[int, JsonDict]:
951+
requester = await self.auth.get_user_by_req(request, allow_guest=False)
922952
set_tag("txn_id", txn_id)
923953

924-
return self.txns.fetch_or_execute_request(
925-
request, self.on_POST, request, room_id, txn_id
954+
return await self.txns.fetch_or_execute_request(
955+
request, requester, self._do, requester, room_id
926956
)
927957

928958

@@ -941,15 +971,14 @@ def register(self, http_server: HttpServer) -> None:
941971
)
942972
register_txn_path(self, PATTERNS, http_server)
943973

944-
async def on_POST(
974+
async def _do(
945975
self,
946976
request: SynapseRequest,
977+
requester: Requester,
947978
room_id: str,
948979
membership_action: str,
949-
txn_id: Optional[str] = None,
980+
txn_id: Optional[str],
950981
) -> Tuple[int, JsonDict]:
951-
requester = await self.auth.get_user_by_req(request, allow_guest=True)
952-
953982
if requester.is_guest and membership_action not in {
954983
Membership.JOIN,
955984
Membership.LEAVE,
@@ -1014,13 +1043,30 @@ async def on_POST(
10141043

10151044
return 200, return_value
10161045

1017-
def on_PUT(
1046+
async def on_POST(
1047+
self,
1048+
request: SynapseRequest,
1049+
room_id: str,
1050+
membership_action: str,
1051+
) -> Tuple[int, JsonDict]:
1052+
requester = await self.auth.get_user_by_req(request, allow_guest=True)
1053+
return await self._do(request, requester, room_id, membership_action, None)
1054+
1055+
async def on_PUT(
10181056
self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
1019-
) -> Awaitable[Tuple[int, JsonDict]]:
1057+
) -> Tuple[int, JsonDict]:
1058+
requester = await self.auth.get_user_by_req(request, allow_guest=True)
10201059
set_tag("txn_id", txn_id)
10211060

1022-
return self.txns.fetch_or_execute_request(
1023-
request, self.on_POST, request, room_id, membership_action, txn_id
1061+
return await self.txns.fetch_or_execute_request(
1062+
request,
1063+
requester,
1064+
self._do,
1065+
request,
1066+
requester,
1067+
room_id,
1068+
membership_action,
1069+
txn_id,
10241070
)
10251071

10261072

@@ -1036,14 +1082,14 @@ def register(self, http_server: HttpServer) -> None:
10361082
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
10371083
register_txn_path(self, PATTERNS, http_server)
10381084

1039-
async def on_POST(
1085+
async def _do(
10401086
self,
10411087
request: SynapseRequest,
1088+
requester: Requester,
10421089
room_id: str,
10431090
event_id: str,
1044-
txn_id: Optional[str] = None,
1091+
txn_id: Optional[str],
10451092
) -> Tuple[int, JsonDict]:
1046-
requester = await self.auth.get_user_by_req(request)
10471093
content = parse_json_object_from_request(request)
10481094

10491095
try:
@@ -1094,13 +1140,23 @@ async def on_POST(
10941140
set_tag("event_id", event_id)
10951141
return 200, {"event_id": event_id}
10961142

1097-
def on_PUT(
1143+
async def on_POST(
1144+
self,
1145+
request: SynapseRequest,
1146+
room_id: str,
1147+
event_id: str,
1148+
) -> Tuple[int, JsonDict]:
1149+
requester = await self.auth.get_user_by_req(request)
1150+
return await self._do(request, requester, room_id, event_id, None)
1151+
1152+
async def on_PUT(
10981153
self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
1099-
) -> Awaitable[Tuple[int, JsonDict]]:
1154+
) -> Tuple[int, JsonDict]:
1155+
requester = await self.auth.get_user_by_req(request)
11001156
set_tag("txn_id", txn_id)
11011157

1102-
return self.txns.fetch_or_execute_request(
1103-
request, self.on_POST, request, room_id, event_id, txn_id
1158+
return await self.txns.fetch_or_execute_request(
1159+
request, requester, self._do, request, requester, room_id, event_id, txn_id
11041160
)
11051161

11061162

@@ -1224,7 +1280,6 @@ def register_txn_path(
12241280
servlet: RestServlet,
12251281
regex_string: str,
12261282
http_server: HttpServer,
1227-
with_get: bool = False,
12281283
) -> None:
12291284
"""Registers a transaction-based path.
12301285
@@ -1236,7 +1291,6 @@ def register_txn_path(
12361291
regex_string: The regex string to register. Must NOT have a
12371292
trailing $ as this string will be appended to.
12381293
http_server: The http_server to register paths with.
1239-
with_get: True to also register respective GET paths for the PUTs.
12401294
"""
12411295
on_POST = getattr(servlet, "on_POST", None)
12421296
on_PUT = getattr(servlet, "on_PUT", None)
@@ -1254,18 +1308,6 @@ def register_txn_path(
12541308
on_PUT,
12551309
servlet.__class__.__name__,
12561310
)
1257-
on_GET = getattr(servlet, "on_GET", None)
1258-
if with_get:
1259-
if on_GET is None:
1260-
raise RuntimeError(
1261-
"register_txn_path called with with_get = True, but no on_GET method exists"
1262-
)
1263-
http_server.register_paths(
1264-
"GET",
1265-
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
1266-
on_GET,
1267-
servlet.__class__.__name__,
1268-
)
12691311

12701312

12711313
class TimestampLookupRestServlet(RestServlet):

0 commit comments

Comments
 (0)