Skip to content

Commit 13cf036

Browse files
committed
feat(api): add distinct option to collect
1 parent 1983675 commit 13cf036

File tree

14 files changed

+180
-99
lines changed

14 files changed

+180
-99
lines changed

ibis/backends/polars/compiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1003,7 +1003,10 @@ def array_collect(op, in_group_by=False, **kw):
10031003
if op.order_by:
10041004
keys = [translate(k.expr, **kw).filter(predicate) for k in op.order_by]
10051005
descending = [k.descending for k in op.order_by]
1006-
arg = arg.sort_by(keys, descending=descending)
1006+
arg = arg.sort_by(keys, descending=descending, nulls_last=True)
1007+
1008+
if op.distinct:
1009+
arg = arg.unique(maintain_order=op.order_by is not None)
10071010

10081011
# Polars' behavior changes for `implode` within a `group_by` currently.
10091012
# See https://github.com/pola-rs/polars/issues/16756

ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -479,16 +479,24 @@ def visit_StringToTimestamp(self, op, *, arg, format_str):
479479
return self.f.parse_timestamp(format_str, arg, timezone)
480480
return self.f.parse_datetime(format_str, arg)
481481

482-
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
483-
if where is not None and include_null:
484-
raise com.UnsupportedOperationError(
485-
"Combining `include_null=True` and `where` is not supported "
486-
"by bigquery"
487-
)
488-
out = self.agg.array_agg(arg, where=where, order_by=order_by)
482+
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
483+
if where is not None:
484+
if include_null:
485+
raise com.UnsupportedOperationError(
486+
"Combining `include_null=True` and `where` is not supported by bigquery"
487+
)
488+
if distinct:
489+
raise com.UnsupportedOperationError(
490+
"Combining `distinct=True` and `where` is not supported by bigquery"
491+
)
492+
arg = compiler.if_(where, arg, NULL)
493+
if distinct:
494+
arg = sge.Distinct(expressions=[arg])
495+
if order_by:
496+
arg = sge.Order(this=arg, expressions=order_by)
489497
if not include_null:
490-
out = sge.IgnoreNulls(this=out)
491-
return out
498+
arg = sge.IgnoreNulls(this=arg)
499+
return self.f.array_agg(arg)
492500

493501
def _neg_idx_to_pos(self, arg, idx):
494502
return self.if_(idx < 0, self.f.array_length(arg) + idx, idx)

ibis/backends/sql/compilers/clickhouse.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -611,12 +611,13 @@ def visit_ArrayUnion(self, op, *, left, right):
611611
def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str:
612612
return self.f.arrayZip(*arg)
613613

614-
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
614+
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
615615
if include_null:
616616
raise com.UnsupportedOperationError(
617617
"`include_null=True` is not supported by the clickhouse backend"
618618
)
619-
return self.agg.groupArray(arg, where=where, order_by=order_by)
619+
func = self.agg.groupUniqArray if distinct else self.agg.groupArray
620+
return func(arg, where=where, order_by=order_by)
620621

621622
def visit_First(self, op, *, arg, where, order_by, include_null):
622623
if include_null:

ibis/backends/sql/compilers/datafusion.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,11 @@ def visit_ArrayRepeat(self, op, *, arg, times):
327327
def visit_ArrayPosition(self, op, *, arg, other):
328328
return self.f.coalesce(self.f.array_position(arg, other), 0)
329329

330-
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
330+
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
331+
if distinct:
332+
raise com.UnsupportedOperationError(
333+
"`collect` with `distinct=True` is not supported"
334+
)
331335
if not include_null:
332336
cond = arg.is_(sg.not_(NULL, copy=False))
333337
where = cond if where is None else sge.And(this=cond, expression=where)

ibis/backends/sql/compilers/duckdb.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,12 @@ def visit_ArrayPosition(self, op, *, arg, other):
156156
self.f.coalesce(self.f.list_indexof(arg, other), 0),
157157
)
158158

159-
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
159+
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
160160
if not include_null:
161161
cond = arg.is_(sg.not_(NULL, copy=False))
162162
where = cond if where is None else sge.And(this=cond, expression=where)
163+
if distinct:
164+
arg = sge.Distinct(expressions=[arg])
163165
return self.agg.array_agg(arg, where=where, order_by=order_by)
164166

