Skip to content

Commit a18cb5d

Browse files
jcristcpcloud
andauthored
feat(api): support order_by in order-sensitive aggregates (collect/group_concat/first/last) (#9729)
Co-authored-by: Phillip Cloud <[email protected]>
1 parent 7d38f09 commit a18cb5d

File tree

30 files changed

+537
-128
lines changed

30 files changed

+537
-128
lines changed

ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/comma_none/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ SELECT
33
WHEN empty(groupArray("t0"."string_col"))
44
THEN NULL
55
ELSE arrayStringConcat(groupArray("t0"."string_col"), ',')
6-
END AS "GroupConcat(string_col, ',')"
6+
END AS "GroupConcat(string_col, ',', ())"
77
FROM "functional_alltypes" AS "t0"

ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/comma_zero/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ SELECT
33
WHEN empty(groupArrayIf("t0"."string_col", "t0"."bool_col" = 0))
44
THEN NULL
55
ELSE arrayStringConcat(groupArrayIf("t0"."string_col", "t0"."bool_col" = 0), ',')
6-
END AS "GroupConcat(string_col, ',', Equals(bool_col, 0))"
6+
END AS "GroupConcat(string_col, ',', (), Equals(bool_col, 0))"
77
FROM "functional_alltypes" AS "t0"

ibis/backends/clickhouse/tests/snapshots/test_functions/test_group_concat/minus_none/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ SELECT
33
WHEN empty(groupArray("t0"."string_col"))
44
THEN NULL
55
ELSE arrayStringConcat(groupArray("t0"."string_col"), '-')
6-
END AS "GroupConcat(string_col, '-')"
6+
END AS "GroupConcat(string_col, '-', ())"
77
FROM "functional_alltypes" AS "t0"
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
SELECT
2-
FIRST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `First(double_col)`
2+
FIRST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `First(double_col, ())`
33
FROM `functional_alltypes` AS `t0`
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
SELECT
2-
LAST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `Last(double_col)`
2+
LAST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `Last(double_col, ())`
33
FROM `functional_alltypes` AS `t0`

ibis/backends/pandas/executor.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@
3030
plan,
3131
)
3232
from ibis.common.dispatch import Dispatched
33-
from ibis.common.exceptions import OperationNotDefinedError, UnboundExpressionError
33+
from ibis.common.exceptions import (
34+
OperationNotDefinedError,
35+
UnboundExpressionError,
36+
UnsupportedOperationError,
37+
)
3438
from ibis.formats.pandas import PandasData, PandasType
3539
from ibis.util import any_of, gen_name
3640

