4
4
import inspect
5
5
from typing import Any
6
6
7
- from ibis .common .validators import Validator , any_ , option
7
+ from ibis .common .validators import Validator , any_ , frozendict_of , option , tuple_of
8
8
from ibis .util import DotDict
9
9
10
10
EMPTY = inspect .Parameter .empty # marker for missing argument
@@ -70,8 +70,15 @@ def initialize(self, this):
70
70
class Argument (Annotation ):
71
71
"""Base class for all fields which should be passed as arguments."""
72
72
73
+ __slots__ = ('_kind' ,)
74
+
75
+ def __init__ (self , validator = None , default = EMPTY , kind = POSITIONAL_OR_KEYWORD ):
76
+ self ._kind = kind
77
+ self ._default = default
78
+ self ._validator = validator
79
+
73
80
@classmethod
74
- def mandatory (cls , validator = None ):
81
+ def required (cls , validator = None ):
75
82
"""Annotation to mark a mandatory argument."""
76
83
return cls (validator )
77
84
@@ -89,6 +96,17 @@ def optional(cls, validator=None, default=None):
89
96
validator = option (validator , default = default )
90
97
return cls (validator , default = None )
91
98
99
+ @classmethod
100
+ def varargs (cls , validator = None ):
101
+ """Annotation to mark a variable length positional argument."""
102
+ validator = None if validator is None else tuple_of (validator )
103
+ return cls (validator , kind = VAR_POSITIONAL )
104
+
105
+ @classmethod
106
+ def varkwds (cls , validator = None ):
107
+ validator = None if validator is None else frozendict_of (any_ , validator )
108
+ return cls (validator , kind = VAR_KEYWORD )
109
+
92
110
93
111
class Parameter (inspect .Parameter ):
94
112
"""Augmented Parameter class to additionally hold a validator object."""
@@ -102,7 +120,7 @@ def __init__(self, name, annotation):
102
120
)
103
121
super ().__init__ (
104
122
name ,
105
- kind = POSITIONAL_OR_KEYWORD ,
123
+ kind = annotation . _kind ,
106
124
default = annotation ._default ,
107
125
annotation = annotation ._validator ,
108
126
)
@@ -150,22 +168,34 @@ def merge(cls, *signatures, **annotations):
150
168
151
169
# mandatory fields without default values must preceed the optional
152
170
# ones in the function signature, the partial ordering will be kept
171
+ var_args , var_kwargs = [], []
153
172
new_args , new_kwargs = [], []
154
- inherited_args , inherited_kwargs = [], []
173
+ old_args , old_kwargs = [], []
155
174
156
175
for name , param in params .items ():
157
- if name in inherited :
176
+ if param .kind == VAR_POSITIONAL :
177
+ var_args .append (param )
178
+ elif param .kind == VAR_KEYWORD :
179
+ var_kwargs .append (param )
180
+ elif name in inherited :
158
181
if param .default is EMPTY :
159
- inherited_args .append (param )
182
+ old_args .append (param )
160
183
else :
161
- inherited_kwargs .append (param )
184
+ old_kwargs .append (param )
162
185
else :
163
186
if param .default is EMPTY :
164
187
new_args .append (param )
165
188
else :
166
189
new_kwargs .append (param )
167
190
168
- return cls (inherited_args + new_args + new_kwargs + inherited_kwargs )
191
+ if len (var_args ) > 1 :
192
+ raise TypeError ('only one variadic positional *args parameter is allowed' )
193
+ if len (var_kwargs ) > 1 :
194
+ raise TypeError ('only one variadic keywords **kwargs parameter is allowed' )
195
+
196
+ return cls (
197
+ old_args + new_args + var_args + new_kwargs + old_kwargs + var_kwargs
198
+ )
169
199
170
200
@classmethod
171
201
def from_callable (cls , fn , validators = None , return_validator = None ):
@@ -199,25 +229,24 @@ def from_callable(cls, fn, validators=None, return_validator=None):
199
229
200
230
parameters = []
201
231
for param in sig .parameters .values ():
202
- if param .kind in {
203
- VAR_POSITIONAL ,
204
- VAR_KEYWORD ,
205
- POSITIONAL_ONLY ,
206
- KEYWORD_ONLY ,
207
- }:
232
+ if param .kind in {POSITIONAL_ONLY , KEYWORD_ONLY }:
208
233
raise TypeError (f"unsupported parameter kind { param .kind } in { fn } " )
209
234
210
235
if param .name in validators :
211
236
validator = validators [param .name ]
212
- elif param .annotation is EMPTY :
213
- validator = any_
214
- else :
237
+ elif param .annotation is not EMPTY :
215
238
validator = Validator .from_annotation (
216
239
param .annotation , module = fn .__module__
217
240
)
218
-
219
- if param .default is EMPTY :
220
- annot = Argument .mandatory (validator )
241
+ else :
242
+ validator = None
243
+
244
+ if param .kind is VAR_POSITIONAL :
245
+ annot = Argument .varargs (validator )
246
+ elif param .kind is VAR_KEYWORD :
247
+ annot = Argument .varkwds (validator )
248
+ elif param .default is EMPTY :
249
+ annot = Argument .required (validator )
221
250
else :
222
251
annot = Argument .default (param .default , validator )
223
252
@@ -250,7 +279,18 @@ def unbind(self, this: Any):
250
279
Tuple of positional and keyword arguments.
251
280
"""
252
281
# does the reverse of bind, but doesn't apply defaults
253
- return {name : getattr (this , name ) for name in self .parameters }
282
+ args , kwargs = [], {}
283
+ for name , param in self .parameters .items ():
284
+ value = getattr (this , name )
285
+ if param .kind is POSITIONAL_OR_KEYWORD :
286
+ args .append (value )
287
+ elif param .kind is VAR_POSITIONAL :
288
+ args .extend (value )
289
+ elif param .kind is VAR_KEYWORD :
290
+ kwargs .update (value )
291
+ else :
292
+ raise TypeError (f"unsupported parameter kind { param .kind } " )
293
+ return tuple (args ), kwargs
254
294
255
295
def validate (self , * args , ** kwargs ):
256
296
"""Validate the arguments against the signature.
@@ -278,7 +318,16 @@ def validate(self, *args, **kwargs):
278
318
param = self .parameters [name ]
279
319
# TODO(kszucs): provide more error context on failure
280
320
this [name ] = param .validate (value , this = this )
321
+ return this
281
322
323
+ def validate_nobind (self , ** kwargs ):
324
+ """Validate the arguments against the signature without binding."""
325
+ this = DotDict ()
326
+ for name , param in self .parameters .items ():
327
+ value = kwargs .get (name , param .default )
328
+ if value is EMPTY :
329
+ raise TypeError (f"missing required argument `{ name !r} `" )
330
+ this [name ] = param .validate (value , this = kwargs )
282
331
return this
283
332
284
333
def validate_return (self , value ):
@@ -303,8 +352,10 @@ def validate_return(self, value):
303
352
# aliases for convenience
304
353
attribute = Attribute
305
354
argument = Argument
306
- mandatory = Argument .mandatory
355
+ required = Argument .required
307
356
optional = Argument .optional
357
+ varargs = Argument .varargs
358
+ varkwds = Argument .varkwds
308
359
default = Argument .default
309
360
310
361
@@ -384,9 +435,10 @@ def annotated(_1=None, _2=None, _3=None, **kwargs):
384
435
385
436
@functools .wraps (func )
386
437
def wrapped (* args , ** kwargs ):
387
- kwargs = sig .validate (* args , ** kwargs )
388
- result = sig .validate_return (func (** kwargs ))
389
- return result
438
+ values = sig .validate (* args , ** kwargs )
439
+ args , kwargs = sig .unbind (values )
440
+ result = func (* args , ** kwargs )
441
+ return sig .validate_return (result )
390
442
391
443
wrapped .__signature__ = sig
392
444
0 commit comments