Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 03bccd5

Browse files
H-Shayerikjohnston
andauthored
Add a class UnpersistedEventContext to allow for the batching up of storing state groups (#14675)
* add class UnpersistedEventContext * modify create new client event to create unpersistedeventcontexts * persist event contexts after creation * fix tests to persist unpersisted event contexts * cleanup * misc lints + cleanup * changelog + fix comments * lints * fix batch insertion? * reduce redundant calculation * add unpersisted event classes * rework compute_event_context, split into function that returns unpersisted event context and then persists it * use calculate_context_info to create unpersisted event contexts * update typing * $%#^&* * black * fix comments and consolidate classes, use attr.s for class * requested changes * lint * requested changes * requested changes * refactor to be stupidly explicit * clearer renaming and flow * make partial state non-optional * update docstrings --------- Co-authored-by: Erik Johnston <[email protected]>
1 parent c1d2ce2 commit 03bccd5

File tree

14 files changed

+359
-162
lines changed

14 files changed

+359
-162
lines changed

changelog.d/14675.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add a class UnpersistedEventContext to allow for the batching up of storing state groups.

synapse/events/snapshot.py

Lines changed: 170 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from abc import ABC, abstractmethod
1415
from typing import TYPE_CHECKING, List, Optional, Tuple
1516

1617
import attr
@@ -26,8 +27,51 @@
2627
from synapse.types.state import StateFilter
2728

2829

30+
class UnpersistedEventContextBase(ABC):
31+
"""
32+
This is a base class for EventContext and UnpersistedEventContext, objects which
33+
hold information relevant to storing an associated event. Note that an
34+
UnpersistedEventContexts must be converted into an EventContext before it is
35+
suitable to send to the db with its associated event.
36+
37+
Attributes:
38+
_storage: storage controllers for interfacing with the database
39+
app_service: If the associated event is being sent by a (local) application service, that
40+
app service.
41+
"""
42+
43+
def __init__(self, storage_controller: "StorageControllers"):
44+
self._storage: "StorageControllers" = storage_controller
45+
self.app_service: Optional[ApplicationService] = None
46+
47+
@abstractmethod
48+
async def persist(
49+
self,
50+
event: EventBase,
51+
) -> "EventContext":
52+
"""
53+
A method to convert an UnpersistedEventContext to an EventContext, suitable for
54+
sending to the database with the associated event.
55+
"""
56+
pass
57+
58+
@abstractmethod
59+
async def get_prev_state_ids(
60+
self, state_filter: Optional["StateFilter"] = None
61+
) -> StateMap[str]:
62+
"""
63+
Gets the room state at the event (ie not including the event if the event is a
64+
state event).
65+
66+
Args:
67+
state_filter: specifies the type of state event to fetch from DB, example:
68+
EventTypes.JoinRules
69+
"""
70+
pass
71+
72+
2973
@attr.s(slots=True, auto_attribs=True)
30-
class EventContext:
74+
class EventContext(UnpersistedEventContextBase):
3175
"""
3276
Holds information relevant to persisting an event
3377
@@ -77,9 +121,6 @@ class EventContext:
77121
delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
78122
and ``state_group``.
79123
80-
app_service: If this event is being sent by a (local) application service, that
81-
app service.
82-
83124
partial_state: if True, we may be storing this event with a temporary,
84125
incomplete state.
85126
"""
@@ -122,6 +163,9 @@ def for_outlier(
122163
"""Return an EventContext instance suitable for persisting an outlier event"""
123164
return EventContext(storage=storage)
124165

166+
async def persist(self, event: EventBase) -> "EventContext":
167+
return self
168+
125169
async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
126170
"""Converts self to a type that can be serialized as JSON, and then
127171
deserialized by `deserialize`
@@ -254,6 +298,128 @@ async def get_prev_state_ids(
254298
)
255299

256300

301+
@attr.s(slots=True, auto_attribs=True)
302+
class UnpersistedEventContext(UnpersistedEventContextBase):
303+
"""
304+
The event context holds information about the state groups for an event. It is important
305+
to remember that an event technically has two state groups: the state group before the
306+
event, and the state group after the event. If the event is not a state event, the state
307+
group will not change (ie the state group before the event will be the same as the state
308+
group after the event), but if it is a state event the state group before the event
309+
will differ from the state group after the event.
310+
This is a version of an EventContext before the new state group (if any) has been
311+
computed and stored. It contains information about the state before the event (which
312+
also may be the information after the event, if the event is not a state event). The
313+
UnpersistedEventContext must be converted into an EventContext by calling the method
314+
'persist' on it before it is suitable to be sent to the DB for processing.
315+
316+
state_group_after_event:
317+
The state group after the event. This will always be None until it is persisted.
318+
If the event is not a state event, this will be the same as
319+
state_group_before_event.
320+
321+
state_group_before_event:
322+
The ID of the state group representing the state of the room before this event.
323+
324+
state_delta_due_to_event:
325+
If the event is a state event, then this is the delta of the state between
326+
`state_group` and `state_group_before_event`
327+
328+
prev_group_for_state_group_before_event:
329+
If it is known, ``state_group_before_event``'s previous state group.
330+
331+
delta_ids_to_state_group_before_event:
332+
If ``prev_group_for_state_group_before_event`` is not None, the state delta
333+
between ``prev_group_for_state_group_before_event`` and ``state_group_before_event``.
334+
335+
partial_state:
336+
Whether the event has partial state.
337+
338+
state_map_before_event:
339+
A map of the state before the event, i.e. the state at `state_group_before_event`
340+
"""
341+
342+
_storage: "StorageControllers"
343+
state_group_before_event: Optional[int]
344+
state_group_after_event: Optional[int]
345+
state_delta_due_to_event: Optional[dict]
346+
prev_group_for_state_group_before_event: Optional[int]
347+
delta_ids_to_state_group_before_event: Optional[StateMap[str]]
348+
partial_state: bool
349+
state_map_before_event: Optional[StateMap[str]] = None
350+
351+
async def get_prev_state_ids(
352+
self, state_filter: Optional["StateFilter"] = None
353+
) -> StateMap[str]:
354+
"""
355+
Gets the room state map, excluding this event.
356+
357+
Args:
358+
state_filter: specifies the type of state event to fetch from DB
359+
360+
Returns:
361+
Maps a (type, state_key) to the event ID of the state event matching
362+
this tuple.
363+
"""
364+
if self.state_map_before_event:
365+
return self.state_map_before_event
366+
367+
assert self.state_group_before_event is not None
368+
return await self._storage.state.get_state_ids_for_group(
369+
self.state_group_before_event, state_filter
370+
)
371+
372+
async def persist(self, event: EventBase) -> EventContext:
373+
"""
374+
Creates a full `EventContext` for the event, persisting any referenced state that
375+
has not yet been persisted.
376+
377+
Args:
378+
event: event that the EventContext is associated with.
379+
380+
Returns: An EventContext suitable for sending to the database with the event
381+
for persisting
382+
"""
383+
assert self.partial_state is not None
384+
385+
# If we have a full set of state for before the event but don't have a state
386+
# group for that state, we need to get one
387+
if self.state_group_before_event is None:
388+
assert self.state_map_before_event
389+
state_group_before_event = await self._storage.state.store_state_group(
390+
event.event_id,
391+
event.room_id,
392+
prev_group=self.prev_group_for_state_group_before_event,
393+
delta_ids=self.delta_ids_to_state_group_before_event,
394+
current_state_ids=self.state_map_before_event,
395+
)
396+
self.state_group_before_event = state_group_before_event
397+
398+
# if the event isn't a state event the state group doesn't change
399+
if not self.state_delta_due_to_event:
400+
state_group_after_event = self.state_group_before_event
401+
402+
# otherwise if it is a state event we need to get a state group for it
403+
else:
404+
state_group_after_event = await self._storage.state.store_state_group(
405+
event.event_id,
406+
event.room_id,
407+
prev_group=self.state_group_before_event,
408+
delta_ids=self.state_delta_due_to_event,
409+
current_state_ids=None,
410+
)
411+
412+
return EventContext.with_state(
413+
storage=self._storage,
414+
state_group=state_group_after_event,
415+
state_group_before_event=self.state_group_before_event,
416+
state_delta_due_to_event=self.state_delta_due_to_event,
417+
partial_state=self.partial_state,
418+
prev_group=self.state_group_before_event,
419+
delta_ids=self.state_delta_due_to_event,
420+
)
421+
422+
257423
def _encode_state_dict(
258424
state_dict: Optional[StateMap[str]],
259425
) -> Optional[List[Tuple[str, str, str]]]:

synapse/events/third_party_rules.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from synapse.api.errors import ModuleFailedException, SynapseError
2020
from synapse.events import EventBase
21-
from synapse.events.snapshot import EventContext
21+
from synapse.events.snapshot import UnpersistedEventContextBase
2222
from synapse.storage.roommember import ProfileInfo
2323
from synapse.types import Requester, StateMap
2424
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
@@ -231,7 +231,9 @@ def register_third_party_rules_callbacks(
231231
self._on_threepid_bind_callbacks.append(on_threepid_bind)
232232

233233
async def check_event_allowed(
234-
self, event: EventBase, context: EventContext
234+
self,
235+
event: EventBase,
236+
context: UnpersistedEventContextBase,
235237
) -> Tuple[bool, Optional[dict]]:
236238
"""Check if a provided event should be allowed in the given context.
237239

synapse/handlers/federation.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from synapse.crypto.event_signing import compute_event_signature
5757
from synapse.event_auth import validate_event_for_room_version
5858
from synapse.events import EventBase
59-
from synapse.events.snapshot import EventContext
59+
from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
6060
from synapse.events.validator import EventValidator
6161
from synapse.federation.federation_client import InvalidResponseError
6262
from synapse.http.servlet import assert_params_in_dict
@@ -990,15 +990,20 @@ async def on_make_join_request(
990990
)
991991

992992
try:
993-
event, context = await self.event_creation_handler.create_new_client_event(
993+
(
994+
event,
995+
unpersisted_context,
996+
) = await self.event_creation_handler.create_new_client_event(
994997
builder=builder
995998
)
996999
except SynapseError as e:
9971000
logger.warning("Failed to create join to %s because %s", room_id, e)
9981001
raise
9991002

10001003
# Ensure the user can even join the room.
1001-
await self._federation_event_handler.check_join_restrictions(context, event)
1004+
await self._federation_event_handler.check_join_restrictions(
1005+
unpersisted_context, event
1006+
)
10021007

10031008
# The remote hasn't signed it yet, obviously. We'll do the full checks
10041009
# when we get the event back in `on_send_join_request`
@@ -1178,7 +1183,7 @@ async def on_make_leave_request(
11781183
},
11791184
)
11801185

1181-
event, context = await self.event_creation_handler.create_new_client_event(
1186+
event, _ = await self.event_creation_handler.create_new_client_event(
11821187
builder=builder
11831188
)
11841189

@@ -1228,12 +1233,13 @@ async def on_make_knock_request(
12281233
},
12291234
)
12301235

1231-
event, context = await self.event_creation_handler.create_new_client_event(
1232-
builder=builder
1233-
)
1236+
(
1237+
event,
1238+
unpersisted_context,
1239+
) = await self.event_creation_handler.create_new_client_event(builder=builder)
12341240

12351241
event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
1236-
event, context
1242+
event, unpersisted_context
12371243
)
12381244
if not event_allowed:
12391245
logger.warning("Creation of knock %s forbidden by third-party rules", event)
@@ -1406,15 +1412,20 @@ async def exchange_third_party_invite(
14061412
try:
14071413
(
14081414
event,
1409-
context,
1415+
unpersisted_context,
14101416
) = await self.event_creation_handler.create_new_client_event(
14111417
builder=builder
14121418
)
14131419

1414-
event, context = await self.add_display_name_to_third_party_invite(
1415-
room_version_obj, event_dict, event, context
1420+
(
1421+
event,
1422+
unpersisted_context,
1423+
) = await self.add_display_name_to_third_party_invite(
1424+
room_version_obj, event_dict, event, unpersisted_context
14161425
)
14171426

1427+
context = await unpersisted_context.persist(event)
1428+
14181429
EventValidator().validate_new(event, self.config)
14191430

14201431
# We need to tell the transaction queue to send this out, even
@@ -1483,14 +1494,19 @@ async def on_exchange_third_party_invite_request(
14831494
try:
14841495
(
14851496
event,
1486-
context,
1497+
unpersisted_context,
14871498
) = await self.event_creation_handler.create_new_client_event(
14881499
builder=builder
14891500
)
1490-
event, context = await self.add_display_name_to_third_party_invite(
1491-
room_version_obj, event_dict, event, context
1501+
(
1502+
event,
1503+
unpersisted_context,
1504+
) = await self.add_display_name_to_third_party_invite(
1505+
room_version_obj, event_dict, event, unpersisted_context
14921506
)
14931507

1508+
context = await unpersisted_context.persist(event)
1509+
14941510
try:
14951511
validate_event_for_room_version(event)
14961512
await self._event_auth_handler.check_auth_rules_from_context(event)
@@ -1522,8 +1538,8 @@ async def add_display_name_to_third_party_invite(
15221538
room_version_obj: RoomVersion,
15231539
event_dict: JsonDict,
15241540
event: EventBase,
1525-
context: EventContext,
1526-
) -> Tuple[EventBase, EventContext]:
1541+
context: UnpersistedEventContextBase,
1542+
) -> Tuple[EventBase, UnpersistedEventContextBase]:
15271543
key = (
15281544
EventTypes.ThirdPartyInvite,
15291545
event.content["third_party_invite"]["signed"]["token"],
@@ -1557,11 +1573,14 @@ async def add_display_name_to_third_party_invite(
15571573
room_version_obj, event_dict
15581574
)
15591575
EventValidator().validate_builder(builder)
1560-
event, context = await self.event_creation_handler.create_new_client_event(
1561-
builder=builder
1562-
)
1576+
1577+
(
1578+
event,
1579+
unpersisted_context,
1580+
) = await self.event_creation_handler.create_new_client_event(builder=builder)
1581+
15631582
EventValidator().validate_new(event, self.config)
1564-
return event, context
1583+
return event, unpersisted_context
15651584

15661585
async def _check_signature(self, event: EventBase, context: EventContext) -> None:
15671586
"""

synapse/handlers/federation_event.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
validate_event_for_room_version,
5959
)
6060
from synapse.events import EventBase
61-
from synapse.events.snapshot import EventContext
61+
from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
6262
from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo
6363
from synapse.logging.context import nested_logging_context
6464
from synapse.logging.opentracing import (
@@ -426,7 +426,9 @@ async def on_send_membership_event(
426426
return event, context
427427

428428
async def check_join_restrictions(
429-
self, context: EventContext, event: EventBase
429+
self,
430+
context: UnpersistedEventContextBase,
431+
event: EventBase,
430432
) -> None:
431433
"""Check that restrictions in restricted join rules are matched
432434

0 commit comments

Comments
 (0)