Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 63c4634

Browse files
authored
Implement MSC3706: partial state in /send_join response (#11967)
* Make `get_auth_chain_ids` return a Set It has a set internally, and a set is often useful where it gets used, so let's avoid converting to an intermediate list. * Minor refactors in `on_send_join_request` A little bit of non-functional groundwork * Implement MSC3706: partial state in /send_join response
1 parent b2b971f commit 63c4634

File tree

7 files changed

+262
-21
lines changed

7 files changed

+262
-21
lines changed

changelog.d/11967.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Experimental implementation of [MSC3706](https://github.com/matrix-org/matrix-doc/pull/3706): extensions to `/send_join` to support reduced response size.

synapse/config/experimental.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,6 @@ def read_config(self, config: JsonDict, **kwargs):
6161
self.msc2409_to_device_messages_enabled: bool = experimental.get(
6262
"msc2409_to_device_messages_enabled", False
6363
)
64+
65+
# MSC3706 (server-side support for partial state in /send_join responses)
66+
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)

synapse/federation/federation_server.py

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Any,
2121
Awaitable,
2222
Callable,
23+
Collection,
2324
Dict,
2425
Iterable,
2526
List,
@@ -64,7 +65,7 @@
6465
ReplicationGetQueryRestServlet,
6566
)
6667
from synapse.storage.databases.main.lock import Lock
67-
from synapse.types import JsonDict, get_domain_from_id
68+
from synapse.types import JsonDict, StateMap, get_domain_from_id
6869
from synapse.util import json_decoder, unwrapFirstError
6970
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
7071
from synapse.util.caches.response_cache import ResponseCache
@@ -571,7 +572,7 @@ async def _on_state_ids_request_compute(
571572
) -> JsonDict:
572573
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
573574
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
574-
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
575+
return {"pdu_ids": state_ids, "auth_chain_ids": list(auth_chain_ids)}
575576

