1
1
from __future__ import annotations
2
2
3
+ import datetime as pydatetime
4
+ import decimal as pydecimal
3
5
import numbers
6
+ import uuid as pyuuid
4
7
from abc import abstractmethod
5
- from collections .abc import Iterator , Mapping
8
+ from collections .abc import Iterator , Mapping , Sequence
9
+ from collections .abc import Set as PySet
10
+ from numbers import Integral , Real
6
11
from typing import Any , Iterable , NamedTuple
7
12
8
13
import numpy as np
14
+ import toolz
9
15
from multipledispatch import Dispatcher
10
16
from public import public
17
+ from typing_extensions import get_args , get_origin , get_type_hints
11
18
12
19
import ibis .expr .types as ir
13
20
from ibis .common .annotations import attribute , optional
14
- from ibis .common .exceptions import IbisTypeError
15
21
from ibis .common .grounds import Concrete , Singleton
16
22
from ibis .common .validators import (
17
23
all_of ,
22
28
validator ,
23
29
)
24
30
31
+ # TODO(kszucs): we don't support union types yet
32
+
25
33
dtype = Dispatcher ('dtype' )
26
34
27
35
28
36
@dtype.register(object)
def dtype_from_object(value, **kwargs) -> DataType:
    """Infer an ibis datatype from a python type or a parametrized typehint.

    This is the fallback ``dtype`` overload, registered for ``object``.

    A plain class (one without a typing origin) is resolved in this order:

    1. a ``DataType`` subclass is instantiated without arguments
    2. an exact match in ``_python_dtypes`` is returned
    3. a class carrying type hints becomes a ``Struct`` whose field types are
       obtained by calling ``dtype`` on each annotation
    4. subclasses of ``bytes``, ``str``, ``Integral`` and ``Real`` map to
       ``binary``, ``string``, ``int64`` and ``float64`` respectively
    5. ``type(None)`` maps to ``null``

    A parametrized typehint is resolved through its origin: ``Sequence``
    origins become ``Array``, ``Mapping`` origins become ``Map`` and set
    origins become ``Set``; the type arguments are converted with ``dtype``.

    Parameters
    ----------
    value
        A python class or a parametrized typehint such as ``list[int]``.
    **kwargs
        Accepted for signature compatibility with the other overloads; unused.

    Returns
    -------
    DataType
        The inferred ibis datatype.

    Raises
    ------
    TypeError
        If no ibis datatype can be constructed from ``value``.
    """
    # TODO(kszucs): implement this in a @dtype.register(type) overload once dtype
    # turned into a singledispatched function because that overload doesn't work
    # with multipledispatch

    # TODO(kszucs): support Tuple[int, str] and Tuple[int, ...] typehints
    # in order to support more kinds of typehints follow the implementation of
    # Validator.from_annotation
    origin_type = get_origin(value)
    if origin_type is None:
        # every check below requires a class; reject anything else up front
        # instead of leaking issubclass()'s "arg 1 must be a class" error
        if not isinstance(value, type):
            raise TypeError(f'Value {value!r} is not a valid datatype')

        if issubclass(value, DataType):
            return value()
        elif result := _python_dtypes.get(value):
            return result
        elif annots := get_type_hints(value):
            return Struct(toolz.valmap(dtype, annots))
        elif issubclass(value, bytes):
            # must be the ibis datatype, not the python builtin `bytes`
            return binary
        elif issubclass(value, str):
            return string
        elif issubclass(value, Integral):
            return int64
        elif issubclass(value, Real):
            return float64
        elif value is type(None):
            return null
        else:
            raise TypeError(
                f"Cannot construct an ibis datatype from python type {value!r}"
            )
    elif not isinstance(origin_type, type):
        # special-form origins (e.g. Union) are not classes and are unsupported
        raise TypeError(f'Value {value!r} is not a valid datatype')
    elif issubclass(origin_type, Sequence):
        # exactly one type argument is expected (see the Tuple TODO above)
        (value_type,) = map(dtype, get_args(value))
        return Array(value_type)
    elif issubclass(origin_type, Mapping):
        key_type, value_type = map(dtype, get_args(value))
        return Map(key_type, value_type)
    elif issubclass(origin_type, PySet):
        (value_type,) = map(dtype, get_args(value))
        return Set(value_type)
    else:
        raise TypeError(f'Value {value!r} is not a valid datatype')
31
78
32
79
33
80
@validator
@@ -239,11 +286,20 @@ class Primitive(DataType, Singleton):
239
286
"""Values with known size."""
240
287
241
288
289
+ # TODO(kszucs): consider removing this since we don't actually use this information
242
290
@public
class Variadic(DataType):
    """Values with unknown size.

    This class declares no fields or methods of its own; it only serves as a
    base class for other datatypes.
    """
