-
-
Notifications
You must be signed in to change notification settings - Fork 336
Add Builtin Checks for Polars Support #1408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
cosmicBboy
merged 1 commit into
unionai-oss:polars-dev
from
AndriiG13:polars-dev-add-builtin-checks
Nov 10, 2023
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,324 @@ | ||
"""Built-in checks for polars.""" | ||
|
||
from typing import Any, Tuple | ||
from typing import Any, TypeVar, Iterable | ||
|
||
import re | ||
import polars as pl | ||
|
||
|
||
from pandera.api.extensions import register_builtin_check | ||
from pandera.api.polars.types import PolarsData | ||
from pandera.backends.polars.constants import CHECK_OUTPUT_KEY | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
@register_builtin_check( | ||
aliases=["eq"], | ||
error="equal_to({value})", | ||
) | ||
def equal_to(data: PolarsData, value: Any) -> pl.LazyFrame: | ||
"""Ensure all elements of a data container equal a certain value. | ||
|
||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param value: values in this polars data structure must be | ||
equal to this value. | ||
""" | ||
return data.dataframe.with_columns( | ||
[pl.col(data.key).eq(value).alias(CHECK_OUTPUT_KEY)] | ||
) | ||
|
||
|
||
@register_builtin_check( | ||
aliases=["ne"], | ||
error="not_equal_to({value})", | ||
) | ||
def not_equal_to(data: PolarsData, value: Any) -> pl.LazyFrame: | ||
"""Ensure no elements of a data container equals a certain value. | ||
|
||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param value: This value must not occur in the checked | ||
""" | ||
return data.dataframe.with_columns( | ||
[pl.col(data.key).ne(value).alias(CHECK_OUTPUT_KEY)] | ||
) | ||
|
||
|
||
@register_builtin_check( | ||
aliases=["gt"], | ||
error="greater_than({min_value})", | ||
) | ||
def greater_than(data: PolarsData, min_value: Any) -> pl.LazyFrame: | ||
""" | ||
Ensure values of a data container are strictly greater than a minimum | ||
value. | ||
|
||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param min_value: Lower bound to be exceeded. Must be | ||
a type comparable to the dtype of the series datatype of Polars | ||
""" | ||
return data.dataframe.with_columns( | ||
[pl.col(data.key).gt(min_value).alias(CHECK_OUTPUT_KEY)] | ||
) | ||
|
||
|
||
@register_builtin_check( | ||
aliases=["ge"], | ||
error="greater_than_or_equal_to({min_value})", | ||
) | ||
def greater_than_or_equal_to(data: PolarsData, min_value: Any) -> pl.LazyFrame: | ||
"""Ensure all elements of a data container equal a certain value. | ||
"""Ensure all values are greater or equal a certain value. | ||
|
||
:param value: values in this pandas data structure must be | ||
equal to this value. | ||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param min_value: Allowed minimum value for values of a series. Must be | ||
a type comparable to the dtype of the series datatype of Polars | ||
""" | ||
return data.dataframe.with_columns( | ||
[pl.col(data.key).ge(min_value).alias(CHECK_OUTPUT_KEY)] | ||
) | ||
|
||
|
||
@register_builtin_check( | ||
aliases=["lt"], | ||
error="less_than({max_value})", | ||
) | ||
def less_than(data: PolarsData, max_value: Any) -> pl.LazyFrame: | ||
"""Ensure values of a series are strictly below a maximum value. | ||
|
||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param max_value: All elements of a series must be strictly smaller | ||
than this. Must be a type comparable to the dtype of the series datatype of Polars | ||
""" | ||
return data.dataframe.with_columns( | ||
[pl.col(data.key).lt(max_value).alias(CHECK_OUTPUT_KEY)] | ||
) | ||
|
||
|
||
@register_builtin_check( | ||
aliases=["le"], | ||
error="less_than_or_equal_to({max_value})", | ||
) | ||
def less_than_or_equal_to(data: PolarsData, max_value: Any) -> pl.LazyFrame: | ||
"""Ensure values of a series are strictly below a maximum value. | ||
|
||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param max_value: Upper bound not to be exceeded. Must be a type comparable to the dtype of the | ||
series datatype of Polars | ||
""" | ||
return data.dataframe.with_columns( | ||
[pl.col(data.key).le(max_value).alias(CHECK_OUTPUT_KEY)] | ||
) | ||
|
||
|
||
@register_builtin_check( | ||
aliases=["between"], | ||
error="in_range({min_value}, {max_value})", | ||
) | ||
def in_range( | ||
data: PolarsData, | ||
min_value: T, | ||
max_value: T, | ||
include_min: bool = True, | ||
include_max: bool = True, | ||
) -> pl.LazyFrame: | ||
"""Ensure all values of a series are within an interval. | ||
|
||
Both endpoints must be a type comparable to the dtype of the | ||
series datatype of Polars | ||
|
||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param min_value: Left / lower endpoint of the interval. | ||
:param max_value: Right / upper endpoint of the interval. Must not be | ||
smaller than min_value. | ||
:param include_min: Defines whether min_value is also an allowed value | ||
(the default) or whether all values must be strictly greater than | ||
min_value. | ||
:param include_max: Defines whether min_value is also an allowed value | ||
(the default) or whether all values must be strictly smaller than | ||
max_value. | ||
""" | ||
col = pl.col(data.key) | ||
is_in_min = col.ge(min_value) if include_min else col.gt(min_value) | ||
is_in_max = col.le(max_value) if include_max else col.lt(max_value) | ||
|
||
return data.dataframe.with_columns( | ||
[is_in_min.and_(is_in_max).alias(CHECK_OUTPUT_KEY)] | ||
) | ||
|
||
|
||
@register_builtin_check( | ||
error="isin({allowed_values})", | ||
) | ||
def isin(data: PolarsData, allowed_values: Iterable) -> pl.LazyFrame: | ||
"""Ensure only allowed values occur within a series. | ||
|
||
This checks whether all elements of a :class:`polars.Series` | ||
are part of the set of elements of allowed values. If allowed | ||
values is a string, the set of elements consists of all distinct | ||
characters of the string. Thus only single characters which occur | ||
in allowed_values at least once can meet this condition. If you | ||
want to check for substrings use :meth:`Check.str_contains`. | ||
|
||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param allowed_values: The set of allowed values. May be any iterable. | ||
""" | ||
return data.dataframe.with_columns( | ||
[pl.col(data.key).is_in(allowed_values).alias(CHECK_OUTPUT_KEY)] | ||
) | ||
|
||
|
||
@register_builtin_check( | ||
error="notin({forbidden_values})", | ||
) | ||
def notin(data: PolarsData, forbidden_values: Iterable) -> pl.LazyFrame: | ||
"""Ensure some defined values don't occur within a series. | ||
|
||
Like :meth:`Check.isin` this check operates on single characters if | ||
it is applied on strings. If forbidden_values is a string, it is understood | ||
as set of prohibited characters. Any string of length > 1 can't be in it by | ||
design. | ||
|
||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param forbidden_values: The set of values which should not occur. May | ||
be any iterable. | ||
""" | ||
return data.dataframe.with_columns( | ||
[ | ||
pl.col(data.key) | ||
.is_in(forbidden_values) | ||
.is_not() | ||
.alias(CHECK_OUTPUT_KEY) | ||
] | ||
) | ||
|
||
|
||
@register_builtin_check( | ||
error="str_matches('{pattern}')", | ||
) | ||
def str_matches( | ||
data: PolarsData, | ||
pattern: str | re.Pattern, | ||
) -> pl.LazyFrame: | ||
"""Ensure that string values match a regular expression. | ||
|
||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param pattern: Regular expression pattern to use for matching | ||
""" | ||
|
||
return data.dataframe.with_columns( | ||
[ | ||
pl.col(data.key) | ||
.str.contains(pattern=pattern) | ||
.alias(CHECK_OUTPUT_KEY) | ||
] | ||
) | ||
|
||
|
||
@register_builtin_check( | ||
error="str_contains('{pattern}')", | ||
) | ||
def str_contains( | ||
data: PolarsData, | ||
pattern: str, | ||
) -> pl.LazyFrame: | ||
"""Ensure that a pattern can be found within each row. | ||
|
||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param pattern: Regular expression pattern to use for searching | ||
""" | ||
return data.dataframe.with_columns( | ||
[ | ||
pl.col(data.key) | ||
.str.contains(pattern=pattern, literal=True) | ||
.alias(CHECK_OUTPUT_KEY) | ||
] | ||
) | ||
|
||
|
||
@register_builtin_check( | ||
error="str_startswith('{string}')", | ||
) | ||
def str_startswith(data: PolarsData, string: str) -> pl.LazyFrame: | ||
"""Ensure that all values start with a certain string. | ||
|
||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param string: String all values should start with | ||
""" | ||
|
||
return data.dataframe.with_columns( | ||
[pl.col(data.key).str.starts_with(string).alias(CHECK_OUTPUT_KEY)] | ||
) | ||
|
||
|
||
@register_builtin_check(error="str_endswith('{string}')") | ||
def str_endswith(data: PolarsData, string: str) -> pl.LazyFrame: | ||
"""Ensure that all values end with a certain string. | ||
|
||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param string: String all values should end with | ||
""" | ||
return data.dataframe.with_columns( | ||
[pl.col(data.key).str.ends_with(string).alias(CHECK_OUTPUT_KEY)] | ||
) | ||
|
||
|
||
@register_builtin_check( | ||
error="str_length({min_value}, {max_value})", | ||
) | ||
def str_length( | ||
data: PolarsData, | ||
min_value: int = None, | ||
max_value: int = None, | ||
) -> pl.LazyFrame: | ||
"""Ensure that the length of strings is within a specified range. | ||
|
||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param min_value: Minimum length of strings (default: no minimum) | ||
:param max_value: Maximum length of strings (default: no maximum) | ||
""" | ||
# TODO: consider using len_bytes (faster but returns != n_chars for non ASCII strings | ||
n_chars = pl.col("string_col").str.n_chars() | ||
is_in_min = ( | ||
n_chars.ge(min_value) if min_value is not None else pl.lit(True) | ||
) | ||
is_in_max = ( | ||
n_chars.le(max_value) if max_value is not None else pl.lit(True) | ||
) | ||
|
||
return data.dataframe.with_columns( | ||
[is_in_min.and_(is_in_max).alias(CHECK_OUTPUT_KEY)] | ||
) | ||
|
||
|
||
@register_builtin_check( | ||
error="unique_values_eq({values})", | ||
) | ||
def unique_values_eq(data: PolarsData, values: Iterable) -> bool: | ||
"""Ensure that unique values in the data object contain all values. | ||
|
||
.. note:: | ||
In contrast with :func:`isin`, this check makes sure that all the items | ||
in the ``values`` iterable are contained within the series. | ||
|
||
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The keys | ||
to access the dataframe is "dataframe" and column name using "key". | ||
:param values: The set of values that must be present. Maybe any iterable. | ||
""" | ||
|
||
return ( | ||
set(data.dataframe.collect().get_column(data.key).unique()) == values | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
"""A flexible and expressive polars validation library for Python.""" | ||
|
||
# pylint: disable=unused-import | ||
from pandera.api.polars.components import Column | ||
from pandera.api.polars.container import DataFrameSchema | ||
|
||
import pandera.backends.polars | ||
from pandera.api.checks import Check |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because some dtypes objects are in instantiated in the schema, for instance
Duration(time_unit="us")
, theis
comparison doesn't work here. So I used theis_
method provided by the PolarsDataType
API instead