Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 56fc75c

Browse files
committed
Handle threads when fetching events for push.
1 parent 4ebb038 commit 56fc75c

File tree

2 files changed

+58
-20
lines changed

2 files changed

+58
-20
lines changed

changelog.d/13878.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).

synapse/storage/databases/main/event_push_actions.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,32 @@
119119
]
120120

121121

122+
@attr.s(slots=True, auto_attribs=True)
123+
class _RoomReceipt:
124+
"""
125+
HttpPushAction instances include the information used to generate HTTP
126+
requests to a push gateway.
127+
"""
128+
129+
unthreaded_stream_ordering: int = 0
130+
# threaded_stream_ordering includes the main pseudo-thread.
131+
threaded_stream_ordering: Dict[str, int] = attr.Factory(dict)
132+
133+
def is_unread(self, thread_id: str, stream_ordering: int) -> bool:
134+
"""Returns True if the stream ordering is unread according to the receipt information."""
135+
136+
# Only include push actions with a stream ordering after both the unthreaded
137+
# and threaded receipt. Properly handles a user without any receipts present.
138+
return (
139+
self.unthreaded_stream_ordering < stream_ordering
140+
and self.threaded_stream_ordering.get(thread_id, 0) < stream_ordering
141+
)
142+
143+
144+
# A _RoomReceipt with no receipts in it.
145+
MISSING_ROOM_RECEIPT = _RoomReceipt()
146+
147+
122148
@attr.s(slots=True, frozen=True, auto_attribs=True)
123149
class HttpPushAction:
124150
"""
@@ -559,7 +585,7 @@ def f(txn: LoggingTransaction) -> List[str]:
559585

560586
def _get_receipts_by_room_txn(
561587
self, txn: LoggingTransaction, user_id: str
562-
) -> Dict[str, int]:
588+
) -> Dict[str, _RoomReceipt]:
563589
"""
564590
Generate a map of room ID to the latest stream ordering that has been
565591
read by the given user.
@@ -569,7 +595,8 @@ def _get_receipts_by_room_txn(
569595
user_id: The user to fetch receipts for.
570596
571597
Returns:
572-
A map of room ID to stream ordering for all rooms the user has a receipt in.
598+
A map including all rooms the user is in with a receipt. It maps
599+
room IDs to _RoomReceipt instances
573600
"""
574601
receipt_types_clause, args = make_in_list_sql_clause(
575602
self.database_engine,
@@ -581,17 +608,26 @@ def _get_receipts_by_room_txn(
581608
)
582609

583610
sql = f"""
584-
SELECT room_id, MAX(stream_ordering)
611+
SELECT room_id, thread_id, MAX(stream_ordering)
585612
FROM receipts_linearized
586613
INNER JOIN events USING (room_id, event_id)
587614
WHERE {receipt_types_clause}
588615
AND user_id = ?
589-
GROUP BY room_id
616+
GROUP BY room_id, thread_id
590617
"""
591618

592619
args.extend((user_id,))
593620
txn.execute(sql, args)
594-
return dict(cast(List[Tuple[str, int]], txn.fetchall()))
621+
622+
result: Dict[str, _RoomReceipt] = {}
623+
for room_id, thread_id, stream_ordering in txn:
624+
room_receipt = result.setdefault(room_id, _RoomReceipt())
625+
if thread_id is None:
626+
room_receipt.unthreaded_stream_ordering = stream_ordering
627+
else:
628+
room_receipt.threaded_stream_ordering[thread_id] = stream_ordering
629+
630+
return result
595631

596632
async def get_unread_push_actions_for_user_in_range_for_http(
597633
self,
@@ -624,9 +660,10 @@ async def get_unread_push_actions_for_user_in_range_for_http(
624660

625661
def get_push_actions_txn(
626662
txn: LoggingTransaction,
627-
) -> List[Tuple[str, str, int, str, bool]]:
663+
) -> List[Tuple[str, str, str, int, str, bool]]:
628664
sql = """
629-
SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
665+
SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering,
666+
ep.actions, ep.highlight
630667
FROM event_push_actions AS ep
631668
WHERE
632669
ep.user_id = ?
@@ -636,7 +673,7 @@ def get_push_actions_txn(
636673
ORDER BY ep.stream_ordering ASC LIMIT ?
637674
"""
638675
txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
639-
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
676+
return cast(List[Tuple[str, str, str, int, str, bool]], txn.fetchall())
640677

641678
push_actions = await self.db_pool.runInteraction(
642679
"get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
@@ -649,10 +686,10 @@ def get_push_actions_txn(
649686
stream_ordering=stream_ordering,
650687
actions=_deserialize_action(actions, highlight),
651688
)
652-
for event_id, room_id, stream_ordering, actions, highlight in push_actions
653-
# Only include push actions with a stream ordering after any receipt, or without any
654-
# receipt present (invited to but never read rooms).
655-
if stream_ordering > receipts_by_room.get(room_id, 0)
689+
for event_id, room_id, thread_id, stream_ordering, actions, highlight in push_actions
690+
if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread(
691+
thread_id, stream_ordering
692+
)
656693
]
657694

658695
# Now sort it so it's ordered correctly, since currently it will
@@ -696,10 +733,10 @@ async def get_unread_push_actions_for_user_in_range_for_email(
696733

697734
def get_push_actions_txn(
698735
txn: LoggingTransaction,
699-
) -> List[Tuple[str, str, int, str, bool, int]]:
736+
) -> List[Tuple[str, str, str, int, str, bool, int]]:
700737
sql = """
701-
SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
702-
ep.highlight, e.received_ts
738+
SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering,
739+
ep.actions, ep.highlight, e.received_ts
703740
FROM event_push_actions AS ep
704741
INNER JOIN events AS e USING (room_id, event_id)
705742
WHERE
@@ -710,7 +747,7 @@ def get_push_actions_txn(
710747
ORDER BY ep.stream_ordering DESC LIMIT ?
711748
"""
712749
txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
713-
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
750+
return cast(List[Tuple[str, str, str, int, str, bool, int]], txn.fetchall())
714751

715752
push_actions = await self.db_pool.runInteraction(
716753
"get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
@@ -725,10 +762,10 @@ def get_push_actions_txn(
725762
actions=_deserialize_action(actions, highlight),
726763
received_ts=received_ts,
727764
)
728-
for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions
729-
# Only include push actions with a stream ordering after any receipt, or without any
730-
# receipt present (invited to but never read rooms).
731-
if stream_ordering > receipts_by_room.get(room_id, 0)
765+
for event_id, room_id, thread_id, stream_ordering, actions, highlight, received_ts in push_actions
766+
if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread(
767+
thread_id, stream_ordering
768+
)
732769
]
733770

734771
# Now sort it so it's ordered correctly, since currently it will

0 commit comments

Comments
 (0)