Skip to content

Commit 40fb04d

Browse files
AndriiG13 authored and max-raphael committed
add builtin checks (unionai-oss#1408)
Signed-off-by: AndriiG13 <[email protected]>
1 parent 6bcc13a commit 40fb04d

File tree

6 files changed

+652
-16
lines changed

6 files changed

+652
-16
lines changed
Lines changed: 304 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,324 @@
11
"""Built-in checks for polars."""
22

3-
from typing import Any, Tuple
3+
from typing import Any, TypeVar, Iterable
44

5+
import re
56
import polars as pl
67

8+
79
from pandera.api.extensions import register_builtin_check
810
from pandera.api.polars.types import PolarsData
911
from pandera.backends.polars.constants import CHECK_OUTPUT_KEY
1012

13+
T = TypeVar("T")
14+
15+
16+
@register_builtin_check(
    aliases=["eq"],
    error="equal_to({value})",
)
def equal_to(data: PolarsData, value: Any) -> pl.LazyFrame:
    """Check that every element of the target column equals ``value``.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param value: value that each element is compared against.
    :returns: the input frame with the element-wise comparison result
        stored in a column named ``CHECK_OUTPUT_KEY``.
    """
    is_equal = pl.col(data.key).eq(value)
    return data.dataframe.with_columns([is_equal.alias(CHECK_OUTPUT_KEY)])
31+
32+
33+
@register_builtin_check(
    aliases=["ne"],
    error="not_equal_to({value})",
)
def not_equal_to(data: PolarsData, value: Any) -> pl.LazyFrame:
    """Check that no element of the target column equals ``value``.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param value: value that must not occur in the checked column.
    :returns: the input frame with the element-wise comparison result
        stored in a column named ``CHECK_OUTPUT_KEY``.
    """
    differs = pl.col(data.key).ne(value)
    return data.dataframe.with_columns([differs.alias(CHECK_OUTPUT_KEY)])
47+
48+
49+
@register_builtin_check(
    aliases=["gt"],
    error="greater_than({min_value})",
)
def greater_than(data: PolarsData, min_value: Any) -> pl.LazyFrame:
    """Check that every element is strictly greater than ``min_value``.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param min_value: exclusive lower bound. Must be comparable to the
        dtype of the checked column.
    :returns: the input frame with the element-wise comparison result
        stored in a column named ``CHECK_OUTPUT_KEY``.
    """
    above_bound = pl.col(data.key).gt(min_value)
    return data.dataframe.with_columns(
        [above_bound.alias(CHECK_OUTPUT_KEY)]
    )
66+
1167

1268
@register_builtin_check(
    aliases=["ge"],
    error="greater_than_or_equal_to({min_value})",
)
def greater_than_or_equal_to(data: PolarsData, min_value: Any) -> pl.LazyFrame:
    """Check that every element is greater than or equal to ``min_value``.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param min_value: inclusive lower bound. Must be comparable to the
        dtype of the checked column.
    :returns: the input frame with the element-wise comparison result
        stored in a column named ``CHECK_OUTPUT_KEY``.
    """
    at_least_min = pl.col(data.key).ge(min_value)
    return data.dataframe.with_columns(
        [at_least_min.alias(CHECK_OUTPUT_KEY)]
    )
83+
84+
85+
@register_builtin_check(
    aliases=["lt"],
    error="less_than({max_value})",
)
def less_than(data: PolarsData, max_value: Any) -> pl.LazyFrame:
    """Check that every element is strictly smaller than ``max_value``.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param max_value: exclusive upper bound. Must be comparable to the
        dtype of the checked column.
    :returns: the input frame with the element-wise comparison result
        stored in a column named ``CHECK_OUTPUT_KEY``.
    """
    below_bound = pl.col(data.key).lt(max_value)
    return data.dataframe.with_columns(
        [below_bound.alias(CHECK_OUTPUT_KEY)]
    )
100+
101+
102+
@register_builtin_check(
    aliases=["le"],
    error="less_than_or_equal_to({max_value})",
)
def less_than_or_equal_to(data: PolarsData, max_value: Any) -> pl.LazyFrame:
    """Check that every element is at most ``max_value``.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param max_value: inclusive upper bound (values equal to it pass).
        Must be comparable to the dtype of the checked column.
    :returns: the input frame with the element-wise comparison result
        stored in a column named ``CHECK_OUTPUT_KEY``.
    """
    at_most_max = pl.col(data.key).le(max_value)
    return data.dataframe.with_columns(
        [at_most_max.alias(CHECK_OUTPUT_KEY)]
    )
117+
118+
119+
@register_builtin_check(
    aliases=["between"],
    error="in_range({min_value}, {max_value})",
)
def in_range(
    data: PolarsData,
    min_value: T,
    max_value: T,
    include_min: bool = True,
    include_max: bool = True,
) -> pl.LazyFrame:
    """Check that every element lies within an interval.

    Both endpoints must be comparable to the dtype of the checked column.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param min_value: left / lower endpoint of the interval.
    :param max_value: right / upper endpoint of the interval. Should not be
        smaller than ``min_value`` (this is not validated here).
    :param include_min: when True (the default) ``min_value`` itself is an
        allowed value; when False values must be strictly greater.
    :param include_max: when True (the default) ``max_value`` itself is an
        allowed value; when False values must be strictly smaller.
    :returns: the input frame with the element-wise result stored in a
        column named ``CHECK_OUTPUT_KEY``.
    """
    column = pl.col(data.key)

    if include_min:
        lower_ok = column.ge(min_value)
    else:
        lower_ok = column.gt(min_value)

    if include_max:
        upper_ok = column.le(max_value)
    else:
        upper_ok = column.lt(max_value)

    within = lower_ok.and_(upper_ok)
    return data.dataframe.with_columns([within.alias(CHECK_OUTPUT_KEY)])
154+
155+
156+
@register_builtin_check(
    error="isin({allowed_values})",
)
def isin(data: PolarsData, allowed_values: Iterable) -> pl.LazyFrame:
    """Check that only allowed values occur in the target column.

    Each element is tested for membership in ``allowed_values``. To test
    for substrings use :meth:`Check.str_contains` instead.

    NOTE(review): the previous docstring stated that a string passed as
    ``allowed_values`` is treated as the set of its characters; that has
    not been confirmed for polars ``is_in`` -- verify before relying on it.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param allowed_values: the collection of permitted values.
    :returns: the input frame with the element-wise membership result
        stored in a column named ``CHECK_OUTPUT_KEY``.
    """
    is_allowed = pl.col(data.key).is_in(allowed_values)
    return data.dataframe.with_columns([is_allowed.alias(CHECK_OUTPUT_KEY)])
176+
177+
178+
@register_builtin_check(
    error="notin({forbidden_values})",
)
def notin(data: PolarsData, forbidden_values: Iterable) -> pl.LazyFrame:
    """Check that none of the forbidden values occur in the target column.

    This is the negation of :meth:`Check.isin`: an element passes when it
    is not a member of ``forbidden_values``.

    NOTE(review): the previous docstring stated that a string passed as
    ``forbidden_values`` is treated as a set of prohibited characters; that
    has not been confirmed for polars ``is_in`` -- verify before relying
    on it.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param forbidden_values: the collection of values that must not occur.
    :returns: the input frame with the element-wise result stored in a
        column named ``CHECK_OUTPUT_KEY``.
    """
    is_forbidden = pl.col(data.key).is_in(forbidden_values)
    # Negate with the ``~`` operator rather than ``Expr.is_not()``: polars
    # deprecated ``is_not`` in favour of ``not_``, while the operator form
    # does not depend on which of the two method names is available.
    return data.dataframe.with_columns(
        [(~is_forbidden).alias(CHECK_OUTPUT_KEY)]
    )
202+
203+
204+
@register_builtin_check(
    error="str_matches('{pattern}')",
)
def str_matches(
    data: PolarsData,
    pattern: str | re.Pattern,
) -> pl.LazyFrame:
    """Check that string values match a regular expression from the start.

    The pattern must match at the beginning of each string (the same
    anchoring as ``re.match``); it does not have to consume the whole
    string. Use :func:`str_contains` to search anywhere in the string.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param pattern: regular expression, either as a string or as a compiled
        ``re.Pattern``.
    :returns: the input frame with the element-wise match result stored in
        a column named ``CHECK_OUTPUT_KEY``.
    """
    # The expression API is given the pattern as text, so unwrap a compiled
    # pattern. NOTE(review): flags set on the compiled pattern are not
    # carried over -- express them inline (e.g. ``(?i)``) if needed.
    regex = pattern.pattern if isinstance(pattern, re.Pattern) else pattern

    # ``str.contains`` searches anywhere in the string; anchoring turns the
    # search into a match. The non-capturing group keeps the anchor applied
    # to every branch of a top-level alternation such as ``a|b``.
    anchored = f"^(?:{regex})"

    matched = pl.col(data.key).str.contains(pattern=anchored)
    return data.dataframe.with_columns([matched.alias(CHECK_OUTPUT_KEY)])
225+
226+
227+
@register_builtin_check(
    error="str_contains('{pattern}')",
)
def str_contains(
    data: PolarsData,
    pattern: str,
) -> pl.LazyFrame:
    """Check that a regular expression can be found within each value.

    The pattern may match anywhere in the string.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param pattern: regular expression to search for. A compiled
        ``re.Pattern`` is also accepted and reduced to its pattern text.
    :returns: the input frame with the element-wise search result stored in
        a column named ``CHECK_OUTPUT_KEY``.
    """
    # NOTE(review): flags set on a compiled pattern are not carried over.
    regex = pattern.pattern if isinstance(pattern, re.Pattern) else pattern

    # ``literal=False`` makes polars interpret the pattern as a regular
    # expression, which is what the documented contract promises; with
    # ``literal=True`` regex metacharacters were matched verbatim.
    found = pl.col(data.key).str.contains(pattern=regex, literal=False)
    return data.dataframe.with_columns([found.alias(CHECK_OUTPUT_KEY)])
247+
248+
249+
@register_builtin_check(
    error="str_startswith('{string}')",
)
def str_startswith(data: PolarsData, string: str) -> pl.LazyFrame:
    """Check that every value starts with ``string``.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param string: prefix every value is expected to begin with.
    :returns: the input frame with the element-wise result stored in a
        column named ``CHECK_OUTPUT_KEY``.
    """
    has_prefix = pl.col(data.key).str.starts_with(string)
    return data.dataframe.with_columns([has_prefix.alias(CHECK_OUTPUT_KEY)])
263+
264+
265+
@register_builtin_check(error="str_endswith('{string}')")
def str_endswith(data: PolarsData, string: str) -> pl.LazyFrame:
    """Check that every value ends with ``string``.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param string: suffix every value is expected to end with.
    :returns: the input frame with the element-wise result stored in a
        column named ``CHECK_OUTPUT_KEY``.
    """
    has_suffix = pl.col(data.key).str.ends_with(string)
    return data.dataframe.with_columns([has_suffix.alias(CHECK_OUTPUT_KEY)])
276+
277+
278+
@register_builtin_check(
    error="str_length({min_value}, {max_value})",
)
def str_length(
    data: PolarsData,
    min_value: int = None,
    max_value: int = None,
) -> pl.LazyFrame:
    """Check that the length of each string is within a specified range.

    Length is measured in characters, not bytes. A bound left as ``None``
    is not enforced.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param min_value: minimum allowed length, inclusive (default: no
        minimum).
    :param max_value: maximum allowed length, inclusive (default: no
        maximum).
    :returns: the input frame with the element-wise result stored in a
        column named ``CHECK_OUTPUT_KEY``.
    """
    # Measure the column under validation (``data.key``); a fixed column
    # name here would ignore which column the check was attached to.
    # TODO: consider len_bytes (faster, but differs from the character
    # count for non-ASCII strings).
    n_chars = pl.col(data.key).str.n_chars()

    # A missing bound contributes a constant True so that only the bound(s)
    # actually supplied constrain the result.
    if min_value is None:
        long_enough = pl.lit(True)
    else:
        long_enough = n_chars.ge(min_value)

    if max_value is None:
        short_enough = pl.lit(True)
    else:
        short_enough = n_chars.le(max_value)

    return data.dataframe.with_columns(
        [long_enough.and_(short_enough).alias(CHECK_OUTPUT_KEY)]
    )
305+
306+
307+
@register_builtin_check(
    error="unique_values_eq({values})",
)
def unique_values_eq(data: PolarsData, values: Iterable) -> bool:
    """Check that the unique values of the column are exactly ``values``.

    .. note::
        In contrast with :func:`isin`, this check also requires every item
        of ``values`` to be present in the column: the set of distinct
        column values must equal the set of ``values``.

    This collects ``data.dataframe`` in order to read the column.

    :param data: PolarsData NamedTuple; the frame is read from
        ``data.dataframe`` and the column name from ``data.key``.
    :param values: the values that must be present. May be any iterable of
        hashable items; order and duplicates are ignored.
    :returns: True when both sets are equal, False otherwise.
    """
    observed = data.dataframe.collect().get_column(data.key).unique()
    # Convert ``values`` to a set as well: a set never compares equal to a
    # list or tuple, so comparing against the raw iterable would fail for
    # every non-set argument.
    return set(observed) == set(values)

pandera/backends/polars/checks.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Check backend for polars."""
22

33
from functools import partial
4-
from typing import Optional, Tuple
4+
from typing import Optional
55

66
import polars as pl
77
from polars.lazyframe.group_by import LazyGroupBy
8-
8+
from multimethod import overload
99
from pandera.api.base.checks import CheckResult
1010
from pandera.api.checks import Check
1111
from pandera.api.polars.types import PolarsData
@@ -48,6 +48,14 @@ def apply(self, check_obj: PolarsData):
4848
"""Apply the check function to a check object."""
4949
return self.check_fn(check_obj)
5050

51+
@overload
def postprocess(self, check_obj, check_output):
    """Fallback overload for check outputs of an unrecognized type.

    Always raises ``TypeError`` naming the type of ``check_output``.
    """
    message = f"output type of check_fn not recognized: {type(check_output)}"
    raise TypeError(message)  # pragma: no cover
57+
58+
@overload # type: ignore [no-redef]
5159
def postprocess(
5260
self,
5361
check_obj: PolarsData,
@@ -68,6 +76,21 @@ def postprocess(
6876
failure_cases=failure_cases,
6977
)
7078

79+
@overload  # type: ignore [no-redef]
def postprocess(
    self,
    check_obj: PolarsData,
    check_output: bool,
) -> CheckResult:
    """Postprocess a check function result that is a single boolean.

    The boolean is wrapped in a one-row LazyFrame under the
    ``CHECK_OUTPUT_KEY`` column; that frame is used as both the check
    output and the pass/fail indicator. No failure cases are reported.

    :param check_obj: the object the check was applied to; returned
        unchanged as ``checked_object``.
    :param check_output: the boolean returned by the check function.
    """
    outcome = pl.LazyFrame({CHECK_OUTPUT_KEY: [check_output]})
    return CheckResult(
        check_output=outcome,
        check_passed=outcome,
        checked_object=check_obj,
        failure_cases=None,
    )
93+
7194
def __call__(
7295
self,
7396
check_obj: pl.LazyFrame,

pandera/backends/polars/components.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def check_dtype(
8787

8888
if schema.dtype is not None:
8989
obj_dtype = check_obj.schema[schema.name]
90-
passed = obj_dtype is schema.dtype
90+
passed = obj_dtype.is_(schema.dtype)
9191

9292
if not passed:
9393
failure_cases = str(obj_dtype)

pandera/polars.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""A flexible and expressive polars validation library for Python."""
2-
2+
# pylint: disable=unused-import
33
from pandera.api.polars.components import Column
44
from pandera.api.polars.container import DataFrameSchema
55

66
import pandera.backends.polars
7+
from pandera.api.checks import Check

0 commit comments

Comments
 (0)