Skip to content

Commit 7186507

Browse files
authored
Bugfix/763 improve type annotations for DataFrameModel.validate (#1905)
* trial type annotations
  Signed-off-by: Matt Richards <[email protected]>
* changes in individual api files
  Signed-off-by: Matt Richards <[email protected]>
* pl.dataframe working in local test
  Signed-off-by: Matt Richards <[email protected]>
* older python union compat
  Signed-off-by: Matt Richards <[email protected]>
* try polars in the mypy env on ci
  Signed-off-by: Matt Richards <[email protected]>
* translate toplevel mypy skip into module specific skips
  Signed-off-by: Matt Richards <[email protected]>
* mypy passes
  Signed-off-by: Matt Richards <[email protected]>
* missing line continuation
  Signed-off-by: Matt Richards <[email protected]>
* python 3.8
  Signed-off-by: Matt Richards <[email protected]>
---------
Signed-off-by: Matt Richards <[email protected]>
1 parent 32b08fd commit 7186507

File tree

8 files changed

+101
-9
lines changed

8 files changed

+101
-9
lines changed

.github/workflows/ci-tests.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ jobs:
5656
types-pytz \
5757
types-pyyaml \
5858
types-requests \
59-
types-setuptools
59+
types-setuptools \
60+
polars
6061
- name: Pip info
6162
run: python -m pip list
6263

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ repos:
5454
- types-pyyaml
5555
- types-requests
5656
- types-setuptools
57+
- polars
5758
args: ["pandera", "tests", "scripts"]
5859
exclude: (^docs/|^tests/mypy/modules/)
5960
pass_filenames: false

mypy.ini

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[mypy]
22
ignore_missing_imports = True
3-
follow_imports = skip
3+
follow_imports = normal
44
allow_redefinition = True
55
warn_return_any = False
66
warn_unused_configs = True
@@ -12,3 +12,17 @@ exclude=(?x)(
1212
| ^pandera/backends/pyspark
1313
| ^tests/pyspark
1414
)
15+
[mypy-pandera.api.pyspark.*]
16+
follow_imports = skip
17+
18+
[mypy-docs.*]
19+
follow_imports = skip
20+
21+
[mypy-pandera.engines.polars_engine]
22+
ignore_errors = True
23+
24+
[mypy-pandera.backends.polars.builtin_checks]
25+
ignore_errors = True
26+
27+
[mypy-tests.polars.*]
28+
ignore_errors = True

pandera/api/pandas/model.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
import copy
44
import sys
5-
from typing import Any, Dict, List, Optional, Tuple, Type, Union
5+
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
66

77
import pandas as pd
88

9+
from pandera.api.base.schema import BaseSchema
910
from pandera.api.checks import Check
1011
from pandera.api.dataframe.model import DataFrameModel as _DataFrameModel
1112
from pandera.api.dataframe.model import get_dtype_kwargs
@@ -22,6 +23,7 @@
2223
AnnotationInfo,
2324
DataFrame,
2425
)
26+
from pandera.utils import docstring_substitution
2527

