Add Builtin Checks for Polars Support #1408

Merged
Changes from all commits
308 changes: 304 additions & 4 deletions pandera/backends/polars/builtin_checks.py
@@ -1,24 +1,324 @@
"""Built-in checks for polars."""

from typing import Any, TypeVar, Iterable

import re
import polars as pl


from pandera.api.extensions import register_builtin_check
from pandera.api.polars.types import PolarsData
from pandera.backends.polars.constants import CHECK_OUTPUT_KEY

T = TypeVar("T")


@register_builtin_check(
aliases=["eq"],
error="equal_to({value})",
)
def equal_to(data: PolarsData, value: Any) -> pl.LazyFrame:
"""Ensure all elements of a data container equal a certain value.

:param data: NamedTuple PolarsData containing the dataframe and column name for the check. The
dataframe is accessed via the "dataframe" attribute and the column name via the "key" attribute.
:param value: values in this polars data structure must be
equal to this value.
"""
return data.dataframe.with_columns(
[pl.col(data.key).eq(value).alias(CHECK_OUTPUT_KEY)]
)
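Aside: every check in this module follows the same output convention; rather than filtering rows, it appends a boolean column named by CHECK_OUTPUT_KEY to the LazyFrame so the backend can locate failing rows later. A minimal sketch of that pattern with plain polars (not part of the diff; the constant's value is assumed here for illustration):

import polars as pl

CHECK_OUTPUT_KEY = "check_output"  # assumed value of the pandera constant

lf = pl.LazyFrame({"a": [1, 1, 2]})
# Same shape as equal_to above: append a boolean mask column instead of filtering.
out = lf.with_columns([pl.col("a").eq(1).alias(CHECK_OUTPUT_KEY)])
print(out.collect())  # the check_output column holds [true, true, false]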


@register_builtin_check(
aliases=["ne"],
error="not_equal_to({value})",
)
def not_equal_to(data: PolarsData, value: Any) -> pl.LazyFrame:
"""Ensure no elements of a data container equals a certain value.

:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys
to access the dataframe is "dataframe" and column name using "key".
:param value: This value must not occur in the checked
"""
return data.dataframe.with_columns(
[pl.col(data.key).ne(value).alias(CHECK_OUTPUT_KEY)]
)


@register_builtin_check(
aliases=["gt"],
error="greater_than({min_value})",
)
def greater_than(data: PolarsData, min_value: Any) -> pl.LazyFrame:
"""
Ensure values of a data container are strictly greater than a minimum
value.

:param data: NamedTuple PolarsData containing the dataframe and column name for the check. The
dataframe is accessed via the "dataframe" attribute and the column name via the "key" attribute.
:param min_value: Lower bound to be exceeded. Must be a type comparable to the dtype of the
Polars column.
"""
return data.dataframe.with_columns(
[pl.col(data.key).gt(min_value).alias(CHECK_OUTPUT_KEY)]
)


@register_builtin_check(
aliases=["ge"],
error="greater_than_or_equal_to({min_value})",
)
def greater_than_or_equal_to(data: PolarsData, min_value: Any) -> pl.LazyFrame:
"""Ensure all elements of a data container equal a certain value.
"""Ensure all values are greater or equal a certain value.

:param value: values in this pandas data structure must be
equal to this value.
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys
to access the dataframe is "dataframe" and column name using "key".
:param min_value: Allowed minimum value for values of a series. Must be
a type comparable to the dtype of the series datatype of Polars
"""
return data.dataframe.with_columns(
[pl.col(data.key).ge(min_value).alias(CHECK_OUTPUT_KEY)]
)


@register_builtin_check(
aliases=["lt"],
error="less_than({max_value})",
)
def less_than(data: PolarsData, max_value: Any) -> pl.LazyFrame:
"""Ensure values of a series are strictly below a maximum value.

:param data: NamedTuple PolarsData containing the dataframe and column name for the check. The
dataframe is accessed via the "dataframe" attribute and the column name via the "key" attribute.
:param max_value: All elements of the column must be strictly smaller than this. Must be a type
comparable to the dtype of the Polars column.
"""
return data.dataframe.with_columns(
[pl.col(data.key).lt(max_value).alias(CHECK_OUTPUT_KEY)]
)


@register_builtin_check(
aliases=["le"],
error="less_than_or_equal_to({max_value})",
)
def less_than_or_equal_to(data: PolarsData, max_value: Any) -> pl.LazyFrame:
"""Ensure values of a series are strictly below a maximum value.

:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys
to access the dataframe is "dataframe" and column name using "key".
:param max_value: Upper bound not to be exceeded. Must be a type comparable to the dtype of the
series datatype of Polars
"""
return data.dataframe.with_columns(
[pl.col(data.key).le(max_value).alias(CHECK_OUTPUT_KEY)]
)


@register_builtin_check(
aliases=["between"],
error="in_range({min_value}, {max_value})",
)
def in_range(
data: PolarsData,
min_value: T,
max_value: T,
include_min: bool = True,
include_max: bool = True,
) -> pl.LazyFrame:
"""Ensure all values of a series are within an interval.

Both endpoints must be a type comparable to the dtype of the Polars column.

:param data: NamedTuple PolarsData containing the dataframe and column name for the check. The
dataframe is accessed via the "dataframe" attribute and the column name via the "key" attribute.
:param min_value: Left / lower endpoint of the interval.
:param max_value: Right / upper endpoint of the interval. Must not be
smaller than min_value.
:param include_min: Defines whether min_value is also an allowed value
(the default) or whether all values must be strictly greater than
min_value.
:param include_max: Defines whether max_value is also an allowed value
(the default) or whether all values must be strictly smaller than
max_value.
"""
col = pl.col(data.key)
is_in_min = col.ge(min_value) if include_min else col.gt(min_value)
is_in_max = col.le(max_value) if include_max else col.lt(max_value)

return data.dataframe.with_columns(
[is_in_min.and_(is_in_max).alias(CHECK_OUTPUT_KEY)]
)
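Aside: a small sketch (not part of the diff) of the interval logic above; the include_* flags switch between inclusive and strict comparisons before the two masks are combined with AND.

import polars as pl

lf = pl.LazyFrame({"x": [0, 5, 10]})
min_value, max_value = 0, 10
include_min, include_max = True, False  # strict upper bound for this example

col = pl.col("x")
is_in_min = col.ge(min_value) if include_min else col.gt(min_value)
is_in_max = col.le(max_value) if include_max else col.lt(max_value)

# x == 10 fails because include_max=False makes the upper bound strict.
print(lf.with_columns(is_in_min.and_(is_in_max).alias("in_range")).collect())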


@register_builtin_check(
error="isin({allowed_values})",
)
def isin(data: PolarsData, allowed_values: Iterable) -> pl.LazyFrame:
"""Ensure only allowed values occur within a series.

This checks whether all elements of a :class:`polars.Series`
are part of the set of elements of allowed values. If allowed
values is a string, the set of elements consists of all distinct
characters of the string. Thus only single characters which occur
in allowed_values at least once can meet this condition. If you
want to check for substrings use :meth:`Check.str_contains`.

:param data: NamedTuple PolarsData containing the dataframe and column name for the check. The
dataframe is accessed via the "dataframe" attribute and the column name via the "key" attribute.
:param allowed_values: The set of allowed values. May be any iterable.
"""
return data.dataframe.with_columns(
[pl.col(data.key).is_in(allowed_values).alias(CHECK_OUTPUT_KEY)]
)


@register_builtin_check(
error="notin({forbidden_values})",
)
def notin(data: PolarsData, forbidden_values: Iterable) -> pl.LazyFrame:
"""Ensure some defined values don't occur within a series.

Like :meth:`Check.isin` this check operates on single characters if
it is applied on strings. If forbidden_values is a string, it is understood
as set of prohibited characters. Any string of length > 1 can't be in it by
design.

:param data: NamedTuple PolarsData containing the dataframe and column name for the check. The
dataframe is accessed via the "dataframe" attribute and the column name via the "key" attribute.
:param forbidden_values: The set of values which should not occur. May
be any iterable.
"""
return data.dataframe.with_columns(
[
pl.col(data.key)
.is_in(forbidden_values)
.is_not()
.alias(CHECK_OUTPUT_KEY)
]
)


@register_builtin_check(
error="str_matches('{pattern}')",
)
def str_matches(
data: PolarsData,
pattern: str | re.Pattern,
) -> pl.LazyFrame:
"""Ensure that string values match a regular expression.

:param data: NamedTuple PolarsData containing the dataframe and column name for the check. The
dataframe is accessed via the "dataframe" attribute and the column name via the "key" attribute.
:param pattern: Regular expression pattern to use for matching
"""

return data.dataframe.with_columns(
[
pl.col(data.key)
.str.contains(pattern=pattern)
.alias(CHECK_OUTPUT_KEY)
]
)


@register_builtin_check(
error="str_contains('{pattern}')",
)
def str_contains(
data: PolarsData,
pattern: str,
) -> pl.LazyFrame:
"""Ensure that a pattern can be found within each row.

:param data: NamedTuple PolarsData containing the dataframe and column name for the check. The
dataframe is accessed via the "dataframe" attribute and the column name via the "key" attribute.
:param pattern: Regular expression pattern to use for searching
"""
return data.dataframe.with_columns(
[
pl.col(data.key)
.str.contains(pattern=pattern, literal=True)
.alias(CHECK_OUTPUT_KEY)
]
)


@register_builtin_check(
error="str_startswith('{string}')",
)
def str_startswith(data: PolarsData, string: str) -> pl.LazyFrame:
"""Ensure that all values start with a certain string.

:param data: NamedTuple PolarsData containing the dataframe and column name for the check. The
dataframe is accessed via the "dataframe" attribute and the column name via the "key" attribute.
:param string: String all values should start with
"""

return data.dataframe.with_columns(
[pl.col(data.key).str.starts_with(string).alias(CHECK_OUTPUT_KEY)]
)


@register_builtin_check(error="str_endswith('{string}')")
def str_endswith(data: PolarsData, string: str) -> pl.LazyFrame:
"""Ensure that all values end with a certain string.

:param data: NamedTuple PolarsData containing the dataframe and column name for the check. The
dataframe is accessed via the "dataframe" attribute and the column name via the "key" attribute.
:param string: String all values should end with
"""
return data.dataframe.with_columns(
[pl.col(data.key).str.ends_with(string).alias(CHECK_OUTPUT_KEY)]
)


@register_builtin_check(
error="str_length({min_value}, {max_value})",
)
def str_length(
data: PolarsData,
min_value: int = None,
max_value: int = None,
) -> pl.LazyFrame:
"""Ensure that the length of strings is within a specified range.

:param data: NamedTuple PolarsData containing the dataframe and column name for the check. The
dataframe is accessed via the "dataframe" attribute and the column name via the "key" attribute.
:param min_value: Minimum length of strings (default: no minimum)
:param max_value: Maximum length of strings (default: no maximum)
"""
# TODO: consider using len_bytes (faster, but returns != n_chars for non-ASCII strings)
n_chars = pl.col(data.key).str.n_chars()
is_in_min = (
n_chars.ge(min_value) if min_value is not None else pl.lit(True)
)
is_in_max = (
n_chars.le(max_value) if max_value is not None else pl.lit(True)
)

return data.dataframe.with_columns(
[is_in_min.and_(is_in_max).alias(CHECK_OUTPUT_KEY)]
)
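Aside on the TODO above (assuming a polars version that exposes both str.n_chars and str.len_bytes): the two measures diverge for non-ASCII strings, which is why character counting is used here despite being slower.

import polars as pl

df = pl.DataFrame({"s": ["abc", "héllo"]})
# n_chars counts Unicode characters, len_bytes counts UTF-8 bytes:
# "héllo" -> n_chars = 5, len_bytes = 6.
print(
    df.select(
        pl.col("s").str.n_chars().alias("n_chars"),
        pl.col("s").str.len_bytes().alias("len_bytes"),
    )
)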


@register_builtin_check(
error="unique_values_eq({values})",
)
def unique_values_eq(data: PolarsData, values: Iterable) -> bool:
"""Ensure that unique values in the data object contain all values.

.. note::
In contrast with :func:`isin`, this check makes sure that all the items
in the ``values`` iterable are contained within the series.

:param data: NamedTuple PolarsData containing the dataframe and column name for the check. The
dataframe is accessed via the "dataframe" attribute and the column name via the "key" attribute.
:param values: The set of values that must be present. May be any iterable.
"""

return (
set(data.dataframe.collect().get_column(data.key).unique()) == values
)
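To put the new builtin checks in context, here is a hedged end-to-end sketch of how they are exercised through the polars schema API (column names, data, and dtype spellings are illustrative and may vary by pandera/polars version):

import polars as pl
import pandera.polars as pa

schema = pa.DataFrameSchema(
    {
        "age": pa.Column(pl.Int64, pa.Check.in_range(0, 120)),
        "name": pa.Column(pl.Utf8, pa.Check.str_length(min_value=1)),
    }
)

lf = pl.LazyFrame({"age": [25, 40], "name": ["alice", "bob"]})
# validate returns the LazyFrame when all dtype and check constraints pass.
validated = schema.validate(lf)
print(validated.collect())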
27 changes: 25 additions & 2 deletions pandera/backends/polars/checks.py
@@ -1,11 +1,11 @@
"""Check backend for pandas."""

from functools import partial
from typing import Optional

import polars as pl
from polars.lazyframe.group_by import LazyGroupBy

from multimethod import overload
from pandera.api.base.checks import CheckResult
from pandera.api.checks import Check
from pandera.api.polars.types import PolarsData
@@ -48,6 +48,14 @@ def apply(self, check_obj: PolarsData):
"""Apply the check function to a check object."""
return self.check_fn(check_obj)

@overload
def postprocess(self, check_obj, check_output):
"""Postprocesses the result of applying the check function."""
raise TypeError( # pragma: no cover
f"output type of check_fn not recognized: {type(check_output)}"
)

@overload # type: ignore [no-redef]
def postprocess(
self,
check_obj: PolarsData,
@@ -68,6 +76,21 @@ def postprocess(
failure_cases=failure_cases,
)

@overload # type: ignore [no-redef]
def postprocess(
self,
check_obj: PolarsData,
check_output: bool,
) -> CheckResult:
"""Postprocesses the result of applying the check function."""
ldf_output = pl.LazyFrame({CHECK_OUTPUT_KEY: [check_output]})
return CheckResult(
check_output=ldf_output,
check_passed=ldf_output,
checked_object=check_obj,
failure_cases=None,
)
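Aside: a short sketch (names assumed for illustration) of what this scalar path produces; a bool returned by a check such as unique_values_eq is wrapped in a one-row LazyFrame under CHECK_OUTPUT_KEY so downstream handling stays uniform with the expression-based checks.

import polars as pl

CHECK_OUTPUT_KEY = "check_output"  # assumed value of the pandera constant
check_output = True  # e.g. the bool returned by unique_values_eq

ldf_output = pl.LazyFrame({CHECK_OUTPUT_KEY: [check_output]})
print(ldf_output.collect())  # one row, one boolean column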

def __call__(
self,
check_obj: pl.LazyFrame,
2 changes: 1 addition & 1 deletion pandera/backends/polars/components.py
@@ -87,7 +87,7 @@ def check_dtype(

if schema.dtype is not None:
obj_dtype = check_obj.schema[schema.name]
passed = obj_dtype.is_(schema.dtype)
Contributor Author
Because some dtype objects are instantiated in the schema, for instance Duration(time_unit="us"), the is comparison doesn't work here, so I used the is_ method provided by the Polars DataType API instead.
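An illustrative sketch of that point (not part of the diff): two parametrized dtype instances compare equal but are distinct objects, so identity fails while the structural comparison used in the fix succeeds.

import polars as pl

a = pl.Duration(time_unit="us")
b = pl.Duration(time_unit="us")

print(a is b)    # False: two distinct dtype instances
print(a == b)    # True: structurally equal dtypes
print(a.is_(b))  # True: the DataType comparison used in the fix above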


if not passed:
failure_cases = str(obj_dtype)
3 changes: 2 additions & 1 deletion pandera/polars.py
@@ -1,6 +1,7 @@
"""A flexible and expressive polars validation library for Python."""

# pylint: disable=unused-import
from pandera.api.polars.components import Column
from pandera.api.polars.container import DataFrameSchema

import pandera.backends.polars
from pandera.api.checks import Check