|
7 | 7 |
|
8 | 8 | import sqlalchemy as sa
|
9 | 9 | from sqlalchemy.ext.compiler import compiles
|
| 10 | +from sqlalchemy.sql.elements import RANGE_CURRENT, RANGE_UNBOUNDED |
10 | 11 | from sqlalchemy.sql.functions import FunctionElement, GenericFunction
|
11 | 12 |
|
12 | 13 | import ibis.common.exceptions as com
|
@@ -308,17 +309,89 @@ def _endswith(t, op):
|
308 | 309 | return t.translate(op.arg).endswith(t.translate(op.end))
|
309 | 310 |
|
310 | 311 |
|
311 |
| -def _translate_window_boundary(boundary): |
| 312 | +def _reinterpret_range_bound(bound): |
| 313 | + if bound is None: |
| 314 | + return RANGE_UNBOUNDED |
| 315 | + |
| 316 | + try: |
| 317 | + lower = int(bound) |
| 318 | + except ValueError as err: |
| 319 | + sa.util.raise_( |
| 320 | + sa.exc.ArgumentError( |
| 321 | + "Integer, None or expression expected for range value" |
| 322 | + ), |
| 323 | + replace_context=err, |
| 324 | + ) |
| 325 | + except TypeError: |
| 326 | + return bound |
| 327 | + else: |
| 328 | + return RANGE_CURRENT if lower == 0 else lower |
| 329 | + |
| 330 | + |
| 331 | +def _interpret_range(self, range_): |
| 332 | + if not isinstance(range_, tuple) or len(range_) != 2: |
| 333 | + raise sa.exc.ArgumentError("2-tuple expected for range/rows") |
| 334 | + |
| 335 | + lower = _reinterpret_range_bound(range_[0]) |
| 336 | + upper = _reinterpret_range_bound(range_[1]) |
| 337 | + return lower, upper |
| 338 | + |
| 339 | + |
| 340 | +# monkeypatch to allow expressions in range and rows bounds |
| 341 | +sa.sql.elements.Over._interpret_range = _interpret_range |
| 342 | + |
| 343 | + |
| 344 | +def _translate_window_boundary(t, boundary): |
312 | 345 | if boundary is None:
|
313 | 346 | return None
|
314 | 347 |
|
315 |
| - if isinstance(boundary.value, ops.Literal): |
316 |
| - if boundary.preceding: |
317 |
| - return -boundary.value.value |
318 |
| - else: |
319 |
| - return boundary.value.value |
| 348 | + value = t.translate(boundary.value) |
| 349 | + return value if boundary.preceding else value |
| 350 | + |
| 351 | + |
| 352 | +def _compile_bounds(compiler, range_, kind, **kw): |
| 353 | + left_, right_ = range_ |
| 354 | + |
| 355 | + if left_ is RANGE_UNBOUNDED: |
| 356 | + left = "UNBOUNDED PRECEDING" |
| 357 | + elif left_ is RANGE_CURRENT: |
| 358 | + left = "CURRENT ROW" |
| 359 | + else: |
| 360 | + left = f"{compiler.process(left_, **kw)} PRECEDING" |
| 361 | + |
| 362 | + if right_ is RANGE_UNBOUNDED: |
| 363 | + right = "UNBOUNDED FOLLOWING" |
| 364 | + elif right_ is RANGE_CURRENT: |
| 365 | + right = "CURRENT ROW" |
| 366 | + else: |
| 367 | + right = f"{compiler.process(right_, **kw)} FOLLOWING" |
| 368 | + |
| 369 | + return f"{kind} BETWEEN {left} AND {right}" |
| 370 | + |
| 371 | + |
| 372 | +@compiles(sa.sql.elements.Over) |
| 373 | +def compile_over(over, compiler, **kw): |
| 374 | + text = compiler.process(over.element, **kw) |
| 375 | + if over.range_: |
| 376 | + range_ = _compile_bounds(compiler, over.range_, kind="RANGE", **kw) |
| 377 | + elif over.rows: |
| 378 | + range_ = _compile_bounds(compiler, over.rows, kind="ROWS", **kw) |
| 379 | + else: |
| 380 | + range_ = None |
| 381 | + |
| 382 | + args = [ |
| 383 | + f"{word} BY {compiler.process(clause, **kw)}" |
| 384 | + for word, clause in ( |
| 385 | + ("PARTITION", over.partition_by), |
| 386 | + ("ORDER", over.order_by), |
| 387 | + ) |
| 388 | + if clause is not None and len(clause) |
| 389 | + ] |
| 390 | + |
| 391 | + if range_ is not None: |
| 392 | + args.append(range_) |
320 | 393 |
|
321 |
| - raise com.TranslationError("Window boundaries must be literal values") |
| 394 | + return f"{text} OVER ({' '.join(args)})" |
322 | 395 |
|
323 | 396 |
|
324 | 397 | def _window_function(t, window):
|
@@ -351,8 +424,8 @@ def _window_function(t, window):
|
351 | 424 | # some functions on some backends don't support frame clauses
|
352 | 425 | additional_params = {}
|
353 | 426 | else:
|
354 |
| - start = _translate_window_boundary(window.frame.start) |
355 |
| - end = _translate_window_boundary(window.frame.end) |
| 427 | + start = _translate_window_boundary(t, window.frame.start) |
| 428 | + end = _translate_window_boundary(t, window.frame.end) |
356 | 429 | additional_params = {how: (start, end)}
|
357 | 430 |
|
358 | 431 | result = sa.over(
|
|
0 commit comments