@@ -253,7 +257,12 @@ def visit(
253257
############################# Reductions ##################################
254258

255259
@classmethod
256-
def visit(cls, op: ops.Reduction, arg, where):
260+
def visit(cls, op: ops.Reduction, arg, where, order_by=()):
261+
if order_by:
262+
raise UnsupportedOperationError(
263+
"ordering of order-sensitive aggregations via `order_by` is "
264+
"not supported for this backend"
265+
)
257266
func = cls.kernels.reductions[type(op)]
258267
return cls.agg(func, arg, where)
259268

@@ -344,7 +353,13 @@ def agg(df):
344353
return agg
345354

346355
@classmethod
347-
def visit(cls, op: ops.GroupConcat, arg, sep, where):
356+
def visit(cls, op: ops.GroupConcat, arg, sep, where, order_by):
357+
if order_by:
358+
raise UnsupportedOperationError(
359+
"ordering of order-sensitive aggregations via `order_by` is "
360+
"not supported for this backend"
361+
)
362+
348363
if where is None:
349364

350365
def agg(df):

ibis/backends/polars/compiler.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _literal_value(op, nan_as_none=False):
4545

4646

4747
@singledispatch
48-
def translate(expr, *, ctx):
48+
def translate(expr, **_):
4949
raise NotImplementedError(expr)
5050

5151

@@ -748,6 +748,11 @@ def execute_first_last(op, **kw):
748748

749749
arg = arg.filter(predicate)
750750

751+
if order_by := getattr(op, "order_by", ()):
752+
keys = [translate(k.expr, **kw).filter(predicate) for k in order_by]
753+
descending = [k.descending for k in order_by]
754+
arg = arg.sort_by(keys, descending=descending)
755+
751756
return arg.last() if isinstance(op, ops.Last) else arg.first()
752757

753758

@@ -985,14 +990,21 @@ def array_column(op, **kw):
985990
@translate.register(ops.ArrayCollect)
986991
def array_collect(op, in_group_by=False, **kw):
987992
arg = translate(op.arg, **kw)
988-
if (where := op.where) is not None:
989-
arg = arg.filter(translate(where, **kw))
990-
out = arg.drop_nulls()
991-
if not in_group_by:
992-
# Polars' behavior changes for `implode` within a `group_by` currently.
993-
# See https://github.com/pola-rs/polars/issues/16756
994-
out = out.implode()
995-
return out
993+
994+
predicate = arg.is_not_null()
995+
if op.where is not None:
996+
predicate &= translate(op.where, **kw)
997+
998+
arg = arg.filter(predicate)
999+
1000+
if op.order_by:
1001+
keys = [translate(k.expr, **kw).filter(predicate) for k in op.order_by]
1002+
descending = [k.descending for k in op.order_by]
1003+
arg = arg.sort_by(keys, descending=descending)
1004+
1005+
# Polars' behavior changes for `implode` within a `group_by` currently.
1006+
# See https://github.com/pola-rs/polars/issues/16756
1007+
return arg if in_group_by else arg.implode()
9961008

9971009

9981010
@translate.register(ops.ArrayFlatten)
@@ -1390,3 +1402,23 @@ def execute_array_all(op, **kw):
13901402
arg = translate(op.arg, **kw)
13911403
no_nulls = arg.list.drop_nulls()
13921404
return pl.when(no_nulls.list.len() == 0).then(None).otherwise(no_nulls.list.all())
1405+
1406+
1407+
@translate.register(ops.GroupConcat)
1408+
def execute_group_concat(op, **kw):
1409+
arg = translate(op.arg, **kw)
1410+
sep = _literal_value(op.sep)
1411+
1412+
predicate = arg.is_not_null()
1413+
1414+
if (where := op.where) is not None:
1415+
predicate &= translate(where, **kw)
1416+
1417+
arg = arg.filter(predicate)
1418+
1419+
if order_by := op.order_by:
1420+
keys = [translate(k.expr, **kw).filter(predicate) for k in order_by]
1421+
descending = [k.descending for k in order_by]
1422+
arg = arg.sort_by(keys, descending=descending)
1423+
1424+
return pl.when(arg.count() > 0).then(arg.str.join(sep)).otherwise(None)

ibis/backends/sql/compilers/base.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ class AggGen:
6363
supports_filter
6464
Whether the backend supports a FILTER clause in the aggregate.
6565
Defaults to False.
66+
supports_order_by
67+
Whether the backend supports an ORDER BY clause in (relevant)
68+
aggregates. Defaults to False.
6669
"""
6770

6871
class _Accessor:
@@ -79,10 +82,13 @@ def __getattr__(self, name: str) -> Callable:
7982

8083
__getitem__ = __getattr__
8184

82-
__slots__ = ("supports_filter",)
85+
__slots__ = ("supports_filter", "supports_order_by")
8386

84-
def __init__(self, *, supports_filter: bool = False):
87+
def __init__(
88+
self, *, supports_filter: bool = False, supports_order_by: bool = False
89+
):
8590
self.supports_filter = supports_filter
91+
self.supports_order_by = supports_order_by
8692

8793
def __get__(self, instance, owner=None):
8894
if instance is None:
@@ -96,6 +102,7 @@ def aggregate(
96102
name: str,
97103
*args: Any,
98104
where: Any = None,
105+
order_by: tuple = (),
99106
):
100107
"""Compile the specified aggregate.
101108
@@ -109,21 +116,31 @@ def aggregate(
109116
Any arguments to pass to the aggregate.
110117
where
111118
An optional column filter to apply before performing the aggregate.
112-
119+
order_by
120+
Optional ordering keys to use to order the rows before performing
121+
the aggregate.
113122
"""
114123
func = compiler.f[name]
115124

116-
if where is None:
117-
return func(*args)
118-
119-
if self.supports_filter:
120-
return sge.Filter(
121-
this=func(*args),
122-
expression=sge.Where(this=where),
125+
if order_by and not self.supports_order_by:
126+
raise com.UnsupportedOperationError(
127+
"ordering of order-sensitive aggregations via `order_by` is "
128+
f"not supported for the {compiler.dialect} backend"
123129
)
124-
else:
130+
131+
if where is not None and not self.supports_filter:
125132
args = tuple(compiler.if_(where, arg, NULL) for arg in args)
126-
return func(*args)
133+
134+
if order_by and self.supports_order_by:
135+
*rest, last = args
136+
out = func(*rest, sge.Order(this=last, expressions=order_by))
137+
else:
138+
out = func(*args)
139+
140+
if where is not None and self.supports_filter:
141+
out = sge.Filter(this=out, expression=sge.Where(this=where))
142+
143+
return out
127144

128145

129146
class VarGen:
@@ -424,8 +441,10 @@ def make_impl(op, target_name):
424441

425442
if issubclass(op, ops.Reduction):
426443

427-
def impl(self, _, *, _name: str = target_name, where, **kw):
428-
return self.agg[_name](*kw.values(), where=where)
444+
def impl(
445+
self, _, *, _name: str = target_name, where, order_by=(), **kw
446+
):
447+
return self.agg[_name](*kw.values(), where=where, order_by=order_by)
429448

430449
else:
431450

ibis/backends/sql/compilers/bigquery.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import ibis.expr.datatypes as dt
1313
import ibis.expr.operations as ops
1414
from ibis import util
15-
from ibis.backends.sql.compilers.base import NULL, STAR, SQLGlotCompiler
15+
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
1616
from ibis.backends.sql.datatypes import BigQueryType, BigQueryUDFType
1717
from ibis.backends.sql.rewrites import (
1818
exclude_unsupported_window_frame_from_ops,
@@ -28,6 +28,9 @@ class BigQueryCompiler(SQLGlotCompiler):
2828
dialect = BigQuery
2929
type_mapper = BigQueryType
3030
udf_type_mapper = BigQueryUDFType
31+
32+
agg = AggGen(supports_order_by=True)
33+
3134
rewrites = (
3235
exclude_unsupported_window_frame_from_ops,
3336
exclude_unsupported_window_frame_from_row_number,
@@ -172,10 +175,14 @@ def visit_TimestampDelta(self, op, *, left, right, part):
172175
"timestamp difference with mixed timezone/timezoneless values is not implemented"
173176
)
174177

175-
def visit_GroupConcat(self, op, *, arg, sep, where):
178+
def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
176179
if where is not None:
177180
arg = self.if_(where, arg, NULL)
178-
return self.f.string_agg(arg, sep)
181+
182+
if order_by:
183+
sep = sge.Order(this=sep, expressions=order_by)
184+
185+
return sge.GroupConcat(this=arg, separator=sep)
179186

180187
def visit_FloorDivide(self, op, *, left, right):
181188
return self.cast(self.f.floor(self.f.ieee_divide(left, right)), op.dtype)
@@ -225,10 +232,10 @@ def visit_StringToTimestamp(self, op, *, arg, format_str):
225232
return self.f.parse_timestamp(format_str, arg, timezone)
226233
return self.f.parse_datetime(format_str, arg)
227234

228-
def visit_ArrayCollect(self, op, *, arg, where):
229-
if where is not None:
230-
arg = self.if_(where, arg, NULL)
231-
return self.f.array_agg(sge.IgnoreNulls(this=arg))
235+
def visit_ArrayCollect(self, op, *, arg, where, order_by):
236+
return sge.IgnoreNulls(
237+
this=self.agg.array_agg(arg, where=where, order_by=order_by)
238+
)
232239

233240
def _neg_idx_to_pos(self, arg, idx):
234241
return self.if_(idx < 0, self.f.array_length(arg) + idx, idx)
@@ -474,17 +481,25 @@ def visit_TimestampRange(self, op, *, start, stop, step):
474481
self.f.generate_timestamp_array, start, stop, step, op.step.dtype
475482
)
476483

477-
def visit_First(self, op, *, arg, where):
484+
def visit_First(self, op, *, arg, where, order_by):
478485
if where is not None:
479486
arg = self.if_(where, arg, NULL)
487+
488+
if order_by:
489+
arg = sge.Order(this=arg, expressions=order_by)
490+
480491
array = self.f.array_agg(
481492
sge.Limit(this=sge.IgnoreNulls(this=arg), expression=sge.convert(1)),
482493
)
483494
return array[self.f.safe_offset(0)]
484495

485-
def visit_Last(self, op, *, arg, where):
496+
def visit_Last(self, op, *, arg, where, order_by):
486497
if where is not None:
487498
arg = self.if_(where, arg, NULL)
499+
500+
if order_by:
501+
arg = sge.Order(this=arg, expressions=order_by)
502+
488503
array = self.f.array_reverse(self.f.array_agg(sge.IgnoreNulls(this=arg)))
489504
return array[self.f.safe_offset(0)]
490505

ibis/backends/sql/compilers/clickhouse.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020

2121

2222
class ClickhouseAggGen(AggGen):
23-
def aggregate(self, compiler, name, *args, where=None):
23+
def aggregate(self, compiler, name, *args, where=None, order_by=()):
24+
if order_by:
25+
raise com.UnsupportedOperationError(
26+
"ordering of order-sensitive aggregations via `order_by` is "
27+
"not supported for this backend"
28+
)
2429
# Clickhouse aggregate functions all have filtering variants with a
2530
# `If` suffix (e.g. `SumIf` instead of `Sum`).
2631
if where is not None:
@@ -433,7 +438,12 @@ def visit_StringSplit(self, op, *, arg, delimiter):
433438
delimiter, self.cast(arg, dt.String(nullable=False))
434439
)
435440

436-
def visit_GroupConcat(self, op, *, arg, sep, where):
441+
def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
442+
if order_by:
443+
raise com.UnsupportedOperationError(
444+
"ordering of order-sensitive aggregations via `order_by` is "
445+
"not supported for this backend"
446+
)
437447
call = self.agg.groupArray(arg, where=where)
438448
return self.if_(self.f.empty(call), NULL, self.f.arrayStringConcat(call, sep))
439449

ibis/backends/sql/compilers/datafusion.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class DataFusionCompiler(SQLGlotCompiler):
3030
*SQLGlotCompiler.rewrites,
3131
)
3232

33-
agg = AggGen(supports_filter=True)
33+
agg = AggGen(supports_filter=True, supports_order_by=True)
3434

3535
UNSUPPORTED_OPS = (
3636
ops.ArgMax,
@@ -425,15 +425,15 @@ def visit_StringConcat(self, op, *, arg):
425425
sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg)
426426
)
427427

428-
def visit_First(self, op, *, arg, where):
428+
def visit_First(self, op, *, arg, where, order_by):
429429
cond = arg.is_(sg.not_(NULL, copy=False))
430430
where = cond if where is None else sge.And(this=cond, expression=where)
431-
return self.agg.first_value(arg, where=where)
431+
return self.agg.first_value(arg, where=where, order_by=order_by)
432432

433-
def visit_Last(self, op, *, arg, where):
433+
def visit_Last(self, op, *, arg, where, order_by):
434434
cond = arg.is_(sg.not_(NULL, copy=False))
435435
where = cond if where is None else sge.And(this=cond, expression=where)
436-
return self.agg.last_value(arg, where=where)
436+
return self.agg.last_value(arg, where=where, order_by=order_by)
437437

438438
def visit_Aggregate(self, op, *, parent, groups, metrics):
439439
"""Support `GROUP BY` expressions in `SELECT` since DataFusion does not."""
@@ -488,3 +488,12 @@ def visit_StructColumn(self, op, *, names, values):
488488
args.append(sge.convert(name))
489489
args.append(value)
490490
return self.f.named_struct(*args)
491+
492+
def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
493+
if order_by:
494+
raise com.UnsupportedOperationError(
495+
"DataFusion does not support order-sensitive group_concat"
496+
)
497+
return super().visit_GroupConcat(
498+
op, arg=arg, sep=sep, where=where, order_by=order_by
499+
)

0 commit comments

Comments
 (0)