Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 922b771

Browse files
authored
Add missing type hints for tests.unittest. (#13397)
1 parent 502f075 commit 922b771

File tree

6 files changed

+66
-52
lines changed

6 files changed

+66
-52
lines changed

changelog.d/13397.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Adding missing type hints to tests.

tests/handlers/test_directory.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -481,17 +481,13 @@ def default_config(self) -> Dict[str, Any]:
481481

482482
return config
483483

484-
def prepare(
485-
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
486-
) -> HomeServer:
484+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
487485
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
488486
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
489487

490488
self.denied_user_id = self.register_user("denied", "pass")
491489
self.denied_access_token = self.login("denied", "pass")
492490

493-
return hs
494-
495491
def test_denied_without_publication_permission(self) -> None:
496492
"""
497493
Try to create a room, register an alias for it, and publish it,
@@ -575,9 +571,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
575571

576572
servlets = [directory.register_servlets, room.register_servlets]
577573

578-
def prepare(
579-
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
580-
) -> HomeServer:
574+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
581575
room_id = self.helper.create_room_as(self.user_id)
582576

583577
channel = self.make_request(
@@ -588,8 +582,6 @@ def prepare(
588582
self.room_list_handler = hs.get_room_list_handler()
589583
self.directory_handler = hs.get_directory_handler()
590584

591-
return hs
592-
593585
def test_disabling_room_list(self) -> None:
594586
self.room_list_handler.enable_room_list_search = True
595587
self.directory_handler.enable_room_list_search = True

tests/rest/client/test_relations.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
10601060
participated, bundled_aggregations.get("current_user_participated")
10611061
)
10621062
# The latest thread event has some fields that don't matter.
1063+
self.assertIn("latest_event", bundled_aggregations)
10631064
self.assert_dict(
10641065
{
10651066
"content": {
@@ -1072,7 +1073,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
10721073
"sender": self.user2_id,
10731074
"type": "m.room.test",
10741075
},
1075-
bundled_aggregations.get("latest_event"),
1076+
bundled_aggregations["latest_event"],
10761077
)
10771078

10781079
return assert_thread
@@ -1112,6 +1113,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
11121113
self.assertEqual(2, bundled_aggregations.get("count"))
11131114
self.assertTrue(bundled_aggregations.get("current_user_participated"))
11141115
# The latest thread event has some fields that don't matter.
1116+
self.assertIn("latest_event", bundled_aggregations)
11151117
self.assert_dict(
11161118
{
11171119
"content": {
@@ -1124,7 +1126,7 @@ def assert_thread(bundled_aggregations: JsonDict) -> None:
11241126
"sender": self.user_id,
11251127
"type": "m.room.test",
11261128
},
1127-
bundled_aggregations.get("latest_event"),
1129+
bundled_aggregations["latest_event"],
11281130
)
11291131
# Check the unsigned field on the latest event.
11301132
self.assert_dict(

tests/rest/client/test_rooms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ def test_get_state_cancellation(self) -> None:
496496

497497
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
498498
self.assertCountEqual(
499-
[state_event["type"] for state_event in channel.json_body],
499+
[state_event["type"] for state_event in channel.json_list],
500500
{
501501
"m.room.create",
502502
"m.room.power_levels",

tests/server.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Callable,
2626
Dict,
2727
Iterable,
28+
List,
2829
MutableMapping,
2930
Optional,
3031
Tuple,
@@ -121,7 +122,15 @@ def request(self, request: Request) -> None:
121122

122123
@property
123124
def json_body(self) -> JsonDict:
124-
return json.loads(self.text_body)
125+
body = json.loads(self.text_body)
126+
assert isinstance(body, dict)
127+
return body
128+
129+
@property
130+
def json_list(self) -> List[JsonDict]:
131+
body = json.loads(self.text_body)
132+
assert isinstance(body, list)
133+
return body
125134

126135
@property
127136
def text_body(self) -> str:

tests/unittest.py

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Generic,
2929
Iterable,
3030
List,
31+
NoReturn,
3132
Optional,
3233
Tuple,
3334
Type,
@@ -39,7 +40,7 @@
3940
import canonicaljson
4041
import signedjson.key
4142
import unpaddedbase64
42-
from typing_extensions import Protocol
43+
from typing_extensions import Concatenate, ParamSpec, Protocol
4344

4445
from twisted.internet.defer import Deferred, ensureDeferred
4546
from twisted.python.failure import Failure
@@ -67,7 +68,7 @@
6768
from synapse.rest import RegisterServletsFunc
6869
from synapse.server import HomeServer
6970
from synapse.storage.keys import FetchKeyResult
70-
from synapse.types import JsonDict, UserID, create_requester
71+
from synapse.types import JsonDict, Requester, UserID, create_requester
7172
from synapse.util import Clock
7273
from synapse.util.httpresourcetree import create_resource_tree
7374

@@ -88,6 +89,10 @@
8889
TV = TypeVar("TV")
8990
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
9091

92+
P = ParamSpec("P")
93+
R = TypeVar("R")
94+
S = TypeVar("S")
95+
9196

9297
class _TypedFailure(Generic[_ExcType], Protocol):
9398
"""Extension to twisted.Failure, where the 'value' has a certain type."""
@@ -97,7 +102,7 @@ def value(self) -> _ExcType:
97102
...
98103

99104

100-
def around(target):
105+
def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
101106
"""A CLOS-style 'around' modifier, which wraps the original method of the
102107
given instance with another piece of code.
103108
@@ -106,11 +111,11 @@ def method_name(orig, *args, **kwargs):
106111
return orig(*args, **kwargs)
107112
"""
108113

109-
def _around(code):
114+
def _around(code: Callable[Concatenate[S, P], R]) -> None:
110115
name = code.__name__
111116
orig = getattr(target, name)
112117

113-
def new(*args, **kwargs):
118+
def new(*args: P.args, **kwargs: P.kwargs) -> R:
114119
return code(orig, *args, **kwargs)
115120

116121
setattr(target, name, new)
@@ -131,7 +136,7 @@ def __init__(self, methodName: str):
131136
level = getattr(method, "loglevel", getattr(self, "loglevel", None))
132137

133138
@around(self)
134-
def setUp(orig):
139+
def setUp(orig: Callable[[], R]) -> R:
135140
# if we're not starting in the sentinel logcontext, then to be honest
136141
# all future bets are off.
137142
if current_context():
@@ -144,7 +149,7 @@ def setUp(orig):
144149
if level is not None and old_level != level:
145150

146151
@around(self)
147-
def tearDown(orig):
152+
def tearDown(orig: Callable[[], R]) -> R:
148153
ret = orig()
149154
logging.getLogger().setLevel(old_level)
150155
return ret
@@ -158,7 +163,7 @@ def tearDown(orig):
158163
return orig()
159164

160165
@around(self)
161-
def tearDown(orig):
166+
def tearDown(orig: Callable[[], R]) -> R:
162167
ret = orig()
163168
# force a GC to workaround problems with deferreds leaking logcontexts when
164169
# they are GCed (see the logcontext docs)
@@ -167,7 +172,7 @@ def tearDown(orig):
167172

168173
return ret
169174

170-
def assertObjectHasAttributes(self, attrs, obj):
175+
def assertObjectHasAttributes(self, attrs: Dict[str, object], obj: object) -> None:
171176
"""Asserts that the given object has each of the attributes given, and
172177
that the value of each matches according to assertEqual."""
173178
for key in attrs.keys():
@@ -178,44 +183,44 @@ def assertObjectHasAttributes(self, attrs, obj):
178183
except AssertionError as e:
179184
raise (type(e))(f"Assert error for '.{key}':") from e
180185

181-
def assert_dict(self, required, actual):
186+
def assert_dict(self, required: dict, actual: dict) -> None:
182187
"""Does a partial assert of a dict.
183188
184189
Args:
185-
required (dict): The keys and value which MUST be in 'actual'.
186-
actual (dict): The test result. Extra keys will not be checked.
190+
required: The keys and value which MUST be in 'actual'.
191+
actual: The test result. Extra keys will not be checked.
187192
"""
188193
for key in required:
189194
self.assertEqual(
190195
required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
191196
)
192197

193198

194-
def DEBUG(target):
199+
def DEBUG(target: TV) -> TV:
195200
"""A decorator to set the .loglevel attribute to logging.DEBUG.
196201
Can apply to either a TestCase or an individual test method."""
197-
target.loglevel = logging.DEBUG
202+
target.loglevel = logging.DEBUG # type: ignore[attr-defined]
198203
return target
199204

200205

201-
def INFO(target):
206+
def INFO(target: TV) -> TV:
202207
"""A decorator to set the .loglevel attribute to logging.INFO.
203208
Can apply to either a TestCase or an individual test method."""
204-
target.loglevel = logging.INFO
209+
target.loglevel = logging.INFO # type: ignore[attr-defined]
205210
return target
206211

207212

208-
def logcontext_clean(target):
213+
def logcontext_clean(target: TV) -> TV:
209214
"""A decorator which marks the TestCase or method as 'logcontext_clean'
210215
211216
... ie, any logcontext errors should cause a test failure
212217
"""
213218

214-
def logcontext_error(msg):
219+
def logcontext_error(msg: str) -> NoReturn:
215220
raise AssertionError("logcontext error: %s" % (msg))
216221

217222
patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
218-
return patcher(target)
223+
return patcher(target) # type: ignore[call-overload]
219224

220225

221226
class HomeserverTestCase(TestCase):
@@ -255,7 +260,7 @@ def __init__(self, methodName: str):
255260
method = getattr(self, methodName)
256261
self._extra_config = getattr(method, "_extra_config", None)
257262

258-
def setUp(self):
263+
def setUp(self) -> None:
259264
"""
260265
Set up the TestCase by calling the homeserver constructor, optionally
261266
hijacking the authentication system to return a fixed user, and then
@@ -306,15 +311,21 @@ def setUp(self):
306311
)
307312
)
308313