165167
def visit_ArrayIndex(self, op, *, arg, index):

ibis/backends/sql/compilers/flink.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -572,20 +572,24 @@ def visit_MapMerge(self, op: ops.MapMerge, *, left, right):
572572
def visit_StructColumn(self, op, *, names, values):
573573
return self.cast(sge.Struct(expressions=list(values)), op.dtype)
574574

575-
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
575+
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
576576
if order_by:
577577
raise com.UnsupportedOperationError(
578578
"ordering of order-sensitive aggregations via `order_by` is "
579579
"not supported for this backend"
580580
)
581-
# the only way to get filtering *and* respecting nulls is to use
582-
# `FILTER` syntax, but it's broken in various ways for other aggregates
583-
out = self.f.array_agg(arg)
584581
if not include_null:
585582
cond = arg.is_(sg.not_(NULL, copy=False))
586583
where = cond if where is None else sge.And(this=cond, expression=where)
584+
out = self.f.array_agg(arg)
587585
if where is not None:
588586
out = sge.Filter(this=out, expression=sge.Where(this=where))
587+
if distinct:
588+
# TODO: Flink supposedly supports `ARRAY_AGG(DISTINCT ...)`, but it
589+
# doesn't work with filtering (either `include_null=False` or
590+
# additional filtering). Their `array_distinct` function does maintain
591+
# ordering though, so we can use it here.
592+
out = self.f.array_distinct(out)
589593
return out
590594

591595
def visit_Strip(self, op, *, arg):

ibis/backends/sql/compilers/postgres.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,12 @@ def visit_ArrayIntersect(self, op, *, left, right):
372372
)
373373
)
374374

375-
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
375+
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
376376
if not include_null:
377377
cond = arg.is_(sg.not_(NULL, copy=False))
378378
where = cond if where is None else sge.And(this=cond, expression=where)
379+
if distinct:
380+
arg = sge.Distinct(expressions=[arg])
379381
return self.agg.array_agg(arg, where=where, order_by=order_by)
380382

381383
def visit_First(self, op, *, arg, where, order_by, include_null):

ibis/backends/sql/compilers/pyspark.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,12 +432,16 @@ def visit_ArrayContains(self, op, *, arg, other):
432432
def visit_ArrayStringJoin(self, op, *, arg, sep):
433433
return self.f.concat_ws(sep, arg)
434434

435-
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
435+
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
436436
if include_null:
437437
raise com.UnsupportedOperationError(
438438
"`include_null=True` is not supported by the pyspark backend"
439439
)
440-
return self.agg.array_agg(arg, where=where, order_by=order_by)
440+
if where:
441+
arg = self.if_(where, arg, NULL)
442+
if distinct:
443+
arg = sge.Distinct(expressions=[arg])
444+
return self.agg.array_agg(arg, order_by=order_by)
441445

442446
def visit_StringFind(self, op, *, arg, substr, start, end):
443447
if end is not None:

ibis/backends/sql/compilers/snowflake.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -452,25 +452,36 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
452452
timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9}
453453
return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short])
454454

455-
def _array_collect(self, *, arg, where, order_by, include_null):
455+
def _array_collect(self, *, arg, where, order_by, include_null, distinct=False):
456456
if include_null:
457457
raise com.UnsupportedOperationError(
458458
"`include_null=True` is not supported by the snowflake backend"
459459
)
460+
if where is not None and distinct:
461+
raise com.UnsupportedOperationError(
462+
"Combining `distinct=True` and `where` is not supported by snowflake"
463+
)
460464

461465
if where is not None:
462466
arg = self.if_(where, arg, NULL)
463467

468+
if distinct:
469+
arg = sge.Distinct(expressions=[arg])
470+
464471
out = self.f.array_agg(arg)
465472

466473
if order_by:
467474
out = sge.WithinGroup(this=out, expression=sge.Order(expressions=order_by))
468475

469476
return out
470477

471-
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
478+
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
472479
return self._array_collect(
473-
arg=arg, where=where, order_by=order_by, include_null=include_null
480+
arg=arg,
481+
where=where,
482+
order_by=order_by,
483+
include_null=include_null,
484+
distinct=distinct,
474485
)
475486

