1
1
"""Class-based api for polars models."""
2
2
3
3
from typing import (
4
- Any ,
5
4
Dict ,
6
5
List ,
7
6
Tuple ,
19
18
from pandera .api .polars .container import DataFrameSchema
20
19
from pandera .api .polars .components import Column
21
20
from pandera .api .polars .model_config import BaseConfig
21
+ from pandera .engines import polars_engine as pe
22
22
from pandera .errors import SchemaInitError
23
23
from pandera .typing import AnnotationInfo
24
+ from pandera .typing .polars import Series
24
25
25
26
26
27
class DataFrameModel (_DataFrameModel [pl .LazyFrame , DataFrameSchema ]):
@@ -52,24 +53,30 @@ def _build_columns( # pylint:disable=too-many-locals
52
53
field_name = field .name
53
54
check_name = getattr (field , "check_name" , None )
54
55
55
- if annotation .metadata :
56
- if field .dtype_kwargs :
57
- raise TypeError (
58
- "Cannot specify redundant 'dtype_kwargs' "
59
- + f"for { annotation .raw_annotation } ."
60
- + "\n Usage Tip: Drop 'typing.Annotated'."
61
- )
62
- dtype_kwargs = get_dtype_kwargs (annotation )
63
- dtype = annotation .arg (** dtype_kwargs ) # type: ignore
64
- elif annotation .default_dtype :
65
- dtype = annotation .default_dtype
66
- else :
67
- dtype = annotation .arg
68
-
69
- dtype = None if dtype is Any else dtype
70
-
71
- if annotation .origin is None or isinstance (
72
- annotation .origin , pl .datatypes .DataTypeClass
56
+ engine_dtype = None
57
+ try :
58
+ engine_dtype = pe .Engine .dtype (annotation .raw_annotation )
59
+ dtype = engine_dtype .type
60
+ except TypeError as exc :
61
+ if annotation .metadata :
62
+ if field .dtype_kwargs :
63
+ raise TypeError (
64
+ "Cannot specify redundant 'dtype_kwargs' "
65
+ + f"for { annotation .raw_annotation } ."
66
+ + "\n Usage Tip: Drop 'typing.Annotated'."
67
+ ) from exc
68
+ dtype_kwargs = get_dtype_kwargs (annotation )
69
+ dtype = annotation .arg (** dtype_kwargs ) # type: ignore
70
+ elif annotation .default_dtype :
71
+ dtype = annotation .default_dtype
72
+ else :
73
+ dtype = annotation .arg
74
+
75
+ if (
76
+ annotation .origin is None
77
+ or isinstance (annotation .origin , pl .datatypes .DataTypeClass )
78
+ or annotation .origin is Series
79
+ or engine_dtype
73
80
):
74
81
if check_name is False :
75
82
raise SchemaInitError (
@@ -89,19 +96,9 @@ def _build_columns( # pylint:disable=too-many-locals
89
96
columns [field_name ] = Column (** column_kwargs )
90
97
91
98
else :
92
- origin_name = (
93
- f"{ annotation .origin .__module__ } ."
94
- f"{ annotation .origin .__name__ } "
95
- )
96
- msg = (
97
- " Series[TYPE] annotations are not supported for polars. "
98
- "Use the bare TYPE directly"
99
- if origin_name == "pandera.typing.pandas.Series"
100
- else ""
101
- )
102
99
raise SchemaInitError (
103
100
f"Invalid annotation '{ field_name } : "
104
- f"{ annotation .raw_annotation } '.{ msg } "
101
+ f"{ annotation .raw_annotation } '."
105
102
)
106
103
107
104
return columns
0 commit comments