119
119
]
120
120
121
121
122
+ @attr .s (slots = True , auto_attribs = True )
123
+ class _RoomReceipt :
124
+ """
125
+ HttpPushAction instances include the information used to generate HTTP
126
+ requests to a push gateway.
127
+ """
128
+
129
+ unthreaded_stream_ordering : int = 0
130
+ # threaded_stream_ordering includes the main pseudo-thread.
131
+ threaded_stream_ordering : Dict [str , int ] = attr .Factory (dict )
132
+
133
+ def is_unread (self , thread_id : str , stream_ordering : int ) -> bool :
134
+ """Returns True if the stream ordering is unread according to the receipt information."""
135
+
136
+ # Only include push actions with a stream ordering after both the unthreaded
137
+ # and threaded receipt. Properly handles a user without any receipts present.
138
+ return (
139
+ self .unthreaded_stream_ordering < stream_ordering
140
+ and self .threaded_stream_ordering .get (thread_id , 0 ) < stream_ordering
141
+ )
142
+
143
+
144
+ # A _RoomReceipt with no receipts in it.
145
+ MISSING_ROOM_RECEIPT = _RoomReceipt ()
146
+
147
+
122
148
@attr .s (slots = True , frozen = True , auto_attribs = True )
123
149
class HttpPushAction :
124
150
"""
@@ -559,7 +585,7 @@ def f(txn: LoggingTransaction) -> List[str]:
559
585
560
586
def _get_receipts_by_room_txn (
561
587
self , txn : LoggingTransaction , user_id : str
562
- ) -> Dict [str , int ]:
588
+ ) -> Dict [str , _RoomReceipt ]:
563
589
"""
564
590
Generate a map of room ID to the latest stream ordering that has been
565
591
read by the given user.
@@ -569,7 +595,8 @@ def _get_receipts_by_room_txn(
569
595
user_id: The user to fetch receipts for.
570
596
571
597
Returns:
572
- A map of room ID to stream ordering for all rooms the user has a receipt in.
598
+ A map including all rooms the user is in with a receipt. It maps
599
+ room IDs to _RoomReceipt instances
573
600
"""
574
601
receipt_types_clause , args = make_in_list_sql_clause (
575
602
self .database_engine ,
@@ -581,17 +608,26 @@ def _get_receipts_by_room_txn(
581
608
)
582
609
583
610
sql = f"""
584
- SELECT room_id, MAX(stream_ordering)
611
+ SELECT room_id, thread_id, MAX(stream_ordering)
585
612
FROM receipts_linearized
586
613
INNER JOIN events USING (room_id, event_id)
587
614
WHERE { receipt_types_clause }
588
615
AND user_id = ?
589
- GROUP BY room_id
616
+ GROUP BY room_id, thread_id
590
617
"""
591
618
592
619
args .extend ((user_id ,))
593
620
txn .execute (sql , args )
594
- return dict (cast (List [Tuple [str , int ]], txn .fetchall ()))
621
+
622
+ result : Dict [str , _RoomReceipt ] = {}
623
+ for room_id , thread_id , stream_ordering in txn :
624
+ room_receipt = result .setdefault (room_id , _RoomReceipt ())
625
+ if thread_id is None :
626
+ room_receipt .unthreaded_stream_ordering = stream_ordering
627
+ else :
628
+ room_receipt .threaded_stream_ordering [thread_id ] = stream_ordering
629
+
630
+ return result
595
631
596
632
async def get_unread_push_actions_for_user_in_range_for_http (
597
633
self ,
@@ -624,9 +660,10 @@ async def get_unread_push_actions_for_user_in_range_for_http(
624
660
625
661
def get_push_actions_txn (
626
662
txn : LoggingTransaction ,
627
- ) -> List [Tuple [str , str , int , str , bool ]]:
663
+ ) -> List [Tuple [str , str , str , int , str , bool ]]:
628
664
sql = """
629
- SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
665
+ SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering,
666
+ ep.actions, ep.highlight
630
667
FROM event_push_actions AS ep
631
668
WHERE
632
669
ep.user_id = ?
@@ -636,7 +673,7 @@ def get_push_actions_txn(
636
673
ORDER BY ep.stream_ordering ASC LIMIT ?
637
674
"""
638
675
txn .execute (sql , (user_id , min_stream_ordering , max_stream_ordering , limit ))
639
- return cast (List [Tuple [str , str , int , str , bool ]], txn .fetchall ())
676
+ return cast (List [Tuple [str , str , str , int , str , bool ]], txn .fetchall ())
640
677
641
678
push_actions = await self .db_pool .runInteraction (
642
679
"get_unread_push_actions_for_user_in_range_http" , get_push_actions_txn
@@ -649,10 +686,10 @@ def get_push_actions_txn(
649
686
stream_ordering = stream_ordering ,
650
687
actions = _deserialize_action (actions , highlight ),
651
688
)
652
- for event_id , room_id , stream_ordering , actions , highlight in push_actions
653
- # Only include push actions with a stream ordering after any receipt, or without any
654
- # receipt present (invited to but never read rooms).
655
- if stream_ordering > receipts_by_room . get ( room_id , 0 )
689
+ for event_id , room_id , thread_id , stream_ordering , actions , highlight in push_actions
690
+ if receipts_by_room . get ( room_id , MISSING_ROOM_RECEIPT ). is_unread (
691
+ thread_id , stream_ordering
692
+ )
656
693
]
657
694
658
695
# Now sort it so it's ordered correctly, since currently it will
@@ -696,10 +733,10 @@ async def get_unread_push_actions_for_user_in_range_for_email(
696
733
697
734
def get_push_actions_txn (
698
735
txn : LoggingTransaction ,
699
- ) -> List [Tuple [str , str , int , str , bool , int ]]:
736
+ ) -> List [Tuple [str , str , str , int , str , bool , int ]]:
700
737
sql = """
701
- SELECT ep.event_id, ep.room_id, ep.stream_ordering , ep.actions ,
702
- ep.highlight, e.received_ts
738
+ SELECT ep.event_id, ep.room_id, ep.thread_id , ep.stream_ordering ,
739
+ ep.actions, ep. highlight, e.received_ts
703
740
FROM event_push_actions AS ep
704
741
INNER JOIN events AS e USING (room_id, event_id)
705
742
WHERE
@@ -710,7 +747,7 @@ def get_push_actions_txn(
710
747
ORDER BY ep.stream_ordering DESC LIMIT ?
711
748
"""
712
749
txn .execute (sql , (user_id , min_stream_ordering , max_stream_ordering , limit ))
713
- return cast (List [Tuple [str , str , int , str , bool , int ]], txn .fetchall ())
750
+ return cast (List [Tuple [str , str , str , int , str , bool , int ]], txn .fetchall ())
714
751
715
752
push_actions = await self .db_pool .runInteraction (
716
753
"get_unread_push_actions_for_user_in_range_email" , get_push_actions_txn
@@ -725,10 +762,10 @@ def get_push_actions_txn(
725
762
actions = _deserialize_action (actions , highlight ),
726
763
received_ts = received_ts ,
727
764
)
728
- for event_id , room_id , stream_ordering , actions , highlight , received_ts in push_actions
729
- # Only include push actions with a stream ordering after any receipt, or without any
730
- # receipt present (invited to but never read rooms).
731
- if stream_ordering > receipts_by_room . get ( room_id , 0 )
765
+ for event_id , room_id , thread_id , stream_ordering , actions , highlight , received_ts in push_actions
766
+ if receipts_by_room . get ( room_id , MISSING_ROOM_RECEIPT ). is_unread (
767
+ thread_id , stream_ordering
768
+ )
732
769
]
733
770
734
771
# Now sort it so it's ordered correctly, since currently it will
0 commit comments