Skip to content

Commit 017f07a

Browse files
jcrist authored and cpcloud committed
feat(dask): implement mode aggregation
1 parent fc023b5 commit 017f07a

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

ibis/backends/dask/execution/generic.py

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import datetime
55
import decimal
66
import numbers
7+
from operator import methodcaller
78

89
import dask.array as da
910
import dask.dataframe as dd
@@ -214,6 +215,50 @@ def execute_arbitrary_series_groupby(op, data, _, aggcontext=None, **kwargs):
214215
return aggcontext.agg(data, how)
215216

216217

218+
def _mode_agg(df):
    """Pick the most frequent value from accumulated value counts.

    Used as the final ``aggregate`` step of a dask ``Series.reduction``
    whose chunk step is ``value_counts``.

    Parameters
    ----------
    df
        Frame of counts whose columns are the distinct values and whose
        rows are summed column-wise here (presumably one row per
        partition, as dask's reduction passes it -- confirm against the
        dask docs). Missing cells (NaN) are skipped by the sum.

    Returns
    -------
    The column label with the largest total count; on ties, the first such
    label in column order. ``None`` when there are no counts at all (for
    example when every row was filtered out).
    """
    totals = df.sum()
    if totals.empty:
        # No values were counted; ``idxmax`` would raise on an empty Series.
        return None
    # ``idxmax`` finds the top label in one pass, with no full sort, and
    # always picks the first label among equal counts.
    return totals.idxmax()
220+
221+
222+
@execute_node.register(ops.Mode, dd.Series, (dd.Series, type(None)))
def execute_mode_series(_, data, mask, **kwargs):
    """Compute the mode of a dask series, optionally filtered by ``mask``.

    The work is expressed as a dask reduction: ``value_counts`` is the chunk
    step, ``sum`` the combine step, and ``_mode_agg`` picks the final value.

    Parameters
    ----------
    _
        The ``ops.Mode`` node being executed (unused).
    data
        The series to aggregate.
    mask
        Optional series used to index ``data`` before counting; ``None``
        means no filtering.
    **kwargs
        Extra execution arguments (unused).
    """
    selected = data if mask is None else data[mask]
    count_values = methodcaller("value_counts")
    merge_counts = methodcaller("sum")
    return selected.reduction(
        chunk=count_values,
        combine=merge_counts,
        aggregate=_mode_agg,
        meta=selected.dtype,
    )
232+
233+
234+
def _grouped_mode_agg(gb):
    """Sum the counts of the grouped object over every index level.

    Ignores the grouping carried by ``gb`` and regroups its underlying
    object by all of that object's index level names, so entries sharing
    the same full index key are added together.
    """
    counts = gb.obj
    level_names = counts.index.names
    regrouped = counts.groupby(level_names)
    return regrouped.sum()
236+
237+
238+
def _grouped_mode_finalize(series):
    """Reduce per-group value counts to the most frequent value per group.

    Parameters
    ----------
    series
        Counts indexed by a multi-level index whose last level holds the
        counted values; the remaining levels are treated as group keys.

    Returns
    -------
    A series indexed by the group keys holding, for each group, the value
    from the row with the highest count.
    """
    count_col = "__counts__"
    value_col = series.index.names[-1]

    # Move the value level out of the index so each row is (value, count).
    frame = series.reset_index(-1, name=count_col)

    def _top_row(group):
        ranked = group.sort_values(count_col, ascending=False)
        return ranked.iloc[0]

    per_group = frame.groupby(frame.index.names).apply(_top_row)
    return per_group[value_col]
246+
247+
248+
@execute_node.register(ops.Mode, ddgb.SeriesGroupBy, (ddgb.SeriesGroupBy, type(None)))
def execute_mode_series_group_by(_, data, mask, **kwargs):
    """Compute the per-group mode of a grouped dask series.

    Builds a custom ``dd.Aggregation`` named ``"mode"``: ``value_counts`` is
    the chunk step, ``_grouped_mode_agg`` merges the counts, and
    ``_grouped_mode_finalize`` selects one value per group.

    Parameters
    ----------
    _
        The ``ops.Mode`` node being executed (unused).
    data
        The grouped series to aggregate.
    mask
        Optional grouped mask; ``None`` means no filtering.
    **kwargs
        Extra execution arguments (unused).
    """
    # NOTE(review): when a mask is given, the grouped data is indexed with
    # it directly; confirm dask's SeriesGroupBy accepts a groupby-typed key.
    grouped = data if mask is None else data[mask]
    mode_aggregation = dd.Aggregation(
        name="mode",
        chunk=methodcaller("value_counts"),
        agg=_grouped_mode_agg,
        finalize=_grouped_mode_finalize,
    )
    return grouped.agg(mode_aggregation)
260+
261+
217262
@execute_node.register(ops.Cast, ddgb.SeriesGroupBy, dt.DataType)
218263
def execute_cast_series_group_by(op, data, type, **kwargs):
219264
result = execute_cast_series_generic(op, make_selected_obj(data), type, **kwargs)

ibis/backends/tests/test_aggregation.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,6 @@ def mean_udf(s):
5757
marks=pytest.mark.notyet(
5858
[
5959
"clickhouse",
60-
"dask",
6160
"datafusion",
6261
"impala",
6362
"mysql",
@@ -273,7 +272,6 @@ def mean_and_std(v):
273272
marks=pytest.mark.notyet(
274273
[
275274
"clickhouse",
276-
"dask",
277275
"datafusion",
278276
"impala",
279277
"mysql",

0 commit comments

Comments
 (0)