Skip to content

Commit 43b97ff

Browse files
committed
Implement gt and ge check for the Ibis backend
Signed-off-by: Deepyaman Datta <[email protected]>
1 parent 8b67b07 commit 43b97ff

File tree

5 files changed

+266
-18
lines changed

5 files changed

+266
-18
lines changed

pandera/backends/ibis/builtin_checks.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,36 @@ def not_equal_to(data: IbisData, value: Any) -> ir.Table:
5353
"""
5454
value = _infer_interval_with_mixed_units(value)
5555
return data.table[data.key] != value
56+
57+
58+
@register_builtin_check(
59+
aliases=["gt"],
60+
error="greater_than({value})",
61+
)
62+
def greater_than(data: IbisData, min_value: Any) -> ir.Table:
63+
"""Ensure values of a data container are strictly greater than a minimum
64+
value.
65+
66+
:param data: NamedTuple IbisData contains the table and column name for the check. The key
67+
to access the table is "table", and the key to access the column name is "key".
68+
:param min_value: Lower bound to be exceeded. Must be a type comparable
69+
to the dtype of the :class:`ir.Column` to be validated.
70+
"""
71+
value = _infer_interval_with_mixed_units(min_value)
72+
return data.table[data.key] > value
73+
74+
75+
@register_builtin_check(
76+
aliases=["ge"],
77+
error="greater_than_or_equal_to({value})",
78+
)
79+
def greater_than_or_equal_to(data: IbisData, min_value: Any) -> ir.Table:
80+
"""Ensure all values are greater than or equal to a certain value.
81+
82+
:param data: NamedTuple IbisData contains the table and column name for the check. The key
83+
to access the table is "table", and the key to access the column name is "key".
84+
:param min_value: Allowed minimum value. Must be a type comparable
85+
to the dtype of the :class:`ir.Column` to be validated.
86+
"""
87+
value = _infer_interval_with_mixed_units(min_value)
88+
return data.table[data.key] >= value

pandera/backends/pandas/builtin_checks.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,7 @@ def greater_than(data: PandasData, min_value: Any) -> PandasData:
8888
value.
8989
9090
:param min_value: Lower bound to be exceeded. Must be a type comparable
91-
to the dtype of the :class:`pandas.Series` to be validated (e.g. a
92-
numerical type for float or int and a datetime for datetime).
91+
to the dtype of the :class:`pandas.Series` to be validated.
9392
"""
9493
return data > min_value
9594

@@ -100,11 +99,10 @@ def greater_than(data: PandasData, min_value: Any) -> PandasData:
10099
error="greater_than_or_equal_to({min_value})",
101100
)
102101
def greater_than_or_equal_to(data: PandasData, min_value: Any) -> PandasData:
103-
"""Ensure all values are greater or equal a certain value.
102+
"""Ensure all values are greater than or equal to a certain value.
104103
105-
:param min_value: Allowed minimum value for values of a series. Must be
106-
a type comparable to the dtype of the :class:`pandas.Series` to be
107-
validated.
104+
:param min_value: Allowed minimum value. Must be a type comparable
105+
to the dtype of the :class:`pandas.Series` to be validated.
108106
"""
109107
return data >= min_value
110108

pandera/backends/polars/builtin_checks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@ def greater_than(data: PolarsData, min_value: Any) -> pl.LazyFrame:
6262
error="greater_than_or_equal_to({min_value})",
6363
)
6464
def greater_than_or_equal_to(data: PolarsData, min_value: Any) -> pl.LazyFrame:
65-
"""Ensure all values are greater or equal a certain value.
65+
"""Ensure all values are greater than or equal to a certain value.
6666
6767
:param data: NamedTuple PolarsData contains the dataframe and column name for the check. The key
6868
to access the dataframe is "dataframe", and the key the to access the column name is "key".
69-
:param min_value: Allowed minimum value for values of a series. Must be
70-
a type comparable to the dtype of the series datatype of Polars.
69+
:param min_value: Allowed minimum value. Must be a type comparable
70+
to the dtype of the series datatype of Polars.
7171
"""
7272
return data.lazyframe.select(pl.col(data.key).ge(min_value))
7373

