Skip to content

Commit 0c7b0bd

Browse files
committed
Add type hints and include mypy in build
This fully type hints the wsproto codebase and uses mypy to ensure the type hints are added and correct. This has identified some potential bugs, see ?? The additional linting disable is because the mypy TYPE_CHECKING is not understood by pylint.
1 parent 498f2cf commit 0c7b0bd

20 files changed

+644
-513
lines changed

.prospector.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pylint:
22
disable:
3+
- cyclic-import
34
- unused-argument
45
- useless-object-inheritance
56

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ combine_as_imports=True
2525
force_grid_wrap=0
2626
include_trailing_comma=True
2727
known_first_party=wsproto, test
28-
known_third_party=h11, pytest
28+
known_third_party=h11, pytest, _pytest
2929
line_length=88
3030
multi_line_output=3
3131
no_lines_before=LOCALFOLDER

test/helpers.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,27 @@
1+
from typing import Optional, Union
2+
13
from wsproto.extensions import Extension
24

35

46
class FakeExtension(Extension):
57
name = "fake"
68

7-
def __init__(self, offer_response=None, accept_response=None):
9+
def __init__(
10+
self,
11+
offer_response: Optional[Union[bool, str]] = None,
12+
accept_response: Optional[Union[bool, str]] = None,
13+
) -> None:
814
self.offer_response = offer_response
9-
self.accepted_offer = None
10-
self.offered = None
15+
self.accepted_offer: Optional[str] = None
16+
self.offered: Optional[str] = None
1117
self.accept_response = accept_response
1218

13-
def offer(self):
19+
def offer(self) -> Union[bool, str]:
1420
return self.offer_response
1521

16-
def finalize(self, offer):
22+
def finalize(self, offer: str) -> None:
1723
self.accepted_offer = offer
1824

19-
def accept(self, offer):
25+
def accept(self, offer: str) -> Union[bool, str]:
2026
self.offered = offer
2127
return self.accept_response

test/test_client.py

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
# These tests test the behaviours expected of wsproto in when the
2-
# connection is a client.
2+
# connectionis a client.
3+
from typing import List, Optional, Tuple
34

45
import h11
56
import pytest
67

