Skip to content

Commit 0370bcb

Browse files
authored
feat(ux): allow window functions in predicates and compile to QUALIFY where possible (#9787)
1 parent 8d4f97f commit 0370bcb

File tree

28 files changed

+390
-34
lines changed

28 files changed

+390
-34
lines changed

ibis/backends/sql/compilers/base.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,9 @@ class SQLGlotCompiler(abc.ABC):
267267
copy_func_args: bool = False
268268
"""Whether to copy function arguments when generating SQL."""
269269

270+
supports_qualify: bool = False
271+
"""Whether the backend supports the QUALIFY clause."""
272+
270273
NAN: ClassVar[sge.Expression] = sge.Cast(
271274
this=sge.convert("NaN"), to=sge.DataType(this=sge.DataType.Type.DOUBLE)
272275
)
@@ -1249,15 +1252,21 @@ def _cleanup_names(self, exprs: Mapping[str, sge.Expression]):
12491252
else:
12501253
yield value.as_(name, quoted=self.quoted, copy=False)
12511254

1252-
def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
1255+
def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
12531256
# if we've constructed a useless projection return the parent relation
1254-
if not selections and not predicates and not sort_keys:
1257+
if not (selections or predicates or qualified or sort_keys):
12551258
return parent
12561259

12571260
result = parent
12581261

12591262
if selections:
1260-
if op.is_star_selection():
1263+
# if there are `qualify` predicates then sqlglot adds a hidden
1264+
# column to implement the functionality if the dialect doesn't
1265+
# support it
1266+
#
1267+
# using STAR in that case would lead to an extra column, so in that
1268+
# case we have to spell out the columns
1269+
if op.is_star_selection() and (not qualified or self.supports_qualify):
12611270
fields = [STAR]
12621271
else:
12631272
fields = self._cleanup_names(selections)
@@ -1266,6 +1275,9 @@ def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
12661275
if predicates:
12671276
result = result.where(*predicates, copy=False)
12681277

1278+
if qualified:
1279+
result = result.qualify(*qualified, copy=False)
1280+
12691281
if sort_keys:
12701282
result = result.order_by(*sort_keys, copy=False)
12711283

ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ class BigQueryCompiler(SQLGlotCompiler):
112112
*SQLGlotCompiler.rewrites,
113113
)
114114

