Skip to content

Commit 45c9d9f

Browse files
authored
feat: add on parameter in dataframe.rolling() and dataframe.groupby.rolling() (#1556)
* feat: add `on` parameter in dataframe/groupby.rolling() * fix test * add docs * update doc * update docs * remove `on` param from windowspec
1 parent b5297f9 commit 45c9d9f

File tree

10 files changed

+167
-75
lines changed

10 files changed

+167
-75
lines changed

bigframes/core/blocks.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -987,7 +987,7 @@ def apply_nary_op(
987987
def multi_apply_window_op(
988988
self,
989989
columns: typing.Sequence[str],
990-
op: agg_ops.WindowOp,
990+
op: agg_ops.UnaryWindowOp,
991991
window_spec: windows.WindowSpec,
992992
*,
993993
skip_null_groups: bool = False,
@@ -1058,7 +1058,7 @@ def project_exprs(
10581058
def apply_window_op(
10591059
self,
10601060
column: str,
1061-
op: agg_ops.WindowOp,
1061+
op: agg_ops.UnaryWindowOp,
10621062
window_spec: windows.WindowSpec,
10631063
*,
10641064
result_label: Label = None,

bigframes/core/groupby/dataframe_group_by.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ def rolling(
310310
self,
311311
window: int,
312312
min_periods=None,
313+
on: str | None = None,
313314
closed: Literal["right", "left", "both", "neither"] = "right",
314315
) -> windows.Window:
315316
window_spec = window_specs.WindowSpec(
@@ -320,8 +321,15 @@ def rolling(
320321
block = self._block.order_by(
321322
[order.ascending_over(col) for col in self._by_col_ids],
322323
)
324+
skip_agg_col_id = (
325+
None if on is None else self._block.resolve_label_exact_or_error(on)
326+
)
323327
return windows.Window(
324-
block, window_spec, self._selected_cols, drop_null_groups=self._dropna
328+
block,
329+
window_spec,
330+
self._selected_cols,
331+
drop_null_groups=self._dropna,
332+
skip_agg_column_id=skip_agg_col_id,
325333
)
326334

327335
@validations.requires_ordering()
@@ -511,7 +519,7 @@ def _aggregate_all(
511519

512520
def _apply_window_op(
513521
self,
514-
op: agg_ops.WindowOp,
522+
op: agg_ops.UnaryWindowOp,
515523
window: typing.Optional[window_specs.WindowSpec] = None,
516524
numeric_only: bool = False,
517525
):

bigframes/core/groupby/series_group_by.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def _aggregate(self, aggregate_op: agg_ops.UnaryAggregateOp) -> series.Series:
294294

295295
def _apply_window_op(
296296
self,
297-
op: agg_ops.WindowOp,
297+
op: agg_ops.UnaryWindowOp,
298298
discard_name=False,
299299
window: typing.Optional[window_specs.WindowSpec] = None,
300300
never_skip_nulls: bool = False,

bigframes/core/window/rolling.py

+37-14
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,16 @@ def __init__(
3434
value_column_ids: typing.Sequence[str],
3535
drop_null_groups: bool = True,
3636
is_series: bool = False,
37+
skip_agg_column_id: str | None = None,
3738
):
3839
self._block = block
3940
self._window_spec = window_spec
4041
self._value_column_ids = value_column_ids
4142
self._drop_null_groups = drop_null_groups
4243
self._is_series = is_series
44+
# The column ID that won't be aggregated on.
45+
# This is equivalent to the pandas `on` parameter in rolling().
46+
self._skip_agg_column_id = skip_agg_column_id
4347

4448
def count(self):
4549
return self._apply_aggregate(agg_ops.count_op)
@@ -66,10 +70,37 @@ def _apply_aggregate(
6670
self,
6771
op: agg_ops.UnaryAggregateOp,
6872
):
69-
block = self._block
70-
labels = [block.col_id_to_label[col] for col in self._value_column_ids]
71-
block, result_ids = block.multi_apply_window_op(
72-
self._value_column_ids,
73+
agg_col_ids = [
74+
col_id
75+
for col_id in self._value_column_ids
76+
if col_id != self._skip_agg_column_id
77+
]
78+
agg_block = self._aggregate_block(op, agg_col_ids)
79+
80+
if self._skip_agg_column_id is not None:
81+
# Concat the skipped column to the result.
82+
agg_block, _ = agg_block.join(
83+
self._block.select_column(self._skip_agg_column_id), how="outer"
84+
)
85+
86+
if self._is_series:
87+
from bigframes.series import Series
88+
89+
return Series(agg_block)
90+
else:
91+
from bigframes.dataframe import DataFrame
92+
93+
# Preserve column order.
94+
column_labels = [
95+
self._block.col_id_to_label[col_id] for col_id in self._value_column_ids
96+
]
97+
return DataFrame(agg_block)._reindex_columns(column_labels)
98+
99+
def _aggregate_block(
100+
self, op: agg_ops.UnaryAggregateOp, agg_col_ids: typing.List[str]
101+
) -> blocks.Block:
102+
block, result_ids = self._block.multi_apply_window_op(
103+
agg_col_ids,
73104
op,
74105
self._window_spec,
75106
skip_null_groups=self._drop_null_groups,
@@ -85,13 +116,5 @@ def _apply_aggregate(
85116
)
86117
block = block.set_index(col_ids=index_ids)
87118

88-
if self._is_series:
89-
from bigframes.series import Series
90-
91-
return Series(block.select_columns(result_ids).with_column_labels(labels))
92-
else:
93-
from bigframes.dataframe import DataFrame
94-
95-
return DataFrame(
96-
block.select_columns(result_ids).with_column_labels(labels)
97-
)
119+
labels = [self._block.col_id_to_label[col] for col in agg_col_ids]
120+
return block.select_columns(result_ids).with_column_labels(labels)

bigframes/core/window_spec.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,12 @@ def __post_init__(self):
187187
class WindowSpec:
188188
"""
189189
Specifies a window over which aggregate and analytic functions may be applied.
190-
grouping_keys: set of column ids to group on
191-
preceding: Number of preceding rows in the window
192-
following: Number of preceding rows in the window
193-
ordering: List of columns ids and ordering direction to override base ordering
190+
191+
Attributes:
192+
grouping_keys: A set of column ids to group on
193+
bounds: The window boundaries
194+
ordering: A list of column ids and ordering directions to override the base ordering
195+
min_periods: The minimum number of observations in window required to have a value
194196
"""
195197

196198
grouping_keys: Tuple[ex.DerefOp, ...] = tuple()

bigframes/dataframe.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -3312,14 +3312,21 @@ def rolling(
33123312
self,
33133313
window: int,
33143314
min_periods=None,
3315+
on: str | None = None,
33153316
closed: Literal["right", "left", "both", "neither"] = "right",
33163317
) -> bigframes.core.window.Window:
33173318
window_def = windows.WindowSpec(
33183319
bounds=windows.RowsWindowBounds.from_window_size(window, closed),
33193320
min_periods=min_periods if min_periods is not None else window,
33203321
)
3322+
skip_agg_col_id = (
3323+
None if on is None else self._block.resolve_label_exact_or_error(on)
3324+
)
33213325
return bigframes.core.window.Window(
3322-
self._block, window_def, self._block.value_columns
3326+
self._block,
3327+
window_def,
3328+
self._block.value_columns,
3329+
skip_agg_column_id=skip_agg_col_id,
33233330
)
33243331

33253332
@validations.requires_ordering()
@@ -3483,7 +3490,7 @@ def pct_change(self, periods: int = 1) -> DataFrame:
34833490

34843491
def _apply_window_op(
34853492
self,
3486-
op: agg_ops.WindowOp,
3493+
op: agg_ops.UnaryWindowOp,
34873494
window_spec: windows.WindowSpec,
34883495
):
34893496
block, result_ids = self._block.multi_apply_window_op(

bigframes/series.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1378,7 +1378,9 @@ def _apply_aggregation(
13781378
) -> Any:
13791379
return self._block.get_stat(self._value_column, op)
13801380

1381-
def _apply_window_op(self, op: agg_ops.WindowOp, window_spec: windows.WindowSpec):
1381+
def _apply_window_op(
1382+
self, op: agg_ops.UnaryWindowOp, window_spec: windows.WindowSpec
1383+
):
13821384
block = self._block
13831385
block, result_id = block.apply_window_op(
13841386
self._value_column, op, window_spec=window_spec, result_label=self.name

tests/system/small/test_window.py

+52-35
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,9 @@
2020
def rolling_dfs(scalars_dfs):
2121
bf_df, pd_df = scalars_dfs
2222

23-
target_cols = ["int64_too", "float64_col", "bool_col"]
23+
target_cols = ["int64_too", "float64_col", "int64_col"]
2424

25-
bf_df = bf_df[target_cols].set_index("bool_col")
26-
pd_df = pd_df[target_cols].set_index("bool_col")
27-
28-
return bf_df, pd_df
25+
return bf_df[target_cols], pd_df[target_cols]
2926

3027

3128
@pytest.fixture(scope="module")
@@ -49,31 +46,65 @@ def test_dataframe_rolling_closed_param(rolling_dfs, closed):
4946
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
5047
def test_dataframe_groupby_rolling_closed_param(rolling_dfs, closed):
5148
bf_df, pd_df = rolling_dfs
49+
# Need to specify column subset for comparison due to b/406841327
50+
check_columns = ["float64_col", "int64_col"]
5251

5352
actual_result = (
54-
bf_df.groupby(level=0).rolling(window=3, closed=closed).sum().to_pandas()
53+
bf_df.groupby(bf_df["int64_too"] % 2)
54+
.rolling(window=3, closed=closed)
55+
.sum()
56+
.to_pandas()
5557
)
5658

57-
expected_result = pd_df.groupby(level=0).rolling(window=3, closed=closed).sum()
58-
pd.testing.assert_frame_equal(actual_result, expected_result, check_dtype=False)
59+
expected_result = (
60+
pd_df.groupby(pd_df["int64_too"] % 2).rolling(window=3, closed=closed).sum()
61+
)
62+
pd.testing.assert_frame_equal(
63+
actual_result[check_columns], expected_result, check_dtype=False
64+
)
5965

6066

61-
def test_dataframe_rolling_default_closed_param(rolling_dfs):
67+
def test_dataframe_rolling_on(rolling_dfs):
6268
bf_df, pd_df = rolling_dfs
6369

64-
actual_result = bf_df.rolling(window=3).sum().to_pandas()
70+
actual_result = bf_df.rolling(window=3, on="int64_too").sum().to_pandas()
6571

66-
expected_result = pd_df.rolling(window=3).sum()
72+
expected_result = pd_df.rolling(window=3, on="int64_too").sum()
6773
pd.testing.assert_frame_equal(actual_result, expected_result, check_dtype=False)
6874

6975

70-
def test_dataframe_groupby_rolling_default_closed_param(rolling_dfs):
76+
def test_dataframe_rolling_on_invalid_column_raise_error(rolling_dfs):
77+
bf_df, _ = rolling_dfs
78+
79+
with pytest.raises(ValueError):
80+
bf_df.rolling(window=3, on="whatever").sum()
81+
82+
83+
def test_dataframe_groupby_rolling_on(rolling_dfs):
7184
bf_df, pd_df = rolling_dfs
85+
# Need to specify column subset for comparison due to b/406841327
86+
check_columns = ["float64_col", "int64_col"]
7287

73-
actual_result = bf_df.groupby(level=0).rolling(window=3).sum().to_pandas()
88+
actual_result = (
89+
bf_df.groupby(bf_df["int64_too"] % 2)
90+
.rolling(window=3, on="float64_col")
91+
.sum()
92+
.to_pandas()
93+
)
7494

75-
expected_result = pd_df.groupby(level=0).rolling(window=3).sum()
76-
pd.testing.assert_frame_equal(actual_result, expected_result, check_dtype=False)
95+
expected_result = (
96+
pd_df.groupby(pd_df["int64_too"] % 2).rolling(window=3, on="float64_col").sum()
97+
)
98+
pd.testing.assert_frame_equal(
99+
actual_result[check_columns], expected_result, check_dtype=False
100+
)
101+
102+
103+
def test_dataframe_groupby_rolling_on_invalid_column_raise_error(rolling_dfs):
104+
bf_df, _ = rolling_dfs
105+
106+
with pytest.raises(ValueError):
107+
bf_df.groupby(level=0).rolling(window=3, on="whatever").sum()
77108

78109

79110
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
@@ -103,24 +134,6 @@ def test_series_groupby_rolling_closed_param(rolling_series, closed):
103134
pd.testing.assert_series_equal(actual_result, expected_result, check_dtype=False)
104135

105136

106-
def test_series_rolling_default_closed_param(rolling_series):
107-
bf_series, df_series = rolling_series
108-
109-
actual_result = bf_series.rolling(window=3).sum().to_pandas()
110-
111-
expected_result = df_series.rolling(window=3).sum()
112-
pd.testing.assert_series_equal(actual_result, expected_result, check_dtype=False)
113-
114-
115-
def test_series_groupby_rolling_default_closed_param(rolling_series):
116-
bf_series, df_series = rolling_series
117-
118-
actual_result = bf_series.groupby(bf_series % 2).rolling(window=3).sum().to_pandas()
119-
120-
expected_result = df_series.groupby(df_series % 2).rolling(window=3).sum()
121-
pd.testing.assert_series_equal(actual_result, expected_result, check_dtype=False)
122-
123-
124137
@pytest.mark.parametrize(
125138
("windowing"),
126139
[
@@ -181,8 +194,12 @@ def test_series_window_agg_ops(rolling_series, windowing, agg_op):
181194
pytest.param(lambda x: x.var(), id="var"),
182195
],
183196
)
184-
def test_dataframe_window_agg_ops(rolling_dfs, windowing, agg_op):
185-
bf_df, pd_df = rolling_dfs
197+
def test_dataframe_window_agg_ops(scalars_dfs, windowing, agg_op):
198+
bf_df, pd_df = scalars_dfs
199+
target_columns = ["int64_too", "float64_col", "bool_col"]
200+
index_column = "bool_col"
201+
bf_df = bf_df[target_columns].set_index(index_column)
202+
pd_df = pd_df[target_columns].set_index(index_column)
186203

187204
bf_result = agg_op(windowing(bf_df)).to_pandas()
188205

0 commit comments

Comments
 (0)