13
13
# limitations under the License.
14
14
15
15
import logging
16
- from typing import TYPE_CHECKING , Dict , Iterable , List , Optional , Tuple , Union , cast
16
+ from typing import (
17
+ TYPE_CHECKING ,
18
+ Collection ,
19
+ Dict ,
20
+ Iterable ,
21
+ List ,
22
+ Optional ,
23
+ Tuple ,
24
+ Union ,
25
+ cast ,
26
+ )
17
27
18
28
import attr
19
29
from frozendict import frozendict
20
30
21
- from synapse .api .constants import EventTypes , RelationTypes
31
+ from synapse .api .constants import RelationTypes
22
32
from synapse .events import EventBase
23
33
from synapse .storage ._base import SQLBaseStore
24
34
from synapse .storage .database import (
28
38
make_in_list_sql_clause ,
29
39
)
30
40
from synapse .storage .databases .main .stream import generate_pagination_where_clause
41
+ from synapse .storage .engines import PostgresEngine
31
42
from synapse .storage .relations import (
32
43
AggregationPaginationToken ,
33
44
PaginationChunk ,
34
45
RelationPaginationToken ,
35
46
)
36
47
from synapse .types import JsonDict
37
- from synapse .util .caches .descriptors import cached
48
+ from synapse .util .caches .descriptors import cached , cachedList
38
49
39
50
if TYPE_CHECKING :
40
51
from synapse .server import HomeServer
@@ -340,20 +351,24 @@ def _get_aggregation_groups_for_event_txn(
340
351
)
341
352
342
353
@cached ()
343
- async def get_applicable_edit (
344
- self , event_id : str , room_id : str
345
- ) -> Optional [EventBase ]:
354
+ def get_applicable_edit (self , event_id : str ) -> Optional [EventBase ]:
355
+ raise NotImplementedError ()
356
+
357
+ @cachedList (cached_method_name = "get_applicable_edit" , list_name = "event_ids" )
358
+ async def _get_applicable_edits (
359
+ self , event_ids : Collection [str ]
360
+ ) -> Dict [str , Optional [EventBase ]]:
346
361
"""Get the most recent edit (if any) that has happened for the given
347
- event .
362
+ events .
348
363
349
364
Correctly handles checking whether edits were allowed to happen.
350
365
351
366
Args:
352
- event_id: The original event ID
353
- room_id: The original event's room ID
367
+ event_ids: The original event IDs
354
368
355
369
Returns:
356
- The most recent edit, if any.
370
+ A map of the most recent edit for each event. If there are no edits,
371
+ the event will map to None.
357
372
"""
358
373
359
374
# We only allow edits for `m.room.message` events that have the same sender
@@ -362,37 +377,67 @@ async def get_applicable_edit(
362
377
363
378
# Fetches latest edit that has the same type and sender as the
364
379
# original, and is an `m.room.message`.
365
- sql = """
366
- SELECT edit.event_id FROM events AS edit
367
- INNER JOIN event_relations USING (event_id)
368
- INNER JOIN events AS original ON
369
- original.event_id = relates_to_id
370
- AND edit.type = original.type
371
- AND edit.sender = original.sender
372
- WHERE
373
- relates_to_id = ?
374
- AND relation_type = ?
375
- AND edit.room_id = ?
376
- AND edit.type = 'm.room.message'
377
- ORDER by edit.origin_server_ts DESC, edit.event_id DESC
378
- LIMIT 1
379
- """
380
+ if isinstance (self .database_engine , PostgresEngine ):
381
+ # The `DISTINCT ON` clause will pick the *first* row it encounters,
382
+ # so ordering by origin server ts + event ID desc will ensure we get
383
+ # the latest edit.
384
+ sql = """
385
+ SELECT DISTINCT ON (original.event_id) original.event_id, edit.event_id FROM events AS edit
386
+ INNER JOIN event_relations USING (event_id)
387
+ INNER JOIN events AS original ON
388
+ original.event_id = relates_to_id
389
+ AND edit.type = original.type
390
+ AND edit.sender = original.sender
391
+ AND edit.room_id = original.room_id
392
+ WHERE
393
+ %s
394
+ AND relation_type = ?
395
+ AND edit.type = 'm.room.message'
396
+ ORDER by original.event_id DESC, edit.origin_server_ts DESC, edit.event_id DESC
397
+ """
398
+ else :
399
+ # SQLite uses a simplified query which returns all edits for an
400
+ # original event. The results are then de-duplicated when turned into
401
+ # a dict. Due to the chosen ordering, the latest edit stomps on
402
+ # earlier edits.
403
+ sql = """
404
+ SELECT original.event_id, edit.event_id FROM events AS edit
405
+ INNER JOIN event_relations USING (event_id)
406
+ INNER JOIN events AS original ON
407
+ original.event_id = relates_to_id
408
+ AND edit.type = original.type
409
+ AND edit.sender = original.sender
410
+ AND edit.room_id = original.room_id
411
+ WHERE
412
+ %s
413
+ AND relation_type = ?
414
+ AND edit.type = 'm.room.message'
415
+ ORDER by edit.origin_server_ts, edit.event_id
416
+ """
380
417
381
- def _get_applicable_edit_txn (txn : LoggingTransaction ) -> Optional [str ]:
382
- txn .execute (sql , (event_id , RelationTypes .REPLACE , room_id ))
383
- row = txn .fetchone ()
384
- if row :
385
- return row [0 ]
386
- return None
418
+ def _get_applicable_edits_txn (txn : LoggingTransaction ) -> Dict [str , str ]:
419
+ clause , args = make_in_list_sql_clause (
420
+ txn .database_engine , "relates_to_id" , event_ids
421
+ )
422
+ args .append (RelationTypes .REPLACE )
387
423
388
- edit_id = await self .db_pool .runInteraction (
389
- "get_applicable_edit" , _get_applicable_edit_txn
424
+ txn .execute (sql % (clause ,), args )
425
+ return dict (cast (Iterable [Tuple [str , str ]], txn .fetchall ()))
426
+
427
+ edit_ids = await self .db_pool .runInteraction (
428
+ "get_applicable_edits" , _get_applicable_edits_txn
390
429
)
391
430
392
- if not edit_id :
393
- return None
431
+ edits = await self .get_events (edit_ids .values ()) # type: ignore[attr-defined]
394
432
395
- return await self .get_event (edit_id , allow_none = True ) # type: ignore[attr-defined]
433
+ # Map to the original event IDs to the edit events.
434
+ #
435
+ # There might not be an edit event due to there being no edits or
436
+ # due to the event not being known, either case is treated the same.
437
+ return {
438
+ original_event_id : edits .get (edit_ids .get (original_event_id ))
439
+ for original_event_id in event_ids
440
+ }
396
441
397
442
@cached ()
398
443
async def get_thread_summary (
@@ -612,9 +657,6 @@ async def _get_bundled_aggregation_for_event(
612
657
The bundled aggregations for an event, if bundled aggregations are
613
658
enabled and the event can have bundled aggregations.
614
659
"""
615
- # State events and redacted events do not get bundled aggregations.
616
- if event .is_state () or event .internal_metadata .is_redacted ():
617
- return None
618
660
619
661
# Do not bundle aggregations for an event which represents an edit or an
620
662
# annotation. It does not make sense for them to have related events.
@@ -642,13 +684,6 @@ async def _get_bundled_aggregation_for_event(
642
684
if references .chunk :
643
685
aggregations .references = references .to_dict ()
644
686
645
- edit = None
646
- if event .type == EventTypes .Message :
647
- edit = await self .get_applicable_edit (event_id , room_id )
648
-
649
- if edit :
650
- aggregations .replace = edit
651
-
652
687
# If this event is the start of a thread, include a summary of the replies.
653
688
if self ._msc3440_enabled :
654
689
thread_count , latest_thread_event = await self .get_thread_summary (
@@ -668,9 +703,7 @@ async def _get_bundled_aggregation_for_event(
668
703
return aggregations
669
704
670
705
async def get_bundled_aggregations (
671
- self ,
672
- events : Iterable [EventBase ],
673
- user_id : str ,
706
+ self , events : Iterable [EventBase ], user_id : str
674
707
) -> Dict [str , BundledAggregations ]:
675
708
"""Generate bundled aggregations for events.
676
709
@@ -683,13 +716,28 @@ async def get_bundled_aggregations(
683
716
events may have bundled aggregations in the results.
684
717
"""
685
718
686
- # TODO Parallelize.
687
- results = {}
719
+ # State events and redacted events do not get bundled aggregations.
720
+ events = [
721
+ event
722
+ for event in events
723
+ if not event .is_state () and not event .internal_metadata .is_redacted ()
724
+ ]
725
+
726
+ # event ID -> bundled aggregation in non-serialized form.
727
+ results : Dict [str , BundledAggregations ] = {}
728
+
729
+ # Fetch other relations per event.
688
730
for event in events :
689
731
event_result = await self ._get_bundled_aggregation_for_event (event , user_id )
690
732
if event_result :
691
733
results [event .event_id ] = event_result
692
734
735
+ # Fetch any edits.
736
+ event_ids = [event .event_id for event in events ]
737
+ edits = await self ._get_applicable_edits (event_ids )
738
+ for event_id , edit in edits .items ():
739
+ results .setdefault (event_id , BundledAggregations ()).replace = edit
740
+
693
741
return results
694
742
695
743
0 commit comments