15
15
""" This module contains base REST classes for constructing REST servlets. """
16
16
17
17
import logging
18
- from typing import Iterable , List , Optional , Union , overload
18
+ from typing import Dict , Iterable , List , Optional , overload
19
19
20
20
from typing_extensions import Literal
21
21
22
+ from twisted .web .server import Request
23
+
22
24
from synapse .api .errors import Codes , SynapseError
23
25
from synapse .util import json_decoder
24
26
@@ -108,13 +110,66 @@ def parse_boolean_from_args(args, name, default=None, required=False):
108
110
return default
109
111
110
112
113
+ @overload
114
+ def parse_bytes_from_args (
115
+ args : Dict [bytes , List [bytes ]],
116
+ name : str ,
117
+ default : Literal [None ] = None ,
118
+ required : Literal [True ] = True ,
119
+ ) -> bytes :
120
+ ...
121
+
122
+
123
+ @overload
124
+ def parse_bytes_from_args (
125
+ args : Dict [bytes , List [bytes ]],
126
+ name : str ,
127
+ default : Optional [bytes ] = None ,
128
+ required : bool = False ,
129
+ ) -> Optional [bytes ]:
130
+ ...
131
+
132
+
133
+ def parse_bytes_from_args (
134
+ args : Dict [bytes , List [bytes ]],
135
+ name : str ,
136
+ default : Optional [bytes ] = None ,
137
+ required : bool = False ,
138
+ ) -> Optional [bytes ]:
139
+ """
140
+ Parse a string parameter as bytes from the request query string.
141
+
142
+ Args:
143
+ args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
144
+ name: the name of the query parameter.
145
+ default: value to use if the parameter is absent,
146
+ defaults to None. Must be bytes if encoding is None.
147
+ required: whether to raise a 400 SynapseError if the
148
+ parameter is absent, defaults to False.
149
+ Returns:
150
+ Bytes or the default value.
151
+
152
+ Raises:
153
+ SynapseError if the parameter is absent and required.
154
+ """
155
+ name_bytes = name .encode ("ascii" )
156
+
157
+ if name_bytes in args :
158
+ return args [name_bytes ][0 ]
159
+ elif required :
160
+ message = "Missing string query parameter %s" % (name ,)
161
+ raise SynapseError (400 , message , errcode = Codes .MISSING_PARAM )
162
+
163
+ return default
164
+
165
+
111
166
def parse_string (
112
- request ,
113
- name : Union [ bytes , str ] ,
167
+ request : Request ,
168
+ name : str ,
114
169
default : Optional [str ] = None ,
115
170
required : bool = False ,
116
171
allowed_values : Optional [Iterable [str ]] = None ,
117
- encoding : Optional [ str ] = "ascii" ,
172
+ encoding : str = "ascii" ,
118
173
):
119
174
"""
120
175
Parse a string parameter from the request query string.
@@ -125,66 +180,65 @@ def parse_string(
125
180
Args:
126
181
request: the twisted HTTP request.
127
182
name: the name of the query parameter.
128
- default: value to use if the parameter is absent,
129
- defaults to None. Must be bytes if encoding is None.
183
+ default: value to use if the parameter is absent, defaults to None.
130
184
required: whether to raise a 400 SynapseError if the
131
185
parameter is absent, defaults to False.
132
186
allowed_values: List of allowed values for the
133
187
string, or None if any value is allowed, defaults to None. Must be
134
188
the same type as name, if given.
135
- encoding : The encoding to decode the string content with.
189
+ encoding: The encoding to decode the string content with.
190
+
136
191
Returns:
137
- A string value or the default. Unicode if encoding
138
- was given, bytes otherwise.
192
+ A string value or the default.
139
193
140
194
Raises:
141
195
SynapseError if the parameter is absent and required, or if the
142
196
parameter is present, must be one of a list of allowed values and
143
197
is not one of those allowed values.
144
198
"""
199
+ args = request .args # type: Dict[bytes, List[bytes]] # type: ignore
145
200
return parse_string_from_args (
146
- request . args , name , default , required , allowed_values , encoding
201
+ args , name , default , required , allowed_values , encoding
147
202
)
148
203
149
204
150
205
def _parse_string_value (
151
- value : Union [ str , bytes ] ,
206
+ value : bytes ,
152
207
allowed_values : Optional [Iterable [str ]],
153
208
name : str ,
154
- encoding : Optional [str ],
155
- ) -> Union [str , bytes ]:
156
- if encoding :
157
- try :
158
- value = value .decode (encoding )
159
- except ValueError :
160
- raise SynapseError (400 , "Query parameter %r must be %s" % (name , encoding ))
209
+ encoding : str ,
210
+ ) -> str :
211
+ try :
212
+ value_str = value .decode (encoding )
213
+ except ValueError :
214
+ raise SynapseError (400 , "Query parameter %r must be %s" % (name , encoding ))
161
215
162
- if allowed_values is not None and value not in allowed_values :
216
+ if allowed_values is not None and value_str not in allowed_values :
163
217
message = "Query parameter %r must be one of [%s]" % (
164
218
name ,
165
219
", " .join (repr (v ) for v in allowed_values ),
166
220
)
167
221
raise SynapseError (400 , message )
168
222
else :
169
- return value
223
+ return value_str
170
224
171
225
172
226
@overload
173
227
def parse_strings_from_args (
174
- args : List [str ],
175
- name : Union [ bytes , str ] ,
228
+ args : Dict [ bytes , List [bytes ] ],
229
+ name : str ,
176
230
default : Optional [List [str ]] = None ,
177
- required : bool = False ,
231
+ required : Literal [ True ] = True ,
178
232
allowed_values : Optional [Iterable [str ]] = None ,
179
- encoding : Literal [ None ] = None ,
180
- ) -> Optional [ List [bytes ] ]:
233
+ encoding : str = "ascii" ,
234
+ ) -> List [str ]:
181
235
...
182
236
183
237
184
238
@overload
185
239
def parse_strings_from_args (
186
- args : List [str ],
187
- name : Union [ bytes , str ] ,
240
+ args : Dict [ bytes , List [bytes ] ],
241
+ name : str ,
188
242
default : Optional [List [str ]] = None ,
189
243
required : bool = False ,
190
244
allowed_values : Optional [Iterable [str ]] = None ,
@@ -194,83 +248,71 @@ def parse_strings_from_args(
194
248
195
249
196
250
def parse_strings_from_args (
197
- args : List [str ],
198
- name : Union [ bytes , str ] ,
251
+ args : Dict [ bytes , List [bytes ] ],
252
+ name : str ,
199
253
default : Optional [List [str ]] = None ,
200
254
required : bool = False ,
201
255
allowed_values : Optional [Iterable [str ]] = None ,
202
- encoding : Optional [ str ] = "ascii" ,
203
- ) -> Optional [List [Union [ bytes , str ] ]]:
256
+ encoding : str = "ascii" ,
257
+ ) -> Optional [List [str ]]:
204
258
"""
205
259
Parse a string parameter from the request query string list.
206
260
207
- If encoding is not None, the content of the query param will be
208
- decoded to Unicode using the encoding, otherwise it will be encoded
261
+ The content of the query param will be decoded to Unicode using the encoding.
209
262
210
263
Args:
211
- args: the twisted HTTP request. args list.
264
+ args: A mapping of request args as bytes to a list of bytes (e.g. request.args) .
212
265
name: the name of the query parameter.
213
- default: value to use if the parameter is absent,
214
- defaults to None. Must be bytes if encoding is None.
215
- required : whether to raise a 400 SynapseError if the
266
+ default: value to use if the parameter is absent, defaults to None.
267
+ required: whether to raise a 400 SynapseError if the
216
268
parameter is absent, defaults to False.
217
- allowed_values (list[bytes|unicode]): List of allowed values for the
218
- string, or None if any value is allowed, defaults to None. Must be
219
- the same type as name, if given.
269
+ allowed_values: List of allowed values for the
270
+ string, or None if any value is allowed, defaults to None.
220
271
encoding: The encoding to decode the string content with.
221
272
222
273
Returns:
223
- A string value or the default. Unicode if encoding
224
- was given, bytes otherwise.
274
+ A string value or the default.
225
275
226
276
Raises:
227
277
SynapseError if the parameter is absent and required, or if the
228
278
parameter is present, must be one of a list of allowed values and
229
279
is not one of those allowed values.
230
280
"""
281
+ name_bytes = name .encode ("ascii" )
231
282
232
- if not isinstance (name , bytes ):
233
- name = name .encode ("ascii" )
234
-
235
- if name in args :
236
- values = args [name ]
283
+ if name_bytes in args :
284
+ values = args [name_bytes ]
237
285
238
286
return [
239
287
_parse_string_value (value , allowed_values , name = name , encoding = encoding )
240
288
for value in values
241
289
]
242
290
else :
243
291
if required :
244
- message = "Missing string query parameter %r" % (name )
292
+ message = "Missing string query parameter %r" % (name , )
245
293
raise SynapseError (400 , message , errcode = Codes .MISSING_PARAM )
246
- else :
247
-
248
- if encoding and isinstance (default , bytes ):
249
- return default .decode (encoding )
250
294
251
- return default
295
+ return default
252
296
253
297
254
298
def parse_string_from_args (
255
- args : List [str ],
256
- name : Union [ bytes , str ] ,
299
+ args : Dict [ bytes , List [bytes ] ],
300
+ name : str ,
257
301
default : Optional [str ] = None ,
258
302
required : bool = False ,
259
303
allowed_values : Optional [Iterable [str ]] = None ,
260
- encoding : Optional [ str ] = "ascii" ,
261
- ) -> Optional [Union [ bytes , str ] ]:
304
+ encoding : str = "ascii" ,
305
+ ) -> Optional [str ]:
262
306
"""
263
307
Parse the string parameter from the request query string list
264
308
and return the first result.
265
309
266
- If encoding is not None, the content of the query param will be
267
- decoded to Unicode using the encoding, otherwise it will be encoded
310
+ The content of the query param will be decoded to Unicode using the encoding.
268
311
269
312
Args:
270
- args: the twisted HTTP request. args list.
313
+ args: A mapping of request args as bytes to a list of bytes (e.g. request.args) .
271
314
name: the name of the query parameter.
272
- default: value to use if the parameter is absent,
273
- defaults to None. Must be bytes if encoding is None.
315
+ default: value to use if the parameter is absent, defaults to None.
274
316
required: whether to raise a 400 SynapseError if the
275
317
parameter is absent, defaults to False.
276
318
allowed_values: List of allowed values for the
@@ -279,8 +321,7 @@ def parse_string_from_args(
279
321
encoding: The encoding to decode the string content with.
280
322
281
323
Returns:
282
- A string value or the default. Unicode if encoding
283
- was given, bytes otherwise.
324
+ A string value or the default.
284
325
285
326
Raises:
286
327
SynapseError if the parameter is absent and required, or if the
@@ -291,12 +332,15 @@ def parse_string_from_args(
291
332
strings = parse_strings_from_args (
292
333
args ,
293
334
name ,
294
- default = [default ],
335
+ default = [default ] if default is not None else None ,
295
336
required = required ,
296
337
allowed_values = allowed_values ,
297
338
encoding = encoding ,
298
339
)
299
340
341
+ if strings is None :
342
+ return None
343
+
300
344
return strings [0 ]
301
345
302
346
@@ -388,9 +432,8 @@ class attribute containing a pre-compiled regular expression. The automatic
388
432
389
433
def register (self , http_server ):
390
434
""" Register this servlet with the given HTTP server. """
391
- if hasattr (self , "PATTERNS" ):
392
- patterns = self .PATTERNS
393
-
435
+ patterns = getattr (self , "PATTERNS" , None )
436
+ if patterns :
394
437
for method in ("GET" , "PUT" , "POST" , "DELETE" ):
395
438
if hasattr (self , "on_%s" % (method ,)):
396
439
servlet_classname = self .__class__ .__name__
0 commit comments