Skip to content

Commit 3d6f5e3

Browse files
squahtxazmeuk
authored andcommitted
Refactor _resolve_state_at_missing_prevs to return an EventContext (matrix-org#13404)
Previously, `_resolve_state_at_missing_prevs` returned the resolved state before an event and a partial state flag. These were unwieldy to carry around would only ever be used to build an event context. Build the event context directly instead. Signed-off-by: Sean Quah <[email protected]>
1 parent 54f9a33 commit 3d6f5e3

File tree

5 files changed

+68
-86
lines changed

5 files changed

+68
-86
lines changed

changelog.d/13404.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Refactor `_resolve_state_at_missing_prevs` to compute an `EventContext` instead.

synapse/handlers/federation_event.py

Lines changed: 44 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
Dict,
2424
Iterable,
2525
List,
26-
Optional,
2726
Sequence,
2827
Set,
2928
Tuple,
@@ -278,19 +277,17 @@ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
278277
)
279278

280279
try:
281-
await self._process_received_pdu(
282-
origin, pdu, state_ids=None, partial_state=None
283-
)
280+
context = await self._state_handler.compute_event_context(pdu)
281+
await self._process_received_pdu(origin, pdu, context)
284282
except PartialStateConflictError:
285283
# The room was un-partial stated while we were processing the PDU.
286284
# Try once more, with full state this time.
287285
logger.info(
288286
"Room %s was un-partial stated while processing the PDU, trying again.",
289287
room_id,
290288
)
291-
await self._process_received_pdu(
292-
origin, pdu, state_ids=None, partial_state=None
293-
)
289+
context = await self._state_handler.compute_event_context(pdu)
290+
await self._process_received_pdu(origin, pdu, context)
294291

