Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit fb46c79

Browse files
committed
Implement MSC3706: partial state in /send_join response
1 parent c81e1e7 commit fb46c79

File tree

5 files changed

+134
-4
lines changed

5 files changed

+134
-4
lines changed

changelog.d/11967.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Experimental implementation of [MSC3706](https://github.com/matrix-org/matrix-doc/pull/3706): extensions to `/send_join` to support reduced response size.

synapse/config/experimental.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,6 @@ def read_config(self, config: JsonDict, **kwargs):
6161
self.msc2409_to_device_messages_enabled: bool = experimental.get(
6262
"msc2409_to_device_messages_enabled", False
6363
)
64+
65+
# MSC3706 (server-side support for partial state in /send_join responses)
66+
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)

synapse/federation/federation_server.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Any,
2121
Awaitable,
2222
Callable,
23+
Collection,
2324
Dict,
2425
Iterable,
2526
List,
@@ -64,7 +65,7 @@
6465
ReplicationGetQueryRestServlet,
6566
)
6667
from synapse.storage.databases.main.lock import Lock
67-
from synapse.types import JsonDict, get_domain_from_id
68+
from synapse.types import JsonDict, StateMap, get_domain_from_id
6869
from synapse.util import json_decoder, unwrapFirstError
6970
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
7071
from synapse.util.caches.response_cache import ResponseCache
@@ -645,19 +646,40 @@ async def on_invite_request(
645646
return {"event": ret_pdu.get_pdu_json(time_now)}
646647

647648
async def on_send_join_request(
648-
self, origin: str, content: JsonDict, room_id: str
649+
self,
650+
origin: str,
651+
content: JsonDict,
652+
room_id: str,
653+
caller_supports_partial_state: bool,
649654
) -> Dict[str, Any]:
650655
event, context = await self._on_send_membership_event(
651656
origin, content, Membership.JOIN, room_id
652657
)
653658

654659
prev_state_ids = await context.get_prev_state_ids()
655-
state_event_ids = prev_state_ids.values()
660+
661+
state_event_ids: Collection[str]
662+
servers_in_room: Optional[Collection[str]]
663+
if caller_supports_partial_state:
664+
state_event_ids = _get_event_ids_for_partial_state_join(
665+
event, prev_state_ids
666+
)
667+
servers_in_room = await self.state.get_hosts_in_room_at_events(
668+
room_id, event_ids=event.prev_event_ids()
669+
)
670+
else:
671+
state_event_ids = prev_state_ids.values()
672+
servers_in_room = None
656673

657674
auth_chain_event_ids = await self.store.get_auth_chain_ids(
658675
room_id, state_event_ids
659676
)
660677

678+
# if the caller has opted in, we can omit any auth_chain events which are
679+
# already in state_event_ids
680+
if caller_supports_partial_state:
681+
auth_chain_event_ids.difference_update(state_event_ids)
682+
661683
auth_chain_events = await self.store.get_events_as_list(auth_chain_event_ids)
662684
state_events = await self.store.get_events_as_list(state_event_ids)
663685

@@ -671,7 +693,12 @@ async def on_send_join_request(
671693
"event": event_json,
672694
"state": [p.get_pdu_json(time_now) for p in state_events],
673695
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
696+
"org.matrix.msc3706.partial_state": caller_supports_partial_state,
674697
}
698+
699+
if caller_supports_partial_state:
700+
resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room)
701+
675702
return resp
676703

677704
async def on_make_leave_request(
@@ -1347,3 +1374,39 @@ async def on_query(self, query_type: str, args: dict) -> JsonDict:
13471374
# error.
13481375
logger.warning("No handler registered for query type %s", query_type)
13491376
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
1377+
1378+
1379+
def _get_event_ids_for_partial_state_join(
1380+
join_event: EventBase,
1381+
prev_state_ids: StateMap[str],
1382+
) -> Collection[str]:
1383+
"""Calculate state to be retuned in a partial_state send_join
1384+
1385+
Args:
1386+
join_event: the join event being send_joined
1387+
prev_state_ids: the event ids of the state before the join
1388+
1389+
Returns:
1390+
the event ids to be returned
1391+
"""
1392+
1393+
# return all non-member events
1394+
state_event_ids = {
1395+
event_id
1396+
for (event_type, state_key), event_id in prev_state_ids.items()
1397+
if event_type != EventTypes.Member
1398+
}
1399+
1400+
# we also need the current state of the current user (it's going to
1401+
# be an auth event for the new join, so we may as well return it)
1402+
current_membership_event_id = prev_state_ids.get(
1403+
(EventTypes.Member, join_event.state_key)
1404+
)
1405+
if current_membership_event_id is not None:
1406+
state_event_ids.add(current_membership_event_id)
1407+
1408+
# TODO: return a few more members:
1409+
# - those with invites
1410+
# - those that are kicked? / banned
1411+
1412+
return state_event_ids

synapse/federation/transport/server/federation.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,10 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
412412

413413
PREFIX = FEDERATION_V2_PREFIX
414414

415+
def __init__(self, hs: "HomeServer", *args, **kwargs):
416+
super().__init__(hs, *args, **kwargs)
417+
self._msc3706_enabled = hs.config.experimental.msc3706_enabled
418+
415419
async def on_PUT(
416420
self,
417421
origin: str,
@@ -422,7 +426,15 @@ async def on_PUT(
422426
) -> Tuple[int, JsonDict]:
423427
# TODO(paul): assert that event_id parsed from path actually
424428
# match those given in content
425-
result = await self.handler.on_send_join_request(origin, content, room_id)
429+
430+
partial_state = False
431+
if self._msc3706_enabled:
432+
partial_state = parse_boolean_from_args(
433+
query, "org.matrix.msc3706.partial_state", default=False
434+
)
435+
result = await self.handler.on_send_join_request(
436+
origin, content, room_id, caller_supports_partial_state=partial_state
437+
)
426438
return 200, result
427439

428440

tests/federation/test_federation_server.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,57 @@ def test_send_join(self):
251251
)
252252
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
253253

254+
@override_config({"experimental_features": {"msc3706_enabled": True}})
255+
def test_send_join_partial_state(self):
256+
"""When MSC3706 support is enabled, /send_join should return partial state"""
257+
joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
258+
join_result = self._make_join(joining_user)
259+
260+
join_event_dict = join_result["event"]
261+
add_hashes_and_signatures(
262+
KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
263+
join_event_dict,
264+
signature_name=self.OTHER_SERVER_NAME,
265+
signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
266+
)
267+
channel = self.make_signed_federation_request(
268+
"PUT",
269+
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
270+
content=join_event_dict,
271+
)
272+
self.assertEquals(channel.code, 200, channel.json_body)
273+
274+
# expect a reduced room state
275+
returned_state = [
276+
(ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
277+
]
278+
self.assertCountEqual(
279+
returned_state,
280+
[
281+
("m.room.create", ""),
282+
("m.room.power_levels", ""),
283+
("m.room.join_rules", ""),
284+
("m.room.history_visibility", ""),
285+
],
286+
)
287+
288+
# the auth chain should not include anything already in "state"
289+
returned_auth_chain_events = [
290+
(ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
291+
]
292+
self.assertCountEqual(
293+
returned_auth_chain_events,
294+
[
295+
("m.room.member", "@kermit:test"),
296+
],
297+
)
298+
299+
# the room should show that the new user is a member
300+
r = self.get_success(
301+
self.hs.get_state_handler().get_current_state(self._room_id)
302+
)
303+
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
304+
254305

255306
def _create_acl_event(content):
256307
return make_event_from_dict(

0 commit comments

Comments
 (0)