Skip to content

Commit 03c1b0d

Browse files
perf: Automatically condense internal expression representation (#516)
1 parent 3aa643f commit 03c1b0d

File tree

3 files changed

+65
-23
lines changed

3 files changed

+65
-23
lines changed

bigframes/core/__init__.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def project_to_id(self, expression: ex.Expression, output_id: str):
183183
child=self.node,
184184
assignments=tuple(exprs),
185185
)
186-
)
186+
).merge_projections()
187187

188188
def assign(self, source_id: str, destination_id: str) -> ArrayValue:
189189
if destination_id in self.column_ids: # Mutate case
@@ -208,7 +208,7 @@ def assign(self, source_id: str, destination_id: str) -> ArrayValue:
208208
child=self.node,
209209
assignments=tuple(exprs),
210210
)
211-
)
211+
).merge_projections()
212212

213213
def assign_constant(
214214
self,
@@ -242,7 +242,7 @@ def assign_constant(
242242
child=self.node,
243243
assignments=tuple(exprs),
244244
)
245-
)
245+
).merge_projections()
246246

247247
def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue:
248248
selections = ((ex.free_var(col_id), col_id) for col_id in column_ids)
@@ -251,7 +251,7 @@ def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue:
251251
child=self.node,
252252
assignments=tuple(selections),
253253
)
254-
)
254+
).merge_projections()
255255

256256
def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
257257
new_projection = (
@@ -264,7 +264,7 @@ def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
264264
child=self.node,
265265
assignments=tuple(new_projection),
266266
)
267-
)
267+
).merge_projections()
268268

