@@ -35,16 +35,21 @@ class SquashedSelect:
35
35
columns : Tuple [Tuple [scalar_exprs .Expression , str ], ...]
36
36
predicate : Optional [scalar_exprs .Expression ]
37
37
ordering : Tuple [order .OrderingExpression , ...]
38
+ reverse_root : bool = False
38
39
39
40
@classmethod
40
- def from_node (cls , node : nodes .BigFrameNode ) -> SquashedSelect :
41
+ def from_node (
42
+ cls , node : nodes .BigFrameNode , projections_only : bool = False
43
+ ) -> SquashedSelect :
41
44
if isinstance (node , nodes .ProjectionNode ):
42
- return cls .from_node (node .child ).project (node .assignments )
43
- elif isinstance (node , nodes .FilterNode ):
45
+ return cls .from_node (node .child , projections_only = projections_only ).project (
46
+ node .assignments
47
+ )
48
+ elif not projections_only and isinstance (node , nodes .FilterNode ):
44
49
return cls .from_node (node .child ).filter (node .predicate )
45
- elif isinstance (node , nodes .ReversedNode ):
50
+ elif not projections_only and isinstance (node , nodes .ReversedNode ):
46
51
return cls .from_node (node .child ).reverse ()
47
- elif isinstance (node , nodes .OrderByNode ):
52
+ elif not projections_only and isinstance (node , nodes .OrderByNode ):
48
53
return cls .from_node (node .child ).order_with (node .by )
49
54
else :
50
55
selection = tuple (
@@ -63,7 +68,9 @@ def project(
63
68
new_columns = tuple (
64
69
(expr .bind_all_variables (self .column_lookup ), id ) for expr , id in projection
65
70
)
66
- return SquashedSelect (self .root , new_columns , self .predicate , self .ordering )
71
+ return SquashedSelect (
72
+ self .root , new_columns , self .predicate , self .ordering , self .reverse_root
73
+ )
67
74
68
75
def filter (self , predicate : scalar_exprs .Expression ) -> SquashedSelect :
69
76
if self .predicate is None :
@@ -72,18 +79,24 @@ def filter(self, predicate: scalar_exprs.Expression) -> SquashedSelect:
72
79
new_predicate = ops .and_op .as_expr (
73
80
self .predicate , predicate .bind_all_variables (self .column_lookup )
74
81
)
75
- return SquashedSelect (self .root , self .columns , new_predicate , self .ordering )
82
+ return SquashedSelect (
83
+ self .root , self .columns , new_predicate , self .ordering , self .reverse_root
84
+ )
76
85
77
86
def reverse (self ) -> SquashedSelect :
78
87
new_ordering = tuple (expr .with_reverse () for expr in self .ordering )
79
- return SquashedSelect (self .root , self .columns , self .predicate , new_ordering )
88
+ return SquashedSelect (
89
+ self .root , self .columns , self .predicate , new_ordering , not self .reverse_root
90
+ )
80
91
81
92
def order_with (self , by : Tuple [order .OrderingExpression , ...]):
82
93
adjusted_orderings = [
83
94
order_part .bind_variables (self .column_lookup ) for order_part in by
84
95
]
85
96
new_ordering = (* adjusted_orderings , * self .ordering )
86
- return SquashedSelect (self .root , self .columns , self .predicate , new_ordering )
97
+ return SquashedSelect (
98
+ self .root , self .columns , self .predicate , new_ordering , self .reverse_root
99
+ )
87
100
88
101
def maybe_join (
89
102
self , right : SquashedSelect , join_def : join_defs .JoinDefinition
@@ -126,8 +139,10 @@ def maybe_join(
126
139
new_columns = remap_names (join_def , lselection , rselection )
127
140
128
141
# Reconstruct ordering
142
+ reverse_root = self .reverse_root
129
143
if join_type == "right" :
130
144
new_ordering = right .ordering
145
+ reverse_root = right .reverse_root
131
146
elif join_type == "outer" :
132
147
if lmask is not None :
133
148
prefix = order .OrderingExpression (lmask , order .OrderingDirection .DESC )
@@ -158,18 +173,31 @@ def maybe_join(
158
173
new_ordering = self .ordering
159
174
else :
160
175
raise ValueError (f"Unexpected join type { join_type } " )
161
- return SquashedSelect (self .root , new_columns , new_predicate , new_ordering )
176
+ return SquashedSelect (
177
+ self .root , new_columns , new_predicate , new_ordering , reverse_root
178
+ )
162
179
163
180
def expand (self ) -> nodes .BigFrameNode :
164
181
# Safest to apply predicates first, as it may filter out inputs that cannot be handled by other expressions
165
182
root = self .root
183
+ if self .reverse_root :
184
+ root = nodes .ReversedNode (child = root )
166
185
if self .predicate :
167
186
root = nodes .FilterNode (child = root , predicate = self .predicate )
168
187
if self .ordering :
169
188
root = nodes .OrderByNode (child = root , by = self .ordering )
170
189
return nodes .ProjectionNode (child = root , assignments = self .columns )
171
190
172
191
192
+ def maybe_squash_projection (node : nodes .BigFrameNode ) -> nodes .BigFrameNode :
193
+ if isinstance (node , nodes .ProjectionNode ) and isinstance (
194
+ node .child , nodes .ProjectionNode
195
+ ):
196
+ # Conservative approach, only squash consecutive projections, even though could also squash filters, reorderings
197
+ return SquashedSelect .from_node (node , projections_only = True ).expand ()
198
+ return node
199
+
200
+
173
201
def maybe_rewrite_join (join_node : nodes .JoinNode ) -> nodes .BigFrameNode :
174
202
left_side = SquashedSelect .from_node (join_node .left_child )
175
203
right_side = SquashedSelect .from_node (join_node .right_child )
0 commit comments