295292
async def on_send_membership_event(
296293
self, origin: str, event: EventBase
@@ -320,6 +317,7 @@ async def on_send_membership_event(
320317
The event and context of the event after inserting it into the room graph.
321318
322319
Raises:
320+
RuntimeError if any prev_events are missing
323321
SynapseError if the event is not accepted into the room
324322
PartialStateConflictError if the room was un-partial stated in between
325323
computing the state at the event and persisting it. The caller should
@@ -380,7 +378,7 @@ async def on_send_membership_event(
380378
# need to.
381379
await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
382380

383-
await self._check_for_soft_fail(event, None, origin=origin)
381+
await self._check_for_soft_fail(event, context=context, origin=origin)
384382
await self._run_push_actions_and_persist_event(event, context)
385383
return event, context
386384

@@ -538,36 +536,10 @@ async def update_state_for_partial_state_event(
538536
#
539537
# This is the same operation as we do when we receive a regular event
540538
# over federation.
541-
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
539+
context = await self._compute_event_context_with_maybe_missing_prevs(
542540
destination, event
543541
)
544-
545-
# There are three possible cases for (state_ids, partial_state):
546-
# * `state_ids` and `partial_state` are both `None` if we had all the
547-
# prev_events. The prev_events may or may not have partial state and
548-
# we won't know until we compute the event context.
549-
# * `state_ids` is not `None` and `partial_state` is `False` if we were
550-
# missing some prev_events (but we have full state for any we did
551-
# have). We calculated the full state after the prev_events.
552-
# * `state_ids` is not `None` and `partial_state` is `True` if we were
553-
# missing some, but not all, prev_events. At least one of the
554-
# prev_events we did have had partial state, so we calculated a partial
555-
# state after the prev_events.
556-
557-
context = None
558-
if state_ids is not None and partial_state:
559-
# the state after the prev events is still partial. We can't de-partial
560-
# state the event, so don't bother building the event context.
561-
pass
562-
else:
563-
# build a new state group for it if need be
564-
context = await self._state_handler.compute_event_context(
565-
event,
566-
state_ids_before_event=state_ids,
567-
partial_state=partial_state,
568-
)
569-
570-
if context is None or context.partial_state:
542+
if context.partial_state:
571543
# this can happen if some or all of the event's prev_events still have
572544
# partial state. We were careful to only pick events from the db without
573545
# partial-state prev events, so that implies that a prev event has
@@ -840,26 +812,25 @@ async def _process_pulled_event(
840812

841813
try:
842814
try:
843-
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
815+
context = await self._compute_event_context_with_maybe_missing_prevs(
844816
origin, event
845817
)
846818
await self._process_received_pdu(
847819
origin,
848820
event,
849-
state_ids=state_ids,
850-
partial_state=partial_state,
821+
context,
851822
backfilled=backfilled,
852823
)
853824
except PartialStateConflictError:
854825
# The room was un-partial stated while we were processing the event.
855826
# Try once more, with full state this time.
856-
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
827+
context = await self._compute_event_context_with_maybe_missing_prevs(
857828
origin, event
858829
)
859830

860831
# We ought to have full state now, barring some unlikely race where we left and
861832
# rejoned the room in the background.
862-
if state_ids is not None and partial_state:
833+
if context.partial_state:
863834
raise AssertionError(
864835
f"Event {event.event_id} still has a partial resolved state "
865836
f"after room {event.room_id} was un-partial stated"
@@ -868,8 +839,7 @@ async def _process_pulled_event(
868839
await self._process_received_pdu(
869840
origin,
870841
event,
871-
state_ids=state_ids,
872-
partial_state=partial_state,
842+
context,
873843
backfilled=backfilled,
874844
)
875845
except FederationError as e:
@@ -878,15 +848,18 @@ async def _process_pulled_event(
878848
else:
879849
raise
880850

881-
async def _resolve_state_at_missing_prevs(
851+
async def _compute_event_context_with_maybe_missing_prevs(
882852
self, dest: str, event: EventBase
883-
) -> Tuple[Optional[StateMap[str]], Optional[bool]]:
884-
"""Calculate the state at an event with missing prev_events.
853+
) -> EventContext:
854+
"""Build an EventContext structure for a non-outlier event whose prev_events may
855+
be missing.
885856
886-
This is used when we have pulled a batch of events from a remote server, and
887-
still don't have all the prev_events.
857+
This is used when we have pulled a batch of events from a remote server, and may
858+
not have all the prev_events.
888859
889-
If we already have all the prev_events for `event`, this method does nothing.
860+
To build an EventContext, we need to calculate the state before the event. If we
861+
already have all the prev_events for `event`, we can simply use the state after
862+
the prev_events to calculate the state before `event`.
890863
891864
Otherwise, the missing prevs become new backwards extremities, and we fall back
892865
to asking the remote server for the state after each missing `prev_event`,
@@ -907,10 +880,7 @@ async def _resolve_state_at_missing_prevs(
907880
event: an event to check for missing prevs.
908881
909882
Returns:
910-
if we already had all the prev events, `None, None`. Otherwise, returns a
911-
tuple containing:
912-
* the event ids of the state at `event`.
913-
* a boolean indicating whether the state may be partial.
883+
The event context.
914884
915885
Raises:
916886
FederationError if we fail to get the state from the remote server after any
@@ -924,7 +894,7 @@ async def _resolve_state_at_missing_prevs(
924894
missing_prevs = prevs - seen
925895

926896
if not missing_prevs:
927-
return None, None
897+
return await self._state_handler.compute_event_context(event)
928898

929899
logger.info(
930900
"Event %s is missing prev_events %s: calculating state for a "
@@ -990,7 +960,9 @@ async def _resolve_state_at_missing_prevs(
990960
"We can't get valid state history.",
991961
affected=event_id,
992962
)
993-
return state_map, partial_state
963+
return await self._state_handler.compute_event_context(
964+
event, state_ids_before_event=state_map, partial_state=partial_state
965+
)
994966

995967
async def _get_state_ids_after_missing_prev_event(
996968
self,
@@ -1159,8 +1131,7 @@ async def _process_received_pdu(
11591131
self,
11601132
origin: str,
11611133
event: EventBase,
1162-
state_ids: Optional[StateMap[str]],
1163-
partial_state: Optional[bool],
1134+
context: EventContext,
11641135
backfilled: bool = False,
11651136
) -> None:
11661137
"""Called when we have a new non-outlier event.
@@ -1182,32 +1153,18 @@ async def _process_received_pdu(
11821153
11831154
event: event to be persisted
11841155
1185-
state_ids: Normally None, but if we are handling a gap in the graph
1186-
(ie, we are missing one or more prev_events), the resolved state at the
1187-
event
1188-
1189-
partial_state:
1190-
`True` if `state_ids` is partial and omits non-critical membership
1191-
events.
1192-
`False` if `state_ids` is the full state.
1193-
`None` if `state_ids` is not provided. In this case, the flag will be
1194-
calculated based on `event`'s prev events.
1156+
context: The `EventContext` to persist the event with.
11951157
11961158
backfilled: True if this is part of a historical batch of events (inhibits
11971159
notification to clients, and validation of device keys.)
11981160
11991161
PartialStateConflictError: if the room was un-partial stated in between
1200-
computing the state at the event and persisting it. The caller should retry
1201-
exactly once in this case.
1162+
computing the state at the event and persisting it. The caller should
1163+
recompute `context` and retry exactly once when this happens.
12021164
"""
12031165
logger.debug("Processing event: %s", event)
12041166
assert not event.internal_metadata.outlier
12051167

1206-
context = await self._state_handler.compute_event_context(
1207-
event,
1208-
state_ids_before_event=state_ids,
1209-
partial_state=partial_state,
1210-
)
12111168
try:
12121169
await self._check_event_auth(origin, event, context)
12131170
except AuthError as e:
@@ -1219,7 +1176,7 @@ async def _process_received_pdu(
12191176
# For new (non-backfilled and non-outlier) events we check if the event
12201177
# passes auth based on the current state. If it doesn't then we
12211178
# "soft-fail" the event.
1222-
await self._check_for_soft_fail(event, state_ids, origin=origin)
1179+
await self._check_for_soft_fail(event, context=context, origin=origin)
12231180

12241181
await self._run_push_actions_and_persist_event(event, context, backfilled)
12251182

@@ -1782,7 +1739,7 @@ async def _maybe_kick_guest_users(self, event: EventBase) -> None:
17821739
async def _check_for_soft_fail(
17831740
self,
17841741
event: EventBase,
1785-
state_ids: Optional[StateMap[str]],
1742+
context: EventContext,
17861743
origin: str,
17871744
) -> None:
17881745
"""Checks if we should soft fail the event; if so, marks the event as
@@ -1793,7 +1750,7 @@ async def _check_for_soft_fail(
17931750
17941751
Args:
17951752
event
1796-
state_ids: The state at the event if we don't have all the event's prev events
1753+
context: The `EventContext` which we are about to persist the event with.
17971754
origin: The host the event originates from.
17981755
"""
17991756
if await self._store.is_partial_state_room(event.room_id):
@@ -1819,11 +1776,15 @@ async def _check_for_soft_fail(
18191776
auth_types = auth_types_for_event(room_version_obj, event)
18201777

18211778
# Calculate the "current state".
1822-
if state_ids is not None:
1823-
# If we're explicitly given the state then we won't have all the
1824-
# prev events, and so we have a gap in the graph. In this case
1825-
# we want to be a little careful as we might have been down for
1826-
# a while and have an incorrect view of the current state,
1779+
seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids)
1780+
has_missing_prevs = bool(prev_event_ids - seen_event_ids)
1781+
if has_missing_prevs:
1782+
# We don't have all the prev_events of this event, which means we have a
1783+
# gap in the graph, and the new event is going to become a new backwards
1784+
# extremity.
1785+
#
1786+
# In this case we want to be a little careful as we might have been
1787+
# down for a while and have an incorrect view of the current state,
18271788
# however we still want to do checks as gaps are easy to
18281789
# maliciously manufacture.
18291790
#
@@ -1836,6 +1797,7 @@ async def _check_for_soft_fail(
18361797
event.room_id, extrem_ids
18371798
)
18381799
state_sets: List[StateMap[str]] = list(state_sets_d.values())
1800+
state_ids = await context.get_prev_state_ids()
18391801
state_sets.append(state_ids)
18401802
current_state_ids = (
18411803
await self._state_resolution_handler.resolve_events_with_store(

synapse/state/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,10 @@ async def compute_event_context(
278278
flag will be calculated based on `event`'s prev events.
279279
Returns:
280280
The event context.
281+
282+
Raises:
283+
RuntimeError if `state_ids_before_event` is not provided and one or more
284+
prev events are missing or outliers.
281285
"""
282286

283287
assert not event.internal_metadata.is_outlier()
@@ -432,6 +436,10 @@ async def resolve_state_groups_for_events(
432436
433437
Returns:
434438
The resolved state
439+
440+
Raises:
441+
RuntimeError if we don't have a state group for one or more of the events
442+
(ie. they are outliers or unknown)
435443
"""
436444
logger.debug("resolve_state_groups event_ids %s", event_ids)
437445

synapse/storage/controllers/state.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,10 @@ async def get_state_group_for_events(
338338
event_ids: events to get state groups for
339339
await_full_state: if true, will block if we do not yet have complete
340340
state at these events.
341+
342+
Raises:
343+
RuntimeError if we don't have a state group for one or more of the events
344+
(ie. they are outliers or unknown)
341345
"""
342346
if await_full_state:
343347
await self._partial_state_events_tracker.await_full_state(event_ids)

tests/handlers/test_federation.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,16 +280,23 @@ def test_backfill_with_many_backward_extremities(self) -> None:
280280

281281
# we poke this directly into _process_received_pdu, to avoid the
282282
# federation handler wanting to backfill the fake event.
283-
self.get_success(
284-
federation_event_handler._process_received_pdu(
285-
self.OTHER_SERVER_NAME,
283+
state_handler = self.hs.get_state_handler()
284+
context = self.get_success(
285+
state_handler.compute_event_context(
286286
event,
287-
state_ids={
287+
state_ids_before_event={
288288
(e.type, e.state_key): e.event_id for e in current_state
289289
},
290290
partial_state=False,
291291
)
292292
)
293+
self.get_success(
294+
federation_event_handler._process_received_pdu(
295+
self.OTHER_SERVER_NAME,
296+
event,
297+
context,
298+
)
299+
)
293300

294301
# we should now have 8 backwards extremities.
295302
backwards_extremities = self.get_success(

0 commit comments

Comments
 (0)