115+
supports_qualify = True
116+
115117
UNSUPPORTED_OPS = (
116118
ops.DateDiff,
117119
ops.ExtractAuthority,

ibis/backends/sql/compilers/clickhouse.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class ClickHouseCompiler(SQLGlotCompiler):
4242

4343
agg = ClickhouseAggGen()
4444

45+
supports_qualify = True
46+
4547
UNSUPPORTED_OPS = (
4648
ops.RowID,
4749
ops.CumeDist,

ibis/backends/sql/compilers/duckdb.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class DuckDBCompiler(SQLGlotCompiler):
4242

4343
agg = AggGen(supports_filter=True, supports_order_by=True)
4444

45+
supports_qualify = True
46+
4547
LOWERED_OPS = {
4648
ops.Sample: None,
4749
ops.StringSlice: None,

ibis/backends/sql/compilers/mssql.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,9 @@ def visit_All(self, op, *, arg, where):
477477
arg = self.if_(where, arg, NULL)
478478
return sge.Min(this=arg)
479479

480-
def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
480+
def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
481481
# if we've constructed a useless projection return the parent relation
482-
if not selections and not predicates and not sort_keys:
482+
if not (selections or predicates or qualified or sort_keys):
483483
return parent
484484

485485
result = parent
@@ -492,6 +492,9 @@ def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
492492
if predicates:
493493
result = result.where(*predicates, copy=True)
494494

495+
if qualified:
496+
result = result.qualify(*qualified, copy=True)
497+
495498
if sort_keys:
496499
result = result.order_by(*sort_keys, copy=False)
497500

ibis/backends/sql/compilers/snowflake.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class SnowflakeCompiler(SQLGlotCompiler):
4545
dialect = Snowflake
4646
type_mapper = SnowflakeType
4747
no_limit_value = NULL
48+
supports_qualify = True
4849

4950
agg = AggGen(supports_order_by=True)
5051

ibis/backends/sql/dialects.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,13 @@ def _create_sql(self, expression: sge.Create) -> str:
368368
sge.Stddev: rename_func("stddev_pop"),
369369
sge.ApproxDistinct: rename_func("approx_count_distinct"),
370370
sge.Create: _create_sql,
371-
sge.Select: transforms.preprocess([transforms.eliminate_semi_and_anti_joins]),
371+
sge.Select: transforms.preprocess(
372+
[
373+
transforms.eliminate_semi_and_anti_joins,
374+
transforms.eliminate_distinct_on,
375+
transforms.eliminate_qualify,
376+
]
377+
),
372378
sge.GroupConcat: rename_func("listagg"),
373379
}
374380

ibis/backends/sql/rewrites.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class Select(ops.Relation):
5151
parent: ops.Relation
5252
selections: FrozenDict[str, ops.Value] = {}
5353
predicates: VarTuple[ops.Value[dt.Boolean]] = ()
54+
qualified: VarTuple[ops.Value[dt.Boolean]] = ()
5455
sort_keys: VarTuple[ops.SortKey] = ()
5556

5657
def is_star_selection(self):
@@ -99,10 +100,26 @@ def project_to_select(_, **kwargs):
99100
return Select(_.parent, selections=_.values)
100101

101102

103+
def partition_predicates(predicates):
104+
qualified = []
105+
unqualified = []
106+
107+
for predicate in predicates:
108+
if predicate.find(ops.WindowFunction, filter=ops.Value):
109+
qualified.append(predicate)
110+
else:
111+
unqualified.append(predicate)
112+
113+
return unqualified, qualified
114+
115+
102116
@replace(p.Filter)
103117
def filter_to_select(_, **kwargs):
104118
"""Convert a Filter node to a Select node."""
105-
return Select(_.parent, selections=_.values, predicates=_.predicates)
119+
predicates, qualified = partition_predicates(_.predicates)
120+
return Select(
121+
_.parent, selections=_.values, predicates=predicates, qualified=qualified
122+
)
106123

107124

108125
@replace(p.Sort)
@@ -233,6 +250,9 @@ def merge_select_select(_, **kwargs):
233250
predicates = tuple(p.replace(subs, filter=ops.Value) for p in _.predicates)
234251
unique_predicates = toolz.unique(_.parent.predicates + predicates)
235252

253+
qualified = tuple(p.replace(subs, filter=ops.Value) for p in _.qualified)
254+
unique_qualified = toolz.unique(_.parent.qualified + qualified)
255+
236256
sort_keys = tuple(s.replace(subs, filter=ops.Value) for s in _.sort_keys)
237257
sort_key_exprs = {s.expr for s in sort_keys}
238258
parent_sort_keys = tuple(
@@ -244,6 +264,7 @@ def merge_select_select(_, **kwargs):
244264
_.parent.parent,
245265
selections=selections,
246266
predicates=unique_predicates,
267+
qualified=unique_qualified,
247268
sort_keys=unique_sort_keys,
248269
)
249270
return result if complexity(result) <= complexity(_) else _
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
SELECT
2+
*
3+
FROM (
4+
SELECT
5+
`t0`.`x`,
6+
SUM(`t0`.`x`) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `y`
7+
FROM `t` AS `t0`
8+
) AS `t1`
9+
WHERE
10+
`t1`.`y` <= 37
11+
QUALIFY
12+
AVG(`t1`.`x`) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) IS NOT NULL
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
SELECT
2+
*
3+
FROM (
4+
SELECT
5+
"t0"."x" AS "x",
6+
SUM("t0"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "y"
7+
FROM "t" AS "t0"
8+
) AS "t1"
9+
WHERE
10+
"t1"."y" <= 37
11+
QUALIFY
12+
isNotNull(AVG("t1"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
SELECT
2+
"x",
3+
"y"
4+
FROM (
5+
SELECT
6+
"t1"."x",
7+
"t1"."y",
8+
AVG("t1"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS _w
9+
FROM (
10+
SELECT
11+
"t0"."x",
12+
SUM("t0"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "y"
13+
FROM "t" AS "t0"
14+
) AS "t1"
15+
WHERE
16+
"t1"."y" <= 37
17+
) AS _t
18+
WHERE
19+
_w IS NOT NULL
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
SELECT
2+
"x",
3+
"y"
4+
FROM (
5+
SELECT
6+
"t1"."x",
7+
"t1"."y",
8+
AVG("t1"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS _w
9+
FROM (
10+
SELECT
11+
"t0"."x",
12+
SUM("t0"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "y"
13+
FROM "t" AS "t0"
14+
) AS "t1"
15+
WHERE
16+
"t1"."y" <= 37
17+
) AS _t
18+
WHERE
19+
_w IS NOT NULL
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
SELECT
2+
*
3+
FROM (
4+
SELECT
5+
"t0"."x",
6+
SUM("t0"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "y"
7+
FROM "t" AS "t0"
8+
) AS "t1"
9+
WHERE
10+
"t1"."y" <= 37
11+
QUALIFY
12+
AVG("t1"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) IS NOT NULL
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
SELECT
2+
"x",
3+
"y"
4+
FROM (
5+
SELECT
6+
"t1"."x",
7+
"t1"."y",
8+
AVG("t1"."x") OVER (ORDER BY NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS _w
9+
FROM (
10+
SELECT
11+
"t0"."x",
12+
SUM("t0"."x") OVER (ORDER BY NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "y"
13+
FROM "t" AS "t0"
14+
) AS "t1"
15+
WHERE
16+
"t1"."y" <= 37
17+
) AS _t
18+
WHERE
19+
_w IS NOT NULL
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
SELECT
2+
`t1`.`x`,
3+
`t1`.`y`
4+
FROM (
5+
SELECT
6+
`t0`.`x`,
7+
SUM(`t0`.`x`) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `y`
8+
FROM `t` AS `t0`
9+
) AS `t1`
10+
WHERE
11+
`t1`.`y` <= 37
12+
QUALIFY
13+
AVG(`t1`.`x`) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) IS NOT NULL
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
SELECT
2+
`x`,
3+
`y`
4+
FROM (
5+
SELECT
6+
`t1`.`x`,
7+
`t1`.`y`,
8+
AVG(`t1`.`x`) OVER (ORDER BY NULL ASC) AS _w
9+
FROM (
10+
SELECT
11+
`t0`.`x`,
12+
SUM(`t0`.`x`) OVER (ORDER BY NULL ASC) AS `y`
13+
FROM `t` AS `t0`
14+
) AS `t1`
15+
WHERE
16+
`t1`.`y` <= 37
17+
) AS _t
18+
WHERE
19+
_w IS NOT NULL
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
SELECT
2+
[x],
3+
[y]
4+
FROM (
5+
SELECT
6+
[t1].[x] AS [x],
7+
[t1].[y] AS [y],
8+
AVG([t1].[x]) OVER (ORDER BY CASE WHEN [t1].[x] IS NULL THEN 1 ELSE 0 END, [t1].[x] ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS _w
9+
FROM (
10+
SELECT
11+
[t0].[x],
12+
SUM([t0].[x]) OVER (ORDER BY CASE WHEN [t0].[x] IS NULL THEN 1 ELSE 0 END, [t0].[x] ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS [y]
13+
FROM [t] AS [t0]
14+
) AS [t1]
15+
WHERE
16+
[t1].[y] <= 37
17+
) AS _t
18+
WHERE
19+
_w IS NOT NULL
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
SELECT
2+
`x`,
3+
`y`
4+
FROM (
5+
SELECT
6+
`t1`.`x`,
7+
`t1`.`y`,
8+
AVG(`t1`.`x`) OVER (ORDER BY CASE WHEN NULL IS NULL THEN 1 ELSE 0 END, NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS _w
9+
FROM (
10+
SELECT
11+
`t0`.`x`,
12+
SUM(`t0`.`x`) OVER (ORDER BY CASE WHEN NULL IS NULL THEN 1 ELSE 0 END, NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `y`
13+
FROM `t` AS `t0`
14+
) AS `t1`
15+
WHERE
16+
`t1`.`y` <= 37
17+
) AS _t
18+
WHERE
19+
_w IS NOT NULL
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
SELECT
2+
"x",
3+
"y"
4+
FROM (
5+
SELECT
6+
"t1"."x",
7+
"t1"."y",
8+
AVG("t1"."x") OVER (ORDER BY NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS _w
9+
FROM (
10+
SELECT
11+
"t0"."x",
12+
SUM("t0"."x") OVER (ORDER BY NULL ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "y"
13+
FROM "t" "t0"
14+
) "t1"
15+
WHERE
16+
"t1"."y" <= 37
17+
) _t
18+
WHERE
19+
_w IS NOT NULL
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
SELECT
2+
"x",
3+
"y"
4+
FROM (
5+
SELECT
6+
"t1"."x",
7+
"t1"."y",
8+
AVG("t1"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS _w
9+
FROM (
10+
SELECT
11+
"t0"."x",
12+
SUM("t0"."x") OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "y"
13+
FROM "t" AS "t0"
14+
) AS "t1"
15+
WHERE
16+
"t1"."y" <= 37
17+
) AS _t
18+
WHERE
19+
_w IS NOT NULL

0 commit comments

Comments
 (0)