28
28
Generic ,
29
29
Iterable ,
30
30
List ,
31
+ NoReturn ,
31
32
Optional ,
32
33
Tuple ,
33
34
Type ,
39
40
import canonicaljson
40
41
import signedjson .key
41
42
import unpaddedbase64
42
- from typing_extensions import Protocol
43
+ from typing_extensions import Concatenate , ParamSpec , Protocol
43
44
44
45
from twisted .internet .defer import Deferred , ensureDeferred
45
46
from twisted .python .failure import Failure
67
68
from synapse .rest import RegisterServletsFunc
68
69
from synapse .server import HomeServer
69
70
from synapse .storage .keys import FetchKeyResult
70
- from synapse .types import JsonDict , UserID , create_requester
71
+ from synapse .types import JsonDict , Requester , UserID , create_requester
71
72
from synapse .util import Clock
72
73
from synapse .util .httpresourcetree import create_resource_tree
73
74
88
89
TV = TypeVar ("TV" )
89
90
_ExcType = TypeVar ("_ExcType" , bound = BaseException , covariant = True )
90
91
92
+ P = ParamSpec ("P" )
93
+ R = TypeVar ("R" )
94
+ S = TypeVar ("S" )
95
+
91
96
92
97
class _TypedFailure (Generic [_ExcType ], Protocol ):
93
98
"""Extension to twisted.Failure, where the 'value' has a certain type."""
@@ -97,7 +102,7 @@ def value(self) -> _ExcType:
97
102
...
98
103
99
104
100
- def around (target ) :
105
+ def around (target : TV ) -> Callable [[ Callable [ Concatenate [ S , P ], R ]], None ] :
101
106
"""A CLOS-style 'around' modifier, which wraps the original method of the
102
107
given instance with another piece of code.
103
108
@@ -106,11 +111,11 @@ def method_name(orig, *args, **kwargs):
106
111
return orig(*args, **kwargs)
107
112
"""
108
113
109
- def _around (code ) :
114
+ def _around (code : Callable [ Concatenate [ S , P ], R ]) -> None :
110
115
name = code .__name__
111
116
orig = getattr (target , name )
112
117
113
- def new (* args , ** kwargs ) :
118
+ def new (* args : P . args , ** kwargs : P . kwargs ) -> R :
114
119
return code (orig , * args , ** kwargs )
115
120
116
121
setattr (target , name , new )
@@ -131,7 +136,7 @@ def __init__(self, methodName: str):
131
136
level = getattr (method , "loglevel" , getattr (self , "loglevel" , None ))
132
137
133
138
@around (self )
134
- def setUp (orig ) :
139
+ def setUp (orig : Callable [[], R ]) -> R :
135
140
# if we're not starting in the sentinel logcontext, then to be honest
136
141
# all future bets are off.
137
142
if current_context ():
@@ -144,7 +149,7 @@ def setUp(orig):
144
149
if level is not None and old_level != level :
145
150
146
151
@around (self )
147
- def tearDown (orig ) :
152
+ def tearDown (orig : Callable [[], R ]) -> R :
148
153
ret = orig ()
149
154
logging .getLogger ().setLevel (old_level )
150
155
return ret
@@ -158,7 +163,7 @@ def tearDown(orig):
158
163
return orig ()
159
164
160
165
@around (self )
161
- def tearDown (orig ) :
166
+ def tearDown (orig : Callable [[], R ]) -> R :
162
167
ret = orig ()
163
168
# force a GC to workaround problems with deferreds leaking logcontexts when
164
169
# they are GCed (see the logcontext docs)
@@ -167,7 +172,7 @@ def tearDown(orig):
167
172
168
173
return ret
169
174
170
- def assertObjectHasAttributes (self , attrs , obj ) :
175
+ def assertObjectHasAttributes (self , attrs : Dict [ str , object ], obj : object ) -> None :
171
176
"""Asserts that the given object has each of the attributes given, and
172
177
that the value of each matches according to assertEqual."""
173
178
for key in attrs .keys ():
@@ -178,44 +183,44 @@ def assertObjectHasAttributes(self, attrs, obj):
178
183
except AssertionError as e :
179
184
raise (type (e ))(f"Assert error for '.{ key } ':" ) from e
180
185
181
- def assert_dict (self , required , actual ) :
186
+ def assert_dict (self , required : dict , actual : dict ) -> None :
182
187
"""Does a partial assert of a dict.
183
188
184
189
Args:
185
- required (dict) : The keys and value which MUST be in 'actual'.
186
- actual (dict) : The test result. Extra keys will not be checked.
190
+ required: The keys and value which MUST be in 'actual'.
191
+ actual: The test result. Extra keys will not be checked.
187
192
"""
188
193
for key in required :
189
194
self .assertEqual (
190
195
required [key ], actual [key ], msg = "%s mismatch. %s" % (key , actual )
191
196
)
192
197
193
198
194
- def DEBUG (target ) :
199
+ def DEBUG (target : TV ) -> TV :
195
200
"""A decorator to set the .loglevel attribute to logging.DEBUG.
196
201
Can apply to either a TestCase or an individual test method."""
197
- target .loglevel = logging .DEBUG
202
+ target .loglevel = logging .DEBUG # type: ignore[attr-defined]
198
203
return target
199
204
200
205
201
- def INFO (target ) :
206
+ def INFO (target : TV ) -> TV :
202
207
"""A decorator to set the .loglevel attribute to logging.INFO.
203
208
Can apply to either a TestCase or an individual test method."""
204
- target .loglevel = logging .INFO
209
+ target .loglevel = logging .INFO # type: ignore[attr-defined]
205
210
return target
206
211
207
212
208
- def logcontext_clean (target ) :
213
+ def logcontext_clean (target : TV ) -> TV :
209
214
"""A decorator which marks the TestCase or method as 'logcontext_clean'
210
215
211
216
... ie, any logcontext errors should cause a test failure
212
217
"""
213
218
214
- def logcontext_error (msg ) :
219
+ def logcontext_error (msg : str ) -> NoReturn :
215
220
raise AssertionError ("logcontext error: %s" % (msg ))
216
221
217
222
patcher = patch ("synapse.logging.context.logcontext_error" , new = logcontext_error )
218
- return patcher (target )
223
+ return patcher (target ) # type: ignore[call-overload]
219
224
220
225
221
226
class HomeserverTestCase (TestCase ):
@@ -255,7 +260,7 @@ def __init__(self, methodName: str):
255
260
method = getattr (self , methodName )
256
261
self ._extra_config = getattr (method , "_extra_config" , None )
257
262
258
- def setUp (self ):
263
+ def setUp (self ) -> None :
259
264
"""
260
265
Set up the TestCase by calling the homeserver constructor, optionally
261
266
hijacking the authentication system to return a fixed user, and then
@@ -306,15 +311,21 @@ def setUp(self):
306
311
)
307
312
)
308
313
309
- async def get_user_by_access_token (token = None , allow_guest = False ):
314
+ async def get_user_by_access_token (
315
+ token : Optional [str ] = None , allow_guest : bool = False
316
+ ) -> JsonDict :
310
317
assert self .helper .auth_user_id is not None
311
318
return {
312
319
"user" : UserID .from_string (self .helper .auth_user_id ),
313
320
"token_id" : token_id ,
314
321
"is_guest" : False ,
315
322
}
316
323
317
- async def get_user_by_req (request , allow_guest = False ):
324
+ async def get_user_by_req (
325
+ request : SynapseRequest ,
326
+ allow_guest : bool = False ,
327
+ allow_expired : bool = False ,
328
+ ) -> Requester :
318
329
assert self .helper .auth_user_id is not None
319
330
return create_requester (
320
331
UserID .from_string (self .helper .auth_user_id ),
@@ -339,11 +350,11 @@ async def get_user_by_req(request, allow_guest=False):
339
350
if hasattr (self , "prepare" ):
340
351
self .prepare (self .reactor , self .clock , self .hs )
341
352
342
- def tearDown (self ):
353
+ def tearDown (self ) -> None :
343
354
# Reset to not use frozen dicts.
344
355
events .USE_FROZEN_DICTS = False
345
356
346
- def wait_on_thread (self , deferred , timeout = 10 ):
357
+ def wait_on_thread (self , deferred : Deferred , timeout : int = 10 ) -> None :
347
358
"""
348
359
Wait until a Deferred is done, where it's waiting on a real thread.
349
360
"""
@@ -374,7 +385,7 @@ def make_homeserver(self, reactor, clock):
374
385
clock (synapse.util.Clock): The Clock, associated with the reactor.
375
386
376
387
Returns:
377
- A homeserver (synapse.server.HomeServer) suitable for testing.
388
+ A homeserver suitable for testing.
378
389
379
390
Function to be overridden in subclasses.
380
391
"""
@@ -408,7 +419,7 @@ def create_resource_dict(self) -> Dict[str, Resource]:
408
419
"/_synapse/admin" : servlet_resource ,
409
420
}
410
421
411
- def default_config (self ):
422
+ def default_config (self ) -> JsonDict :
412
423
"""
413
424
Get a default HomeServer config dict.
414
425
"""
@@ -421,7 +432,9 @@ def default_config(self):
421
432
422
433
return config
423
434
424
- def prepare (self , reactor : MemoryReactor , clock : Clock , homeserver : HomeServer ):
435
+ def prepare (
436
+ self , reactor : MemoryReactor , clock : Clock , homeserver : HomeServer
437
+ ) -> None :
425
438
"""
426
439
Prepare for the test. This involves things like mocking out parts of
427
440
the homeserver, or building test data common across the whole test
@@ -519,7 +532,7 @@ def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer:
519
532
config_obj .parse_config_dict (config , "" , "" )
520
533
kwargs ["config" ] = config_obj
521
534
522
- async def run_bg_updates ():
535
+ async def run_bg_updates () -> None :
523
536
with LoggingContext ("run_bg_updates" ):
524
537
self .get_success (stor .db_pool .updates .run_background_updates (False ))
525
538
@@ -538,11 +551,7 @@ def pump(self, by: float = 0.0) -> None:
538
551
"""
539
552
self .reactor .pump ([by ] * 100 )
540
553
541
- def get_success (
542
- self ,
543
- d : Awaitable [TV ],
544
- by : float = 0.0 ,
545
- ) -> TV :
554
+ def get_success (self , d : Awaitable [TV ], by : float = 0.0 ) -> TV :
546
555
deferred : Deferred [TV ] = ensureDeferred (d ) # type: ignore[arg-type]
547
556
self .pump (by = by )
548
557
return self .successResultOf (deferred )
@@ -755,7 +764,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
755
764
OTHER_SERVER_NAME = "other.example.com"
756
765
OTHER_SERVER_SIGNATURE_KEY = signedjson .key .generate_signing_key ("test" )
757
766
758
- def prepare (self , reactor : MemoryReactor , clock : Clock , hs : HomeServer ):
767
+ def prepare (self , reactor : MemoryReactor , clock : Clock , hs : HomeServer ) -> None :
759
768
super ().prepare (reactor , clock , hs )
760
769
761
770
# poke the other server's signing key into the key store, so that we don't
@@ -879,7 +888,7 @@ def _auth_header_for_request(
879
888
)
880
889
881
890
882
- def override_config (extra_config ) :
891
+ def override_config (extra_config : JsonDict ) -> Callable [[ TV ], TV ] :
883
892
"""A decorator which can be applied to test functions to give additional HS config
884
893
885
894
For use
@@ -892,12 +901,13 @@ def test_foo(self):
892
901
...
893
902
894
903
Args:
895
- extra_config(dict) : Additional config settings to be merged into the default
904
+ extra_config: Additional config settings to be merged into the default
896
905
config dict before instantiating the test homeserver.
897
906
"""
898
907
899
- def decorator (func ):
900
- func ._extra_config = extra_config
908
+ def decorator (func : TV ) -> TV :
909
+ # This attribute is being defined.
910
+ func ._extra_config = extra_config # type: ignore[attr-defined]
901
911
return func
902
912
903
913
return decorator
0 commit comments