|
4 | 4 | import datetime
|
5 | 5 | import decimal
|
6 | 6 | import numbers
|
| 7 | +from operator import methodcaller |
7 | 8 |
|
8 | 9 | import dask.array as da
|
9 | 10 | import dask.dataframe as dd
|
@@ -214,6 +215,50 @@ def execute_arbitrary_series_groupby(op, data, _, aggcontext=None, **kwargs):
|
214 | 215 | return aggcontext.agg(data, how)
|
215 | 216 |
|
216 | 217 |
|
def _mode_agg(df):
    """Return the index label whose summed total is the largest.

    ``df.sum()`` is taken first, the result is ordered from largest to
    smallest, and the first label of that ordering is returned.
    """
    # presumably each column is a distinct value and each row holds one
    # partition's counts -- TODO confirm against dask's reduction semantics
    totals = df.sum()
    ranked = totals.sort_values(ascending=False)
    return ranked.index[0]
| 220 | + |
| 221 | + |
@execute_node.register(ops.Mode, dd.Series, (dd.Series, type(None)))
def execute_mode_series(_, data, mask, **kwargs):
    """Execute ``ops.Mode`` on a dask series, optionally filtered by ``mask``.

    When ``mask`` is not ``None`` the series is first indexed by it. The
    result is a dask ``reduction`` that calls ``value_counts`` on each
    chunk, ``sum`` to combine, and ``_mode_agg`` to aggregate.
    """
    selected = data if mask is None else data[mask]
    count_values = methodcaller("value_counts")
    add_up = methodcaller("sum")
    return selected.reduction(
        chunk=count_values,
        combine=add_up,
        aggregate=_mode_agg,
        # the output metadata is the dtype of the (possibly filtered) input
        meta=selected.dtype,
    )
| 232 | + |
| 233 | + |
def _grouped_mode_agg(gb):
    """Sum the object underlying ``gb`` over all of its index levels.

    The grouping carried by ``gb`` itself is not used; only ``gb.obj`` is
    read, regrouped by every name in its index, and summed.
    """
    underlying = gb.obj
    level_names = underlying.index.names
    return underlying.groupby(level_names).sum()
| 236 | + |
| 237 | + |
def _grouped_mode_finalize(series):
    """Pick, for every group, the value carrying the highest count.

    The last index level of ``series`` is moved into a column and the
    series' own values are stored under ``"__counts__"``. The frame is
    then grouped by the remaining index levels, and for each group the
    row with the largest count is kept. Only the column that came from
    the last index level is returned.
    """
    count_col = "__counts__"
    value_level = series.index.names[-1]
    frame = series.reset_index(-1, name=count_col)

    def _top_row(group):
        # order by count, largest first, and keep the leading row
        return group.sort_values(count_col, ascending=False).iloc[0]

    winners = frame.groupby(frame.index.names).apply(_top_row)
    return winners[value_level]
| 246 | + |
| 247 | + |
@execute_node.register(ops.Mode, ddgb.SeriesGroupBy, (ddgb.SeriesGroupBy, type(None)))
def execute_mode_series_group_by(_, data, mask, **kwargs):
    """Execute ``ops.Mode`` on a grouped dask series.

    Builds a ``dd.Aggregation`` named ``"mode"`` whose chunk step calls
    ``value_counts``, whose agg step is ``_grouped_mode_agg`` and whose
    finalize step is ``_grouped_mode_finalize``, then applies it via
    ``agg``.
    """
    # NOTE(review): a non-None mask is applied by indexing the grouped
    # object with it -- confirm this filters rows as intended
    grouped = data if mask is None else data[mask]
    mode_aggregation = dd.Aggregation(
        name="mode",
        chunk=methodcaller("value_counts"),
        agg=_grouped_mode_agg,
        finalize=_grouped_mode_finalize,
    )
    return grouped.agg(mode_aggregation)
| 260 | + |
| 261 | + |
217 | 262 | @execute_node.register(ops.Cast, ddgb.SeriesGroupBy, dt.DataType)
|
218 | 263 | def execute_cast_series_group_by(op, data, type, **kwargs):
|
219 | 264 | result = execute_cast_series_generic(op, make_selected_obj(data), type, **kwargs)
|
|
0 commit comments