Skip to content

Commit ae4a415

Browse files
kszucscpcloud
authored andcommitted
feat(common): add support for annotating with coercible types
1 parent ddc6603 commit ae4a415

File tree

3 files changed

+108
-11
lines changed

3 files changed

+108
-11
lines changed

ibis/common/tests/test_validators.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
from __future__ import annotations
2+
3+
import sys
14
from typing import Dict, List, Optional, Tuple, Union
25

36
import pytest
47
from typing_extensions import Annotated
58

69
from ibis.common.validators import (
10+
Coercible,
711
Validator,
812
all_of,
913
any_of,
@@ -14,6 +18,7 @@
1418
int_,
1519
isin,
1620
list_of,
21+
mapping_of,
1722
min_,
1823
str_,
1924
tuple_of,
@@ -98,5 +103,60 @@ def endswith_d(x, this):
98103
],
99104
)
100105
def test_validator_from_annotation(annot, expected):
101-
validator = Validator.from_annotation(annot)
102-
assert validator == expected
106+
assert Validator.from_annotation(annot) == expected
107+
108+
109+
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
110+
def test_validator_from_annotation_uniontype():
111+
# uniontype marks `type1 | type2` annotations and it's different from
112+
# Union[type1, type2]
113+
validator = Validator.from_annotation(str | int | float)
114+
assert validator == any_of((instance_of(str), instance_of(int), instance_of(float)))
115+
116+
117+
class Something(Coercible):
118+
def __init__(self, value):
119+
self.value = value
120+
121+
@classmethod
122+
def __coerce__(cls, obj):
123+
return cls(obj + 1)
124+
125+
def __eq__(self, other):
126+
return type(self) == type(other) and self.value == other.value
127+
128+
129+
class SomethingSimilar(Something):
130+
pass
131+
132+
133+
class SomethingDifferent(Coercible):
134+
@classmethod
135+
def __coerce__(cls, obj):
136+
return obj + 2
137+
138+
139+
def test_coercible():
140+
s = Validator.from_annotation(Something)
141+
assert s(1) == Something(2)
142+
assert s(10) == Something(11)
143+
144+
145+
def test_coercible_checks_type():
146+
s = Validator.from_annotation(SomethingSimilar)
147+
v = Validator.from_annotation(SomethingDifferent)
148+
149+
assert s(1) == SomethingSimilar(2)
150+
assert SomethingDifferent.__coerce__(1) == 3
151+
152+
with pytest.raises(TypeError, match="not an instance of .*SomethingDifferent.*"):
153+
v(1)
154+
155+
156+
def test_mapping_of():
157+
value = {"a": 1, "b": 2}
158+
assert mapping_of(str, int, value, type=dict) == value
159+
assert mapping_of(str, int, value, type=frozendict) == frozendict(value)
160+
161+
with pytest.raises(TypeError, match="Argument must be a mapping"):
162+
mapping_of(str, float, 10, type=dict)

ibis/common/typing.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
if sys.version_info >= (3, 9):
1111

1212
@toolz.memoize
13-
def evaluate_typehint(hint, module_name) -> Any:
13+
def evaluate_typehint(hint, module_name=None) -> Any:
1414
if isinstance(hint, str):
1515
hint = ForwardRef(hint)
1616
if isinstance(hint, ForwardRef):
17-
globalns = sys.modules[module_name].__dict__
17+
if module_name is None:
18+
globalns = {}
19+
else:
20+
globalns = sys.modules[module_name].__dict__
1821
return hint._evaluate(globalns, locals(), frozenset())
1922
else:
2023
return hint
@@ -26,7 +29,10 @@ def evaluate_typehint(hint, module_name) -> Any:
2629
if isinstance(hint, str):
2730
hint = ForwardRef(hint)
2831
if isinstance(hint, ForwardRef):
29-
globalns = sys.modules[module_name].__dict__
32+
if module_name is None:
33+
globalns = {}
34+
else:
35+
globalns = sys.modules[module_name].__dict__
3036
return hint._evaluate(globalns, locals())
3137
else:
3238
return hint

ibis/common/validators.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import math
4+
from abc import ABC, abstractmethod
45
from contextlib import suppress
56
from typing import Any, Callable, Iterable, Mapping, Sequence, Union
67

@@ -12,6 +13,20 @@
1213
from ibis.common.typing import evaluate_typehint
1314
from ibis.util import flatten_iterable, frozendict, is_function, is_iterable
1415

16+
try:
17+
from types import UnionType
18+
except ImportError:
19+
UnionType = object()
20+
21+
22+
class Coercible(ABC):
23+
__slots__ = ()
24+
25+
@classmethod
26+
@abstractmethod
27+
def __coerce__(cls, obj):
28+
...
29+
1530

1631
class Validator(Callable):
1732
"""Abstract base class for defining argument validators."""
@@ -24,11 +39,14 @@ def from_annotation(cls, annot, module=None):
2439
annot = evaluate_typehint(annot, module)
2540
origin_type = get_origin(annot)
2641

27-
if annot is Any:
28-
return any_
29-
elif origin_type is None:
30-
return instance_of(annot)
31-
elif origin_type is Union:
42+
if origin_type is None:
43+
if annot is Any:
44+
return any_
45+
elif issubclass(annot, Coercible):
46+
return coerced_to(annot)
47+
else:
48+
return instance_of(annot)
49+
elif origin_type is UnionType or origin_type is Union:
3250
inners = map(cls.from_annotation, get_args(annot))
3351
return any_of(tuple(inners))
3452
elif origin_type is Annotated:
@@ -40,8 +58,13 @@ def from_annotation(cls, annot, module=None):
4058
elif issubclass(origin_type, Mapping):
4159
key_type, value_type = map(cls.from_annotation, get_args(annot))
4260
return mapping_of(key_type, value_type, type=origin_type)
61+
elif issubclass(origin_type, Callable):
62+
# TODO(kszucs): add a more comprehensive callable_with rule here
63+
return instance_of(Callable)
4364
else:
44-
return instance_of(annot)
65+
raise NotImplementedError(
66+
f"Cannot create validator from annotation {annot} {origin_type}"
67+
)
4568

4669

4770
# TODO(kszucs): in order to cache valiadator instances we could subclass
@@ -96,6 +119,12 @@ def instance_of(klasses, arg, **kwargs):
96119
return arg
97120

98121

122+
@validator
123+
def coerced_to(klass, arg, **kwargs):
124+
value = klass.__coerce__(arg)
125+
return instance_of(klass, value, **kwargs)
126+
127+
99128
class lazy_instance_of(Validator):
100129
"""A version of `instance_of` that accepts qualnames instead of imported classes.
101130
@@ -194,6 +223,8 @@ def sequence_of(inner, arg, *, type, min_length=0, flatten=False, **kwargs):
194223

195224
@validator
196225
def mapping_of(key_inner, value_inner, arg, *, type, **kwargs):
226+
if not isinstance(arg, Mapping):
227+
raise IbisTypeError('Argument must be a mapping')
197228
return type(
198229
(key_inner(k, **kwargs), value_inner(v, **kwargs)) for k, v in arg.items()
199230
)

0 commit comments

Comments
 (0)