pandera/backends/pyspark/builtin_checks.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
6060
def not_equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
6161
"""Ensure no element of a data container equals a certain value.
6262
63-
:param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
63+
:param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The key
6464
to access the dataframe is "dataframe" and column name using "column_name".
6565
:param value: This value must not occur in the checked
6666
"""
@@ -76,11 +76,11 @@ def not_equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
7676
acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE)
7777
)
7878
def greater_than(data: PysparkDataframeColumnObject, min_value: Any) -> bool:
79-
"""
80-
Ensure values of a data container are strictly greater than a minimum
79+
"""Ensure values of a data container are strictly greater than a minimum
8180
value.
82-
:param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
83-
to access the dataframe is "dataframe" and column name using "column_name".
81+
82+
:param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The key
83+
to access the dataframe is "dataframe" and column name using "column_name".
8484
:param min_value: Lower bound to be exceeded.
8585
"""
8686
cond = col(data.column_name) > min_value
@@ -98,11 +98,12 @@ def greater_than(data: PysparkDataframeColumnObject, min_value: Any) -> bool:
9898
def greater_than_or_equal_to(
9999
data: PysparkDataframeColumnObject, min_value: Any
100100
) -> bool:
101-
"""Ensure all values are greater or equal a certain value.
101+
"""Ensure all values are greater than or equal to a certain value.
102+
102103
:param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
103-
to access the dataframe is "dataframe" and column name using "column_name".
104-
:param min_value: Allowed minimum value for values of a series. Must be
105-
a type comparable to the dtype of the column datatype of pyspark
104+
to access the dataframe is "dataframe" and column name using "column_name".
105+
:param min_value: Allowed minimum value. Must be a type comparable
106+
to the dtype of the column datatype of pyspark
106107
"""
107108
cond = col(data.column_name) >= min_value
108109
return data.dataframe.filter(~cond).limit(1).count() == 0

