Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit fd78beb

Browse files
committed
Implement MSC3706: partial state in /send_join response
1 parent 2676794 commit fd78beb

File tree

5 files changed

+249
-10
lines changed

5 files changed

+249
-10
lines changed

changelog.d/11967.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Experimental implementation of [MSC3706](https://github.com/matrix-org/matrix-doc/pull/3706): extensions to `/send_join` to support reduced response size.

synapse/config/experimental.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,6 @@ def read_config(self, config: JsonDict, **kwargs):
6161
self.msc2409_to_device_messages_enabled: bool = experimental.get(
6262
"msc2409_to_device_messages_enabled", False
6363
)
64+
65+
# MSC3706 (server-side support for partial state in /send_join responses)
66+
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)

synapse/federation/federation_server.py

Lines changed: 81 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Any,
2121
Awaitable,
2222
Callable,
23+
Collection,
2324
Dict,
2425
Iterable,
2526
List,
@@ -64,7 +65,7 @@
6465
ReplicationGetQueryRestServlet,
6566
)
6667
from synapse.storage.databases.main.lock import Lock
67-
from synapse.types import JsonDict, get_domain_from_id
68+
from synapse.types import JsonDict, StateMap, get_domain_from_id
6869
from synapse.util import json_decoder, unwrapFirstError
6970
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
7071
from synapse.util.caches.response_cache import ResponseCache
@@ -645,27 +646,62 @@ async def on_invite_request(
645646
return {"event": ret_pdu.get_pdu_json(time_now)}
646647

647648
async def on_send_join_request(
648-
self, origin: str, content: JsonDict, room_id: str
649+
self,
650+
origin: str,
651+
content: JsonDict,
652+
room_id: str,
653+
caller_supports_partial_state: bool,
649654
) -> Dict[str, Any]:
650655
event, context = await self._on_send_membership_event(
651656
origin, content, Membership.JOIN, room_id
652657
)
653658

659+
state_event_ids: Collection[str]
660+
servers_in_room: Optional[Collection[str]] = None
654661
prev_state_ids = await context.get_prev_state_ids()
655-
state_ids = list(prev_state_ids.values())
656-
auth_chain = await self.store.get_auth_chain(room_id, state_ids)
657-
state = await self.store.get_events(state_ids)
658662

663+
if caller_supports_partial_state:
664+
state_event_ids = _get_event_ids_for_partial_state_join(
665+
event, prev_state_ids
666+
)
667+
668+
servers_in_room = await self.state.get_hosts_in_room_at_events(
669+
room_id, event_ids=event.prev_event_ids()
670+
)
671+
672+
else:
673+
state_event_ids = prev_state_ids.values()
674+
675+
auth_chain_event_ids = await self.store.get_auth_chain_ids(
676+
room_id, state_event_ids
677+
)
678+
679+
# if the caller has opted in, we can omit any auth_chain events which are
680+
# already in state_event_ids
681+
if caller_supports_partial_state:
682+
auth_chain_event_ids.difference_update(state_event_ids)
683+
684+
auth_chain_events = await self.store.get_events_as_list(auth_chain_event_ids)
685+
state_events = await self.store.get_events_as_list(state_event_ids)
686+
687+
# we try to do all the async stuff before this point, so that time_now is as
688+
# accurate as possible.
659689
time_now = self._clock.time_msec()
660-
event_json = event.get_pdu_json()
661-
return {
690+
event_json = event.get_pdu_json(time_now)
691+
resp = {
662692
# TODO Remove the unstable prefix when servers have updated.
663693
"org.matrix.msc3083.v2.event": event_json,
664694
"event": event_json,
665-
"state": [p.get_pdu_json(time_now) for p in state.values()],
666-
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
695+
"state": [p.get_pdu_json(time_now) for p in state_events],
696+
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
697+
"org.matrix.msc3706.partial_state": caller_supports_partial_state,
667698
}
668699

700+
if caller_supports_partial_state:
701+
resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room)
702+
703+
return resp
704+
669705
async def on_make_leave_request(
670706
self, origin: str, room_id: str, user_id: str
671707
) -> Dict[str, Any]:
@@ -1339,3 +1375,39 @@ async def on_query(self, query_type: str, args: dict) -> JsonDict:
13391375
# error.
13401376
logger.warning("No handler registered for query type %s", query_type)
13411377
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
1378+
1379+
1380+
def _get_event_ids_for_partial_state_join(
1381+
join_event: EventBase,
1382+
prev_state_ids: StateMap[str],
1383+
) -> Collection[str]:
1384+
"""Calculate state to be retuned in a partial_state send_join
1385+
1386+
Args:
1387+
join_event: the join event being send_joined
1388+
prev_state_ids: the event ids of the state before the join
1389+
1390+
Returns:
1391+
the event ids to be returned
1392+
"""
1393+
1394+
# return all non-member events
1395+
state_event_ids = {
1396+
event_id
1397+
for (event_type, state_key), event_id in prev_state_ids.items()
1398+
if event_type != EventTypes.Member
1399+
}
1400+
1401+
# we also need the current state of the current user (it's going to
1402+
# be an auth event for the new join, so we may as well return it)
1403+
current_membership_event_id = prev_state_ids.get(
1404+
(EventTypes.Member, join_event.state_key)
1405+
)
1406+
if current_membership_event_id is not None:
1407+
state_event_ids.add(current_membership_event_id)
1408+
1409+
# TODO: return a few more members:
1410+
# - those with invites
1411+
# - those that are kicked? / banned
1412+
1413+
return state_event_ids

synapse/federation/transport/server/federation.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,10 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
412412

413413
PREFIX = FEDERATION_V2_PREFIX
414414

415+
def __init__(self, hs: "HomeServer", *args, **kwargs):
416+
super().__init__(hs, *args, **kwargs)
417+
self._msc3706_enabled = hs.config.experimental.msc3706_enabled
418+
415419
async def on_PUT(
416420
self,
417421
origin: str,
@@ -422,7 +426,15 @@ async def on_PUT(
422426
) -> Tuple[int, JsonDict]:
423427
# TODO(paul): assert that event_id parsed from path actually
424428
# match those given in content
425-
result = await self.handler.on_send_join_request(origin, content, room_id)
429+
430+
partial_state = False
431+
if self._msc3706_enabled:
432+
partial_state = parse_boolean_from_args(
433+
query, "org.matrix.msc3706.partial_state", default=False
434+
)
435+
result = await self.handler.on_send_join_request(
436+
origin, content, room_id, caller_supports_partial_state=partial_state
437+
)
426438
return 200, result
427439

428440

tests/federation/test_federation_server.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,24 @@
1616

1717
from parameterized import parameterized
1818

19+
from twisted.test.proto_helpers import MemoryReactor
20+
21+
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
22+
from synapse.config.server import DEFAULT_ROOM_VERSION
23+
from synapse.crypto.event_signing import (
24+
add_hashes_and_signatures,
25+
compute_event_signature,
26+
)
1927
from synapse.events import make_event_from_dict
2028
from synapse.federation.federation_server import server_matches_acl_event
2129
from synapse.rest import admin
2230
from synapse.rest.client import login, room
31+
from synapse.server import HomeServer
32+
from synapse.types import JsonDict
33+
from synapse.util import Clock
2334

2435
from tests import unittest
36+
from tests.unittest import override_config
2537

2638

2739
class FederationServerTests(unittest.FederatingHomeserverTestCase):
@@ -152,6 +164,145 @@ def test_needs_to_be_in_room(self):
152164
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
153165

154166

167+
class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
168+
servlets = [
169+
admin.register_servlets,
170+
room.register_servlets,
171+
login.register_servlets,
172+
]
173+
174+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
175+
super().prepare(reactor, clock, hs)
176+
177+
# create the room
178+
creator_user_id = self.register_user("kermit", "test")
179+
tok = self.login("kermit", "test")
180+
self._room_id = self.helper.create_room_as(
181+
room_creator=creator_user_id, tok=tok
182+
)
183+
184+
# a second member on the orgin HS
185+
second_member_user_id = self.register_user("fozzie", "bear")
186+
tok2 = self.login("fozzie", "bear")
187+
self.helper.join(self._room_id, second_member_user_id, tok=tok2)
188+
189+
def _make_join(self, user_id) -> JsonDict:
190+
channel = self.make_signed_federation_request(
191+
"GET",
192+
f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
193+
f"?ver={DEFAULT_ROOM_VERSION}",
194+
)
195+
self.assertEquals(channel.code, 200, channel.json_body)
196+
return channel.json_body
197+
198+
def test_send_join(self):
199+
"""happy-path test of send_join"""
200+
joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
201+
join_result = self._make_join(joining_user)
202+
203+
join_event_dict = join_result["event"]
204+
add_hashes_and_signatures(
205+
KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
206+
join_event_dict,
207+
signature_name=self.OTHER_SERVER_NAME,
208+
signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
209+
)
210+
channel = self.make_signed_federation_request(
211+
"PUT",
212+
f"/_matrix/federation/v2/send_join/{self._room_id}/x",
213+
content=join_event_dict,
214+
)
215+
self.assertEquals(channel.code, 200, channel.json_body)
216+
217+
# we should get complete room state back
218+
returned_state = [
219+
(ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
220+
]
221+
self.assertCountEqual(
222+
returned_state,
223+
[
224+
("m.room.create", ""),
225+
("m.room.power_levels", ""),
226+
("m.room.join_rules", ""),
227+
("m.room.history_visibility", ""),
228+
("m.room.member", "@kermit:test"),
229+
("m.room.member", "@fozzie:test"),
230+
# nb: *not* the joining user
231+
],
232+
)
233+
234+
# also check the auth chain
235+
returned_auth_chain_events = [
236+
(ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
237+
]
238+
self.assertCountEqual(
239+
returned_auth_chain_events,
240+
[
241+
("m.room.create", ""),
242+
("m.room.member", "@kermit:test"),
243+
("m.room.power_levels", ""),
244+
("m.room.join_rules", ""),
245+
],
246+
)
247+
248+
# the room should show that the new user is a member
249+
r = self.get_success(
250+
self.hs.get_state_handler().get_current_state(self._room_id)
251+
)
252+
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
253+
254+
@override_config({"experimental_features": {"msc3706_enabled": True}})
255+
def test_send_join_partial_state(self):
256+
"""When MSC3706 support is enabled, /send_join should return partial state"""
257+
joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
258+
join_result = self._make_join(joining_user)
259+
260+
join_event_dict = join_result["event"]
261+
add_hashes_and_signatures(
262+
KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
263+
join_event_dict,
264+
signature_name=self.OTHER_SERVER_NAME,
265+
signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
266+
)
267+
channel = self.make_signed_federation_request(
268+
"PUT",
269+
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
270+
content=join_event_dict,
271+
)
272+
self.assertEquals(channel.code, 200, channel.json_body)
273+
274+
# expect a reduced room state
275+
returned_state = [
276+
(ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
277+
]
278+
self.assertCountEqual(
279+
returned_state,
280+
[
281+
("m.room.create", ""),
282+
("m.room.power_levels", ""),
283+
("m.room.join_rules", ""),
284+
("m.room.history_visibility", ""),
285+
],
286+
)
287+
288+
# the auth chain should not include anything already in "state"
289+
returned_auth_chain_events = [
290+
(ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
291+
]
292+
self.assertCountEqual(
293+
returned_auth_chain_events,
294+
[
295+
("m.room.member", "@kermit:test"),
296+
],
297+
)
298+
299+
# the room should show that the new user is a member
300+
r = self.get_success(
301+
self.hs.get_state_handler().get_current_state(self._room_id)
302+
)
303+
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
304+
305+
155306
def _create_acl_event(content):
156307
return make_event_from_dict(
157308
{

0 commit comments

Comments
 (0)