Skip to content

Add Polars Builtin Check Tests #1518

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandera/api/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def notin(cls, forbidden_values: Iterable, **kwargs) -> "Check":

@classmethod
def str_matches(cls, pattern: Union[str, re.Pattern], **kwargs) -> "Check":
"""Ensure that string values match a regular expression.
"""Ensure that strings start with regular expression match.

:param pattern: Regular expression pattern to use for matching
:param kwargs: key-word arguments passed into the `Check` initializer.
Expand Down
53 changes: 29 additions & 24 deletions pandera/backends/polars/builtin_checks.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""Built-in checks for polars."""

from typing import Any, TypeVar, Iterable, Union
from typing import Any, TypeVar, Iterable, Union, Optional

import re
import polars as pl


from pandera.api.extensions import register_builtin_check
from pandera.api.polars.types import PolarsData
from pandera.backends.polars.constants import CHECK_OUTPUT_KEY

T = TypeVar("T")

Expand Down Expand Up @@ -187,15 +186,17 @@ def str_matches(
data: PolarsData,
pattern: Union[str, re.Pattern],
) -> pl.LazyFrame:
"""Ensure that string values match a regular expression.
"""Ensure that string starts with a match of a regular expression pattern.

:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys
to access the dataframe is "dataframe" and column name using "key".
:param pattern: Regular expression pattern to use for matching
"""

pattern = pattern.pattern if isinstance(pattern, re.Pattern) else pattern
if not pattern.startswith("^"):
pattern = f"^{pattern}"
return data.dataframe.select(
pl.col(data.key).str.contains(pattern=pattern).alias(CHECK_OUTPUT_KEY)
pl.col(data.key).str.contains(pattern=pattern)
)


Expand All @@ -204,18 +205,18 @@ def str_matches(
)
def str_contains(
data: PolarsData,
pattern: str,
pattern: re.Pattern,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be consistent with the public Check api: https://github.com/unionai-oss/pandera/blob/main/pandera/api/checks.py#L430

Suggested change
pattern: re.Pattern,
pattern: Union[str, re.Pattern],

) -> pl.LazyFrame:
"""Ensure that a pattern can be found within each row.
"""Ensure that a pattern can be found in the string.

:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys
to access the dataframe is "dataframe" and column name using "key".
:param pattern: Regular expression pattern to use for searching
"""

pattern = pattern.pattern if isinstance(pattern, re.Pattern) else pattern
return data.dataframe.select(
pl.col(data.key)
.str.contains(pattern=pattern, literal=True)
.alias(CHECK_OUTPUT_KEY)
pl.col(data.key).str.contains(pattern=pattern, literal=False)
)


Expand Down Expand Up @@ -249,26 +250,30 @@ def str_endswith(data: PolarsData, string: str) -> pl.LazyFrame:
)
def str_length(
data: PolarsData,
min_value: int = None,
max_value: int = None,
min_value: Optional[int] = None,
max_value: Optional[int] = None,
) -> pl.LazyFrame:
"""Ensure that the length of strings is within a specified range.

:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys
to access the dataframe is "dataframe" and column name using "key".
:param min_value: Minimum length of strings (default: no minimum)
:param max_value: Maximum length of strings (default: no maximum)
:param min_value: Minimum length of strings (including) (default: no minimum)
:param max_value: Maximum length of strings (including) (default: no maximum)
"""
# NOTE: consider using len_bytes (faster but returns != n_chars for non ASCII strings
n_chars = pl.col("string_col").str.n_chars()
is_in_min = (
n_chars.ge(min_value) if min_value is not None else pl.lit(True)
)
is_in_max = (
n_chars.le(max_value) if max_value is not None else pl.lit(True)
)

return data.dataframe.select(is_in_min.and_(is_in_max))
if min_value is None and max_value is None:
raise ValueError(
"Must provide at least on of 'min_value' and 'max_value'"
)

n_chars = pl.col(data.key).str.n_chars()
if min_value is None:
expr = n_chars.le(max_value)
elif max_value is None:
expr = n_chars.ge(min_value)
else:
expr = n_chars.is_between(min_value, max_value)

return data.dataframe.select(expr)


@register_builtin_check(
Expand Down
Loading