Skip to content

Commit e12ce8d

Browse files
kszucs, cpcloud
authored and committed
fix(ir): merge window frames for bound analytic window functions with a subsequent over call
1 parent b1137f7 commit e12ce8d

File tree

4 files changed

+45
-22
lines changed

4 files changed

+45
-22
lines changed

ibis/backends/tests/test_window.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1248,3 +1248,21 @@ def test_range_expression_bounds(backend):
12481248

12491249
assert not result.empty
12501250
assert len(result) == con.execute(t.count())
1251+
1252+
1253+
def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df):
1254+
# GH #7631
1255+
t = alltypes
1256+
expr = t.int_col.percent_rank().over(ibis.window(group_by=t.int_col.notnull()))
1257+
result = expr.execute()
1258+
1259+
expected = (
1260+
df.sort_values("int_col")
1261+
.groupby(df["int_col"].notnull())
1262+
.apply(lambda df: (df.int_col.rank(method="min").sub(1).div(len(df) - 1)))
1263+
.T.reset_index(drop=True)
1264+
.iloc[:, 0]
1265+
.rename(expr.get_name())
1266+
)
1267+
1268+
backend.assert_series_equal(result, expected)

ibis/expr/analysis.py

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,9 @@
1010
import ibis.expr.operations.relations as rels
1111
import ibis.expr.types as ir
1212
from ibis import util
13-
from ibis.common.deferred import _, deferred, var
13+
from ibis.common.deferred import deferred, var
1414
from ibis.common.exceptions import IbisTypeError, IntegrityError
15-
from ibis.common.patterns import Eq, In, pattern
15+
from ibis.common.patterns import Eq, In, pattern, replace
1616
from ibis.util import Namespace
1717

1818
if TYPE_CHECKING:
@@ -163,25 +163,25 @@ def pushdown_selection_filters(parent, predicates):
163163
return parent.copy(predicates=parent.predicates + tuple(simplified))
164164

165165

166-
def windowize_function(expr, default_frame, merge_frames=False):
167-
func, frame = var("func"), var("frame")
168-
169-
wrap_analytic = (p.Analytic | p.Reduction) >> c.WindowFunction(_, default_frame)
170-
merge_windows = p.WindowFunction(func, frame) >> c.WindowFunction(
171-
func,
172-
frame.copy(
173-
order_by=frame.order_by + default_frame.order_by,
174-
group_by=frame.group_by + default_frame.group_by,
175-
),
176-
)
166+
@replace(p.Analytic | p.Reduction)
167+
def wrap_analytic(_, default_frame):
168+
return ops.WindowFunction(_, default_frame)
169+
170+
171+
@replace(p.WindowFunction)
172+
def merge_windows(_, default_frame):
173+
group_by = tuple(toolz.unique(_.frame.group_by + default_frame.group_by))
174+
order_by = tuple(toolz.unique(_.frame.order_by + default_frame.order_by))
175+
frame = _.frame.copy(group_by=group_by, order_by=order_by)
176+
return ops.WindowFunction(_.func, frame)
177177

178+
179+
def windowize_function(expr, default_frame, merge_frames=False):
180+
ctx = {"default_frame": default_frame}
178181
node = expr.op()
179182
if merge_frames:
180-
# it only happens in ibis.expr.groupby.GroupedTable, but the projector
181-
# changes the windowization order to put everything here
182-
node = node.replace(merge_windows, filter=p.Value)
183-
node = node.replace(wrap_analytic, filter=p.Value & ~p.WindowFunction)
184-
183+
node = node.replace(merge_windows, filter=p.Value, context=ctx)
184+
node = node.replace(wrap_analytic, filter=p.Value & ~p.WindowFunction, context=ctx)
185185
return node.to_expr()
186186

187187

ibis/expr/types/generic.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -758,17 +758,15 @@ def over(
758758

759759
def bind(table):
760760
frame = window.bind(table)
761-
expr = an.windowize_function(self, frame)
761+
expr = an.windowize_function(self, frame, merge_frames=True)
762762
if expr.equals(self):
763763
raise com.IbisTypeError(
764764
"No reduction or analytic function found to construct a window expression"
765765
)
766766
return expr
767767

768768
op = self.op()
769-
if isinstance(op, ops.WindowFunction):
770-
return op.func.to_expr().over(window)
771-
elif isinstance(window, bl.WindowBuilder):
769+
if isinstance(window, bl.WindowBuilder):
772770
if table := an.find_first_base_table(self.op()):
773771
return bind(table)
774772
else:

ibis/tests/expr/test_window_functions.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,3 +55,10 @@ def test_value_over_api(alltypes):
5555
expr = t.f.cumsum().over(range=(-1, 1), group_by=[t.g, t.a], order_by=[t.f])
5656
expected = t.f.cumsum().over(w2)
5757
assert expr.equals(expected)
58+
59+
60+
def test_rank_followed_by_over_call_merge_frames(alltypes):
61+
t = alltypes
62+
expr1 = t.f.percent_rank().over(ibis.window(group_by=t.f.notnull()))
63+
expr2 = ibis.percent_rank().over(group_by=t.f.notnull(), order_by=t.f).resolve(t)
64+
assert expr1.equals(expr2)

0 commit comments

Comments
 (0)