1
1
"""Built-in checks for Ibis."""
2
2
3
3
import datetime
4
- from typing import Any , TypeVar
4
+ from typing import Any , Optional , TypeVar
5
5
6
6
import ibis
7
7
import ibis .expr .types as ir
8
+ from ibis import _ , selectors as s
9
+ from ibis .common .selectors import Selector
8
10
9
11
from pandera .api .extensions import register_builtin_check
10
12
from pandera .api .ibis .types import IbisData
13
+ from pandera .backends .ibis .utils import select_column
14
+ from pandera .constants import check_col_name
15
+
11
16
12
17
T = TypeVar ("T" )
13
18
@@ -24,6 +29,10 @@ def _infer_interval_with_mixed_units(value: Any) -> Any:
24
29
return value
25
30
26
31
32
+ def _selector (key : Optional [str ]) -> Selector :
33
+ return s .all () if key is None else select_column (key )
34
+
35
+
27
36
@register_builtin_check (
28
37
aliases = ["eq" ],
29
38
error = "equal_to({value})" ,
@@ -37,7 +46,9 @@ def equal_to(data: IbisData, value: Any) -> ir.Table:
37
46
equal to this value.
38
47
"""
39
48
value = _infer_interval_with_mixed_units (value )
40
- return data .table [data .key ] == value
49
+ return data .table .mutate (
50
+ s .across (_selector (data .key ), _ == value , names = check_col_name )
51
+ )
41
52
42
53
43
54
@register_builtin_check (
@@ -52,12 +63,14 @@ def not_equal_to(data: IbisData, value: Any) -> ir.Table:
52
63
:param value: This value must not occur in the checked data structure.
53
64
"""
54
65
value = _infer_interval_with_mixed_units (value )
55
- return data .table [data .key ] != value
66
+ return data .table .mutate (
67
+ s .across (_selector (data .key ), _ != value , names = check_col_name )
68
+ )
56
69
57
70
58
71
@register_builtin_check (
59
72
aliases = ["gt" ],
60
- error = "greater_than({value })" ,
73
+ error = "greater_than({min_value })" ,
61
74
)
62
75
def greater_than (data : IbisData , min_value : Any ) -> ir .Table :
63
76
"""Ensure values of a column are strictly greater than a minimum
@@ -69,12 +82,14 @@ def greater_than(data: IbisData, min_value: Any) -> ir.Table:
69
82
to the dtype of the :class:`ir.Column` to be validated.
70
83
"""
71
84
value = _infer_interval_with_mixed_units (min_value )
72
- return data .table [data .key ] > value
85
+ return data .table .mutate (
86
+ s .across (_selector (data .key ), _ > value , names = check_col_name )
87
+ )
73
88
74
89
75
90
@register_builtin_check (
76
91
aliases = ["ge" ],
77
- error = "greater_than_or_equal_to({value })" ,
92
+ error = "greater_than_or_equal_to({min_value })" ,
78
93
)
79
94
def greater_than_or_equal_to (data : IbisData , min_value : Any ) -> ir .Table :
80
95
"""Ensure all values are greater than or equal to a minimum value.
@@ -85,7 +100,9 @@ def greater_than_or_equal_to(data: IbisData, min_value: Any) -> ir.Table:
85
100
to the dtype of the :class:`ir.Column` to be validated.
86
101
"""
87
102
value = _infer_interval_with_mixed_units (min_value )
88
- return data .table [data .key ] >= value
103
+ return data .table .mutate (
104
+ s .across (_selector (data .key ), _ >= value , names = check_col_name )
105
+ )
89
106
90
107
91
108
@register_builtin_check (
@@ -102,7 +119,9 @@ def less_than(data: IbisData, max_value: Any) -> ir.Table:
102
119
:class:`ir.Column` to be validated.
103
120
"""
104
121
value = _infer_interval_with_mixed_units (max_value )
105
- return data .table [data .key ] < value
122
+ return data .table .mutate (
123
+ s .across (_selector (data .key ), _ < value , names = check_col_name )
124
+ )
106
125
107
126
108
127
@register_builtin_check (
@@ -118,4 +137,6 @@ def less_than_or_equal_to(data: IbisData, max_value: Any) -> ir.Table:
118
137
:class:`ir.Column` to be validated.
119
138
"""
120
139
value = _infer_interval_with_mixed_units (max_value )
121
- return data .table [data .key ] <= value
140
+ return data .table .mutate (
141
+ s .across (_selector (data .key ), _ <= value , names = check_col_name )
142
+ )
0 commit comments