Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 9e4610c

Browse files
authored
Correct type hints for parse_string(s)_from_args. (#10137)
1 parent 7dc1473 commit 9e4610c

File tree

8 files changed

+132
-83
lines changed

8 files changed

+132
-83
lines changed

changelog.d/10137.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `parse_strings_from_args` for parsing an array from query parameters.

mypy.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ files =
3232
synapse/http/federation/matrix_federation_agent.py,
3333
synapse/http/federation/well_known_resolver.py,
3434
synapse/http/matrixfederationclient.py,
35+
synapse/http/servlet.py,
3536
synapse/http/server.py,
3637
synapse/http/site.py,
3738
synapse/logging,

synapse/http/servlet.py

Lines changed: 111 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
""" This module contains base REST classes for constructing REST servlets. """
1616

1717
import logging
18-
from typing import Iterable, List, Optional, Union, overload
18+
from typing import Dict, Iterable, List, Optional, overload
1919

2020
from typing_extensions import Literal
2121

22+
from twisted.web.server import Request
23+
2224
from synapse.api.errors import Codes, SynapseError
2325
from synapse.util import json_decoder
2426

@@ -108,13 +110,66 @@ def parse_boolean_from_args(args, name, default=None, required=False):
108110
return default
109111

110112

113+
@overload
114+
def parse_bytes_from_args(
115+
args: Dict[bytes, List[bytes]],
116+
name: str,
117+
default: Literal[None] = None,
118+
required: Literal[True] = True,
119+
) -> bytes:
120+
...
121+
122+
123+
@overload
124+
def parse_bytes_from_args(
125+
args: Dict[bytes, List[bytes]],
126+
name: str,
127+
default: Optional[bytes] = None,
128+
required: bool = False,
129+
) -> Optional[bytes]:
130+
...
131+
132+
133+
def parse_bytes_from_args(
134+
args: Dict[bytes, List[bytes]],
135+
name: str,
136+
default: Optional[bytes] = None,
137+
required: bool = False,
138+
) -> Optional[bytes]:
139+
"""
140+
Parse a string parameter as bytes from the request query string.
141+
142+
Args:
143+
args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
144+
name: the name of the query parameter.
145+
default: value to use if the parameter is absent,
146+
defaults to None. Must be bytes if encoding is None.
147+
required: whether to raise a 400 SynapseError if the
148+
parameter is absent, defaults to False.
149+
Returns:
150+
Bytes or the default value.
151+
152+
Raises:
153+
SynapseError if the parameter is absent and required.
154+
"""
155+
name_bytes = name.encode("ascii")
156+
157+
if name_bytes in args:
158+
return args[name_bytes][0]
159+
elif required:
160+
message = "Missing string query parameter %s" % (name,)
161+
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
162+
163+
return default
164+
165+
111166
def parse_string(
112-
request,
113-
name: Union[bytes, str],
167+
request: Request,
168+
name: str,
114169
default: Optional[str] = None,
115170
required: bool = False,
116171
allowed_values: Optional[Iterable[str]] = None,
117-
encoding: Optional[str] = "ascii",
172+
encoding: str = "ascii",
118173
):
119174
"""
120175
Parse a string parameter from the request query string.
@@ -125,66 +180,65 @@ def parse_string(
125180
Args:
126181
request: the twisted HTTP request.
127182
name: the name of the query parameter.
128-
default: value to use if the parameter is absent,
129-
defaults to None. Must be bytes if encoding is None.
183+
default: value to use if the parameter is absent, defaults to None.
130184
required: whether to raise a 400 SynapseError if the
131185
parameter is absent, defaults to False.
132186
allowed_values: List of allowed values for the
133187
string, or None if any value is allowed, defaults to None. Must be
134188
the same type as name, if given.
135-
encoding : The encoding to decode the string content with.
189+
encoding: The encoding to decode the string content with.
190+
136191
Returns:
137-
A string value or the default. Unicode if encoding
138-
was given, bytes otherwise.
192+
A string value or the default.
139193
140194
Raises:
141195
SynapseError if the parameter is absent and required, or if the
142196
parameter is present, must be one of a list of allowed values and
143197
is not one of those allowed values.
144198
"""
199+
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
145200
return parse_string_from_args(
146-
request.args, name, default, required, allowed_values, encoding
201+
args, name, default, required, allowed_values, encoding
147202
)
148203

149204

150205
def _parse_string_value(
151-
value: Union[str, bytes],
206+
value: bytes,
152207
allowed_values: Optional[Iterable[str]],
153208
name: str,
154-
encoding: Optional[str],
155-
) -> Union[str, bytes]:
156-
if encoding:
157-
try:
158-
value = value.decode(encoding)
159-
except ValueError:
160-
raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
209+
encoding: str,
210+
) -> str:
211+
try:
212+
value_str = value.decode(encoding)
213+
except ValueError:
214+
raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
161215

