Skip to content

Commit baea1fa

Browse files
kszucscpcloud
authored andcommitted
feat(common): add support for variadic positional and variadic keyword annotations
1 parent b368b04 commit baea1fa

File tree

5 files changed

+288
-63
lines changed

5 files changed

+288
-63
lines changed

ibis/common/annotations.py

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import inspect
55
from typing import Any
66

7-
from ibis.common.validators import Validator, any_, option
7+
from ibis.common.validators import Validator, any_, frozendict_of, option, tuple_of
88
from ibis.util import DotDict
99

1010
EMPTY = inspect.Parameter.empty # marker for missing argument
@@ -70,8 +70,15 @@ def initialize(self, this):
7070
class Argument(Annotation):
7171
"""Base class for all fields which should be passed as arguments."""
7272

73+
__slots__ = ('_kind',)
74+
75+
def __init__(self, validator=None, default=EMPTY, kind=POSITIONAL_OR_KEYWORD):
76+
self._kind = kind
77+
self._default = default
78+
self._validator = validator
79+
7380
@classmethod
74-
def mandatory(cls, validator=None):
81+
def required(cls, validator=None):
7582
"""Annotation to mark a mandatory argument."""
7683
return cls(validator)
7784

@@ -89,6 +96,17 @@ def optional(cls, validator=None, default=None):
8996
validator = option(validator, default=default)
9097
return cls(validator, default=None)
9198

99+
@classmethod
100+
def varargs(cls, validator=None):
101+
"""Annotation to mark a variable length positional argument."""
102+
validator = None if validator is None else tuple_of(validator)
103+
return cls(validator, kind=VAR_POSITIONAL)
104+
105+
@classmethod
106+
def varkwds(cls, validator=None):
107+
validator = None if validator is None else frozendict_of(any_, validator)
108+
return cls(validator, kind=VAR_KEYWORD)
109+
92110

93111
class Parameter(inspect.Parameter):
94112
"""Augmented Parameter class to additionally hold a validator object."""
@@ -102,7 +120,7 @@ def __init__(self, name, annotation):
102120
)
103121
super().__init__(
104122
name,
105-
kind=POSITIONAL_OR_KEYWORD,
123+
kind=annotation._kind,
106124
default=annotation._default,
107125
annotation=annotation._validator,
108126
)
@@ -150,22 +168,34 @@ def merge(cls, *signatures, **annotations):
150168

151169
# mandatory fields without default values must preceed the optional
152170
# ones in the function signature, the partial ordering will be kept
171+
var_args, var_kwargs = [], []
153172
new_args, new_kwargs = [], []
154-
inherited_args, inherited_kwargs = [], []
173+
old_args, old_kwargs = [], []
155174

156175
for name, param in params.items():
157-
if name in inherited:
176+
if param.kind == VAR_POSITIONAL:
177+
var_args.append(param)
178+
elif param.kind == VAR_KEYWORD:
179+
var_kwargs.append(param)
180+
elif name in inherited:
158181
if param.default is EMPTY:
159-
inherited_args.append(param)
182+
old_args.append(param)
160183
else:
161-
inherited_kwargs.append(param)
184+
old_kwargs.append(param)
162185
else:
163186
if param.default is EMPTY:
164187
new_args.append(param)
165188
else:
166189
new_kwargs.append(param)
167190

168-
return cls(inherited_args + new_args + new_kwargs + inherited_kwargs)
191+
if len(var_args) > 1:
192+
raise TypeError('only one variadic positional *args parameter is allowed')
193+
if len(var_kwargs) > 1:
194+
raise TypeError('only one variadic keywords **kwargs parameter is allowed')
195+
196+
return cls(
197+
old_args + new_args + var_args + new_kwargs + old_kwargs + var_kwargs
198+
)
169199

170200
@classmethod
171201
def from_callable(cls, fn, validators=None, return_validator=None):
@@ -199,25 +229,24 @@ def from_callable(cls, fn, validators=None, return_validator=None):
199229