309-
async def get_user_by_access_token(token=None, allow_guest=False):
314+
async def get_user_by_access_token(
315+
token: Optional[str] = None, allow_guest: bool = False
316+
) -> JsonDict:
310317
assert self.helper.auth_user_id is not None
311318
return {
312319
"user": UserID.from_string(self.helper.auth_user_id),
313320
"token_id": token_id,
314321
"is_guest": False,
315322
}
316323

317-
async def get_user_by_req(request, allow_guest=False):
324+
async def get_user_by_req(
325+
request: SynapseRequest,
326+
allow_guest: bool = False,
327+
allow_expired: bool = False,
328+
) -> Requester:
318329
assert self.helper.auth_user_id is not None
319330
return create_requester(
320331
UserID.from_string(self.helper.auth_user_id),
@@ -339,11 +350,11 @@ async def get_user_by_req(request, allow_guest=False):
339350
if hasattr(self, "prepare"):
340351
self.prepare(self.reactor, self.clock, self.hs)
341352

342-
def tearDown(self):
353+
def tearDown(self) -> None:
343354
# Reset to not use frozen dicts.
344355
events.USE_FROZEN_DICTS = False
345356

346-
def wait_on_thread(self, deferred, timeout=10):
357+
def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None:
347358
"""
348359
Wait until a Deferred is done, where it's waiting on a real thread.
349360
"""
@@ -374,7 +385,7 @@ def make_homeserver(self, reactor, clock):
374385
clock (synapse.util.Clock): The Clock, associated with the reactor.
375386
376387
Returns:
377-
A homeserver (synapse.server.HomeServer) suitable for testing.
388+
A homeserver suitable for testing.
378389
379390
Function to be overridden in subclasses.
380391
"""
@@ -408,7 +419,7 @@ def create_resource_dict(self) -> Dict[str, Resource]:
408419
"/_synapse/admin": servlet_resource,
409420
}
410421

411-
def default_config(self):
422+
def default_config(self) -> JsonDict:
412423
"""
413424
Get a default HomeServer config dict.
414425
"""
@@ -421,7 +432,9 @@ def default_config(self):
421432

422433
return config
423434

424-
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
435+
def prepare(
436+
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
437+
) -> None:
425438
"""
426439
Prepare for the test. This involves things like mocking out parts of
427440
the homeserver, or building test data common across the whole test
@@ -519,7 +532,7 @@ def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer:
519532
config_obj.parse_config_dict(config, "", "")
520533
kwargs["config"] = config_obj
521534

522-
async def run_bg_updates():
535+
async def run_bg_updates() -> None:
523536
with LoggingContext("run_bg_updates"):
524537
self.get_success(stor.db_pool.updates.run_background_updates(False))
525538

@@ -538,11 +551,7 @@ def pump(self, by: float = 0.0) -> None:
538551
"""
539552
self.reactor.pump([by] * 100)
540553

541-
def get_success(
542-
self,
543-
d: Awaitable[TV],
544-
by: float = 0.0,
545-
) -> TV:
554+
def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
546555
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
547556
self.pump(by=by)
548557
return self.successResultOf(deferred)
@@ -755,7 +764,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
755764
OTHER_SERVER_NAME = "other.example.com"
756765
OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
757766

758-
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
767+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
759768
super().prepare(reactor, clock, hs)
760769

761770
# poke the other server's signing key into the key store, so that we don't
@@ -879,7 +888,7 @@ def _auth_header_for_request(
879888
)
880889

881890

882-
def override_config(extra_config):
891+
def override_config(extra_config: JsonDict) -> Callable[[TV], TV]:
883892
"""A decorator which can be applied to test functions to give additional HS config
884893
885894
For use
@@ -892,12 +901,13 @@ def test_foo(self):
892901
...
893902
894903
Args:
895-
extra_config(dict): Additional config settings to be merged into the default
904+
extra_config: Additional config settings to be merged into the default
896905
config dict before instantiating the test homeserver.
897906
"""
898907

899-
def decorator(func):
900-
func._extra_config = extra_config
908+
def decorator(func: TV) -> TV:
909+
# This attribute is being defined.
910+
func._extra_config = extra_config # type: ignore[attr-defined]
901911
return func
902912

903913
return decorator

0 commit comments

Comments
 (0)