Skip to content

Commit 8bc1970

Browse files
committed
built-in checks support table-level checks
Signed-off-by: cosmicBboy <[email protected]>
1 parent 4b0f5c4 commit 8bc1970

File tree

5 files changed

+51
-15
lines changed

5 files changed

+51
-15
lines changed

pandera/backends/ibis/builtin_checks.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
"""Built-in checks for Ibis."""
22

33
import datetime
4-
from typing import Any, TypeVar
4+
from typing import Any, Optional, TypeVar
55

66
import ibis
77
import ibis.expr.types as ir
8+
from ibis import _, selectors as s
9+
from ibis.common.selectors import Selector
810

911
from pandera.api.extensions import register_builtin_check
1012
from pandera.api.ibis.types import IbisData
13+
from pandera.backends.ibis.utils import select_column
14+
from pandera.constants import check_col_name
15+
1116

1217
T = TypeVar("T")
1318

@@ -24,6 +29,10 @@ def _infer_interval_with_mixed_units(value: Any) -> Any:
2429
return value
2530

2631

32+
def _selector(key: Optional[str]) -> Selector:
33+
return s.all() if key is None else select_column(key)
34+
35+
2736
@register_builtin_check(
2837
aliases=["eq"],
2938
error="equal_to({value})",
@@ -37,7 +46,9 @@ def equal_to(data: IbisData, value: Any) -> ir.Table:
3746
equal to this value.
3847
"""
3948
value = _infer_interval_with_mixed_units(value)
40-
return data.table[data.key] == value
49+
return data.table.mutate(
50+
s.across(_selector(data.key), _ == value, names=check_col_name)
51+
)
4152

4253

4354
@register_builtin_check(
@@ -52,12 +63,14 @@ def not_equal_to(data: IbisData, value: Any) -> ir.Table:
5263
:param value: This value must not occur in the checked data structure.
5364
"""
5465
value = _infer_interval_with_mixed_units(value)
55-
return data.table[data.key] != value
66+
return data.table.mutate(
67+
s.across(_selector(data.key), _ != value, names=check_col_name)
68+
)
5669

5770

5871
@register_builtin_check(
5972
aliases=["gt"],
60-
error="greater_than({value})",
73+
error="greater_than({min_value})",
6174
)
6275
def greater_than(data: IbisData, min_value: Any) -> ir.Table:
6376
"""Ensure values of a column are strictly greater than a minimum
@@ -69,12 +82,14 @@ def greater_than(data: IbisData, min_value: Any) -> ir.Table:
6982
to the dtype of the :class:`ir.Column` to be validated.
7083
"""
7184
value = _infer_interval_with_mixed_units(min_value)
72-
return data.table[data.key] > value
85+
return data.table.mutate(
86+
s.across(_selector(data.key), _ > value, names=check_col_name)
87+
)
7388

7489

7590
@register_builtin_check(
7691
aliases=["ge"],
77-
error="greater_than_or_equal_to({value})",
92+
error="greater_than_or_equal_to({min_value})",
7893
)
7994
def greater_than_or_equal_to(data: IbisData, min_value: Any) -> ir.Table:
8095
"""Ensure all values are greater than or equal to a minimum value.
@@ -85,7 +100,9 @@ def greater_than_or_equal_to(data: IbisData, min_value: Any) -> ir.Table:
85100
to the dtype of the :class:`ir.Column` to be validated.
86101
"""
87102
value = _infer_interval_with_mixed_units(min_value)
88-
return data.table[data.key] >= value
103+
return data.table.mutate(
104+
s.across(_selector(data.key), _ >= value, names=check_col_name)
105+
)
89106

90107

91108
@register_builtin_check(
@@ -102,7 +119,9 @@ def less_than(data: IbisData, max_value: Any) -> ir.Table:
102119
:class:`ir.Column` to be validated.
103120
"""
104121
value = _infer_interval_with_mixed_units(max_value)
105-
return data.table[data.key] < value
122+
return data.table.mutate(
123+
s.across(_selector(data.key), _ < value, names=check_col_name)
124+
)
106125

107126

108127
@register_builtin_check(
@@ -118,4 +137,6 @@ def less_than_or_equal_to(data: IbisData, max_value: Any) -> ir.Table:
118137
:class:`ir.Column` to be validated.
119138
"""
120139
value = _infer_interval_with_mixed_units(max_value)
121-
return data.table[data.key] <= value
140+
return data.table.mutate(
141+
s.across(_selector(data.key), _ <= value, names=check_col_name)
142+
)

pandera/backends/ibis/checks.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
from pandera.api.ibis.types import IbisData
1414
from pandera.backends.base import BaseCheckBackend
1515
from pandera.backends.ibis.utils import select_column
16-
from pandera.constants import CHECK_OUTPUT_KEY
17-
18-
CHECK_OUTPUT_SUFFIX = f"__{CHECK_OUTPUT_KEY}__"
16+
from pandera.constants import CHECK_OUTPUT_KEY, CHECK_OUTPUT_SUFFIX
1917

2018

2119
class IbisCheckBackend(BaseCheckBackend):
@@ -69,15 +67,15 @@ def apply(self, check_obj: IbisData):
6967
)
7068

7169
if isinstance(out, ir.Table):
72-
# for checks that return a boolean dataframe, make sure all columns
70+
# for checks that return a boolean table, make sure all columns
7371
# are boolean and reduce to a single boolean column.
7472
acc = ibis.literal(True)
7573
for col in out.columns:
7674
if col.endswith(CHECK_OUTPUT_SUFFIX):
7775
assert out[col].type().is_boolean(), (
7876
f"column '{col[: -len(CHECK_OUTPUT_SUFFIX)]}' "
7977
"is not boolean. If check function returns a "
80-
"dataframe, it must contain only boolean columns."
78+
"table, it must contain only boolean columns."
8179
)
8280
acc = acc & out[col]
8381
return out.mutate({CHECK_OUTPUT_KEY: acc})
@@ -145,7 +143,6 @@ def postprocess_table_output(
145143
s.endswith(f"__{CHECK_OUTPUT_KEY}__")
146144
| select_column(CHECK_OUTPUT_KEY)
147145
)
148-
149146
if check_obj.key is not None:
150147
failure_cases = failure_cases.select(check_obj.key)
151148
return CheckResult(

pandera/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""Pandera constants."""
22

33
CHECK_OUTPUT_KEY = "check_output"
4+
CHECK_OUTPUT_SUFFIX = f"__{CHECK_OUTPUT_KEY}__"
45
FAILURE_CASE_KEY = "failure_case"
6+
check_col_name = f"{{col}}{CHECK_OUTPUT_SUFFIX}"

pandera/polars.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pandera.api.polars.types import PolarsData
1414
from pandera.backends.polars.register import register_polars_backends
1515
from pandera.decorators import check_input, check_io, check_output, check_types
16+
from pandera.constants import check_col_name
1617

1718
register_polars_backends()
1819

@@ -31,4 +32,5 @@
3132
"errors",
3233
"Field",
3334
"PolarsData",
35+
"check_col_name",
3436
]

tests/ibis/test_ibis_check.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def test_ibis_column_check(
6565
assert check_output == expected_output
6666

6767

68+
def _df_check_fn_table_out(data: pa.IbisData) -> ir.Table:
69+
return data.table.mutate(
70+
{col: data.table[col] >= 0 for col in data.table.columns}
71+
)
72+
73+
6874
def _df_check_fn_dict_out(data: pa.IbisData) -> Dict[str, ir.BooleanColumn]:
6975
return {col: data.table[col] >= 0 for col in data.table.columns}
7076

@@ -83,6 +89,14 @@ def _df_check_fn_scalar_out(data: pa.IbisData) -> ir.BooleanScalar:
8389
@pytest.mark.parametrize(
8490
"check_fn, invalid_data, expected_output",
8591
[
92+
[
93+
_df_check_fn_table_out,
94+
{
95+
"col_1": [-1, 2, -3, 4],
96+
"col_2": [1, 2, 3, -4],
97+
},
98+
[False, True, False, False],
99+
],
86100
[
87101
_df_check_fn_dict_out,
88102
{

0 commit comments

Comments
 (0)