15
15
# See the License for the specific language governing permissions and
16
16
# limitations under the License.
17
17
import logging
18
- from typing import Any , Dict , List , Optional , Tuple
18
+ from typing import TYPE_CHECKING , Any , Dict , Iterable , List , Optional , Set , Tuple
19
19
20
20
from synapse .api import errors
21
21
from synapse .api .constants import EventTypes
29
29
from synapse .logging .opentracing import log_kv , set_tag , trace
30
30
from synapse .metrics .background_process_metrics import run_as_background_process
31
31
from synapse .types import (
32
+ Collection ,
32
33
JsonDict ,
33
34
StreamToken ,
35
+ UserID ,
34
36
get_domain_from_id ,
35
37
get_verify_key_from_cross_signing_key ,
36
38
)
42
44
43
45
from ._base import BaseHandler
44
46
47
+ if TYPE_CHECKING :
48
+ from synapse .app .homeserver import HomeServer
49
+
45
50
logger = logging .getLogger (__name__ )
46
51
47
52
MAX_DEVICE_DISPLAY_NAME_LEN = 100
48
53
49
54
50
55
class DeviceWorkerHandler (BaseHandler ):
51
- def __init__ (self , hs ):
56
+ def __init__ (self , hs : "HomeServer" ):
52
57
super ().__init__ (hs )
53
58
54
59
self .hs = hs
@@ -106,7 +111,9 @@ async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
106
111
107
112
@trace
108
113
@measure_func ("device.get_user_ids_changed" )
109
- async def get_user_ids_changed (self , user_id : str , from_token : StreamToken ):
114
+ async def get_user_ids_changed (
115
+ self , user_id : str , from_token : StreamToken
116
+ ) -> JsonDict :
110
117
"""Get list of users that have had the devices updated, or have newly
111
118
joined a room, that `user_id` may be interested in.
112
119
"""
@@ -222,16 +229,16 @@ async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
222
229
possibly_joined = possibly_changed & users_who_share_room
223
230
possibly_left = (possibly_changed | possibly_left ) - users_who_share_room
224
231
else :
225
- possibly_joined = []
226
- possibly_left = []
232
+ possibly_joined = set ()
233
+ possibly_left = set ()
227
234
228
235
result = {"changed" : list (possibly_joined ), "left" : list (possibly_left )}
229
236
230
237
log_kv (result )
231
238
232
239
return result
233
240
234
- async def on_federation_query_user_devices (self , user_id ) :
241
+ async def on_federation_query_user_devices (self , user_id : str ) -> JsonDict :
235
242
stream_id , devices = await self .store .get_e2e_device_keys_for_federation_query (
236
243
user_id
237
244
)
@@ -250,7 +257,7 @@ async def on_federation_query_user_devices(self, user_id):
250
257
251
258
252
259
class DeviceHandler (DeviceWorkerHandler ):
253
- def __init__ (self , hs ):
260
+ def __init__ (self , hs : "HomeServer" ):
254
261
super ().__init__ (hs )
255
262
256
263
self .federation_sender = hs .get_federation_sender ()
@@ -265,7 +272,7 @@ def __init__(self, hs):
265
272
266
273
hs .get_distributor ().observe ("user_left_room" , self .user_left_room )
267
274
268
- def _check_device_name_length (self , name : str ):
275
+ def _check_device_name_length (self , name : Optional [ str ] ):
269
276
"""
270
277
Checks whether a device name is longer than the maximum allowed length.
271
278
@@ -284,21 +291,23 @@ def _check_device_name_length(self, name: str):
284
291
)
285
292
286
293
async def check_device_registered (
287
- self , user_id , device_id , initial_device_display_name = None
288
- ):
294
+ self ,
295
+ user_id : str ,
296
+ device_id : Optional [str ],
297
+ initial_device_display_name : Optional [str ] = None ,
298
+ ) -> str :
289
299
"""
290
300
If the given device has not been registered, register it with the
291
301
supplied display name.
292
302
293
303
If no device_id is supplied, we make one up.
294
304
295
305
Args:
296
- user_id (str): @user:id
297
- device_id (str | None): device id supplied by client
298
- initial_device_display_name (str | None): device display name from
299
- client
306
+ user_id: @user:id
307
+ device_id: device id supplied by client
308
+ initial_device_display_name: device display name from client
300
309
Returns:
301
- str: device id (generated if none was supplied)
310
+ device id (generated if none was supplied)
302
311
"""
303
312
304
313
self ._check_device_name_length (initial_device_display_name )
@@ -317,15 +326,15 @@ async def check_device_registered(
317
326
# times in case of a clash.
318
327
attempts = 0
319
328
while attempts < 5 :
320
- device_id = stringutils .random_string (10 ).upper ()
329
+ new_device_id = stringutils .random_string (10 ).upper ()
321
330
new_device = await self .store .store_device (
322
331
user_id = user_id ,
323
- device_id = device_id ,
332
+ device_id = new_device_id ,
324
333
initial_device_display_name = initial_device_display_name ,
325
334
)
326
335
if new_device :
327
- await self .notify_device_update (user_id , [device_id ])
328
- return device_id
336
+ await self .notify_device_update (user_id , [new_device_id ])
337
+ return new_device_id
329
338
attempts += 1
330
339
331
340
raise errors .StoreError (500 , "Couldn't generate a device ID." )
@@ -434,7 +443,9 @@ async def update_device(self, user_id: str, device_id: str, content: dict) -> No
434
443
435
444
@trace
436
445
@measure_func ("notify_device_update" )
437
- async def notify_device_update (self , user_id , device_ids ):
446
+ async def notify_device_update (
447
+ self , user_id : str , device_ids : Collection [str ]
448
+ ) -> None :
438
449
"""Notify that a user's device(s) has changed. Pokes the notifier, and
439
450
remote servers if the user is local.
440
451
"""
@@ -446,7 +457,7 @@ async def notify_device_update(self, user_id, device_ids):
446
457
user_id
447
458
)
448
459
449
- hosts = set ()
460
+ hosts = set () # type: Set[str]
450
461
if self .hs .is_mine_id (user_id ):
451
462
hosts .update (get_domain_from_id (u ) for u in users_who_share_room )
452
463
hosts .discard (self .server_name )
@@ -498,7 +509,7 @@ async def notify_user_signature_update(
498
509
499
510
self .notifier .on_new_event ("device_list_key" , position , users = [from_user_id ])
500
511
501
- async def user_left_room (self , user , room_id ) :
512
+ async def user_left_room (self , user : UserID , room_id : str ) -> None :
502
513
user_id = user .to_string ()
503
514
room_ids = await self .store .get_rooms_for_user (user_id )
504
515
if not room_ids :
@@ -586,15 +597,17 @@ async def rehydrate_device(
586
597
return {"success" : True }
587
598
588
599
589
- def _update_device_from_client_ips (device , client_ips ):
600
+ def _update_device_from_client_ips (
601
+ device : Dict [str , Any ], client_ips : Dict [Tuple [str , str ], Dict [str , Any ]]
602
+ ) -> None :
590
603
ip = client_ips .get ((device ["user_id" ], device ["device_id" ]), {})
591
604
device .update ({"last_seen_ts" : ip .get ("last_seen" ), "last_seen_ip" : ip .get ("ip" )})
592
605
593
606
594
607
class DeviceListUpdater :
595
608
"Handles incoming device list updates from federation and updates the DB"
596
609
597
- def __init__ (self , hs , device_handler ):
610
+ def __init__ (self , hs : "HomeServer" , device_handler : DeviceHandler ):
598
611
self .store = hs .get_datastore ()
599
612
self .federation = hs .get_federation_client ()
600
613
self .clock = hs .get_clock ()
@@ -603,7 +616,9 @@ def __init__(self, hs, device_handler):
603
616
self ._remote_edu_linearizer = Linearizer (name = "remote_device_list" )
604
617
605
618
# user_id -> list of updates waiting to be handled.
606
- self ._pending_updates = {}
619
+ self ._pending_updates = (
620
+ {}
621
+ ) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
607
622
608
623
# Recently seen stream ids. We don't bother keeping these in the DB,
609
624
# but they're useful to have them about to reduce the number of spurious
@@ -626,7 +641,9 @@ def __init__(self, hs, device_handler):
626
641
)
627
642
628
643
@trace
629
- async def incoming_device_list_update (self , origin , edu_content ):
644
+ async def incoming_device_list_update (
645
+ self , origin : str , edu_content : JsonDict
646
+ ) -> None :
630
647
"""Called on incoming device list update from federation. Responsible
631
648
for parsing the EDU and adding to pending updates list.
632
649
"""
@@ -687,7 +704,7 @@ async def incoming_device_list_update(self, origin, edu_content):
687
704
await self ._handle_device_updates (user_id )
688
705
689
706
@measure_func ("_incoming_device_list_update" )
690
- async def _handle_device_updates (self , user_id ) :
707
+ async def _handle_device_updates (self , user_id : str ) -> None :
691
708
"Actually handle pending updates."
692
709
693
710
with (await self ._remote_edu_linearizer .queue (user_id )):
@@ -735,7 +752,9 @@ async def _handle_device_updates(self, user_id):
735
752
stream_id for _ , stream_id , _ , _ in pending_updates
736
753
)
737
754
738
- async def _need_to_do_resync (self , user_id , updates ):
755
+ async def _need_to_do_resync (
756
+ self , user_id : str , updates : Iterable [Tuple [str , str , Iterable [str ], JsonDict ]]
757
+ ) -> bool :
739
758
"""Given a list of updates for a user figure out if we need to do a full
740
759
resync, or whether we have enough data that we can just apply the delta.
741
760
"""
@@ -766,7 +785,7 @@ async def _need_to_do_resync(self, user_id, updates):
766
785
return False
767
786
768
787
@trace
769
- async def _maybe_retry_device_resync (self ):
788
+ async def _maybe_retry_device_resync (self ) -> None :
770
789
"""Retry to resync device lists that are out of sync, except if another retry is
771
790
in progress.
772
791
"""
@@ -809,7 +828,7 @@ async def _maybe_retry_device_resync(self):
809
828
810
829
async def user_device_resync (
811
830
self , user_id : str , mark_failed_as_stale : bool = True
812
- ) -> Optional [dict ]:
831
+ ) -> Optional [JsonDict ]:
813
832
"""Fetches all devices for a user and updates the device cache with them.
814
833
815
834
Args:
@@ -833,7 +852,7 @@ async def user_device_resync(
833
852
# it later.
834
853
await self .store .mark_remote_user_device_cache_as_stale (user_id )
835
854
836
- return
855
+ return None
837
856
except (RequestSendFailed , HttpResponseException ) as e :
838
857
logger .warning (
839
858
"Failed to handle device list update for %s: %s" , user_id , e ,
@@ -850,12 +869,12 @@ async def user_device_resync(
850
869
# next time we get a device list update for this user_id.
851
870
# This makes it more likely that the device lists will
852
871
# eventually become consistent.
853
- return
872
+ return None
854
873
except FederationDeniedError as e :
855
874
set_tag ("error" , True )
856
875
log_kv ({"reason" : "FederationDeniedError" })
857
876
logger .info (e )
858
- return
877
+ return None
859
878
except Exception as e :
860
879
set_tag ("error" , True )
861
880
log_kv (
@@ -868,7 +887,7 @@ async def user_device_resync(
868
887
# it later.
869
888
await self .store .mark_remote_user_device_cache_as_stale (user_id )
870
889
871
- return
890
+ return None
872
891
log_kv ({"result" : result })
873
892
stream_id = result ["stream_id" ]
874
893
devices = result ["devices" ]
@@ -929,7 +948,7 @@ async def process_cross_signing_key_update(
929
948
user_id : str ,
930
949
master_key : Optional [Dict [str , Any ]],
931
950
self_signing_key : Optional [Dict [str , Any ]],
932
- ) -> list :
951
+ ) -> List [ str ] :
933
952
"""Process the given new master and self-signing key for the given remote user.
934
953
935
954
Args:
0 commit comments