tests/ibis/test_ibis_builtin_checks.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,3 +373,219 @@ def test_not_equal_to_check(self, check_fn, datatype, data) -> None:
373373
datatype,
374374
data["test_expression"],
375375
)
376+
377+
378+
class TestGreaterThanCheck(BaseClass):
379+
"""This class is used to test the greater than check"""
380+
381+
sample_numeric_data = {
382+
"test_pass_data": [("foo", 31), ("bar", 32)],
383+
"test_fail_data": [("foo", 30), ("bar", 31)],
384+
"test_expression": 30,
385+
}
386+
387+
sample_datetime_data = {
388+
"test_pass_data": [
389+
("foo", datetime.datetime(2020, 10, 2, 11, 0)),
390+
("bar", datetime.datetime(2020, 10, 2, 11, 0)),
391+
],
392+
"test_fail_data": [
393+
("foo", datetime.datetime(2020, 10, 1, 10, 0)),
394+
("bar", datetime.datetime(2020, 10, 2, 11, 0)),
395+
],
396+
"test_expression": datetime.datetime(2020, 10, 1, 10, 0),
397+
}
398+
399+
sample_duration_data = {
400+
"test_pass_data": [
401+
("foo", datetime.timedelta(100, 11, 1)),
402+
("bar", datetime.timedelta(100, 12, 1)),
403+
],
404+
"test_fail_data": [
405+
("foo", datetime.timedelta(100, 10, 1)),
406+
("bar", datetime.timedelta(100, 11, 1)),
407+
],
408+
"test_expression": datetime.timedelta(100, 10, 1),
409+
}
410+
411+
def pytest_generate_tests(self, metafunc):
412+
"""This function passes the parameter for each function based on parameter form get_data_param function"""
413+
# called once per each test function
414+
funcarglist = self.get_data_param()[metafunc.function.__name__]
415+
argnames = sorted(funcarglist[0])
416+
metafunc.parametrize(
417+
argnames,
418+
[
419+
[funcargs[name] for name in argnames]
420+
for funcargs in funcarglist
421+
],
422+
)
423+
424+
def get_data_param(self):
425+
"""Generate the params which will be used to test this function. All the acceptable
426+
data types would be tested"""
427+
return {
428+
"test_greater_than_check": [
429+
{"datatype": dt.UInt8, "data": self.sample_numeric_data},
430+
{"datatype": dt.UInt16, "data": self.sample_numeric_data},
431+
{"datatype": dt.UInt32, "data": self.sample_numeric_data},
432+
{"datatype": dt.UInt64, "data": self.sample_numeric_data},
433+
{"datatype": dt.Int8, "data": self.sample_numeric_data},
434+
{"datatype": dt.Int16, "data": self.sample_numeric_data},
435+
{"datatype": dt.Int32, "data": self.sample_numeric_data},
436+
{"datatype": dt.Int64, "data": self.sample_numeric_data},
437+
{
438+
"datatype": dt.Float32,
439+
"data": self.convert_data(
440+
self.sample_numeric_data, "float32"
441+
),
442+
},
443+
{
444+
"datatype": dt.Float64,
445+
"data": self.convert_data(
446+
self.sample_numeric_data, "float64"
447+
),
448+
},
449+
{
450+
"datatype": dt.Date,
451+
"data": self.convert_data(
452+
self.sample_datetime_data, "date"
453+
),
454+
},
455+
{
456+
"datatype": dt.Timestamp.from_unit("us"),
457+
"data": self.sample_datetime_data,
458+
},
459+
{
460+
"datatype": dt.Time,
461+
"data": self.convert_data(
462+
self.sample_datetime_data, "time"
463+
),
464+
},
465+
{
466+
"datatype": dt.Interval(unit="us"),
467+
"data": self.sample_duration_data,
468+
},
469+
]
470+
}
471+
472+
@pytest.mark.parametrize("check_fn", [pa.Check.greater_than, pa.Check.gt])
473+
def test_greater_than_check(self, check_fn, datatype, data) -> None:
474+
"""Test the Check to see if all the values are equal to defined value"""
475+
self.check_function(
476+
check_fn,
477+
data["test_pass_data"],
478+
data["test_fail_data"],
479+
datatype,
480+
data["test_expression"],
481+
)
482+
483+
484+
class TestGreaterThanEqualToCheck(BaseClass):
485+
"""This class is used to test the greater than equal to check"""
486+
487+
sample_numeric_data = {
488+
"test_pass_data": [("foo", 31), ("bar", 32)],
489+
"test_fail_data": [("foo", 30), ("bar", 31)],
490+
"test_expression": 31,
491+
}
492+
493+
sample_datetime_data = {
494+
"test_pass_data": [
495+
("foo", datetime.datetime(2020, 10, 1, 11, 0)),
496+
("bar", datetime.datetime(2020, 10, 2, 11, 0)),
497+
],
498+
"test_fail_data": [
499+
("foo", datetime.datetime(2020, 10, 1, 11, 0)),
500+
("bar", datetime.datetime(2020, 9, 1, 10, 0)),
501+
],
502+
"test_expression": datetime.datetime(2020, 10, 1, 11, 0),
503+
}
504+
505+
sample_duration_data = {
506+
"test_pass_data": [
507+
("foo", datetime.timedelta(100, 10, 1)),
508+
("bar", datetime.timedelta(100, 11, 1)),
509+
],
510+
"test_fail_data": [
511+
("foo", datetime.timedelta(100, 11, 1)),
512+
("bar", datetime.timedelta(100, 9, 1)),
513+
],
514+
"test_expression": datetime.timedelta(100, 10, 1),
515+
}
516+
517+
def pytest_generate_tests(self, metafunc):
518+
"""This function passes the parameter for each function based on parameter form get_data_param function"""
519+
# called once per each test function
520+
funcarglist = self.get_data_param()[metafunc.function.__name__]
521+
argnames = sorted(funcarglist[0])
522+
metafunc.parametrize(
523+
argnames,
524+
[
525+
[funcargs[name] for name in argnames]
526+
for funcargs in funcarglist
527+
],
528+
)
529+
530+
def get_data_param(self):
531+
"""Generate the params which will be used to test this function. All the acceptable
532+
data types would be tested"""
533+
return {
534+
"test_greater_than_or_equal_to_check": [
535+
{"datatype": dt.UInt8, "data": self.sample_numeric_data},
536+
{"datatype": dt.UInt16, "data": self.sample_numeric_data},
537+
{"datatype": dt.UInt32, "data": self.sample_numeric_data},
538+
{"datatype": dt.UInt64, "data": self.sample_numeric_data},
539+
{"datatype": dt.Int8, "data": self.sample_numeric_data},
540+
{"datatype": dt.Int16, "data": self.sample_numeric_data},
541+
{"datatype": dt.Int32, "data": self.sample_numeric_data},
542+
{"datatype": dt.Int64, "data": self.sample_numeric_data},
543+
{
544+
"datatype": dt.Float32,
545+
"data": self.convert_data(
546+
self.sample_numeric_data, "float32"
547+
),
548+
},
549+
{
550+
"datatype": dt.Float64,
551+
"data": self.convert_data(
552+
self.sample_numeric_data, "float64"
553+
),
554+
},
555+
{
556+
"datatype": dt.Date,
557+
"data": self.convert_data(
558+
self.sample_datetime_data, "date"
559+
),
560+
},
561+
{
562+
"datatype": dt.Timestamp.from_unit("us"),
563+
"data": self.sample_datetime_data,
564+
},
565+
{
566+
"datatype": dt.Time,
567+
"data": self.convert_data(
568+
self.sample_datetime_data, "time"
569+
),
570+
},
571+
{
572+
"datatype": dt.Interval(unit="us"),
573+
"data": self.sample_duration_data,
574+
},
575+
]
576+
}
577+
578+
@pytest.mark.parametrize(
579+
"check_fn", [pa.Check.greater_than_or_equal_to, pa.Check.ge]
580+
)
581+
def test_greater_than_or_equal_to_check(
582+
self, check_fn, datatype, data
583+
) -> None:
584+
"""Test the Check to see if all the values are equal to defined value"""
585+
self.check_function(
586+
check_fn,
587+
data["test_pass_data"],
588+
data["test_fail_data"],
589+
datatype,
590+
data["test_expression"],
591+
)

0 commit comments

Comments
 (0)