Skip to content

Commit 6722c31

Browse files
committed
feat(datatype/schema): support datatype and schema declaration using type annotated classes
1 parent d22ae7b commit 6722c31

File tree

6 files changed

+499
-15
lines changed

6 files changed

+499
-15
lines changed

ibis/expr/datatypes/cast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def can_cast_struct(source, target, **kwargs):
168168

169169
@castable.register(dt.Array, dt.Array)
170170
@castable.register(dt.Set, dt.Set)
171-
def can_cast_variadic(
171+
def can_cast_array_or_set(
172172
source: dt.Array | dt.Set, target: dt.Array | dt.Set, **kwargs
173173
) -> bool:
174174
return castable(source.value_type, target.value_type)

ibis/expr/datatypes/core.py

Lines changed: 86 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
from __future__ import annotations
22

3+
import datetime as pydatetime
4+
import decimal as pydecimal
35
import numbers
6+
import uuid as pyuuid
47
from abc import abstractmethod
5-
from collections.abc import Iterator, Mapping
8+
from collections.abc import Iterator, Mapping, Sequence
9+
from collections.abc import Set as PySet
10+
from numbers import Integral, Real
611
from typing import Any, Iterable, NamedTuple
712

813
import numpy as np
14+
import toolz
915
from multipledispatch import Dispatcher
1016
from public import public
17+
from typing_extensions import get_args, get_origin, get_type_hints
1118

1219
import ibis.expr.types as ir
1320
from ibis.common.annotations import attribute, optional
14-
from ibis.common.exceptions import IbisTypeError
1521
from ibis.common.grounds import Concrete, Singleton
1622
from ibis.common.validators import (
1723
all_of,
@@ -22,12 +28,53 @@
2228
validator,
2329
)
2430

31+
# TODO(kszucs): we don't support union types yet
32+
2533
dtype = Dispatcher('dtype')
2634

2735

2836
@dtype.register(object)
2937
def dtype_from_object(value, **kwargs) -> DataType:
30-
raise IbisTypeError(f'Value {value!r} is not a valid datatype')
38+
# TODO(kszucs): implement this in a @dtype.register(type) overload once dtype
39+
# turned into a singledispatched function because that overload doesn't work
40+
# with multipledispatch
41+
42+
# TODO(kszucs): support Tuple[int, str] and Tuple[int, ...] typehints
43+
# in order to support more kinds of typehints follow the implementation of
44+
# Validator.from_annotation
45+
origin_type = get_origin(value)
46+
if origin_type is None:
47+
if issubclass(value, DataType):
48+
return value()
49+
elif result := _python_dtypes.get(value):
50+
return result
51+
elif annots := get_type_hints(value):
52+
return Struct(toolz.valmap(dtype, annots))
53+
elif issubclass(value, bytes):
54+
return bytes
55+
elif issubclass(value, str):
56+
return string
57+
elif issubclass(value, Integral):
58+
return int64
59+
elif issubclass(value, Real):
60+
return float64
61+
elif value is type(None):
62+
return null
63+
else:
64+
raise TypeError(
65+
f"Cannot construct an ibis datatype from python type {value!r}"
66+
)
67+
elif issubclass(origin_type, Sequence):
68+
(value_type,) = map(dtype, get_args(value))
69+
return Array(value_type)
70+
elif issubclass(origin_type, Mapping):
71+
key_type, value_type = map(dtype, get_args(value))
72+
return Map(key_type, value_type)
73+
elif issubclass(origin_type, PySet):
74+
(value_type,) = map(dtype, get_args(value))
75+
return Set(value_type)
76+
else:
77+
raise TypeError(f'Value {value!r} is not a valid datatype')
3178

3279

3380
@validator
@@ -239,11 +286,20 @@ class Primitive(DataType, Singleton):
239286
"""Values with known size."""
240287

241288

289+
# TODO(kszucs): consider to remove since we don't actually use this information
242290
@public
243291
class Variadic(DataType):
244292
"""Values with unknown size."""
245293

246294

295+
@public
296+
class Parametric(DataType):
297+
"""Types that can be parameterized."""
298+
299+
def __class_getitem__(cls, params):
300+
return cls(*params) if isinstance(params, tuple) else cls(params)
301+
302+
247303
@public
248304
class Null(Primitive):
249305
"""Null values."""
@@ -339,7 +395,7 @@ class Time(Temporal, Primitive):
339395

340396

341397
@public
342-
class Timestamp(Temporal):
398+
class Timestamp(Temporal, Parametric):
343399
"""Timestamp values."""
344400

345401
timezone = optional(instance_of(str))
@@ -487,7 +543,7 @@ class Float64(Floating):
487543

488544

489545
@public
490-
class Decimal(Numeric):
546+
class Decimal(Numeric, Parametric):
491547
"""Fixed-precision decimal values."""
492548

493549
precision = optional(instance_of(int))
@@ -551,7 +607,7 @@ def _pretty_piece(self) -> str:
551607

552608

553609
@public
554-
class Interval(DataType):
610+
class Interval(Parametric):
555611
"""Interval values."""
556612

557613
__valid_units__ = {
@@ -617,7 +673,7 @@ def _pretty_piece(self) -> str:
617673

618674

619675
@public
620-
class Category(DataType):
676+
class Category(Parametric):
621677
cardinality = optional(instance_of(int))
622678

623679
scalar = ir.CategoryScalar
@@ -640,14 +696,17 @@ def to_integer_type(self):
640696

641697

642698
@public
643-
class Struct(DataType, Mapping):
699+
class Struct(Parametric, Mapping):
644700
"""Structured values."""
645701

646702
fields = frozendict_of(instance_of(str), datatype)
647703

648704
scalar = ir.StructScalar
649705
column = ir.StructColumn
650706

707+
def __class_getitem__(cls, fields):
708+
return cls({slice_.start: slice_.stop for slice_ in fields})
709+
651710
@classmethod
652711
def from_tuples(
653712
cls, pairs: Iterable[tuple[str, str | DataType]], nullable: bool = True
@@ -697,7 +756,7 @@ def _pretty_piece(self) -> str:
697756

698757

699758
@public
700-
class Array(Variadic):
759+
class Array(Variadic, Parametric):
701760
"""Array values."""
702761

703762
value_type = datatype
@@ -711,7 +770,7 @@ def _pretty_piece(self) -> str:
711770

712771

713772
@public
714-
class Set(Variadic):
773+
class Set(Variadic, Parametric):
715774
"""Set values."""
716775

717776
value_type = datatype
@@ -725,7 +784,7 @@ def _pretty_piece(self) -> str:
725784

726785

727786
@public
728-
class Map(Variadic):
787+
class Map(Variadic, Parametric):
729788
"""Associative array values."""
730789

731790
key_type = datatype
@@ -887,6 +946,21 @@ class INET(String):
887946

888947
Enum = String
889948

949+
950+
_python_dtypes = {
951+
bool: boolean,
952+
int: int64,
953+
float: float64,
954+
str: string,
955+
bytes: binary,
956+
pydatetime.date: date,
957+
pydatetime.time: time,
958+
pydatetime.datetime: timestamp,
959+
pydatetime.timedelta: interval,
960+
pydecimal.Decimal: decimal,
961+
pyuuid.UUID: uuid,
962+
}
963+
890964
_numpy_dtypes = {
891965
np.dtype("bool"): boolean,
892966
np.dtype("int8"): int8,
@@ -929,7 +1003,7 @@ class INET(String):
9291003

9301004

9311005
@dtype.register(np.dtype)
932-
def _(value):
1006+
def from_numpy_dtype(value):
9331007
try:
9341008
return _numpy_dtypes[value]
9351009
except KeyError:

ibis/expr/schema.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,11 @@ def schema_from_pairs(lst):
333333
return Schema.from_tuples(lst)
334334

335335

336+
@schema.register(type)
337+
def schema_from_class(cls):
338+
return Schema(dt.dtype(cls))
339+
340+
336341
@schema.register(Iterable, Iterable)
337342
def schema_from_names_types(names, types):
338343
# validate lengths of names and types are the same

0 commit comments

Comments
 (0)