Skip to content

Commit 6c2a9b0

Browse files
feat: add axis param to simple df aggregations
1 parent 33274c2 commit 6c2a9b0

File tree

4 files changed

+143
-42
lines changed

4 files changed

+143
-42
lines changed

bigframes/core/blocks.py

+41-9
Original file line numberDiff line numberDiff line change
@@ -822,22 +822,54 @@ def filter(self, column_id: str, keep_null: bool = False):
822822
index_labels=self.index.names,
823823
)
824824

825-
def aggregate_all_and_pivot(
825+
def aggregate_all_and_stack(
826826
self,
827827
operation: agg_ops.AggregateOp,
828828
*,
829+
axis: int | str = 0,
829830
value_col_id: str = "values",
830831
dropna: bool = True,
831832
dtype=pd.Float64Dtype(),
832833
) -> Block:
833-
aggregations = [(col_id, operation, col_id) for col_id in self.value_columns]
834-
result_expr = self.expr.aggregate(aggregations, dropna=dropna).unpivot(
835-
row_labels=self.column_labels.to_list(),
836-
index_col_id="index",
837-
unpivot_columns=[(value_col_id, self.value_columns)],
838-
dtype=dtype,
839-
)
840-
return Block(result_expr, index_columns=["index"], column_labels=[None])
834+
axis_n = utils.get_axis_number(axis)
835+
if axis_n == 0:
836+
aggregations = [
837+
(col_id, operation, col_id) for col_id in self.value_columns
838+
]
839+
result_expr = self.expr.aggregate(aggregations, dropna=dropna).unpivot(
840+
row_labels=self.column_labels.to_list(),
841+
index_col_id="index",
842+
unpivot_columns=[(value_col_id, self.value_columns)],
843+
dtype=dtype,
844+
)
845+
return Block(result_expr, index_columns=["index"], column_labels=[None])
846+
else: # axis_n == 1
847+
# using offsets as identity to group on.
848+
# TODO: Allow to promote identity/total_order columns instead for better perf
849+
expr_with_offsets, offset_col = self.expr.promote_offsets()
850+
stacked_expr = expr_with_offsets.unpivot(
851+
row_labels=self.column_labels.to_list(),
852+
index_col_id=guid.generate_guid(),
853+
unpivot_columns=[(value_col_id, self.value_columns)],
854+
passthrough_columns=[*self.index_columns, offset_col],
855+
dtype=dtype,
856+
)
857+
index_aggregations = [
858+
(col_id, agg_ops.AnyValueOp(), col_id)
859+
for col_id in [*self.index_columns]
860+
]
861+
main_aggregation = (value_col_id, operation, value_col_id)
862+
result_expr = stacked_expr.aggregate(
863+
[*index_aggregations, main_aggregation],
864+
by_column_ids=[offset_col],
865+
dropna=dropna,
866+
)
867+
return Block(
868+
result_expr.drop_columns([offset_col]),
869+
self.index_columns,
870+
column_labels=[None],
871+
index_labels=self.index_labels,
872+
)
841873

842874
def select_column(self, id: str) -> Block:
843875
return self.select_columns([id])

bigframes/dataframe.py

+40-23
Original file line numberDiff line numberDiff line change
@@ -1462,41 +1462,48 @@ def dropna(
14621462
def any(
14631463
self,
14641464
*,
1465+
axis: typing.Union[str, int] = 0,
14651466
bool_only: bool = False,
14661467
) -> bigframes.series.Series:
14671468
if not bool_only:
14681469
frame = self._raise_on_non_boolean("any")
14691470
else:
14701471
frame = self._drop_non_bool()
1471-
block = frame._block.aggregate_all_and_pivot(
1472-
agg_ops.any_op, dtype=pandas.BooleanDtype()
1472+
block = frame._block.aggregate_all_and_stack(
1473+
agg_ops.any_op, dtype=pandas.BooleanDtype(), axis=axis
14731474
)
14741475
return bigframes.series.Series(block.select_column("values"))
14751476

1476-
def all(self, *, bool_only: bool = False) -> bigframes.series.Series:
1477+
def all(
1478+
self, axis: typing.Union[str, int] = 0, *, bool_only: bool = False
1479+
) -> bigframes.series.Series:
14771480
if not bool_only:
14781481
frame = self._raise_on_non_boolean("all")
14791482
else:
14801483
frame = self._drop_non_bool()
1481-
block = frame._block.aggregate_all_and_pivot(
1482-
agg_ops.all_op, dtype=pandas.BooleanDtype()
1484+
block = frame._block.aggregate_all_and_stack(
1485+
agg_ops.all_op, dtype=pandas.BooleanDtype(), axis=axis
14831486
)
14841487
return bigframes.series.Series(block.select_column("values"))
14851488

1486-
def sum(self, *, numeric_only: bool = False) -> bigframes.series.Series:
1489+
def sum(
1490+
self, axis: typing.Union[str, int] = 0, *, numeric_only: bool = False
1491+
) -> bigframes.series.Series:
14871492
if not numeric_only:
14881493
frame = self._raise_on_non_numeric("sum")
14891494
else:
14901495
frame = self._drop_non_numeric()
1491-
block = frame._block.aggregate_all_and_pivot(agg_ops.sum_op)
1496+
block = frame._block.aggregate_all_and_stack(agg_ops.sum_op, axis=axis)
14921497
return bigframes.series.Series(block.select_column("values"))
14931498

1494-
def mean(self, *, numeric_only: bool = False) -> bigframes.series.Series:
1499+
def mean(
1500+
self, axis: typing.Union[str, int] = 0, *, numeric_only: bool = False
1501+
) -> bigframes.series.Series:
14951502
if not numeric_only:
14961503
frame = self._raise_on_non_numeric("mean")
14971504
else:
14981505
frame = self._drop_non_numeric()
1499-
block = frame._block.aggregate_all_and_pivot(agg_ops.mean_op)
1506+
block = frame._block.aggregate_all_and_stack(agg_ops.mean_op, axis=axis)
15001507
return bigframes.series.Series(block.select_column("values"))
15011508

15021509
def median(
@@ -1510,47 +1517,57 @@ def median(
15101517
frame = self._raise_on_non_numeric("median")
15111518
else:
15121519
frame = self._drop_non_numeric()
1513-
block = frame._block.aggregate_all_and_pivot(agg_ops.median_op)
1520+
block = frame._block.aggregate_all_and_stack(agg_ops.median_op)
15141521
return bigframes.series.Series(block.select_column("values"))
15151522

1516-
def std(self, *, numeric_only: bool = False) -> bigframes.series.Series:
1523+
def std(
1524+
self, axis: typing.Union[str, int] = 0, *, numeric_only: bool = False
1525+
) -> bigframes.series.Series:
15171526
if not numeric_only:
15181527
frame = self._raise_on_non_numeric("std")
15191528
else:
15201529
frame = self._drop_non_numeric()
1521-
block = frame._block.aggregate_all_and_pivot(agg_ops.std_op)
1530+
block = frame._block.aggregate_all_and_stack(agg_ops.std_op, axis=axis)
15221531
return bigframes.series.Series(block.select_column("values"))
15231532

1524-
def var(self, *, numeric_only: bool = False) -> bigframes.series.Series:
1533+
def var(
1534+
self, axis: typing.Union[str, int] = 0, *, numeric_only: bool = False
1535+
) -> bigframes.series.Series:
15251536
if not numeric_only:
15261537
frame = self._raise_on_non_numeric("var")
15271538
else:
15281539
frame = self._drop_non_numeric()
1529-
block = frame._block.aggregate_all_and_pivot(agg_ops.var_op)
1540+
block = frame._block.aggregate_all_and_stack(agg_ops.var_op, axis=axis)
15301541
return bigframes.series.Series(block.select_column("values"))
15311542

1532-
def min(self, *, numeric_only: bool = False) -> bigframes.series.Series:
1543+
def min(
1544+
self, axis: typing.Union[str, int] = 0, *, numeric_only: bool = False
1545+
) -> bigframes.series.Series:
15331546
if not numeric_only:
15341547
frame = self._raise_on_non_numeric("min")
15351548
else:
15361549
frame = self._drop_non_numeric()
1537-
block = frame._block.aggregate_all_and_pivot(agg_ops.min_op)
1550+
block = frame._block.aggregate_all_and_stack(agg_ops.min_op, axis=axis)
15381551
return bigframes.series.Series(block.select_column("values"))
15391552

1540-
def max(self, *, numeric_only: bool = False) -> bigframes.series.Series:
1553+
def max(
1554+
self, axis: typing.Union[str, int] = 0, *, numeric_only: bool = False
1555+
) -> bigframes.series.Series:
15411556
if not numeric_only:
15421557
frame = self._raise_on_non_numeric("max")
15431558
else:
15441559
frame = self._drop_non_numeric()
1545-
block = frame._block.aggregate_all_and_pivot(agg_ops.max_op)
1560+
block = frame._block.aggregate_all_and_stack(agg_ops.max_op, axis=axis)
15461561
return bigframes.series.Series(block.select_column("values"))
15471562

1548-
def prod(self, *, numeric_only: bool = False) -> bigframes.series.Series:
1563+
def prod(
1564+
self, axis: typing.Union[str, int] = 0, *, numeric_only: bool = False
1565+
) -> bigframes.series.Series:
15491566
if not numeric_only:
15501567
frame = self._raise_on_non_numeric("prod")
15511568
else:
15521569
frame = self._drop_non_numeric()
1553-
block = frame._block.aggregate_all_and_pivot(agg_ops.product_op)
1570+
block = frame._block.aggregate_all_and_stack(agg_ops.product_op, axis=axis)
15541571
return bigframes.series.Series(block.select_column("values"))
15551572

15561573
product = prod
@@ -1560,11 +1577,11 @@ def count(self, *, numeric_only: bool = False) -> bigframes.series.Series:
15601577
frame = self
15611578
else:
15621579
frame = self._drop_non_numeric()
1563-
block = frame._block.aggregate_all_and_pivot(agg_ops.count_op)
1580+
block = frame._block.aggregate_all_and_stack(agg_ops.count_op)
15641581
return bigframes.series.Series(block.select_column("values"))
15651582

15661583
def nunique(self) -> bigframes.series.Series:
1567-
block = self._block.aggregate_all_and_pivot(agg_ops.nunique_op)
1584+
block = self._block.aggregate_all_and_stack(agg_ops.nunique_op)
15681585
return bigframes.series.Series(block.select_column("values"))
15691586

15701587
def agg(
@@ -1587,7 +1604,7 @@ def agg(
15871604
)
15881605
else:
15891606
return bigframes.series.Series(
1590-
self._block.aggregate_all_and_pivot(
1607+
self._block.aggregate_all_and_stack(
15911608
agg_ops.lookup_agg_func(typing.cast(str, func))
15921609
)
15931610
)

tests/system/small/test_dataframe.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -1999,6 +1999,29 @@ def test_dataframe_aggregates(scalars_df_index, scalars_pandas_df_index, op):
19991999
pd.testing.assert_series_equal(pd_series, bf_result, check_index_type=False)
20002000

20012001

2002+
@pytest.mark.parametrize(
2003+
("op"),
2004+
[
2005+
(lambda x: x.sum(axis=1, numeric_only=True)),
2006+
(lambda x: x.mean(axis=1, numeric_only=True)),
2007+
(lambda x: x.min(axis=1, numeric_only=True)),
2008+
(lambda x: x.max(axis=1, numeric_only=True)),
2009+
(lambda x: x.std(axis=1, numeric_only=True)),
2010+
(lambda x: x.var(axis=1, numeric_only=True)),
2011+
],
2012+
ids=["sum", "mean", "min", "max", "std", "var"],
2013+
)
2014+
def test_dataframe_aggregates_axis_1(scalars_df_index, scalars_pandas_df_index, op):
2015+
col_names = ["int64_too", "int64_col", "float64_col", "bool_col", "string_col"]
2016+
bf_result = op(scalars_df_index[col_names]).to_pandas()
2017+
pd_result = op(scalars_pandas_df_index[col_names])
2018+
2019+
# Pandas may produce narrower numeric types, but bigframes always produces Float64
2020+
pd_result = pd_result.astype("Float64")
2021+
# Pandas has object index type
2022+
pd.testing.assert_series_equal(pd_result, bf_result, check_index_type=False)
2023+
2024+
20022025
def test_dataframe_aggregates_median(scalars_df_index, scalars_pandas_df_index):
20032026
col_names = ["int64_too", "float64_col", "int64_col", "bool_col"]
20042027
bf_result = scalars_df_index[col_names].median(numeric_only=True).to_pandas()
@@ -2019,11 +2042,16 @@ def test_dataframe_aggregates_median(scalars_df_index, scalars_pandas_df_index):
20192042
[
20202043
(lambda x: x.all(bool_only=True)),
20212044
(lambda x: x.any(bool_only=True)),
2045+
(lambda x: x.all(axis=1, bool_only=True)),
2046+
(lambda x: x.any(axis=1, bool_only=True)),
20222047
],
2023-
ids=["all", "any"],
2048+
ids=["all_axis0", "any_axis0", "all_axis1", "any_axis1"],
20242049
)
20252050
def test_dataframe_bool_aggregates(scalars_df_index, scalars_pandas_df_index, op):
20262051
# Pandas will drop nullable 'boolean' dtype so we convert first to bool, then cast back later
2052+
scalars_df_index = scalars_df_index.assign(
2053+
bool_col=scalars_pandas_df_index.bool_col.fillna(False)
2054+
)
20272055
scalars_pandas_df_index = scalars_pandas_df_index.assign(
20282056
bool_col=scalars_pandas_df_index.bool_col.fillna(False).astype("bool")
20292057
)

0 commit comments

Comments
 (0)