Skip to content

Commit 6603c6c

Browse files
authored
feat(arrays): add modes array aggregation (#10737)
1 parent 847ed85 commit 6603c6c

File tree

5 files changed

+66
-1
lines changed

5 files changed

+66
-1
lines changed

ibis/backends/sql/compilers/duckdb.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class DuckDBCompiler(SQLGlotCompiler):
6262
ops.ArrayAll: "list_bool_and",
6363
ops.ArraySum: "list_sum",
6464
ops.ArrayMean: "list_avg",
65+
ops.ArrayMode: "list_mode",
6566
ops.BitAnd: "bit_and",
6667
ops.BitOr: "bit_or",
6768
ops.BitXor: "bit_xor",

ibis/backends/sql/compilers/postgres.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,14 @@ def visit_ArrayAny(self, op, *, arg):
831831
def visit_ArrayAll(self, op, *, arg):
832832
return self._array_reduction(arg=arg, reduction="bool_and")
833833

834+
def visit_ArrayMode(self, op, *, arg):
    """Compile an array-mode reduction into a scalar subquery.

    The array is unnested under a generated column alias, and the
    ordered-set aggregate ``mode() WITHIN GROUP (ORDER BY <alias>)`` is
    selected from that unnested relation. The resulting ``SELECT`` is
    returned wrapped as a subquery.
    """
    # Alias produced by gen_name; shared by the unnest and the ORDER BY.
    element = sg.to_identifier(gen_name("pg_arr_mode"))
    mode_call = self.f.mode()
    ordering = sge.Order(expressions=[sge.Ordered(this=element)])
    mode_over_elements = sge.WithinGroup(this=mode_call, expression=ordering)
    unnested = self._unnest(arg, as_=element)
    return sg.select(mode_over_elements).from_(unnested).subquery()
841+
834842
def visit_StringToTime(self, op, *, arg, format_str):
835843
return self.cast(self.f.str_to_time(arg, format_str), to=dt.time)
836844

ibis/backends/tests/test_array.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1816,8 +1816,32 @@ def _agg_with_nulls(agg, x):
18161816
(ir.ArrayValue.mins, lambda x: _agg_with_nulls(min, x)),
18171817
(ir.ArrayValue.maxs, lambda x: _agg_with_nulls(max, x)),
18181818
(ir.ArrayValue.means, lambda x: _agg_with_nulls(statistics.mean, x)),
1819+
param(
1820+
ir.ArrayValue.modes,
1821+
lambda x: _agg_with_nulls(statistics.mode, x),
1822+
marks=[
1823+
pytest.mark.notyet(
1824+
[
1825+
"athena",
1826+
"bigquery",
1827+
"clickhouse",
1828+
"databricks",
1829+
"polars",
1830+
"pyspark",
1831+
"trino",
1832+
],
1833+
raises=com.OperationNotDefinedError,
1834+
reason="no mode aggregate in the engine",
1835+
),
1836+
pytest.mark.notimpl(
1837+
["snowflake"],
1838+
raises=com.OperationNotDefinedError,
1839+
reason="not yet implemented in Ibis",
1840+
),
1841+
],
1842+
),
18191843
],
1820-
ids=["sums", "mins", "maxs", "means"],
1844+
ids=["sums", "mins", "maxs", "means", "modes"],
18211845
)
18221846
@notimpl_aggs
18231847
@pytest.mark.parametrize(

ibis/expr/operations/arrays.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,11 @@ class ArrayMax(ArrayAgg):
292292
"""Compute the maximum value of an array."""
293293

294294

295+
@public
class ArrayMode(ArrayAgg):
    """Array aggregation that computes the mode of an array."""
298+
299+
295300
# in duckdb summing an array of ints leads to an int, but for other backends
296301
# it might lead to a float??
297302
@public

ibis/expr/types/arrays.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,33 @@ def alls(self) -> ir.BooleanValue:
12081208
"""
12091209
return ops.ArrayAll(self).to_expr()
12101210

1211+
def modes(self) -> ir.Value:
    """Compute the mode of the elements of each array.

    As the example below shows, null elements are skipped, and an empty
    array, an all-null array, or a null array all produce NULL.

    Returns
    -------
    Value
        The mode of the values in the array

    See Also
    --------
    [`Column.mode`](./expression-generic.qmd#ibis.expr.types.generic.Column.mode)

    Examples
    --------
    >>> import ibis
    >>> ibis.options.interactive = True
    >>> t = ibis.memtable({"arr": [[1, 2, 3, 3], [None, 6], [None], [], None]})
    >>> t.mutate(mode=t.arr.modes())
    ┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━┓
    ┃ arr                  ┃ mode  ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━┩
    │ array<int64>         │ int64 │
    ├──────────────────────┼───────┤
    │ [1, 2, ... +2]       │     3 │
    │ [None, 6]            │     6 │
    │ [None]               │  NULL │
    │ []                   │  NULL │
    │ NULL                 │  NULL │
    └──────────────────────┴───────┘
    """
    mode_op = ops.ArrayMode(self)
    return mode_op.to_expr()
1237+
12111238
def mins(self) -> ir.NumericValue:
12121239
"""Return the minimum value in the array.
12131240

0 commit comments

Comments
 (0)