476487
def visit_First(self, op, *, arg, where, order_by, include_null):

ibis/backends/sql/compilers/trino.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,12 @@ def visit_ArrayContains(self, op, *, arg, other):
182182
NULL,
183183
)
184184

185-
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
185+
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
186186
if not include_null:
187187
cond = arg.is_(sg.not_(NULL, copy=False))
188188
where = cond if where is None else sge.And(this=cond, expression=where)
189+
if distinct:
190+
arg = sge.Distinct(expressions=[arg])
189191
return self.agg.array_agg(arg, where=where, order_by=order_by)
190192

191193
def visit_JSONGetItem(self, op, *, arg, index):

ibis/backends/tests/test_aggregation.py

Lines changed: 58 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import itertools
34
from datetime import date
45
from operator import methodcaller
56

@@ -1301,67 +1302,75 @@ def test_group_concat_ordered(alltypes, df, filtered):
13011302
assert result == expected
13021303

13031304

1304-
@pytest.mark.notimpl(
1305-
["druid", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"],
1306-
raises=com.OperationNotDefinedError,
1307-
)
1308-
@pytest.mark.notimpl(
1309-
["clickhouse", "pyspark", "flink"], raises=com.UnsupportedOperationError
1310-
)
1311-
@pytest.mark.parametrize("filtered", [True, False])
1312-
def test_collect_ordered(alltypes, df, filtered):
1313-
ibis_cond = (_.id % 13 == 0) if filtered else None
1314-
pd_cond = (df.id % 13 == 0) if filtered else True
1315-
result = (
1316-
alltypes.filter(_.bigint_col == 10)
1317-
.id.cast("str")
1318-
.collect(where=ibis_cond, order_by=_.id.desc())
1319-
.execute()
1320-
)
1321-
expected = list(
1322-
df.id[(df.bigint_col == 10) & pd_cond].sort_values(ascending=False).astype(str)
1323-
)
1324-
assert result == expected
1305+
def gen_test_collect_marks(distinct, filtered, ordered, include_null):
1306+
"""The marks for this test fail for different combinations of parameters.
1307+
Rather than set `strict=False` (which can let bugs sneak through), we split
1308+
the mark generation into a function"""
1309+
if distinct:
1310+
yield pytest.mark.notimpl(["datafusion"], raises=com.UnsupportedOperationError)
1311+
if ordered:
1312+
yield pytest.mark.notimpl(
1313+
["clickhouse", "pyspark", "flink"], raises=com.UnsupportedOperationError
1314+
)
1315+
if include_null:
1316+
yield pytest.mark.notimpl(
1317+
["clickhouse", "pyspark", "snowflake"], raises=com.UnsupportedOperationError
1318+
)
1319+
1320+
# Handle special cases
1321+
if filtered and distinct:
1322+
yield pytest.mark.notimpl(
1323+
["bigquery", "snowflake"],
1324+
raises=com.UnsupportedOperationError,
1325+
reason="Can't combine where and distinct",
1326+
)
1327+
elif filtered and include_null:
1328+
yield pytest.mark.notimpl(
1329+
["bigquery"],
1330+
raises=com.UnsupportedOperationError,
1331+
reason="Can't combine where and include_null",
1332+
)
1333+
elif include_null:
1334+
yield pytest.mark.notimpl(
1335+
["bigquery"],
1336+
raises=GoogleBadRequest,
1337+
reason="BigQuery can't retrieve arrays with null values",
1338+
)
13251339

13261340

13271341
@pytest.mark.notimpl(
13281342
["druid", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"],
13291343
raises=com.OperationNotDefinedError,
13301344
)
1331-
@pytest.mark.parametrize("filtered", [True, False])
13321345
@pytest.mark.parametrize(
1333-
"include_null",
1346+
"distinct, filtered, ordered, include_null",
13341347
[
1335-
False,
1336-
param(
1337-
True,
1338-
marks=[
1339-
pytest.mark.notimpl(
1340-
["clickhouse", "pyspark", "snowflake"],
1341-
raises=com.UnsupportedOperationError,
1342-
reason="`include_null=True` is not supported",
1343-
),
1344-
pytest.mark.notimpl(
1345-
["bigquery"],
1346-
raises=com.UnsupportedOperationError,
1347-
reason="Can't mix `where` and `include_null=True`",
1348-
strict=False,
1349-
),
1350-
],
1351-
),
1348+
param(*ps, marks=list(gen_test_collect_marks(*ps)))
1349+
for ps in itertools.product(*([[True, False]] * 4))
13521350
],
13531351
)
1354-
def test_collect(alltypes, df, filtered, include_null):
1355-
ibis_cond = (_.id % 13 == 0) if filtered else None
1356-
pd_cond = (df.id % 13 == 0) if filtered else slice(None)
1357-
expr = (
1358-
alltypes.string_col.nullif("3")
1359-
.collect(where=ibis_cond, include_null=include_null)
1360-
.length()
1352+
def test_collect(alltypes, df, distinct, filtered, ordered, include_null):
1353+
expr = alltypes.mutate(x=_.string_col.nullif("3")).x.collect(
1354+
where=((_.id % 13 == 0) if filtered else None),
1355+
include_null=include_null,
1356+
distinct=distinct,
1357+
order_by=(_.x.desc() if ordered else ()),
13611358
)
13621359
res = expr.execute()
1363-
vals = df.string_col if include_null else df.string_col[df.string_col != "3"]
1364-
sol = len(vals[pd_cond])
1360+
1361+
x = df.string_col.where(df.string_col != "3", None)
1362+
if filtered:
1363+
x = x[df.id % 13 == 0]
1364+
if not include_null:
1365+
x = x.dropna()
1366+
if distinct:
1367+
x = x.drop_duplicates()
1368+
sol = sorted(x, key=lambda x: (x is not None, x), reverse=True)
1369+
1370+
if not ordered:
1371+
# If unordered, order afterwards so we can compare
1372+
res = sorted(res, key=lambda x: (x is not None, x), reverse=True)
1373+
13651374
assert res == sol
13661375

13671376

ibis/expr/operations/reductions.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import ibis.expr.datashape as ds
1010
import ibis.expr.datatypes as dt
1111
import ibis.expr.rules as rlz
12-
from ibis.common.annotations import attribute
12+
from ibis.common.annotations import ValidationError, attribute
1313
from ibis.common.typing import VarTuple # noqa: TCH001
1414
from ibis.expr.operations.core import Column, Value
1515
from ibis.expr.operations.relations import Relation # noqa: TCH001
@@ -376,6 +376,15 @@ class ArrayCollect(Filterable, Reduction):
376376
arg: Column
377377
order_by: VarTuple[SortKey] = ()
378378
include_null: bool = False
379+
distinct: bool = False
380+
381+
def __init__(self, arg, order_by, distinct, **kwargs):
382+
if distinct and order_by and [arg] != [key.expr for key in order_by]:
383+
raise ValidationError(
384+
"`collect` with `order_by` and `distinct=True` and may only "
385+
"order by the collected column"
386+
)
387+
super().__init__(arg=arg, order_by=order_by, distinct=distinct, **kwargs)
379388

380389
@attribute
381390
def dtype(self):

ibis/expr/tests/test_reductions.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import ibis
77
import ibis.expr.operations as ops
88
from ibis import _
9+
from ibis.common.annotations import ValidationError
910
from ibis.common.deferred import Deferred
1011
from ibis.common.exceptions import IbisTypeError
1112

@@ -161,3 +162,22 @@ def test_ordered_aggregations_no_order(method):
161162
q3 = func(order_by=())
162163
assert q1.equals(q2)
163164
assert q1.equals(q3)
165+
166+
167+
def test_collect_distinct():
168+
t = ibis.table({"a": "string", "b": "int", "c": "int"}, name="t")
169+
# Fine
170+
t.a.collect(distinct=True)
171+
t.a.collect(distinct=True, order_by=t.a.desc())
172+
(t.a + 1).collect(distinct=True, order_by=(t.a + 1).desc())
173+
174+
with pytest.raises(ValidationError, match="only order by the collected column"):
175+
t.b.collect(distinct=True, order_by=t.a)
176+
with pytest.raises(ValidationError, match="only order by the collected column"):
177+
t.b.collect(
178+
distinct=True,
179+
order_by=(
180+
t.a,
181+
t.b,
182+
),
183+
)

0 commit comments

Comments
 (0)