Skip to content

Commit ae57c36

Browse files
kszucscpcloud
authored andcommitted
feat(common): support Callable arguments and return types in Validator.from_annotable()
1 parent 560474e commit ae57c36

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

ibis/common/tests/test_validators.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import sys
4-
from typing import Dict, List, Literal, Optional, Tuple, Union
4+
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
55

66
import pytest
77
from typing_extensions import Annotated
@@ -110,6 +110,11 @@ def endswith_d(x, this):
110110
(Dict[str, float], dict_of(instance_of(str), instance_of(float))),
111111
(frozendict[str, int], frozendict_of(instance_of(str), instance_of(int))),
112112
(Literal["alpha", "beta", "gamma"], isin(("alpha", "beta", "gamma"))),
113+
(
114+
Callable[[str, int], str],
115+
callable_with((instance_of(str), instance_of(int)), instance_of(str)),
116+
),
117+
(Callable, instance_of(Callable)),
113118
],
114119
)
115120
def test_validator_from_annotation(annot, expected):

ibis/common/validators.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,15 @@ def from_annotation(cls, annot, module=None):
5959
(inner,) = map(cls.from_annotation, get_args(annot))
6060
return sequence_of(inner, type=origin_type)
6161
elif issubclass(origin_type, Mapping):
62-
key_type, value_type = map(cls.from_annotation, get_args(annot))
63-
return mapping_of(key_type, value_type, type=origin_type)
62+
key_inner, value_inner = map(cls.from_annotation, get_args(annot))
63+
return mapping_of(key_inner, value_inner, type=origin_type)
6464
elif issubclass(origin_type, Callable):
65-
# TODO(kszucs): add a more comprehensive callable_with rule here
66-
return instance_of(Callable)
65+
if args := get_args(annot):
66+
arg_inners = map(cls.from_annotation, args[0])
67+
return_inner = cls.from_annotation(args[1])
68+
return callable_with(tuple(arg_inners), return_inner)
69+
else:
70+
return instance_of(Callable)
6771
else:
6872
raise NotImplementedError(
6973
f"Cannot create validator from annotation {annot} {origin_type}"
@@ -264,13 +268,13 @@ def mapping_of(key_inner, value_inner, arg, *, type, **kwargs):
264268

265269

266270
@validator
267-
def callable_with(args_inner, return_inner, value, **kwargs):
271+
def callable_with(arg_inners, return_inner, value, **kwargs):
268272
from ibis.common.annotations import annotated
269273

270274
if not callable(value):
271275
raise IbisTypeError("Argument must be a callable")
272276

273-
fn = annotated(args_inner, return_inner, value)
277+
fn = annotated(arg_inners, return_inner, value)
274278

275279
has_varargs = False
276280
positional, keyword_only = [], []
@@ -286,9 +290,9 @@ def callable_with(args_inner, return_inner, value, **kwargs):
286290
raise IbisTypeError(
287291
"Callable has mandatory keyword-only arguments which cannot be specified"
288292
)
289-
elif len(positional) > len(args_inner):
293+
elif len(positional) > len(arg_inners):
290294
raise IbisTypeError("Callable has more positional arguments than expected")
291-
elif len(positional) < len(args_inner) and not has_varargs:
295+
elif len(positional) < len(arg_inners) and not has_varargs:
292296
raise IbisTypeError("Callable has less positional arguments than expected")
293297
else:
294298
return fn

0 commit comments

Comments
 (0)