78
from wsproto import WSConnection
89
from wsproto.connection import CLIENT
9-
from wsproto.events import AcceptConnection, RejectConnection, RejectData, Request
10+
from wsproto.events import (
11+
AcceptConnection,
12+
Event,
13+
RejectConnection,
14+
RejectData,
15+
Request,
16+
)
17+
from wsproto.extensions import Extension
1018
from wsproto.frame_protocol import CloseReason
19+
from wsproto.typing import Headers
1120
from wsproto.utilities import (
1221
generate_accept_token,
1322
normed_header_dict,
@@ -16,15 +25,14 @@
1625
from .helpers import FakeExtension
1726

1827

19-
def _make_connection_request(request):
20-
# type: (Request) -> h11.Request
28+
def _make_connection_request(request: Request) -> h11.Request:
2129
client = WSConnection(CLIENT)
2230
server = h11.Connection(h11.SERVER)
2331
server.receive_data(client.send(request))
2432
return server.next_event()
2533

2634

27-
def test_connection_request():
35+
def test_connection_request() -> None:
2836
request = _make_connection_request(Request(host="localhost", target="/"))
2937

3038
assert request.http_version == b"1.1"
@@ -38,7 +46,7 @@ def test_connection_request():
3846
assert b"sec-websocket-key" in headers
3947

4048

41-
def test_connection_request_additional_headers():
49+
def test_connection_request_additional_headers() -> None:
4250
request = _make_connection_request(
4351
Request(
4452
host="localhost",
@@ -52,40 +60,40 @@ def test_connection_request_additional_headers():
5260
assert headers[b"x-bar"] == b"Foo"
5361

5462

55-
def test_connection_request_simple_extension():
63+
def test_connection_request_simple_extension() -> None:
5664
extension = FakeExtension(offer_response=True)
5765
request = _make_connection_request(
58-
Request(host="localhost", target="/", extensions=[extension])
66+
Request(host="localhost", target="/", extensions=[extension]) # type: ignore
5967
)
6068

6169
headers = normed_header_dict(request.headers)
6270
assert headers[b"sec-websocket-extensions"] == extension.name.encode("ascii")
6371

6472

65-
def test_connection_request_simple_extension_no_offer():
73+
def test_connection_request_simple_extension_no_offer() -> None:
6674
extension = FakeExtension(offer_response=False)
6775
request = _make_connection_request(
68-
Request(host="localhost", target="/", extensions=[extension])
76+
Request(host="localhost", target="/", extensions=[extension]) # type: ignore
6977
)
7078

7179
headers = normed_header_dict(request.headers)
7280
assert b"sec-websocket-extensions" not in headers
7381

7482

75-
def test_connection_request_parametrised_extension():
83+
def test_connection_request_parametrised_extension() -> None:
7684
extension = FakeExtension(offer_response="parameter1=value1; parameter2=value2")
7785
request = _make_connection_request(
78-
Request(host="localhost", target="/", extensions=[extension])
86+
Request(host="localhost", target="/", extensions=[extension]) # type: ignore
7987
)
8088

8189
headers = normed_header_dict(request.headers)
8290
assert headers[b"sec-websocket-extensions"] == b"%s; %s" % (
8391
extension.name.encode("ascii"),
84-
extension.offer_response.encode("ascii"),
92+
extension.offer_response.encode("ascii"), # type: ignore
8593
)
8694

8795

88-
def test_connection_request_subprotocols():
96+
def test_connection_request_subprotocols() -> None:
8997
request = _make_connection_request(
9098
Request(host="localhost", target="/", subprotocols=["one", "two"])
9199
)
@@ -95,12 +103,12 @@ def test_connection_request_subprotocols():
95103

96104

97105
def _make_handshake(
98-
response_status,
99-
response_headers,
100-
subprotocols=None,
101-
extensions=None,
102-
auto_accept_key=True,
103-
):
106+
response_status: int,
107+
response_headers: Headers,
108+
subprotocols: Optional[List[str]] = None,
109+
extensions: Optional[List[Extension]] = None,
110+
auto_accept_key: bool = True,
111+
) -> List[Event]:
104112
client = WSConnection(CLIENT)
105113
server = h11.Connection(h11.SERVER)
106114
server.receive_data(
@@ -130,22 +138,22 @@ def _make_handshake(
130138
return list(client.events())
131139

132140

133-
def test_handshake():
141+
def test_handshake() -> None:
134142
events = _make_handshake(
135143
101, [(b"connection", b"Upgrade"), (b"upgrade", b"WebSocket")]
136144
)
137145
assert events == [AcceptConnection()]
138146

139147

140-
def test_broken_handshake():
148+
def test_broken_handshake() -> None:
141149
events = _make_handshake(
142150
102, [(b"connection", b"Upgrade"), (b"upgrade", b"WebSocket")]
143151
)
144152
assert isinstance(events[0], RejectConnection)
145153
assert events[0].status_code == 102
146154

147155

148-
def test_handshake_extra_accept_headers():
156+
def test_handshake_extra_accept_headers() -> None:
149157
events = _make_handshake(
150158
101,
151159
[(b"connection", b"Upgrade"), (b"upgrade", b"WebSocket"), (b"X-Foo", b"bar")],
@@ -154,20 +162,20 @@ def test_handshake_extra_accept_headers():
154162

155163

156164
@pytest.mark.parametrize("extra_headers", [[], [(b"connection", b"Keep-Alive")]])
157-
def test_handshake_response_broken_connection_header(extra_headers):
165+
def test_handshake_response_broken_connection_header(extra_headers: Headers) -> None:
158166
with pytest.raises(RemoteProtocolError) as excinfo:
159167
events = _make_handshake(101, [(b"upgrade", b"WebSocket")] + extra_headers)
160168
assert str(excinfo.value) == "Missing header, 'Connection: Upgrade'"
161169

162170

163171
@pytest.mark.parametrize("extra_headers", [[], [(b"upgrade", b"h2")]])
164-
def test_handshake_response_broken_upgrade_header(extra_headers):
172+
def test_handshake_response_broken_upgrade_header(extra_headers: Headers) -> None:
165173
with pytest.raises(RemoteProtocolError) as excinfo:
166174
events = _make_handshake(101, [(b"connection", b"Upgrade")] + extra_headers)
167175
assert str(excinfo.value) == "Missing header, 'Upgrade: WebSocket'"
168176

169177

170-
def test_handshake_response_missing_websocket_key_header():
178+
def test_handshake_response_missing_websocket_key_header() -> None:
171179
with pytest.raises(RemoteProtocolError) as excinfo:
172180
events = _make_handshake(
173181
101,
@@ -177,7 +185,7 @@ def test_handshake_response_missing_websocket_key_header():
177185
assert str(excinfo.value) == "Bad accept token"
178186

179187

180-
def test_handshake_with_subprotocol():
188+
def test_handshake_with_subprotocol() -> None:
181189
events = _make_handshake(
182190
101,
183191
[
@@ -190,7 +198,7 @@ def test_handshake_with_subprotocol():
190198
assert events == [AcceptConnection(subprotocol="one")]
191199

192200

193-
def test_handshake_bad_subprotocol():
201+
def test_handshake_bad_subprotocol() -> None:
194202
with pytest.raises(RemoteProtocolError) as excinfo:
195203
events = _make_handshake(
196204
101,
@@ -203,7 +211,7 @@ def test_handshake_bad_subprotocol():
203211
assert str(excinfo.value) == "unrecognized subprotocol new"
204212

205213

206-
def test_handshake_with_extension():
214+
def test_handshake_with_extension() -> None:
207215
extension = FakeExtension(offer_response=True)
208216
events = _make_handshake(
209217
101,
@@ -217,7 +225,7 @@ def test_handshake_with_extension():
217225
assert events == [AcceptConnection(extensions=[extension])]
218226

219227

220-
def test_handshake_bad_extension():
228+
def test_handshake_bad_extension() -> None:
221229
with pytest.raises(RemoteProtocolError) as excinfo:
222230
events = _make_handshake(
223231
101,
@@ -230,15 +238,17 @@ def test_handshake_bad_extension():
230238
assert str(excinfo.value) == "unrecognized extension bad"
231239

232240

233-
def test_protocol_error():
241+
def test_protocol_error() -> None:
234242
client = WSConnection(CLIENT)
235243
client.send(Request(host="localhost", target="/"))
236244
with pytest.raises(RemoteProtocolError) as excinfo:
237245
client.receive_data(b"broken nonsense\r\n\r\n")
238246
assert str(excinfo.value) == "Bad HTTP message"
239247

240248

241-
def _make_handshake_rejection(status_code, body=None):
249+
def _make_handshake_rejection(
250+
status_code: int, body: Optional[bytes] = None
251+
) -> List[Event]:
242252
client = WSConnection(CLIENT)
243253
server = h11.Connection(h11.SERVER)
244254
server.receive_data(client.send(Request(host="localhost", target="/")))
@@ -255,7 +265,7 @@ def _make_handshake_rejection(status_code, body=None):
255265
return list(client.events())
256266

257267

258-
def test_handshake_rejection():
268+
def test_handshake_rejection() -> None:
259269
events = _make_handshake_rejection(400)
260270
assert events == [
261271
RejectConnection(
@@ -265,7 +275,7 @@ def test_handshake_rejection():
265275
]
266276

267277

268-
def test_handshake_rejection_with_body():
278+
def test_handshake_rejection_with_body() -> None:
269279
events = _make_handshake_rejection(400, b"Hello")
270280
assert events == [
271281
RejectConnection(

0 commit comments

Comments
 (0)