162-
if allowed_values is not None and value not in allowed_values:
216+
if allowed_values is not None and value_str not in allowed_values:
163217
message = "Query parameter %r must be one of [%s]" % (
164218
name,
165219
", ".join(repr(v) for v in allowed_values),
166220
)
167221
raise SynapseError(400, message)
168222
else:
169-
return value
223+
return value_str
170224

171225

172226
@overload
173227
def parse_strings_from_args(
174-
args: List[str],
175-
name: Union[bytes, str],
228+
args: Dict[bytes, List[bytes]],
229+
name: str,
176230
default: Optional[List[str]] = None,
177-
required: bool = False,
231+
required: Literal[True] = True,
178232
allowed_values: Optional[Iterable[str]] = None,
179-
encoding: Literal[None] = None,
180-
) -> Optional[List[bytes]]:
233+
encoding: str = "ascii",
234+
) -> List[str]:
181235
...
182236

183237

184238
@overload
185239
def parse_strings_from_args(
186-
args: List[str],
187-
name: Union[bytes, str],
240+
args: Dict[bytes, List[bytes]],
241+
name: str,
188242
default: Optional[List[str]] = None,
189243
required: bool = False,
190244
allowed_values: Optional[Iterable[str]] = None,
@@ -194,83 +248,71 @@ def parse_strings_from_args(
194248

195249

196250
def parse_strings_from_args(
197-
args: List[str],
198-
name: Union[bytes, str],
251+
args: Dict[bytes, List[bytes]],
252+
name: str,
199253
default: Optional[List[str]] = None,
200254
required: bool = False,
201255
allowed_values: Optional[Iterable[str]] = None,
202-
encoding: Optional[str] = "ascii",
203-
) -> Optional[List[Union[bytes, str]]]:
256+
encoding: str = "ascii",
257+
) -> Optional[List[str]]:
204258
"""
205259
Parse a string parameter from the request query string list.
206260
207-
If encoding is not None, the content of the query param will be
208-
decoded to Unicode using the encoding, otherwise it will be encoded
261+
The content of the query param will be decoded to Unicode using the encoding.
209262
210263
Args:
211-
args: the twisted HTTP request.args list.
264+
args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
212265
name: the name of the query parameter.
213-
default: value to use if the parameter is absent,
214-
defaults to None. Must be bytes if encoding is None.
215-
required : whether to raise a 400 SynapseError if the
266+
default: value to use if the parameter is absent, defaults to None.
267+
required: whether to raise a 400 SynapseError if the
216268
parameter is absent, defaults to False.
217-
allowed_values (list[bytes|unicode]): List of allowed values for the
218-
string, or None if any value is allowed, defaults to None. Must be
219-
the same type as name, if given.
269+
allowed_values: List of allowed values for the
270+
string, or None if any value is allowed, defaults to None.
220271
encoding: The encoding to decode the string content with.
221272
222273
Returns:
223-
A string value or the default. Unicode if encoding
224-
was given, bytes otherwise.
274+
A string value or the default.
225275
226276
Raises:
227277
SynapseError if the parameter is absent and required, or if the
228278
parameter is present, must be one of a list of allowed values and
229279
is not one of those allowed values.
230280
"""
281+
name_bytes = name.encode("ascii")
231282

232-
if not isinstance(name, bytes):
233-
name = name.encode("ascii")
234-
235-
if name in args:
236-
values = args[name]
283+
if name_bytes in args:
284+
values = args[name_bytes]
237285

238286
return [
239287
_parse_string_value(value, allowed_values, name=name, encoding=encoding)
240288
for value in values
241289
]
242290
else:
243291
if required:
244-
message = "Missing string query parameter %r" % (name)
292+
message = "Missing string query parameter %r" % (name,)
245293
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
246-
else:
247-
248-
if encoding and isinstance(default, bytes):
249-
return default.decode(encoding)
250294

251-
return default
295+
return default
252296

253297

254298
def parse_string_from_args(
255-
args: List[str],
256-
name: Union[bytes, str],
299+
args: Dict[bytes, List[bytes]],
300+
name: str,
257301
default: Optional[str] = None,
258302
required: bool = False,
259303
allowed_values: Optional[Iterable[str]] = None,
260-
encoding: Optional[str] = "ascii",
261-
) -> Optional[Union[bytes, str]]:
304+
encoding: str = "ascii",
305+
) -> Optional[str]:
262306
"""
263307
Parse the string parameter from the request query string list
264308
and return the first result.
265309
266-
If encoding is not None, the content of the query param will be
267-
decoded to Unicode using the encoding, otherwise it will be encoded
310+
The content of the query param will be decoded to Unicode using the encoding.
268311
269312
Args:
270-
args: the twisted HTTP request.args list.
313+
args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
271314
name: the name of the query parameter.
272-
default: value to use if the parameter is absent,
273-
defaults to None. Must be bytes if encoding is None.
315+
default: value to use if the parameter is absent, defaults to None.
274316
required: whether to raise a 400 SynapseError if the
275317
parameter is absent, defaults to False.
276318
allowed_values: List of allowed values for the
@@ -279,8 +321,7 @@ def parse_string_from_args(
279321
encoding: The encoding to decode the string content with.
280322
281323
Returns:
282-
A string value or the default. Unicode if encoding
283-
was given, bytes otherwise.
324+
A string value or the default.
284325
285326
Raises:
286327
SynapseError if the parameter is absent and required, or if the
@@ -291,12 +332,15 @@ def parse_string_from_args(
291332
strings = parse_strings_from_args(
292333
args,
293334
name,
294-
default=[default],
335+
default=[default] if default is not None else None,
295336
required=required,
296337
allowed_values=allowed_values,
297338
encoding=encoding,
298339
)
299340

341+
if strings is None:
342+
return None
343+
300344
return strings[0]
301345

302346

@@ -388,9 +432,8 @@ class attribute containing a pre-compiled regular expression. The automatic
388432

389433
def register(self, http_server):
390434
""" Register this servlet with the given HTTP server. """
391-
if hasattr(self, "PATTERNS"):
392-
patterns = self.PATTERNS
393-
435+
patterns = getattr(self, "PATTERNS", None)
436+
if patterns:
394437
for method in ("GET", "PUT", "POST", "DELETE"):
395438
if hasattr(self, "on_%s" % (method,)):
396439
servlet_classname = self.__class__.__name__

synapse/rest/admin/rooms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ async def on_GET(
649649
limit = parse_integer(request, "limit", default=10)
650650

651651
# picking the API shape for symmetry with /messages
652-
filter_str = parse_string(request, b"filter", encoding="utf-8")
652+
filter_str = parse_string(request, "filter", encoding="utf-8")
653653
if filter_str:
654654
filter_json = urlparse.unquote(filter_str)
655655
event_filter = Filter(

synapse/rest/client/v1/login.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import logging
1616
import re
17-
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
17+
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
1818

1919
from synapse.api.errors import Codes, LoginError, SynapseError
2020
from synapse.api.ratelimiting import Ratelimiter
@@ -25,6 +25,7 @@
2525
from synapse.http.server import HttpServer, finish_request
2626
from synapse.http.servlet import (
2727
RestServlet,
28+
parse_bytes_from_args,
2829
parse_json_object_from_request,
2930
parse_string,
3031
)
@@ -437,9 +438,8 @@ async def on_GET(
437438
finish_request(request)
438439
return
439440

440-
client_redirect_url = parse_string(
441-
request, "redirectUrl", required=True, encoding=None
442-
)
441+
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
442+
client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True)
443443
sso_url = await self._sso_handler.handle_redirect_request(
444444
request,
445445
client_redirect_url,

synapse/rest/client/v1/room.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ async def on_GET(self, request, room_id):
537537
self.store, request, default_limit=10
538538
)
539539
as_client_event = b"raw" not in request.args
540-
filter_str = parse_string(request, b"filter", encoding="utf-8")
540+
filter_str = parse_string(request, "filter", encoding="utf-8")
541541
if filter_str:
542542
filter_json = urlparse.unquote(filter_str)
543543
event_filter = Filter(
@@ -652,7 +652,7 @@ async def on_GET(self, request, room_id, event_id):
652652
limit = parse_integer(request, "limit", default=10)
653653

654654
# picking the API shape for symmetry with /messages
655-
filter_str = parse_string(request, b"filter", encoding="utf-8")
655+
filter_str = parse_string(request, "filter", encoding="utf-8")
656656
if filter_str:
657657
filter_json = urlparse.unquote(filter_str)
658658
event_filter = Filter(

0 commit comments

Comments
 (0)