|
27 | 27 | import os
|
28 | 28 | import random
|
29 | 29 | import typing
|
30 |
| -from typing import Iterable, List, Literal, Mapping, Optional, Sequence, Tuple |
| 30 | +from typing import Iterable, List, Literal, Mapping, Optional, Sequence, Tuple, Union |
31 | 31 | import warnings
|
32 | 32 |
|
33 | 33 | import google.cloud.bigquery as bigquery
|
@@ -105,6 +105,8 @@ def __init__(
|
105 | 105 | index_columns: Iterable[str],
|
106 | 106 | column_labels: typing.Union[pd.Index, typing.Iterable[Label]],
|
107 | 107 | index_labels: typing.Union[pd.Index, typing.Iterable[Label], None] = None,
|
| 108 | + *, |
| 109 | + transpose_cache: Optional[Block] = None, |
108 | 110 | ):
|
109 | 111 | """Construct a block object, will create default index if no index columns specified."""
|
110 | 112 | index_columns = list(index_columns)
|
@@ -144,6 +146,7 @@ def __init__(
|
144 | 146 | # TODO(kemppeterson) Add a cache for corr to parallel the single-column stats.
|
145 | 147 |
|
146 | 148 | self._stats_cache[" ".join(self.index_columns)] = {}
|
| 149 | + self._transpose_cache: Optional[Block] = transpose_cache |
147 | 150 |
|
148 | 151 | @classmethod
|
149 | 152 | def from_local(cls, data: pd.DataFrame, session: bigframes.Session) -> Block:
|
@@ -716,6 +719,15 @@ def with_column_labels(
|
716 | 719 | index_labels=self.index.names,
|
717 | 720 | )
|
718 | 721 |
|
def with_transpose_cache(self, transposed: Block) -> Block:
    """Return a new block over the same data that carries a transpose cache.

    The new block is built from this block's expression, index columns,
    column labels and index names; only the transpose cache differs.

    Args:
        transposed: Block to store as the transpose cache of the result.
            NOTE(review): presumably this must be the exact transpose of
            this block; nothing here verifies that, so the caller is
            responsible for it.

    Returns:
        A new Block; this block is not modified.
    """
    return Block(
        self._expr,
        index_columns=self.index_columns,
        column_labels=self._column_labels,
        index_labels=self.index.names,
        transpose_cache=transposed,
    )
| 730 | + |
719 | 731 | def with_index_labels(self, value: typing.Sequence[Label]) -> Block:
|
720 | 732 | if len(value) != len(self.index_columns):
|
721 | 733 | raise ValueError(
|
@@ -804,18 +816,35 @@ def multi_apply_window_op(
|
def multi_apply_unary_op(
    self,
    columns: typing.Sequence[str],
    op: Union[ops.UnaryOp, ex.Expression],
) -> Block:
    """Apply a single-input operation to each of the given columns in place.

    Each listed column is replaced by the result of the operation applied
    to that column; the column keeps its id and label.

    Args:
        columns: Ids of the columns to transform.
        op: Either a unary op, or an expression with exactly one unbound
            variable that stands for the input column.

    Returns:
        A new Block with the listed columns transformed. If this block has
        a transpose cache and ``columns`` covers every value column, the
        same operation is also applied to the cached transpose and the
        result is attached as the new block's transpose cache.
    """
    if isinstance(op, ops.UnaryOp):
        input_varname = guid.generate_guid()
        # Build the expression over the same placeholder name that is bound
        # per column below. Using a different literal name here would leave
        # the expression's variable unbound at bind time.
        expr = op.as_expr(input_varname)
    else:
        input_varnames = op.unbound_variables
        assert len(input_varnames) == 1
        expr = op
        input_varname = input_varnames[0]

    block = self
    for col_id in columns:
        label = self.col_id_to_label[col_id]
        # Compute the result as a new column, copy it over the original
        # column, then drop the temporary column.
        block, result_id = block.project_expr(
            expr.bind_all_variables({input_varname: ex.free_var(col_id)}),
            label=label,
        )
        block = block.copy_values(result_id, col_id)
        block = block.drop_columns([result_id])
    # Special case, we can preserve transpose cache for full-frame unary ops
    if (self._transpose_cache is not None) and set(self.value_columns) == set(
        columns
    ):
        transpose_columns = self._transpose_cache.value_columns
        new_transpose_cache = self._transpose_cache.multi_apply_unary_op(
            transpose_columns, op
        )
        block = block.with_transpose_cache(new_transpose_cache)
    return block
|
820 | 849 |
|
821 | 850 | def apply_window_op(
|
@@ -922,20 +951,17 @@ def aggregate_all_and_stack(
|
922 | 951 | (ex.UnaryAggregation(operation, ex.free_var(col_id)), col_id)
|
923 | 952 | for col_id in self.value_columns
|
924 | 953 | ]
|
925 |
| - index_col_ids = [ |
926 |
| - guid.generate_guid() for i in range(self.column_labels.nlevels) |
927 |
| - ] |
928 |
| - result_expr = self.expr.aggregate(aggregations, dropna=dropna).unpivot( |
929 |
| - row_labels=self.column_labels.to_list(), |
930 |
| - index_col_ids=index_col_ids, |
931 |
| - unpivot_columns=tuple([(value_col_id, tuple(self.value_columns))]), |
932 |
| - ) |
| 954 | + index_id = guid.generate_guid() |
| 955 | + result_expr = self.expr.aggregate( |
| 956 | + aggregations, dropna=dropna |
| 957 | + ).assign_constant(index_id, None, None) |
| 958 | + # Transpose as last operation so that final block has valid transpose cache |
933 | 959 | return Block(
|
934 | 960 | result_expr,
|
935 |
| - index_columns=index_col_ids, |
936 |
| - column_labels=[None], |
| 961 | + index_columns=[index_id], |
| 962 | + column_labels=self.column_labels, |
937 | 963 | index_labels=self.column_labels.names,
|
938 |
| - ) |
| 964 | + ).transpose(original_row_index=pd.Index([None])) |
939 | 965 | else: # axis_n == 1
|
940 | 966 | # using offsets as identity to group on.
|
941 | 967 | # TODO: Allow to promote identity/total_order columns instead for better perf
|
@@ -1575,10 +1601,19 @@ def melt(
|
1575 | 1601 | index_columns=[index_id],
|
1576 | 1602 | )
|
1577 | 1603 |
|
1578 |
| - def transpose(self) -> Block: |
1579 |
| - """Transpose the block. Will fail if dtypes aren't coercible to a common type or too many rows""" |
| 1604 | + def transpose(self, *, original_row_index: Optional[pd.Index] = None) -> Block: |
| 1605 | + """Transpose the block. Will fail if dtypes aren't coercible to a common type or too many rows. |
| 1606 | + Can provide the original_row_index directly if it is already known, otherwise a query is needed. |
| 1607 | + """ |
| 1608 | + if self._transpose_cache is not None: |
| 1609 | + return self._transpose_cache.with_transpose_cache(self) |
| 1610 | + |
1580 | 1611 | original_col_index = self.column_labels
|
1581 |
| - original_row_index = self.index.to_pandas() |
| 1612 | + original_row_index = ( |
| 1613 | + original_row_index |
| 1614 | + if original_row_index is not None |
| 1615 | + else self.index.to_pandas() |
| 1616 | + ) |
1582 | 1617 | original_row_count = len(original_row_index)
|
1583 | 1618 | if original_row_count > bigframes.constants.MAX_COLUMNS:
|
1584 | 1619 | raise NotImplementedError(
|
@@ -1619,6 +1654,7 @@ def transpose(self) -> Block:
|
1619 | 1654 | result.with_column_labels(original_row_index)
|
1620 | 1655 | .order_by([ordering.ascending_over(result.index_columns[-1])])
|
1621 | 1656 | .drop_levels([result.index_columns[-1]])
|
| 1657 | + .with_transpose_cache(self) |
1622 | 1658 | )
|
1623 | 1659 |
|
1624 | 1660 | def _create_stack_column(
|
|
0 commit comments