Skip to content

Commit 5dbb3b1

Browse files
cpcloudkszucs
authored andcommitted
feat(sqlalchemy): support expressions in window bounds
1 parent da2a699 commit 5dbb3b1

File tree

1 file changed

+82
-9
lines changed

1 file changed

+82
-9
lines changed

ibis/backends/base/sql/alchemy/registry.py

Lines changed: 82 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import sqlalchemy as sa
99
from sqlalchemy.ext.compiler import compiles
10+
from sqlalchemy.sql.elements import RANGE_CURRENT, RANGE_UNBOUNDED
1011
from sqlalchemy.sql.functions import FunctionElement, GenericFunction
1112

1213
import ibis.common.exceptions as com
@@ -308,17 +309,89 @@ def _endswith(t, op):
308309
return t.translate(op.arg).endswith(t.translate(op.end))
309310

310311

311-
def _translate_window_boundary(boundary):
312+
def _reinterpret_range_bound(bound):
313+
if bound is None:
314+
return RANGE_UNBOUNDED
315+
316+
try:
317+
lower = int(bound)
318+
except ValueError as err:
319+
sa.util.raise_(
320+
sa.exc.ArgumentError(
321+
"Integer, None or expression expected for range value"
322+
),
323+
replace_context=err,
324+
)
325+
except TypeError:
326+
return bound
327+
else:
328+
return RANGE_CURRENT if lower == 0 else lower
329+
330+
331+
def _interpret_range(self, range_):
332+
if not isinstance(range_, tuple) or len(range_) != 2:
333+
raise sa.exc.ArgumentError("2-tuple expected for range/rows")
334+
335+
lower = _reinterpret_range_bound(range_[0])
336+
upper = _reinterpret_range_bound(range_[1])
337+
return lower, upper
338+
339+
340+
# monkeypatch to allow expressions in range and rows bounds
341+
sa.sql.elements.Over._interpret_range = _interpret_range
342+
343+
344+
def _translate_window_boundary(t, boundary):
312345
if boundary is None:
313346
return None
314347

315-
if isinstance(boundary.value, ops.Literal):
316-
if boundary.preceding:
317-
return -boundary.value.value
318-
else:
319-
return boundary.value.value
348+
value = t.translate(boundary.value)
349+
return value if boundary.preceding else value
350+
351+
352+
def _compile_bounds(compiler, range_, kind, **kw):
353+
left_, right_ = range_
354+
355+
if left_ is RANGE_UNBOUNDED:
356+
left = "UNBOUNDED PRECEDING"
357+
elif left_ is RANGE_CURRENT:
358+
left = "CURRENT ROW"
359+
else:
360+
left = f"{compiler.process(left_, **kw)} PRECEDING"
361+
362+
if right_ is RANGE_UNBOUNDED:
363+
right = "UNBOUNDED FOLLOWING"
364+
elif right_ is RANGE_CURRENT:
365+
right = "CURRENT ROW"
366+
else:
367+
right = f"{compiler.process(right_, **kw)} FOLLOWING"
368+
369+
return f"{kind} BETWEEN {left} AND {right}"
370+
371+
372+
@compiles(sa.sql.elements.Over)
373+
def compile_over(over, compiler, **kw):
374+
text = compiler.process(over.element, **kw)
375+
if over.range_:
376+
range_ = _compile_bounds(compiler, over.range_, kind="RANGE", **kw)
377+
elif over.rows:
378+
range_ = _compile_bounds(compiler, over.rows, kind="ROWS", **kw)
379+
else:
380+
range_ = None
381+
382+
args = [
383+
f"{word} BY {compiler.process(clause, **kw)}"
384+
for word, clause in (
385+
("PARTITION", over.partition_by),
386+
("ORDER", over.order_by),
387+
)
388+
if clause is not None and len(clause)
389+
]
390+
391+
if range_ is not None:
392+
args.append(range_)
320393

321-
raise com.TranslationError("Window boundaries must be literal values")
394+
return f"{text} OVER ({' '.join(args)})"
322395

323396

324397
def _window_function(t, window):
@@ -351,8 +424,8 @@ def _window_function(t, window):
351424
# some functions on some backends don't support frame clauses
352425
additional_params = {}
353426
else:
354-
start = _translate_window_boundary(window.frame.start)
355-
end = _translate_window_boundary(window.frame.end)
427+
start = _translate_window_boundary(t, window.frame.start)
428+
end = _translate_window_boundary(t, window.frame.end)
356429
additional_params = {how: (start, end)}
357430

358431
result = sa.over(

0 commit comments

Comments
 (0)