245
293
246
294
295
@public
class Parametric(DataType):
    """Types that can be parameterized.

    Subscripting the class constructs an instance: ``cls[a, b]`` is
    equivalent to ``cls(a, b)`` and ``cls[a]`` is equivalent to ``cls(a)``.
    """

    def __class_getitem__(cls, params):
        # a subscript with several items arrives as a tuple, which is
        # spread over the constructor's positional arguments
        if isinstance(params, tuple):
            return cls(*params)
        return cls(params)
301
+
302
+
247
303
@public
248
304
class Null (Primitive ):
249
305
"""Null values."""
@@ -339,7 +395,7 @@ class Time(Temporal, Primitive):
339
395
340
396
341
397
@public
342
- class Timestamp (Temporal ):
398
+ class Timestamp (Temporal , Parametric ):
343
399
"""Timestamp values."""
344
400
345
401
timezone = optional (instance_of (str ))
@@ -487,7 +543,7 @@ class Float64(Floating):
487
543
488
544
489
545
@public
490
- class Decimal (Numeric ):
546
+ class Decimal (Numeric , Parametric ):
491
547
"""Fixed-precision decimal values."""
492
548
493
549
precision = optional (instance_of (int ))
@@ -551,7 +607,7 @@ def _pretty_piece(self) -> str:
551
607
552
608
553
609
@public
554
- class Interval (DataType ):
610
+ class Interval (Parametric ):
555
611
"""Interval values."""
556
612
557
613
__valid_units__ = {
@@ -617,7 +673,7 @@ def _pretty_piece(self) -> str:
617
673
618
674
619
675
@public
620
- class Category (DataType ):
676
+ class Category (Parametric ):
621
677
cardinality = optional (instance_of (int ))
622
678
623
679
scalar = ir .CategoryScalar
@@ -640,14 +696,17 @@ def to_integer_type(self):
640
696
641
697
642
698
@public
643
- class Struct (DataType , Mapping ):
699
+ class Struct (Parametric , Mapping ):
644
700
"""Structured values."""
645
701
646
702
fields = frozendict_of (instance_of (str ), datatype )
647
703
648
704
scalar = ir .StructScalar
649
705
column = ir .StructColumn
650
706
707
+ def __class_getitem__ (cls , fields ):
708
+ return cls ({slice_ .start : slice_ .stop for slice_ in fields })
709
+
651
710
@classmethod
652
711
def from_tuples (
653
712
cls , pairs : Iterable [tuple [str , str | DataType ]], nullable : bool = True
@@ -697,7 +756,7 @@ def _pretty_piece(self) -> str:
697
756
698
757
699
758
@public
700
- class Array (Variadic ):
759
+ class Array (Variadic , Parametric ):
701
760
"""Array values."""
702
761
703
762
value_type = datatype
@@ -711,7 +770,7 @@ def _pretty_piece(self) -> str:
711
770
712
771
713
772
@public
714
- class Set (Variadic ):
773
+ class Set (Variadic , Parametric ):
715
774
"""Set values."""
716
775
717
776
value_type = datatype
@@ -725,7 +784,7 @@ def _pretty_piece(self) -> str:
725
784
726
785
727
786
@public
728
- class Map (Variadic ):
787
+ class Map (Variadic , Parametric ):
729
788
"""Associative array values."""
730
789
731
790
key_type = datatype
@@ -887,6 +946,21 @@ class INET(String):
887
946
888
947
Enum = String
889
948
949
+
950
# Exact-match lookup table from builtin/stdlib python types to ibis datatypes.
_python_dtypes = {
    bool: boolean,
    int: int64,
    float: float64,
    str: string,
    bytes: binary,
    pydatetime.date: date,
    pydatetime.time: time,
    pydatetime.datetime: timestamp,
    pydatetime.timedelta: interval,
    pydecimal.Decimal: decimal,
    pyuuid.UUID: uuid,
}
963
+
890
964
_numpy_dtypes = {
891
965
np .dtype ("bool" ): boolean ,
892
966
np .dtype ("int8" ): int8 ,
@@ -929,7 +1003,7 @@ class INET(String):
929
1003
930
1004
931
1005
@dtype .register (np .dtype )
932
- def _ (value ):
1006
+ def from_numpy_dtype (value ):
933
1007
try :
934
1008
return _numpy_dtypes [value ]
935
1009
except KeyError :
0 commit comments