Skip to content

Commit f2eb0fa

Browse files
committed
Bugfix/763 improve type annotations for DataFrameModel.validate (unionai-oss#1905)
Signed-off-by: Matt Richards <[email protected]>

* trial type annotations
  Signed-off-by: Matt Richards <[email protected]>
* changes in individual api files
  Signed-off-by: Matt Richards <[email protected]>
* pl.dataframe working in local test
  Signed-off-by: Matt Richards <[email protected]>
* older python union compat
  Signed-off-by: Matt Richards <[email protected]>
* try polars in the mypy env on ci
  Signed-off-by: Matt Richards <[email protected]>
* translate toplevel mypy skip into module specific skips
  Signed-off-by: Matt Richards <[email protected]>
* mypy passes
  Signed-off-by: Matt Richards <[email protected]>
* missing line continuation
  Signed-off-by: Matt Richards <[email protected]>
* python 3.8
  Signed-off-by: Matt Richards <[email protected]>

---------

Signed-off-by: Matt Richards <[email protected]>
1 parent 3ab295f commit f2eb0fa

File tree

3 files changed

+78
-5
lines changed

3 files changed

+78
-5
lines changed

pandera/api/pandas/model.py

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,11 @@
22

33
import copy
44
import sys
5-
from typing import Any, Dict, List, Optional, Tuple, Type, Union
5+
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
66

77
import pandas as pd
88

9+
from pandera.api.base.schema import BaseSchema
910
from pandera.api.checks import Check
1011
from pandera.api.dataframe.model import DataFrameModel as _DataFrameModel
1112
from pandera.api.dataframe.model import get_dtype_kwargs
@@ -22,6 +23,7 @@
2223
AnnotationInfo,
2324
DataFrame,
2425
)
26+
from pandera.utils import docstring_substitution
2527

2628
# if python version is < 3.11, import Self from typing_extensions
2729
if sys.version_info < (3, 11):
@@ -171,6 +173,26 @@ def _build_columns_index( # pylint:disable=too-many-locals,too-many-branches
171173

172174
return columns, _build_schema_index(indices, **multiindex_kwargs)
173175

176+
@classmethod
177+
@docstring_substitution(validate_doc=BaseSchema.validate.__doc__)
178+
def validate(
179+
cls: Type[Self],
180+
check_obj: pd.DataFrame,
181+
head: Optional[int] = None,
182+
tail: Optional[int] = None,
183+
sample: Optional[int] = None,
184+
random_state: Optional[int] = None,
185+
lazy: bool = False,
186+
inplace: bool = False,
187+
) -> DataFrame[Self]:
188+
"""%(validate_doc)s"""
189+
return cast(
190+
DataFrame[Self],
191+
cls.to_schema().validate(
192+
check_obj, head, tail, sample, random_state, lazy, inplace
193+
),
194+
)
195+
174196
@classmethod
175197
def to_json_schema(cls):
176198
"""Serialize schema metadata into json-schema format.

pandera/api/polars/model.py

Lines changed: 52 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,13 @@
11
"""Class-based api for polars models."""
22

33
import inspect
4-
from typing import Dict, List, Tuple, Type
4+
from typing import Dict, List, Tuple, Type, cast, Optional, overload, Union
5+
from typing_extensions import Self
56

67
import pandas as pd
78
import polars as pl
89

10+
from pandera.api.base.schema import BaseSchema
911
from pandera.api.checks import Check
1012
from pandera.api.dataframe.model import DataFrameModel as _DataFrameModel
1113
from pandera.api.dataframe.model import get_dtype_kwargs
@@ -16,7 +18,8 @@
1618
from pandera.engines import polars_engine as pe
1719
from pandera.errors import SchemaInitError
1820
from pandera.typing import AnnotationInfo
19-
from pandera.typing.polars import Series
21+
from pandera.typing.polars import Series, LazyFrame, DataFrame
22+
from pandera.utils import docstring_substitution
2023

2124

2225
class DataFrameModel(_DataFrameModel[pl.LazyFrame, DataFrameSchema]):
@@ -109,6 +112,53 @@ def _build_columns( # pylint:disable=too-many-locals
109112

110113
return columns
111114

115+
@classmethod
116+
@overload
117+
def validate(
118+
cls: Type[Self],
119+
check_obj: pl.DataFrame,
120+
head: Optional[int] = None,
121+
tail: Optional[int] = None,
122+
sample: Optional[int] = None,
123+
random_state: Optional[int] = None,
124+
lazy: bool = False,
125+
inplace: bool = False,
126+
) -> DataFrame[Self]: ...
127+
128+
@classmethod
129+
@overload
130+
def validate(
131+
cls: Type[Self],
132+
check_obj: pl.LazyFrame,
133+
head: Optional[int] = None,
134+
tail: Optional[int] = None,
135+
sample: Optional[int] = None,
136+
random_state: Optional[int] = None,
137+
lazy: bool = False,
138+
inplace: bool = False,
139+
) -> LazyFrame[Self]: ...
140+
141+
@classmethod
142+
@docstring_substitution(validate_doc=BaseSchema.validate.__doc__)
143+
def validate(
144+
cls: Type[Self],
145+
check_obj: Union[pl.LazyFrame, pl.DataFrame],
146+
head: Optional[int] = None,
147+
tail: Optional[int] = None,
148+
sample: Optional[int] = None,
149+
random_state: Optional[int] = None,
150+
lazy: bool = False,
151+
inplace: bool = False,
152+
) -> Union[LazyFrame[Self], DataFrame[Self]]:
153+
"""%(validate_doc)s"""
154+
result = cls.to_schema().validate(
155+
check_obj, head, tail, sample, random_state, lazy, inplace
156+
)
157+
if isinstance(check_obj, pl.LazyFrame):
158+
return cast(LazyFrame[Self], result)
159+
else:
160+
return cast(DataFrame[Self], result)
161+
112162
@classmethod
113163
def to_json_schema(cls):
114164
"""Serialize schema metadata into json-schema format.

pandera/api/pyspark/model.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
from pandera.errors import SchemaInitError
4242
from pandera.typing import AnnotationInfo
4343
from pandera.typing.common import DataFrameBase
44+
from pandera.typing.pyspark import DataFrame
4445

4546
try:
4647
from typing_extensions import get_type_hints
@@ -300,10 +301,10 @@ def validate(
300301
random_state: Optional[int] = None,
301302
lazy: bool = True,
302303
inplace: bool = False,
303-
) -> Optional[DataFrameBase[TDataFrameModel]]:
304+
) -> DataFrame[TDataFrameModel]:
304305
"""%(validate_doc)s"""
305306
return cast(
306-
DataFrameBase[TDataFrameModel],
307+
DataFrame[TDataFrameModel],
307308
cls.to_schema().validate(
308309
check_obj, head, tail, sample, random_state, lazy, inplace
309310
),

0 commit comments

Comments
 (0)