|
1 | 1 | """Class-based api for polars models."""
|
2 | 2 |
|
3 | 3 | import inspect
|
4 |
| -from typing import Dict, List, Tuple, Type |
| 4 | +from typing import Dict, List, Tuple, Type, cast, Optional, overload, Union |
| 5 | +from typing_extensions import Self |
5 | 6 |
|
6 | 7 | import pandas as pd
|
7 | 8 | import polars as pl
|
8 | 9 |
|
| 10 | +from pandera.api.base.schema import BaseSchema |
9 | 11 | from pandera.api.checks import Check
|
10 | 12 | from pandera.api.dataframe.model import DataFrameModel as _DataFrameModel
|
11 | 13 | from pandera.api.dataframe.model import get_dtype_kwargs
|
|
16 | 18 | from pandera.engines import polars_engine as pe
|
17 | 19 | from pandera.errors import SchemaInitError
|
18 | 20 | from pandera.typing import AnnotationInfo
|
19 |
| -from pandera.typing.polars import Series |
| 21 | +from pandera.typing.polars import Series, LazyFrame, DataFrame |
| 22 | +from pandera.utils import docstring_substitution |
20 | 23 |
|
21 | 24 |
|
22 | 25 | class DataFrameModel(_DataFrameModel[pl.LazyFrame, DataFrameSchema]):
|
@@ -109,6 +112,53 @@ def _build_columns( # pylint:disable=too-many-locals
|
109 | 112 |
|
110 | 113 | return columns
|
111 | 114 |
|
| 115 | + @classmethod |
| 116 | + @overload |
| 117 | + def validate( |
| 118 | + cls: Type[Self], |
| 119 | + check_obj: pl.DataFrame, |
| 120 | + head: Optional[int] = None, |
| 121 | + tail: Optional[int] = None, |
| 122 | + sample: Optional[int] = None, |
| 123 | + random_state: Optional[int] = None, |
| 124 | + lazy: bool = False, |
| 125 | + inplace: bool = False, |
| 126 | + ) -> DataFrame[Self]: ... |
| 127 | + |
| 128 | + @classmethod |
| 129 | + @overload |
| 130 | + def validate( |
| 131 | + cls: Type[Self], |
| 132 | + check_obj: pl.LazyFrame, |
| 133 | + head: Optional[int] = None, |
| 134 | + tail: Optional[int] = None, |
| 135 | + sample: Optional[int] = None, |
| 136 | + random_state: Optional[int] = None, |
| 137 | + lazy: bool = False, |
| 138 | + inplace: bool = False, |
| 139 | + ) -> LazyFrame[Self]: ... |
| 140 | + |
| 141 | + @classmethod |
| 142 | + @docstring_substitution(validate_doc=BaseSchema.validate.__doc__) |
| 143 | + def validate( |
| 144 | + cls: Type[Self], |
| 145 | + check_obj: Union[pl.LazyFrame, pl.DataFrame], |
| 146 | + head: Optional[int] = None, |
| 147 | + tail: Optional[int] = None, |
| 148 | + sample: Optional[int] = None, |
| 149 | + random_state: Optional[int] = None, |
| 150 | + lazy: bool = False, |
| 151 | + inplace: bool = False, |
| 152 | + ) -> Union[LazyFrame[Self], DataFrame[Self]]: |
| 153 | + """%(validate_doc)s""" |
| 154 | + result = cls.to_schema().validate( |
| 155 | + check_obj, head, tail, sample, random_state, lazy, inplace |
| 156 | + ) |
| 157 | + if isinstance(check_obj, pl.LazyFrame): |
| 158 | + return cast(LazyFrame[Self], result) |
| 159 | + else: |
| 160 | + return cast(DataFrame[Self], result) |
| 161 | + |
112 | 162 | @classmethod
|
113 | 163 | def to_json_schema(cls):
|
114 | 164 | """Serialize schema metadata into json-schema format.
|
|
0 commit comments