200230
parameters = []
201231
for param in sig.parameters.values():
202-
if param.kind in {
203-
VAR_POSITIONAL,
204-
VAR_KEYWORD,
205-
POSITIONAL_ONLY,
206-
KEYWORD_ONLY,
207-
}:
232+
if param.kind in {POSITIONAL_ONLY, KEYWORD_ONLY}:
208233
raise TypeError(f"unsupported parameter kind {param.kind} in {fn}")
209234

210235
if param.name in validators:
211236
validator = validators[param.name]
212-
elif param.annotation is EMPTY:
213-
validator = any_
214-
else:
237+
elif param.annotation is not EMPTY:
215238
validator = Validator.from_annotation(
216239
param.annotation, module=fn.__module__
217240
)
218-
219-
if param.default is EMPTY:
220-
annot = Argument.mandatory(validator)
241+
else:
242+
validator = None
243+
244+
if param.kind is VAR_POSITIONAL:
245+
annot = Argument.varargs(validator)
246+
elif param.kind is VAR_KEYWORD:
247+
annot = Argument.varkwds(validator)
248+
elif param.default is EMPTY:
249+
annot = Argument.required(validator)
221250
else:
222251
annot = Argument.default(param.default, validator)
223252

@@ -250,7 +279,18 @@ def unbind(self, this: Any):
250279
Tuple of positional and keyword arguments.
251280
"""
252281
# does the reverse of bind, but doesn't apply defaults
253-
return {name: getattr(this, name) for name in self.parameters}
282+
args, kwargs = [], {}
283+
for name, param in self.parameters.items():
284+
value = getattr(this, name)
285+
if param.kind is POSITIONAL_OR_KEYWORD:
286+
args.append(value)
287+
elif param.kind is VAR_POSITIONAL:
288+
args.extend(value)
289+
elif param.kind is VAR_KEYWORD:
290+
kwargs.update(value)
291+
else:
292+
raise TypeError(f"unsupported parameter kind {param.kind}")
293+
return tuple(args), kwargs
254294

255295
def validate(self, *args, **kwargs):
256296
"""Validate the arguments against the signature.
@@ -278,7 +318,16 @@ def validate(self, *args, **kwargs):
278318
param = self.parameters[name]
279319
# TODO(kszucs): provide more error context on failure
280320
this[name] = param.validate(value, this=this)
321+
return this
281322

323+
def validate_nobind(self, **kwargs):
324+
"""Validate the arguments against the signature without binding."""
325+
this = DotDict()
326+
for name, param in self.parameters.items():
327+
value = kwargs.get(name, param.default)
328+
if value is EMPTY:
329+
raise TypeError(f"missing required argument `{name!r}`")
330+
this[name] = param.validate(value, this=kwargs)
282331
return this
283332

284333
def validate_return(self, value):
@@ -303,8 +352,10 @@ def validate_return(self, value):
303352
# aliases for convenience
304353
attribute = Attribute
305354
argument = Argument
306-
mandatory = Argument.mandatory
355+
required = Argument.required
307356
optional = Argument.optional
357+
varargs = Argument.varargs
358+
varkwds = Argument.varkwds
308359
default = Argument.default
309360

310361

@@ -384,9 +435,10 @@ def annotated(_1=None, _2=None, _3=None, **kwargs):
384435

385436
@functools.wraps(func)
386437
def wrapped(*args, **kwargs):
387-
kwargs = sig.validate(*args, **kwargs)
388-
result = sig.validate_return(func(**kwargs))
389-
return result
438+
values = sig.validate(*args, **kwargs)
439+
args, kwargs = sig.unbind(values)
440+
result = func(*args, **kwargs)
441+
return sig.validate_return(result)
390442

391443
wrapped.__signature__ = sig
392444

