Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 2897fb6

Browse files
authored
Improvements to bundling aggregations. (#11815)
This is some odds and ends found during the review of #11791 and while continuing to work in this code:

* Return attrs classes instead of dictionaries from some methods to improve type safety.
* Call `get_bundled_aggregations` fewer times.
* Adds a missing assertion in the tests.
* Do not return empty bundled aggregations for an event (preferring to not include the bundle at all, as the docstring states).
1 parent d8df8e6 commit 2897fb6

File tree

12 files changed

+212
-139
lines changed

12 files changed

+212
-139
lines changed

changelog.d/11815.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve type safety of bundled aggregations code.

synapse/events/utils.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,17 @@
1414
# limitations under the License.
1515
import collections.abc
1616
import re
17-
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
17+
from typing import (
18+
TYPE_CHECKING,
19+
Any,
20+
Callable,
21+
Dict,
22+
Iterable,
23+
List,
24+
Mapping,
25+
Optional,
26+
Union,
27+
)
1828

1929
from frozendict import frozendict
2030

@@ -26,6 +36,10 @@
2636

2737
from . import EventBase
2838

39+
if TYPE_CHECKING:
40+
from synapse.storage.databases.main.relations import BundledAggregations
41+
42+
2943
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
3044
# (?<!stuff) matches if the current position in the string is not preceded
3145
# by a match for 'stuff'.
@@ -376,7 +390,7 @@ def serialize_event(
376390
event: Union[JsonDict, EventBase],
377391
time_now: int,
378392
*,
379-
bundle_aggregations: Optional[Dict[str, JsonDict]] = None,
393+
bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
380394
**kwargs: Any,
381395
) -> JsonDict:
382396
"""Serializes a single event.
@@ -415,7 +429,7 @@ def _inject_bundled_aggregations(
415429
self,
416430
event: EventBase,
417431
time_now: int,
418-
aggregations: JsonDict,
432+
aggregations: "BundledAggregations",
419433
serialized_event: JsonDict,
420434
) -> None:
421435
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event.
@@ -427,13 +441,18 @@ def _inject_bundled_aggregations(
427441
serialized_event: The serialized event which may be modified.
428442
429443
"""
430-
# Make a copy in-case the object is cached.
431-
aggregations = aggregations.copy()
444+
serialized_aggregations = {}
445+
446+
if aggregations.annotations:
447+
serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations
448+
449+
if aggregations.references:
450+
serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references
432451

433-
if RelationTypes.REPLACE in aggregations:
452+
if aggregations.replace:
434453
# If there is an edit replace the content, preserving existing
435454
# relations.
436-
edit = aggregations[RelationTypes.REPLACE]
455+
edit = aggregations.replace
437456

438457
# Ensure we take copies of the edit content, otherwise we risk modifying
439458
# the original event.
@@ -451,24 +470,28 @@ def _inject_bundled_aggregations(
451470
else:
452471
serialized_event["content"].pop("m.relates_to", None)
453472

454-
aggregations[RelationTypes.REPLACE] = {
473+
serialized_aggregations[RelationTypes.REPLACE] = {
455474
"event_id": edit.event_id,
456475
"origin_server_ts": edit.origin_server_ts,
457476
"sender": edit.sender,
458477
}
459478

460479
# If this event is the start of a thread, include a summary of the replies.
461-
if RelationTypes.THREAD in aggregations:
462-
# Serialize the latest thread event.
463-
latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"]
464-
465-
# Don't bundle aggregations as this could recurse forever.
466-
aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event(
467-
latest_thread_event, time_now, bundle_aggregations=None
468-
)
480+
if aggregations.thread:
481+
serialized_aggregations[RelationTypes.THREAD] = {
482+
# Don't bundle aggregations as this could recurse forever.
483+
"latest_event": self.serialize_event(
484+
aggregations.thread.latest_event, time_now, bundle_aggregations=None
485+
),
486+
"count": aggregations.thread.count,
487+
"current_user_participated": aggregations.thread.current_user_participated,
488+
}
469489

470490
# Include the bundled aggregations in the event.
471-
serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations)
491+
if serialized_aggregations:
492+
serialized_event["unsigned"].setdefault("m.relations", {}).update(
493+
serialized_aggregations
494+
)
472495

473496
def serialize_events(
474497
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any

synapse/handlers/room.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
Tuple,
3131
)
3232

33+
import attr
3334
from typing_extensions import TypedDict
3435

3536
from synapse.api.constants import (
@@ -60,6 +61,7 @@
6061
from synapse.federation.federation_client import InvalidResponseError
6162
from synapse.handlers.federation import get_domains_from_state
6263
from synapse.rest.admin._base import assert_user_is_admin
64+
from synapse.storage.databases.main.relations import BundledAggregations
6365
from synapse.storage.state import StateFilter
6466
from synapse.streams import EventSource
6567
from synapse.types import (
@@ -90,6 +92,17 @@
9092
FIVE_MINUTES_IN_MS = 5 * 60 * 1000
9193

9294

95+
@attr.s(slots=True, frozen=True, auto_attribs=True)
96+
class EventContext:
97+
events_before: List[EventBase]
98+
event: EventBase
99+
events_after: List[EventBase]
100+
state: List[EventBase]
101+
aggregations: Dict[str, BundledAggregations]
102+
start: str
103+
end: str
104+
105+
93106
class RoomCreationHandler:
94107
def __init__(self, hs: "HomeServer"):
95108
self.store = hs.get_datastore()
@@ -1119,7 +1132,7 @@ async def get_event_context(
11191132
limit: int,
11201133
event_filter: Optional[Filter],
11211134
use_admin_priviledge: bool = False,
1122-
) -> Optional[JsonDict]:
1135+
) -> Optional[EventContext]:
11231136
"""Retrieves events, pagination tokens and state around a given event
11241137
in a room.
11251138
@@ -1167,48 +1180,38 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
11671180
results = await self.store.get_events_around(
11681181
room_id, event_id, before_limit, after_limit, event_filter
11691182
)
1183+
events_before = results.events_before
1184+
events_after = results.events_after
11701185

11711186
if event_filter:
1172-
results["events_before"] = await event_filter.filter(
1173-
results["events_before"]
1174-
)
1175-
results["events_after"] = await event_filter.filter(results["events_after"])
1187+
events_before = await event_filter.filter(events_before)
1188+
events_after = await event_filter.filter(events_after)
11761189

1177-
results["events_before"] = await filter_evts(results["events_before"])
1178-
results["events_after"] = await filter_evts(results["events_after"])
1190+
events_before = await filter_evts(events_before)
1191+
events_after = await filter_evts(events_after)
11791192
# filter_evts can return a pruned event in case the user is allowed to see that
11801193
# there's something there but not see the content, so use the event that's in
11811194
# `filtered` rather than the event we retrieved from the datastore.
1182-
results["event"] = filtered[0]
1195+
event = filtered[0]
11831196

11841197
# Fetch the aggregations.
11851198
aggregations = await self.store.get_bundled_aggregations(
1186-
[results["event"]], user.to_string()
1199+
itertools.chain(events_before, (event,), events_after),
1200+
user.to_string(),
11871201
)
1188-
aggregations.update(
1189-
await self.store.get_bundled_aggregations(
1190-
results["events_before"], user.to_string()
1191-
)
1192-
)
1193-
aggregations.update(
1194-
await self.store.get_bundled_aggregations(
1195-
results["events_after"], user.to_string()
1196-
)
1197-
)
1198-
results["aggregations"] = aggregations
11991202

1200-
if results["events_after"]:
1201-
last_event_id = results["events_after"][-1].event_id
1203+
if events_after:
1204+
last_event_id = events_after[-1].event_id
12021205
else:
12031206
last_event_id = event_id
12041207

12051208
if event_filter and event_filter.lazy_load_members:
12061209
state_filter = StateFilter.from_lazy_load_member_list(
12071210
ev.sender
12081211
for ev in itertools.chain(
1209-
results["events_before"],
1210-
(results["event"],),
1211-
results["events_after"],
1212+
events_before,
1213+
(event,),
1214+
events_after,
12121215
)
12131216
)
12141217
else:
@@ -1226,21 +1229,23 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
12261229
if event_filter:
12271230
state_events = await event_filter.filter(state_events)
12281231

1229-
results["state"] = await filter_evts(state_events)
1230-
12311232
# We use a dummy token here as we only care about the room portion of
12321233
# the token, which we replace.
12331234
token = StreamToken.START
12341235

1235-
results["start"] = await token.copy_and_replace(
1236-
"room_key", results["start"]
1237-
).to_string(self.store)
1238-
1239-
results["end"] = await token.copy_and_replace(
1240-
"room_key", results["end"]
1241-
).to_string(self.store)
1242-
1243-
return results
1236+
return EventContext(
1237+
events_before=events_before,
1238+
event=event,
1239+
events_after=events_after,
1240+
state=await filter_evts(state_events),
1241+
aggregations=aggregations,
1242+
start=await token.copy_and_replace("room_key", results.start).to_string(
1243+
self.store
1244+
),
1245+
end=await token.copy_and_replace("room_key", results.end).to_string(
1246+
self.store
1247+
),
1248+
)
12441249

12451250

12461251
class TimestampLookupHandler:

synapse/handlers/search.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -361,36 +361,37 @@ async def search(
361361

362362
logger.info(
363363
"Context for search returned %d and %d events",
364-
len(res["events_before"]),
365-
len(res["events_after"]),
364+
len(res.events_before),
365+
len(res.events_after),
366366
)
367367

368-
res["events_before"] = await filter_events_for_client(
369-
self.storage, user.to_string(), res["events_before"]
368+
events_before = await filter_events_for_client(
369+
self.storage, user.to_string(), res.events_before
370370
)
371371

372-
res["events_after"] = await filter_events_for_client(
373-
self.storage, user.to_string(), res["events_after"]
372+
events_after = await filter_events_for_client(
373+
self.storage, user.to_string(), res.events_after
374374
)
375375

376-
res["start"] = await now_token.copy_and_replace(
377-
"room_key", res["start"]
378-
).to_string(self.store)
379-
380-
res["end"] = await now_token.copy_and_replace(
381-
"room_key", res["end"]
382-
).to_string(self.store)
376+
context = {
377+
"events_before": events_before,
378+
"events_after": events_after,
379+
"start": await now_token.copy_and_replace(
380+
"room_key", res.start
381+
).to_string(self.store),
382+
"end": await now_token.copy_and_replace(
383+
"room_key", res.end
384+
).to_string(self.store),
385+
}
383386

384387
if include_profile:
385388
senders = {
386389
ev.sender
387-
for ev in itertools.chain(
388-
res["events_before"], [event], res["events_after"]
389-
)
390+
for ev in itertools.chain(events_before, [event], events_after)
390391
}
391392

392-
if res["events_after"]:
393-
last_event_id = res["events_after"][-1].event_id
393+
if events_after:
394+
last_event_id = events_after[-1].event_id
394395
else:
395396
last_event_id = event.event_id
396397

@@ -402,7 +403,7 @@ async def search(
402403
last_event_id, state_filter
403404
)
404405

405-
res["profile_info"] = {
406+
context["profile_info"] = {
406407
s.state_key: {
407408
"displayname": s.content.get("displayname", None),
408409
"avatar_url": s.content.get("avatar_url", None),
@@ -411,7 +412,7 @@ async def search(
411412
if s.type == EventTypes.Member and s.state_key in senders
412413
}
413414

414-
contexts[event.event_id] = res
415+
contexts[event.event_id] = context
415416
else:
416417
contexts = {}
417418

@@ -421,10 +422,10 @@ async def search(
421422

422423
for context in contexts.values():
423424
context["events_before"] = self._event_serializer.serialize_events(
424-
context["events_before"], time_now
425+
context["events_before"], time_now # type: ignore[arg-type]
425426
)
426427
context["events_after"] = self._event_serializer.serialize_events(
427-
context["events_after"], time_now
428+
context["events_after"], time_now # type: ignore[arg-type]
428429
)
429430

430431
state_results = {}

synapse/handlers/sync.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
3838
from synapse.push.clientformat import format_push_rules_for_user
3939
from synapse.storage.databases.main.event_push_actions import NotifCounts
40+
from synapse.storage.databases.main.relations import BundledAggregations
4041
from synapse.storage.roommember import MemberSummary
4142
from synapse.storage.state import StateFilter
4243
from synapse.types import (
@@ -100,7 +101,7 @@ class TimelineBatch:
100101
limited: bool
101102
# A mapping of event ID to the bundled aggregations for the above events.
102103
# This is only calculated if limited is true.
103-
bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None
104+
bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None
104105

105106
def __bool__(self) -> bool:
106107
"""Make the result appear empty if there are no updates. This is used

synapse/push/mailer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ async def _get_notif_vars(
455455
}
456456

457457
the_events = await filter_events_for_client(
458-
self.storage, user_id, results["events_before"]
458+
self.storage, user_id, results.events_before
459459
)
460460
the_events.append(notif_event)
461461

0 commit comments

Comments (0)