Skip to content

Commit dee3ab7

Browse files
AndriiG13 authored and cosmicBboy committed
add polars builtin check tests
Signed-off-by: Andrii G <[email protected]>
1 parent 0a24f3e commit dee3ab7

File tree

2 files changed

+1217
-60
lines changed

2 files changed

+1217
-60
lines changed

pandera/backends/polars/builtin_checks.py

Lines changed: 12 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,13 @@
11
"""Built-in checks for polars."""
22

3-
from typing import Any, TypeVar, Iterable, Union
3+
from typing import Any, TypeVar, Iterable, Union, Optional
44

55
import re
66
import polars as pl
77

88

99
from pandera.api.extensions import register_builtin_check
1010
from pandera.api.polars.types import PolarsData
11-
from pandera.backends.polars.constants import CHECK_OUTPUT_KEY
1211

1312
T = TypeVar("T")
1413

@@ -180,42 +179,24 @@ def notin(data: PolarsData, forbidden_values: Iterable) -> pl.LazyFrame:
180179
)
181180

182181

183-
@register_builtin_check(
184-
error="str_matches('{pattern}')",
185-
)
186-
def str_matches(
187-
data: PolarsData,
188-
pattern: Union[str, re.Pattern],
189-
) -> pl.LazyFrame:
190-
"""Ensure that string values match a regular expression.
191-
192-
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys
193-
to access the dataframe is "dataframe" and column name using "key".
194-
:param pattern: Regular expression pattern to use for matching
195-
"""
196-
197-
return data.dataframe.select(
198-
pl.col(data.key).str.contains(pattern=pattern).alias(CHECK_OUTPUT_KEY)
199-
)
200-
201-
202182
@register_builtin_check(
203183
error="str_contains('{pattern}')",
204184
)
205185
def str_contains(
206186
data: PolarsData,
207-
pattern: str,
187+
pattern: re.Pattern,
208188
) -> pl.LazyFrame:
209189
"""Ensure that a pattern can be found within each row.
210190
211191
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys
212192
to access the dataframe is "dataframe" and column name using "key".
213193
:param pattern: Regular expression pattern to use for searching
214194
"""
195+
215196
return data.dataframe.select(
216-
pl.col(data.key)
217-
.str.contains(pattern=pattern, literal=True)
218-
.alias(CHECK_OUTPUT_KEY)
197+
pl.col(data.key).str.contains(
198+
pattern=f"{pattern.pattern}", literal=False
199+
)
219200
)
220201

221202

@@ -249,26 +230,26 @@ def str_endswith(data: PolarsData, string: str) -> pl.LazyFrame:
249230
)
250231
def str_length(
251232
data: PolarsData,
252-
min_value: int = None,
253-
max_value: int = None,
233+
min_value: Optional[int] = None,
234+
max_value: Optional[int] = None,
254235
) -> pl.LazyFrame:
255236
"""Ensure that the length of strings is within a specified range.
256237
257238
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys
258239
to access the dataframe is "dataframe" and column name using "key".
259-
:param min_value: Minimum length of strings (default: no minimum)
260-
:param max_value: Maximum length of strings (default: no maximum)
240+
:param min_value: Minimum length of strings (including) (default: no minimum)
241+
:param max_value: Maximum length of strings (including) (default: no maximum)
261242
"""
262243
# NOTE: consider using len_bytes (faster but returns != n_chars for non ASCII strings
263-
n_chars = pl.col("string_col").str.n_chars()
244+
n_chars = pl.col(data.key).str.n_chars()
264245
is_in_min = (
265246
n_chars.ge(min_value) if min_value is not None else pl.lit(True)
266247
)
267248
is_in_max = (
268249
n_chars.le(max_value) if max_value is not None else pl.lit(True)
269250
)
270251

271-
return data.dataframe.select(is_in_min.and_(is_in_max))
252+
return data.dataframe.select(is_in_min.and_(is_in_max).alias(data.key))
272253

273254

274255
@register_builtin_check(

0 commit comments

Comments
 (0)