ibis/common/grounds.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,14 @@ def __new__(metacls, clsname, bases, dct, **kwargs):
5454
if name in dct:
5555
dct[name] = Argument.default(dct[name], validator)
5656
else:
57-
dct[name] = Argument.mandatory(validator)
57+
dct[name] = Argument.required(validator)
5858

5959
# collect the newly defined annotations
6060
slots = list(dct.pop('__slots__', []))
6161
namespace, arguments = {}, {}
6262
for name, attrib in dct.items():
6363
if isinstance(attrib, Validator):
64-
attrib = Argument.mandatory(attrib)
64+
attrib = Argument.required(attrib)
6565

6666
if isinstance(attrib, Argument):
6767
arguments[name] = attrib
@@ -96,6 +96,12 @@ def __create__(cls, *args, **kwargs) -> Annotable:
9696
kwargs = cls.__signature__.validate(*args, **kwargs)
9797
return super().__create__(**kwargs)
9898

99+
@classmethod
100+
def __recreate__(cls, kwargs) -> Annotable:
101+
# bypass signature binding by requiring keyword arguments only
102+
kwargs = cls.__signature__.validate_nobind(**kwargs)
103+
return super().__create__(**kwargs)
104+
99105
def __init__(self, **kwargs) -> None:
100106
# set the already validated arguments
101107
for name, value in kwargs.items():
@@ -221,7 +227,8 @@ def __precomputed_hash__(self):
221227
def __reduce__(self):
222228
# assuming immutability and idempotency of the __init__ method, we can
223229
# reconstruct the instance from the arguments without additional attributes
224-
return (self.__class__, self.__args__)
230+
state = dict(zip(self.__argnames__, self.__args__))
231+
return (self.__recreate__, (state,))
225232

226233
def __hash__(self):
227234
return self.__precomputed_hash__
@@ -240,4 +247,4 @@ def argnames(self):
240247
def copy(self, **overrides):
241248
kwargs = dict(zip(self.__argnames__, self.__args__))
242249
kwargs.update(overrides)
243-
return self.__class__(**kwargs)
250+
return self.__recreate__(kwargs)

ibis/common/tests/test_annotations.py

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_parameter():
7777
def fn(x, this):
7878
return int(x) + this['other']
7979

80-
annot = Argument.mandatory(fn)
80+
annot = Argument.required(fn)
8181
p = Parameter('test', annotation=annot)
8282

8383
assert p.annotation is fn
@@ -104,8 +104,8 @@ def to_int(x, this):
104104
def add_other(x, this):
105105
return int(x) + this['other']
106106

107-
other = Parameter('other', annotation=Argument.mandatory(to_int))
108-
this = Parameter('this', annotation=Argument.mandatory(add_other))
107+
other = Parameter('other', annotation=Argument.required(to_int))
108+
this = Parameter('this', annotation=Argument.required(add_other))
109109

110110
sig = Signature(parameters=[other, this])
111111
assert sig.validate(1, 2) == {'other': 1, 'this': 3}
@@ -124,19 +124,20 @@ def test(a: int, b: int, c: int = 1):
124124
sig.validate(2, 3, "4")
125125

126126

127-
def test_signature_from_callable_unsupported_argument_kinds():
128-
def test(a: int, b: int, *args):
129-
pass
127+
def test_signature_from_callable_with_varargs():
128+
def test(a: int, b: int, *args: int):
129+
return a + b + sum(args)
130130

131-
with pytest.raises(TypeError, match="unsupported parameter kind VAR_POSITIONAL"):
132-
Signature.from_callable(test)
131+
sig = Signature.from_callable(test)
132+
assert sig.validate(2, 3) == {'a': 2, 'b': 3, 'args': ()}
133+
assert sig.validate(2, 3, 4) == {'a': 2, 'b': 3, 'args': (4,)}
134+
assert sig.validate(2, 3, 4, 5) == {'a': 2, 'b': 3, 'args': (4, 5)}
133135

