Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 2c15d4c

Browse files
committed
Merge branch 'rav/faster_joins/03_get_room_state' into rav/faster_joins/05_optimise_get_state
2 parents db88ddd + df03367 commit 2c15d4c

File tree

7 files changed

+375
-17
lines changed

7 files changed

+375
-17
lines changed

changelog.d/12013.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server.

synapse/federation/federation_base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ async def _check_sigs_and_hash(
4747
) -> EventBase:
4848
"""Checks that event is correctly signed by the sending server.
4949
50+
Also checks the content hash, and redacts the event if there is a mismatch.
51+
52+
Also runs the event through the spam checker; if it fails, redacts the event
53+
and flags it as soft-failed.
54+
5055
Args:
5156
room_version: The room version of the PDU
5257
pdu: the event to be checked
@@ -55,7 +60,10 @@ async def _check_sigs_and_hash(
5560
* the original event if the checks pass
5661
* a redacted version of the event (if the signature
5762
matched but the hash did not)
58-
* throws a SynapseError if the signature check failed."""
63+
64+
Raises:
65+
SynapseError if the signature check failed.
66+
"""
5967
try:
6068
await _check_sigs_on_pdu(self.keyring, room_version, pdu)
6169
except SynapseError as e:

synapse/federation/federation_client.py

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -419,26 +419,90 @@ async def get_room_state_ids(
419419

420420
return state_event_ids, auth_event_ids
421421

422+
async def get_room_state(
423+
self,
424+
destination: str,
425+
room_id: str,
426+
event_id: str,
427+
room_version: RoomVersion,
428+
) -> Tuple[List[EventBase], List[EventBase]]:
429+
"""Calls the /state endpoint to fetch the state at a particular point
430+
in the room.
431+
432+
Any invalid events (those with incorrect or unverifiable signatures or hashes)
433+
are filtered out from the response, and any duplicate events are removed.
434+
435+
(Size limits and other event-format checks are *not* performed.)
436+
437+
Note that the result is not ordered, so callers must be careful to process
438+
the events in an order that handles dependencies.
439+
440+
Returns:
441+
a tuple of (state events, auth events)
442+
"""
443+
result = await self.transport_layer.get_room_state(
444+
room_version,
445+
destination,
446+
room_id,
447+
event_id,
448+
)
449+
state_events = result.state
450+
auth_events = result.auth_events
451+
452+
# we may as well filter out any duplicates from the response, to save
453+
# processing them multiple times. (In particular, events may be present in
454+
# `auth_events` as well as `state`, which is redundant).
455+
#
456+
# We don't rely on the sort order of the events, so we can just stick them
457+
# in a dict.
458+
state_event_map = {event.event_id: event for event in state_events}
459+
auth_event_map = {
460+
event.event_id: event
461+
for event in auth_events
462+
if event.event_id not in state_event_map
463+
}
464+
465+
logger.info(
466+
"Processing from /state: %d state events, %d auth events",
467+
len(state_event_map),
468+
len(auth_event_map),
469+
)
470+
471+
valid_auth_events = await self._check_sigs_and_hash_and_fetch(
472+
destination, auth_event_map.values(), room_version
473+
)
474+
475+
valid_state_events = await self._check_sigs_and_hash_and_fetch(
476+
destination, state_event_map.values(), room_version
477+
)
478+
479+
return valid_state_events, valid_auth_events
480+
422481
async def _check_sigs_and_hash_and_fetch(
423482
self,
424483
origin: str,
425484
pdus: Collection[EventBase],
426485
room_version: RoomVersion,
427486
) -> List[EventBase]:
428-
"""Takes a list of PDUs and checks the signatures and hashes of each
429-
one. If a PDU fails its signature check then we check if we have it in
430-
the database and if not then request if from the originating server of
431-
that PDU.
487+
"""Checks the signatures and hashes of a list of events.
488+
489+
If a PDU fails its signature check then we check if we have it in
490+
the database, and if not then request it from the sender's server (if that
491+
is different from `origin`). If that still fails, the event is omitted from
492+
the returned list.
432493
433494
If a PDU fails its content hash check then it is redacted.
434495
435-
The given list of PDUs are not modified, instead the function returns
496+
Also runs each event through the spam checker; if it fails, redacts the event
497+
and flags it as soft-failed.
498+
499+
The given list of PDUs is not modified; instead the function returns
436500
a new list.
437501
438502
Args:
439-
origin
440-
pdu
441-
room_version
503+
origin: The server that sent us these events
504+
pdus: The events to be checked
505+
room_version: the version of the room these events are in
442506
443507
Returns:
444508
A list of PDUs that have valid signatures and hashes.
@@ -469,11 +533,16 @@ async def _check_sigs_and_hash_and_fetch_one(
469533
origin: str,
470534
room_version: RoomVersion,
471535
) -> Optional[EventBase]:
472-
"""Takes a PDU and checks its signatures and hashes. If the PDU fails
473-
its signature check then we check if we have it in the database and if
474-
not then request if from the originating server of that PDU.
536+
"""Takes a PDU and checks its signatures and hashes.
537+
538+
If the PDU fails its signature check then we check if we have it in the
539+
database; if not, we then request it from the sender's server (if that is not the
540+
same as `origin`). If that still fails, we return None.
541+
542+
If the PDU fails its content hash check, it is redacted.
475543
476-
If then PDU fails its content hash check then it is redacted.
544+
Also runs the event through the spam checker; if it fails, redacts the event
545+
and flags it as soft-failed.
477546
478547
Args:
479548
origin

synapse/federation/transport/client.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,12 @@ def __init__(self, hs):
6565
async def get_room_state_ids(
6666
self, destination: str, room_id: str, event_id: str
6767
) -> JsonDict:
68-
"""Requests all state for a given room from the given server at the
69-
given event. Returns the state's event_id's
68+
"""Requests the IDs of all state for a given room at the given event.
7069
7170
Args:
7271
destination: The host name of the remote homeserver we want
7372
to get the state from.
74-
context: The name of the context we want the state of
73+
room_id: the room we want the state of
7574
event_id: The event we want the context at.
7675
7776
Returns:
@@ -87,6 +86,29 @@ async def get_room_state_ids(
8786
try_trailing_slash_on_400=True,
8887
)
8988

89+
async def get_room_state(
90+
self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
91+
) -> "StateRequestResponse":
92+
"""Requests the full state for a given room at the given event.
93+
94+
Args:
95+
room_version: the version of the room (required to build the event objects)
96+
destination: The host name of the remote homeserver we want
97+
to get the state from.
98+
room_id: the room we want the state of
99+
event_id: The event we want the state at.
100+
101+
Returns:
102+
The parsed response from the remote homeserver, as a `StateRequestResponse`.
103+
"""
104+
path = _create_v1_path("/state/%s", room_id)
105+
return await self.client.get_json(
106+
destination,
107+
path=path,
108+
args={"event_id": event_id},
109+
parser=_StateParser(room_version),
110+
)
111+
90112
async def get_event(
91113
self, destination: str, event_id: str, timeout: Optional[int] = None
92114
) -> JsonDict:
@@ -1284,6 +1306,14 @@ class SendJoinResponse:
12841306
servers_in_room: Optional[List[str]] = None
12851307

12861308

1309+
@attr.s(slots=True, auto_attribs=True)
1310+
class StateRequestResponse:
1311+
"""The parsed response of a `/state` request."""
1312+
1313+
auth_events: List[EventBase]
1314+
state: List[EventBase]
1315+
1316+
12871317
@ijson.coroutine
12881318
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
12891319
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
@@ -1411,3 +1441,37 @@ def finish(self) -> SendJoinResponse:
14111441
self._response.event_dict, self._room_version
14121442
)
14131443
return self._response
1444+
1445+
1446+
class _StateParser(ByteParser[StateRequestResponse]):
1447+
"""A parser for the response to `/state` requests.
1448+
1449+
Args:
1450+
room_version: The version of the room.
1451+
"""
1452+
1453+
CONTENT_TYPE = "application/json"
1454+
1455+
def __init__(self, room_version: RoomVersion):
1456+
self._response = StateRequestResponse([], [])
1457+
self._room_version = room_version
1458+
self._coros = [
1459+
ijson.items_coro(
1460+
_event_list_parser(room_version, self._response.state),
1461+
"pdus.item",
1462+
use_float=True,
1463+
),
1464+
ijson.items_coro(
1465+
_event_list_parser(room_version, self._response.auth_events),
1466+
"auth_chain.item",
1467+
use_float=True,
1468+
),
1469+
]
1470+
1471+
def write(self, data: bytes) -> int:
1472+
for c in self._coros:
1473+
c.send(data)
1474+
return len(data)
1475+
1476+
def finish(self) -> StateRequestResponse:
1477+
return self._response

synapse/http/matrixfederationclient.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,7 @@ async def post_json(
958958
)
959959
return body
960960

961+
@overload
961962
async def get_json(
962963
self,
963964
destination: str,
@@ -967,7 +968,38 @@ async def get_json(
967968
timeout: Optional[int] = None,
968969
ignore_backoff: bool = False,
969970
try_trailing_slash_on_400: bool = False,
971+
parser: Literal[None] = None,
972+
max_response_size: Optional[int] = None,
970973
) -> Union[JsonDict, list]:
974+
...
975+
976+
@overload
977+
async def get_json(
978+
self,
979+
destination: str,
980+
path: str,
981+
args: Optional[QueryArgs] = ...,
982+
retry_on_dns_fail: bool = ...,
983+
timeout: Optional[int] = ...,
984+
ignore_backoff: bool = ...,
985+
try_trailing_slash_on_400: bool = ...,
986+
parser: ByteParser[T] = ...,
987+
max_response_size: Optional[int] = ...,
988+
) -> T:
989+
...
990+
991+
async def get_json(
992+
self,
993+
destination: str,
994+
path: str,
995+
args: Optional[QueryArgs] = None,
996+
retry_on_dns_fail: bool = True,
997+
timeout: Optional[int] = None,
998+
ignore_backoff: bool = False,
999+
try_trailing_slash_on_400: bool = False,
1000+
parser: Optional[ByteParser] = None,
1001+
max_response_size: Optional[int] = None,
1002+
):
9711003
"""GETs some json from the given host homeserver and path
9721004
9731005
Args:
@@ -992,6 +1024,13 @@ async def get_json(
9921024
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
9931025
response we should try appending a trailing slash to the end of
9941026
the request. Workaround for #3622 in Synapse <= v0.99.3.
1027+
1028+
parser: The parser to use to decode the response. Defaults to
1029+
parsing as JSON.
1030+
1031+
max_response_size: The maximum size to read from the response. If None,
1032+
uses the default.
1033+
9951034
Returns:
9961035
Succeeds when we get a 2xx HTTP response. The
9971036
result will be the decoded JSON body, or the output of `parser` if one is given.
@@ -1026,8 +1065,17 @@ async def get_json(
10261065
else:
10271066
_sec_timeout = self.default_timeout
10281067

1068+
if parser is None:
1069+
parser = JsonParser()
1070+
10291071
body = await _handle_response(
1030-
self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
1072+
self.reactor,
1073+
_sec_timeout,
1074+
request,
1075+
response,
1076+
start_ms,
1077+
parser=parser,
1078+
max_response_size=max_response_size,
10311079
)
10321080

10331081
return body

0 commit comments

Comments
 (0)