2628
# if python version is < 3.11, import Self from typing_extensions
2729
if sys.version_info < (3, 11):
@@ -171,6 +173,26 @@ def _build_columns_index( # pylint:disable=too-many-locals,too-many-branches
171173

172174
return columns, _build_schema_index(indices, **multiindex_kwargs)
173175

176+
@classmethod
177+
@docstring_substitution(validate_doc=BaseSchema.validate.__doc__)
178+
def validate(
179+
cls: Type[Self],
180+
check_obj: pd.DataFrame,
181+
head: Optional[int] = None,
182+
tail: Optional[int] = None,
183+
sample: Optional[int] = None,
184+
random_state: Optional[int] = None,
185+
lazy: bool = False,
186+
inplace: bool = False,
187+
) -> DataFrame[Self]:
188+
"""%(validate_doc)s"""
189+
return cast(
190+
DataFrame[Self],
191+
cls.to_schema().validate(
192+
check_obj, head, tail, sample, random_state, lazy, inplace
193+
),
194+
)
195+
174196
@classmethod
175197
def to_json_schema(cls):
176198
"""Serialize schema metadata into json-schema format.

pandera/api/polars/components.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class Column(ComponentSchema[PolarsCheckObjects]):
2323

2424
def __init__(
2525
self,
26-
dtype: PolarsDtypeInputTypes = None,
26+
dtype: Optional[PolarsDtypeInputTypes] = None,
2727
checks: Optional[CheckList] = None,
2828
nullable: bool = False,
2929
unique: bool = False,

pandera/api/polars/model.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Class-based api for polars models."""
22

33
import inspect
4-
from typing import Dict, List, Tuple, Type
4+
from typing import Dict, List, Tuple, Type, cast, Optional, overload, Union
5+
from typing_extensions import Self
56

67
import pandas as pd
78
import polars as pl
89

10+
from pandera.api.base.schema import BaseSchema
911
from pandera.api.checks import Check
1012
from pandera.api.dataframe.model import DataFrameModel as _DataFrameModel
1113
from pandera.api.dataframe.model import get_dtype_kwargs
@@ -16,7 +18,8 @@
1618
from pandera.engines import polars_engine as pe
1719
from pandera.errors import SchemaInitError
1820
from pandera.typing import AnnotationInfo
19-
from pandera.typing.polars import Series
21+
from pandera.typing.polars import Series, LazyFrame, DataFrame
22+
from pandera.utils import docstring_substitution
2023

2124

2225
class DataFrameModel(_DataFrameModel[pl.LazyFrame, DataFrameSchema]):
@@ -109,6 +112,53 @@ def _build_columns( # pylint:disable=too-many-locals
109112

110113
return columns
111114

115+
@classmethod
116+
@overload
117+
def validate(
118+
cls: Type[Self],
119+
check_obj: pl.DataFrame,
120+
head: Optional[int] = None,
121+
tail: Optional[int] = None,
122+
sample: Optional[int] = None,
123+
random_state: Optional[int] = None,
124+
lazy: bool = False,
125+
inplace: bool = False,
126+
) -> DataFrame[Self]: ...
127+
128+
@classmethod
129+
@overload
130+
def validate(
131+
cls: Type[Self],
132+
check_obj: pl.LazyFrame,
133+
head: Optional[int] = None,
134+
tail: Optional[int] = None,
135+
sample: Optional[int] = None,
136+
random_state: Optional[int] = None,
137+
lazy: bool = False,
138+
inplace: bool = False,
139+
) -> LazyFrame[Self]: ...
140+
141+
@classmethod
142+
@docstring_substitution(validate_doc=BaseSchema.validate.__doc__)
143+
def validate(
144+
cls: Type[Self],
145+
check_obj: Union[pl.LazyFrame, pl.DataFrame],
146+
head: Optional[int] = None,
147+
tail: Optional[int] = None,
148+
sample: Optional[int] = None,
149+
random_state: Optional[int] = None,
150+
lazy: bool = False,
151+
inplace: bool = False,
152+
) -> Union[LazyFrame[Self], DataFrame[Self]]:
153+
"""%(validate_doc)s"""
154+
result = cls.to_schema().validate(
155+
check_obj, head, tail, sample, random_state, lazy, inplace
156+
)
157+
if isinstance(check_obj, pl.LazyFrame):
158+
return cast(LazyFrame[Self], result)
159+
else:
160+
return cast(DataFrame[Self], result)
161+
112162
@classmethod
113163
def to_json_schema(cls):
114164
"""Serialize schema metadata into json-schema format.

pandera/api/pyspark/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from pandera.errors import SchemaInitError
4242
from pandera.typing import AnnotationInfo
4343
from pandera.typing.common import DataFrameBase
44+
from pandera.typing.pyspark import DataFrame
4445

4546
try:
4647
from typing_extensions import get_type_hints
@@ -300,10 +301,10 @@ def validate(
300301
random_state: Optional[int] = None,
301302
lazy: bool = True,
302303
inplace: bool = False,
303-
) -> Optional[DataFrameBase[TDataFrameModel]]:
304+
) -> DataFrame[TDataFrameModel]:
304305
"""%(validate_doc)s"""
305306
return cast(
306-
DataFrameBase[TDataFrameModel],
307+
DataFrame[TDataFrameModel],
307308
cls.to_schema().validate(
308309
check_obj, head, tail, sample, random_state, lazy, inplace
309310
),

pandera/backends/polars/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ def subsample(
4747
obj_subsample.append(check_obj.tail(tail))
4848
if sample is not None:
4949
obj_subsample.append(
50-
check_obj.sample(sample, random_state=random_state)
50+
# mypy is detecting a bug https://github.com/unionai-oss/pandera/issues/1912
51+
check_obj.sample( # type:ignore [attr-defined]
52+
sample, random_state=random_state
53+
)
5154
)
5255
return (
5356
check_obj

0 commit comments

Comments
 (0)