134-
def test(a: int, b: int, **kwargs):
135-
pass
136+
with pytest.raises(TypeError):
137+
sig.validate(2, 3, 4, "5")
136138

137-
with pytest.raises(TypeError, match="unsupported parameter kind VAR_KEYWORD"):
138-
Signature.from_callable(test)
139139

140+
def test_signature_from_callable_unsupported_argument_kinds():
140141
def test(a: int, b: int, *, c: int):
141142
pass
142143

@@ -157,14 +158,15 @@ def to_int(x, this):
157158
def add_other(x, this):
158159
return int(x) + this['other']
159160

160-
other = Parameter('other', annotation=Argument.mandatory(to_int))
161-
this = Parameter('this', annotation=Argument.mandatory(add_other))
161+
other = Parameter('other', annotation=Argument.required(to_int))
162+
this = Parameter('this', annotation=Argument.required(add_other))
162163

163164
sig = Signature(parameters=[other, this])
164165
params = sig.validate(1, this=2)
165166

166-
kwargs = sig.unbind(params)
167-
assert kwargs == {"other": 1, "this": 3}
167+
args, kwargs = sig.unbind(params)
168+
assert args == (1, 3)
169+
assert kwargs == {}
168170

169171

170172
def as_float(x, this):
@@ -175,8 +177,8 @@ def as_tuple_of_floats(x, this):
175177
return tuple(float(i) for i in x)
176178

177179

178-
a = Parameter('a', annotation=Argument.mandatory(validator=as_float))
179-
b = Parameter('b', annotation=Argument.mandatory(validator=as_float))
180+
a = Parameter('a', annotation=Argument.required(validator=as_float))
181+
b = Parameter('b', annotation=Argument.required(validator=as_float))
180182
c = Parameter('c', annotation=Argument.default(default=0, validator=as_float))
181183
d = Parameter(
182184
'd', annotation=Argument.default(default=tuple(), validator=as_tuple_of_floats)
@@ -190,10 +192,11 @@ def test_signature_unbind_with_empty_variadic(d):
190192
params = sig.validate(1, 2, 3, d, e=4)
191193
assert params == {'a': 1.0, 'b': 2.0, 'c': 3.0, 'd': d, 'e': 4.0}
192194

193-
kwargs = sig.unbind(params)
194-
assert kwargs == {'a': 1.0, 'b': 2.0, 'c': 3.0, 'd': d, 'e': 4.0}
195+
args, kwargs = sig.unbind(params)
196+
assert args == (1.0, 2.0, 3.0, tuple(map(float, d)), 4.0)
197+
assert kwargs == {}
195198

196-
params_again = sig.validate(**kwargs)
199+
params_again = sig.validate(*args, **kwargs)
197200
assert params_again == params
198201

199202

@@ -333,3 +336,27 @@ def test(a, b, c):
333336
func(1, 2)
334337

335338
assert func(1, 2, c=3) == 6
339+
340+
341+
def test_annotated_function_with_varargs():
342+
@annotated
343+
def test(a: float, b: float, *args: int):
344+
return sum((a, b) + args)
345+
346+
assert test(1.0, 2.0, 3, 4) == 10.0
347+
assert test(1.0, 2.0, 3, 4, 5) == 15.0
348+
349+
with pytest.raises(TypeError):
350+
test(1.0, 2.0, 3, 4, 5, 6.0)
351+
352+
353+
def test_annotated_function_with_varkwds():
354+
@annotated
355+
def test(a: float, b: float, **kwargs: int):
356+
return sum((a, b) + tuple(kwargs.values()))
357+
358+
assert test(1.0, 2.0, c=3, d=4) == 10.0
359+
assert test(1.0, 2.0, c=3, d=4, e=5) == 15.0
360+
361+
with pytest.raises(TypeError):
362+
test(1.0, 2.0, c=3, d=4, e=5, f=6.0)

0 commit comments

Comments
 (0)