119
119
]
120
120
121
121
122
+ @attr .s (slots = True , auto_attribs = True )
123
+ class _RoomReceipt :
124
+ """
125
+ HttpPushAction instances include the information used to generate HTTP
126
+ requests to a push gateway.
127
+ """
128
+
129
+ unthreaded_stream_ordering : int = 0
130
+ # threaded_stream_ordering includes the main pseudo-thread.
131
+ threaded_stream_ordering : Dict [str , int ] = attr .Factory (dict )
132
+
133
+ def is_unread (self , thread_id : str , stream_ordering : int ) -> bool :
134
+ """Returns True if the stream ordering is unread according to the receipt information."""
135
+
136
+ # Only include push actions with a stream ordering after both the unthreaded
137
+ # and threaded receipt. Properly handles a user without any receipts present.
138
+ return (
139
+ self .unthreaded_stream_ordering < stream_ordering
140
+ and self .threaded_stream_ordering .get (thread_id , 0 ) < stream_ordering
141
+ )
142
+
143
+
144
+ # A _RoomReceipt with no receipts in it.
145
+ MISSING_ROOM_RECEIPT = _RoomReceipt ()
146
+
147
+
122
148
@attr .s (slots = True , frozen = True , auto_attribs = True )
123
149
class HttpPushAction :
124
150
"""
@@ -589,7 +615,7 @@ def f(txn: LoggingTransaction) -> List[str]:
589
615
590
616
def _get_receipts_by_room_txn (
591
617
self , txn : LoggingTransaction , user_id : str
592
- ) -> Dict [str , int ]:
618
+ ) -> Dict [str , _RoomReceipt ]:
593
619
"""
594
620
Generate a map of room ID to the latest stream ordering that has been
595
621
read by the given user.
@@ -599,7 +625,8 @@ def _get_receipts_by_room_txn(
599
625
user_id: The user to fetch receipts for.
600
626
601
627
Returns:
602
- A map of room ID to stream ordering for all rooms the user has a receipt in.
628
+ A map including all rooms the user is in with a receipt. It maps
629
+ room IDs to _RoomReceipt instances
603
630
"""
604
631
receipt_types_clause , args = make_in_list_sql_clause (
605
632
self .database_engine ,
@@ -611,17 +638,26 @@ def _get_receipts_by_room_txn(
611
638
)
612
639
613
640
sql = f"""
614
- SELECT room_id, MAX(stream_ordering)
641
+ SELECT room_id, thread_id, MAX(stream_ordering)
615
642
FROM receipts_linearized
616
643
INNER JOIN events USING (room_id, event_id)
617
644
WHERE { receipt_types_clause }
618
645
AND user_id = ?
619
- GROUP BY room_id
646
+ GROUP BY room_id, thread_id
620
647
"""
621
648
622
649
args .extend ((user_id ,))
623
650
txn .execute (sql , args )
624
- return dict (cast (List [Tuple [str , int ]], txn .fetchall ()))
651
+
652
+ result = {}
653
+ for room_id , thread_id , stream_ordering in txn :
654
+ room_receipt = result .setdefault (room_id , _RoomReceipt ())
655
+ if thread_id is None :
656
+ room_receipt .unthreaded_stream_ordering = stream_ordering
657
+ else :
658
+ room_receipt .threaded_stream_ordering [thread_id ] = stream_ordering
659
+
660
+ return result
625
661
626
662
async def get_unread_push_actions_for_user_in_range_for_http (
627
663
self ,
@@ -656,7 +692,8 @@ def get_push_actions_txn(
656
692
txn : LoggingTransaction ,
657
693
) -> List [Tuple [str , str , int , str , bool ]]:
658
694
sql = """
659
- SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
695
+ SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering,
696
+ ep.actions, ep.highlight
660
697
FROM event_push_actions AS ep
661
698
WHERE
662
699
ep.user_id = ?
@@ -679,10 +716,10 @@ def get_push_actions_txn(
679
716
stream_ordering = stream_ordering ,
680
717
actions = _deserialize_action (actions , highlight ),
681
718
)
682
- for event_id , room_id , stream_ordering , actions , highlight in push_actions
683
- # Only include push actions with a stream ordering after any receipt, or without any
684
- # receipt present (invited to but never read rooms).
685
- if stream_ordering > receipts_by_room . get ( room_id , 0 )
719
+ for event_id , room_id , thread_id , stream_ordering , actions , highlight in push_actions
720
+ if receipts_by_room . get ( room_id , MISSING_ROOM_RECEIPT ). is_unread (
721
+ thread_id , stream_ordering
722
+ )
686
723
]
687
724
688
725
# Now sort it so it's ordered correctly, since currently it will
@@ -728,8 +765,8 @@ def get_push_actions_txn(
728
765
txn : LoggingTransaction ,
729
766
) -> List [Tuple [str , str , int , str , bool , int ]]:
730
767
sql = """
731
- SELECT ep.event_id, ep.room_id, ep.stream_ordering , ep.actions ,
732
- ep.highlight, e.received_ts
768
+ SELECT ep.event_id, ep.room_id, ep.thread_id , ep.stream_ordering ,
769
+ ep.actions, ep. highlight, e.received_ts
733
770
FROM event_push_actions AS ep
734
771
INNER JOIN events AS e USING (room_id, event_id)
735
772
WHERE
@@ -755,10 +792,10 @@ def get_push_actions_txn(
755
792
actions = _deserialize_action (actions , highlight ),
756
793
received_ts = received_ts ,
757
794
)
758
- for event_id , room_id , stream_ordering , actions , highlight , received_ts in push_actions
759
- # Only include push actions with a stream ordering after any receipt, or without any
760
- # receipt present (invited to but never read rooms).
761
- if stream_ordering > receipts_by_room . get ( room_id , 0 )
795
+ for event_id , room_id , thread_id , stream_ordering , actions , highlight , received_ts in push_actions
796
+ if receipts_by_room . get ( room_id , MISSING_ROOM_RECEIPT ). is_unread (
797
+ thread_id , stream_ordering
798
+ )
762
799
]
763
800
764
801
# Now sort it so it's ordered correctly, since currently it will
0 commit comments