269269
def aggregate(
270270
self,
@@ -466,3 +466,7 @@ def _uniform_sampling(self, fraction: float) -> ArrayValue:
466466
The row numbers of result is non-deterministic, avoid to use.
467467
"""
468468
return ArrayValue(nodes.RandomSampleNode(self.node, fraction))
469+
470+
def merge_projections(self) -> ArrayValue:
471+
new_node = bigframes.core.rewrite.maybe_squash_projection(self.node)
472+
return ArrayValue(new_node)

bigframes/core/compile/compiled.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,27 +1050,37 @@ def _hide_column(self, column_id) -> OrderedIR:
10501050
def _bake_ordering(self) -> OrderedIR:
10511051
"""Bakes ordering expression into the selection, maybe creating hidden columns."""
10521052
ordering_expressions = self._ordering.all_ordering_columns
1053-
new_exprs = []
1054-
new_baked_cols = []
1053+
new_exprs: list[OrderingExpression] = []
1054+
new_baked_cols: list[ibis_types.Value] = []
10551055
for expr in ordering_expressions:
10561056
if isinstance(expr.scalar_expression, ex.OpExpression):
10571057
baked_column = self._compile_expression(expr.scalar_expression).name(
10581058
bigframes.core.guid.generate_guid()
10591059
)
10601060
new_baked_cols.append(baked_column)
10611061
new_expr = OrderingExpression(
1062-
ex.free_var(baked_column.name), expr.direction, expr.na_last
1062+
ex.free_var(baked_column.get_name()), expr.direction, expr.na_last
10631063
)
10641064
new_exprs.append(new_expr)
1065-
else:
1065+
elif isinstance(expr.scalar_expression, ex.UnboundVariableExpression):
1066+
order_col = expr.scalar_expression.id
10661067
new_exprs.append(expr)
1067-
1068-
ordering = self._ordering.with_ordering_columns(new_exprs)
1068+
if order_col not in self.column_ids:
1069+
new_baked_cols.append(
1070+
self._ibis_bindings[expr.scalar_expression.id]
1071+
)
1072+
1073+
new_ordering = ExpressionOrdering(
1074+
tuple(new_exprs),
1075+
self._ordering.integer_encoding,
1076+
self._ordering.string_encoding,
1077+
self._ordering.total_ordering_columns,
1078+
)
10691079
return OrderedIR(
10701080
self._table,
10711081
columns=self.columns,
1072-
hidden_ordering_columns=[*self._hidden_ordering_columns, *new_baked_cols],
1073-
ordering=ordering,
1082+
hidden_ordering_columns=tuple(new_baked_cols),
1083+
ordering=new_ordering,
10741084
predicates=self._predicates,
10751085
)
10761086

bigframes/core/rewrite.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,21 @@ class SquashedSelect:
3535
columns: Tuple[Tuple[scalar_exprs.Expression, str], ...]
3636
predicate: Optional[scalar_exprs.Expression]
3737
ordering: Tuple[order.OrderingExpression, ...]
38+
reverse_root: bool = False
3839

3940
@classmethod
40-
def from_node(cls, node: nodes.BigFrameNode) -> SquashedSelect:
41+
def from_node(
42+
cls, node: nodes.BigFrameNode, projections_only: bool = False
43+
) -> SquashedSelect:
4144
if isinstance(node, nodes.ProjectionNode):
42-
return cls.from_node(node.child).project(node.assignments)
43-
elif isinstance(node, nodes.FilterNode):
45+
return cls.from_node(node.child, projections_only=projections_only).project(
46+
node.assignments
47+
)
48+
elif not projections_only and isinstance(node, nodes.FilterNode):
4449
return cls.from_node(node.child).filter(node.predicate)
45-
elif isinstance(node, nodes.ReversedNode):
50+
elif not projections_only and isinstance(node, nodes.ReversedNode):
4651
return cls.from_node(node.child).reverse()
47-
elif isinstance(node, nodes.OrderByNode):
52+
elif not projections_only and isinstance(node, nodes.OrderByNode):
4853
return cls.from_node(node.child).order_with(node.by)
4954
else:
5055
selection = tuple(
@@ -63,7 +68,9 @@ def project(
6368
new_columns = tuple(
6469
(expr.bind_all_variables(self.column_lookup), id) for expr, id in projection
6570
)
66-
return SquashedSelect(self.root, new_columns, self.predicate, self.ordering)
71+
return SquashedSelect(
72+
self.root, new_columns, self.predicate, self.ordering, self.reverse_root
73+
)
6774

6875
def filter(self, predicate: scalar_exprs.Expression) -> SquashedSelect:
6976
if self.predicate is None:
@@ -72,18 +79,24 @@ def filter(self, predicate: scalar_exprs.Expression) -> SquashedSelect:
7279
new_predicate = ops.and_op.as_expr(
7380
self.predicate, predicate.bind_all_variables(self.column_lookup)
7481
)
75-
return SquashedSelect(self.root, self.columns, new_predicate, self.ordering)
82+
return SquashedSelect(
83+
self.root, self.columns, new_predicate, self.ordering, self.reverse_root
84+
)
7685

7786
def reverse(self) -> SquashedSelect:
7887
new_ordering = tuple(expr.with_reverse() for expr in self.ordering)
79-
return SquashedSelect(self.root, self.columns, self.predicate, new_ordering)
88+
return SquashedSelect(
89+
self.root, self.columns, self.predicate, new_ordering, not self.reverse_root
90+
)
8091

8192
def order_with(self, by: Tuple[order.OrderingExpression, ...]):
8293
adjusted_orderings = [
8394
order_part.bind_variables(self.column_lookup) for order_part in by
8495
]
8596
new_ordering = (*adjusted_orderings, *self.ordering)
86-
return SquashedSelect(self.root, self.columns, self.predicate, new_ordering)
97+
return SquashedSelect(
98+
self.root, self.columns, self.predicate, new_ordering, self.reverse_root
99+
)
87100

88101
def maybe_join(
89102
self, right: SquashedSelect, join_def: join_defs.JoinDefinition
@@ -126,8 +139,10 @@ def maybe_join(
126139
new_columns = remap_names(join_def, lselection, rselection)
127140

128141
# Reconstruct ordering
142+
reverse_root = self.reverse_root
129143
if join_type == "right":
130144
new_ordering = right.ordering
145+
reverse_root = right.reverse_root
131146
elif join_type == "outer":
132147
if lmask is not None:
133148
prefix = order.OrderingExpression(lmask, order.OrderingDirection.DESC)
@@ -158,18 +173,31 @@ def maybe_join(
158173
new_ordering = self.ordering
159174
else:
160175
raise ValueError(f"Unexpected join type {join_type}")
161-
return SquashedSelect(self.root, new_columns, new_predicate, new_ordering)
176+
return SquashedSelect(
177+
self.root, new_columns, new_predicate, new_ordering, reverse_root
178+
)
162179

163180
def expand(self) -> nodes.BigFrameNode:
164181
# Safest to apply predicates first, as it may filter out inputs that cannot be handled by other expressions
165182
root = self.root
183+
if self.reverse_root:
184+
root = nodes.ReversedNode(child=root)
166185
if self.predicate:
167186
root = nodes.FilterNode(child=root, predicate=self.predicate)
168187
if self.ordering:
169188
root = nodes.OrderByNode(child=root, by=self.ordering)
170189
return nodes.ProjectionNode(child=root, assignments=self.columns)
171190

172191

192+
def maybe_squash_projection(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
193+
if isinstance(node, nodes.ProjectionNode) and isinstance(
194+
node.child, nodes.ProjectionNode
195+
):
196+
# Conservative approach, only squash consecutive projections, even though could also squash filters, reorderings
197+
return SquashedSelect.from_node(node, projections_only=True).expand()
198+
return node
199+
200+
173201
def maybe_rewrite_join(join_node: nodes.JoinNode) -> nodes.BigFrameNode:
174202
left_side = SquashedSelect.from_node(join_node.left_child)
175203
right_side = SquashedSelect.from_node(join_node.right_child)

0 commit comments

Comments
 (0)