576577
async def _on_context_state_request_compute(
577578
self, room_id: str, event_id: Optional[str]
@@ -645,27 +646,61 @@ async def on_invite_request(
645646
return {"event": ret_pdu.get_pdu_json(time_now)}
646647

647648
async def on_send_join_request(
648-
self, origin: str, content: JsonDict, room_id: str
649+
self,
650+
origin: str,
651+
content: JsonDict,
652+
room_id: str,
653+
caller_supports_partial_state: bool = False,
649654
) -> Dict[str, Any]:
650655
event, context = await self._on_send_membership_event(
651656
origin, content, Membership.JOIN, room_id
652657
)
653658

654659
prev_state_ids = await context.get_prev_state_ids()
655-
state_ids = list(prev_state_ids.values())
656-
auth_chain = await self.store.get_auth_chain(room_id, state_ids)
657-
state = await self.store.get_events(state_ids)
658660

661+
state_event_ids: Collection[str]
662+
servers_in_room: Optional[Collection[str]]
663+
if caller_supports_partial_state:
664+
state_event_ids = _get_event_ids_for_partial_state_join(
665+
event, prev_state_ids
666+
)
667+
servers_in_room = await self.state.get_hosts_in_room_at_events(
668+
room_id, event_ids=event.prev_event_ids()
669+
)
670+
else:
671+
state_event_ids = prev_state_ids.values()
672+
servers_in_room = None
673+
674+
auth_chain_event_ids = await self.store.get_auth_chain_ids(
675+
room_id, state_event_ids
676+
)
677+
678+
# if the caller has opted in, we can omit any auth_chain events which are
679+
# already in state_event_ids
680+
if caller_supports_partial_state:
681+
auth_chain_event_ids.difference_update(state_event_ids)
682+
683+
auth_chain_events = await self.store.get_events_as_list(auth_chain_event_ids)
684+
state_events = await self.store.get_events_as_list(state_event_ids)
685+
686+
# we try to do all the async stuff before this point, so that time_now is as
687+
# accurate as possible.
659688
time_now = self._clock.time_msec()
660-
event_json = event.get_pdu_json()
661-
return {
689+
event_json = event.get_pdu_json(time_now)
690+
resp = {
662691
# TODO Remove the unstable prefix when servers have updated.
663692
"org.matrix.msc3083.v2.event": event_json,
664693
"event": event_json,
665-
"state": [p.get_pdu_json(time_now) for p in state.values()],
666-
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
694+
"state": [p.get_pdu_json(time_now) for p in state_events],
695+
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
696+
"org.matrix.msc3706.partial_state": caller_supports_partial_state,
667697
}
668698

699+
if servers_in_room is not None:
700+
resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room)
701+
702+
return resp
703+
669704
async def on_make_leave_request(
670705
self, origin: str, room_id: str, user_id: str
671706
) -> Dict[str, Any]:
@@ -1339,3 +1374,39 @@ async def on_query(self, query_type: str, args: dict) -> JsonDict:
13391374
# error.
13401375
logger.warning("No handler registered for query type %s", query_type)
13411376
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
1377+
1378+
1379+
def _get_event_ids_for_partial_state_join(
1380+
join_event: EventBase,
1381+
prev_state_ids: StateMap[str],
1382+
) -> Collection[str]:
1383+
"""Calculate state to be retuned in a partial_state send_join
1384+
1385+
Args:
1386+
join_event: the join event being send_joined
1387+
prev_state_ids: the event ids of the state before the join
1388+
1389+
Returns:
1390+
the event ids to be returned
1391+
"""
1392+
1393+
# return all non-member events
1394+
state_event_ids = {
1395+
event_id
1396+
for (event_type, state_key), event_id in prev_state_ids.items()
1397+
if event_type != EventTypes.Member
1398+
}
1399+
1400+
# we also need the current state of the current user (it's going to
1401+
# be an auth event for the new join, so we may as well return it)
1402+
current_membership_event_id = prev_state_ids.get(
1403+
(EventTypes.Member, join_event.state_key)
1404+
)
1405+
if current_membership_event_id is not None:
1406+
state_event_ids.add(current_membership_event_id)
1407+
1408+
# TODO: return a few more members:
1409+
# - those with invites
1410+
# - those that are kicked? / banned
1411+
1412+
return state_event_ids

synapse/federation/transport/server/federation.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,16 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
412412

413413
PREFIX = FEDERATION_V2_PREFIX
414414

415+
def __init__(
416+
self,
417+
hs: "HomeServer",
418+
authenticator: Authenticator,
419+
ratelimiter: FederationRateLimiter,
420+
server_name: str,
421+
):
422+
super().__init__(hs, authenticator, ratelimiter, server_name)
423+
self._msc3706_enabled = hs.config.experimental.msc3706_enabled
424+
415425
async def on_PUT(
416426
self,
417427
origin: str,
@@ -422,7 +432,15 @@ async def on_PUT(
422432
) -> Tuple[int, JsonDict]:
423433
# TODO(paul): assert that event_id parsed from path actually
424434
# match those given in content
425-
result = await self.handler.on_send_join_request(origin, content, room_id)
435+
436+
partial_state = False
437+
if self._msc3706_enabled:
438+
partial_state = parse_boolean_from_args(
439+
query, "org.matrix.msc3706.partial_state", default=False
440+
)
441+
result = await self.handler.on_send_join_request(
442+
origin, content, room_id, caller_supports_partial_state=partial_state
443+
)
426444
return 200, result
427445

428446

synapse/storage/databases/main/event_federation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ async def get_auth_chain_ids(
121121
room_id: str,
122122
event_ids: Collection[str],
123123
include_given: bool = False,
124-
) -> List[str]:
124+
) -> Set[str]:
125125
"""Get auth events for given event_ids. The events *must* be state events.
126126
127127
Args:
@@ -130,7 +130,7 @@ async def get_auth_chain_ids(
130130
include_given: include the given events in result
131131
132132
Returns:
133-
list of event_ids
133+
set of event_ids
134134
"""
135135

136136
# Check if we have indexed the room so we can use the chain cover
@@ -159,7 +159,7 @@ async def get_auth_chain_ids(
159159

160160
def _get_auth_chain_ids_using_cover_index_txn(
161161
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
162-
) -> List[str]:
162+
) -> Set[str]:
163163
"""Calculates the auth chain IDs using the chain index."""
164164

165165
# First we look up the chain ID/sequence numbers for the given events.
@@ -272,11 +272,11 @@ def _get_auth_chain_ids_using_cover_index_txn(
272272
txn.execute(sql, (chain_id, max_no))
273273
results.update(r for r, in txn)
274274

275-
return list(results)
275+
return results
276276

277277
def _get_auth_chain_ids_txn(
278278
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
279-
) -> List[str]:
279+
) -> Set[str]:
280280
"""Calculates the auth chain IDs.
281281
282282
This is used when we don't have a cover index for the room.
@@ -331,7 +331,7 @@ def _get_auth_chain_ids_txn(
331331
front = new_front
332332
results.update(front)
333333

334-
return list(results)
334+
return results
335335

336336
async def get_auth_chain_difference(
337337
self, room_id: str, state_sets: List[Set[str]]

tests/federation/test_federation_server.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,21 @@
1616

1717
from parameterized import parameterized
1818

19+
from twisted.test.proto_helpers import MemoryReactor
20+
21+
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
22+
from synapse.config.server import DEFAULT_ROOM_VERSION
23+
from synapse.crypto.event_signing import add_hashes_and_signatures
1924
from synapse.events import make_event_from_dict
2025
from synapse.federation.federation_server import server_matches_acl_event
2126
from synapse.rest import admin
2227
from synapse.rest.client import login, room
28+
from synapse.server import HomeServer
29+
from synapse.types import JsonDict
30+
from synapse.util import Clock
2331

2432
from tests import unittest
33+
from tests.unittest import override_config
2534

2635

2736
class FederationServerTests(unittest.FederatingHomeserverTestCase):
@@ -152,6 +161,145 @@ def test_needs_to_be_in_room(self):
152161
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
153162

154163

164+
class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
165+
servlets = [
166+
admin.register_servlets,
167+
room.register_servlets,
168+
login.register_servlets,
169+
]
170+
171+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
172+
super().prepare(reactor, clock, hs)
173+
174+
# create the room
175+
creator_user_id = self.register_user("kermit", "test")
176+
tok = self.login("kermit", "test")
177+
self._room_id = self.helper.create_room_as(
178+
room_creator=creator_user_id, tok=tok
179+
)
180+
181+
# a second member on the orgin HS
182+
second_member_user_id = self.register_user("fozzie", "bear")
183+
tok2 = self.login("fozzie", "bear")
184+
self.helper.join(self._room_id, second_member_user_id, tok=tok2)
185+
186+
def _make_join(self, user_id) -> JsonDict:
187+
channel = self.make_signed_federation_request(
188+
"GET",
189+
f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
190+
f"?ver={DEFAULT_ROOM_VERSION}",
191+
)
192+
self.assertEquals(channel.code, 200, channel.json_body)
193+
return channel.json_body
194+
195+
def test_send_join(self):
196+
"""happy-path test of send_join"""
197+
joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
198+
join_result = self._make_join(joining_user)
199+
200+
join_event_dict = join_result["event"]
201+
add_hashes_and_signatures(
202+
KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
203+
join_event_dict,
204+
signature_name=self.OTHER_SERVER_NAME,
205+
signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
206+
)
207+
channel = self.make_signed_federation_request(
208+
"PUT",
209+
f"/_matrix/federation/v2/send_join/{self._room_id}/x",
210+
content=join_event_dict,
211+
)
212+
self.assertEquals(channel.code, 200, channel.json_body)
213+
214+
# we should get complete room state back
215+
returned_state = [
216+
(ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
217+
]
218+
self.assertCountEqual(
219+
returned_state,
220+
[
221+
("m.room.create", ""),
222+
("m.room.power_levels", ""),
223+
("m.room.join_rules", ""),
224+
("m.room.history_visibility", ""),
225+
("m.room.member", "@kermit:test"),
226+
("m.room.member", "@fozzie:test"),
227+
# nb: *not* the joining user
228+
],
229+
)
230+
231+
# also check the auth chain
232+
returned_auth_chain_events = [
233+
(ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
234+
]
235+
self.assertCountEqual(
236+
returned_auth_chain_events,
237+
[
238+
("m.room.create", ""),
239+
("m.room.member", "@kermit:test"),
240+
("m.room.power_levels", ""),
241+
("m.room.join_rules", ""),
242+
],
243+
)
244+
245+
# the room should show that the new user is a member
246+
r = self.get_success(
247+
self.hs.get_state_handler().get_current_state(self._room_id)
248+
)
249+
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
250+
251+
@override_config({"experimental_features": {"msc3706_enabled": True}})
252+
def test_send_join_partial_state(self):
253+
"""When MSC3706 support is enabled, /send_join should return partial state"""
254+
joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
255+
join_result = self._make_join(joining_user)
256+
257+
join_event_dict = join_result["event"]
258+
add_hashes_and_signatures(
259+
KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
260+
join_event_dict,
261+
signature_name=self.OTHER_SERVER_NAME,
262+
signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
263+
)
264+
channel = self.make_signed_federation_request(
265+
"PUT",
266+
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
267+
content=join_event_dict,
268+
)
269+
self.assertEquals(channel.code, 200, channel.json_body)
270+
271+
# expect a reduced room state
272+
returned_state = [
273+
(ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
274+
]
275+
self.assertCountEqual(
276+
returned_state,
277+
[
278+
("m.room.create", ""),
279+
("m.room.power_levels", ""),
280+
("m.room.join_rules", ""),
281+
("m.room.history_visibility", ""),
282+
],
283+
)
284+
285+
# the auth chain should not include anything already in "state"
286+
returned_auth_chain_events = [
287+
(ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
288+
]
289+
self.assertCountEqual(
290+
returned_auth_chain_events,
291+
[
292+
("m.room.member", "@kermit:test"),
293+
],
294+
)
295+
296+
# the room should show that the new user is a member
297+
r = self.get_success(
298+
self.hs.get_state_handler().get_current_state(self._room_id)
299+
)
300+
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
301+
302+
155303
def _create_acl_event(content):
156304
return make_event_from_dict(
157305
{

0 commit comments

Comments
 (0)