Skip to content

Commit d3e4104

Browse files
committed
feat(risingwave): support include_null in first/last aggs
1 parent ba2a0be commit d3e4104

File tree

2 files changed

+11
-18
lines changed

2 files changed

+11
-18
lines changed

ibis/backends/sql/compilers/risingwave.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import sqlglot as sg
34
import sqlglot.expressions as sge
45

56
import ibis.common.exceptions as com
@@ -40,25 +41,23 @@ def visit_DateNow(self, op):
4041
return self.cast(sge.CurrentTimestamp(), dt.date)
4142

4243
def visit_First(self, op, *, arg, where, order_by, include_null):
43-
if include_null:
44-
raise com.UnsupportedOperationError(
45-
"`include_null=True` is not supported by the risingwave backend"
46-
)
4744
if not order_by:
4845
raise com.UnsupportedOperationError(
4946
"RisingWave requires an `order_by` be specified in `first`"
5047
)
48+
if not include_null:
49+
cond = arg.is_(sg.not_(NULL, copy=False))
50+
where = cond if where is None else sge.And(this=cond, expression=where)
5151
return self.agg.first_value(arg, where=where, order_by=order_by)
5252

5353
def visit_Last(self, op, *, arg, where, order_by, include_null):
54-
if include_null:
55-
raise com.UnsupportedOperationError(
56-
"`include_null=True` is not supported by the risingwave backend"
57-
)
5854
if not order_by:
5955
raise com.UnsupportedOperationError(
6056
"RisingWave requires an `order_by` be specified in `last`"
6157
)
58+
if not include_null:
59+
cond = arg.is_(sg.not_(NULL, copy=False))
60+
where = cond if where is None else sge.And(this=cond, expression=where)
6261
return self.agg.last_value(arg, where=where, order_by=order_by)
6362

6463
def visit_Correlation(self, op, *, left, right, how, where):

ibis/backends/tests/test_aggregation.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -647,23 +647,16 @@ def test_first_last(alltypes, method, filtered, include_null):
647647
raises=com.OperationNotDefinedError,
648648
)
649649
@pytest.mark.parametrize("method", ["first", "last"])
650-
@pytest.mark.parametrize("filtered", [False, True])
650+
@pytest.mark.parametrize("filtered", [False, True], ids=["not-filtered", "filtered"])
651651
@pytest.mark.parametrize(
652652
"include_null",
653653
[
654-
False,
654+
param(False, id="exclude-null"),
655655
param(
656656
True,
657657
marks=[
658658
pytest.mark.notimpl(
659-
[
660-
"clickhouse",
661-
"exasol",
662-
"flink",
663-
"postgres",
664-
"risingwave",
665-
"snowflake",
666-
],
659+
["clickhouse", "exasol", "flink", "postgres", "snowflake"],
667660
raises=com.UnsupportedOperationError,
668661
reason="`include_null=True` is not supported",
669662
),
@@ -674,6 +667,7 @@ def test_first_last(alltypes, method, filtered, include_null):
674667
strict=False,
675668
),
676669
],
670+
id="include-null",
677671
),
678672
],
679673
)

0 commit comments

Comments
 (0)