Skip to content

Commit efa42bd

Browse files
committed
fix(sqlalchemy): handle correlated exists sanely
Also fix an issue where EXISTS queries on projections didn't work at all.
1 parent: 9f4ff54 · commit: efa42bd

File tree

4 files changed

+75
-35
lines changed

4 files changed

+75
-35
lines changed

ibis/backends/base/sql/alchemy/query_builder.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_result(self):
4242
for jtype, table, preds in zip(
4343
self.join_types, self.join_tables[1:], self.join_predicates
4444
):
45-
if len(preds):
45+
if preds:
4646
sqla_preds = [self._translate(pred) for pred in preds]
4747
onclause = functools.reduce(sql.and_, sqla_preds)
4848
else:
@@ -59,9 +59,23 @@ def get_result(self):
5959
elif jtype is ops.OuterJoin:
6060
result = result.outerjoin(table, onclause, full=True)
6161
elif jtype is ops.LeftSemiJoin:
62-
result = result.select().where(sa.exists(sa.select(1).where(onclause)))
62+
# subquery is required for semi and anti joins done using
63+
# sqlalchemy, otherwise multiple references to the original
64+
# select are treated as distinct tables
65+
#
66+
# with a subquery, the result is a distinct table and so there's only one
67+
# thing for subsequent expressions to reference
68+
result = (
69+
result.select()
70+
.where(sa.exists(sa.select(1).where(onclause)))
71+
.subquery()
72+
)
6373
elif jtype is ops.LeftAntiJoin:
64-
result = result.select().where(~sa.exists(sa.select(1).where(onclause)))
74+
result = (
75+
result.select()
76+
.where(~sa.exists(sa.select(1).where(onclause)))
77+
.subquery()
78+
)
6579
else:
6680
raise NotImplementedError(jtype)
6781

@@ -227,32 +241,7 @@ def _add_select(self, table_set):
227241
if has_select_star or table_set is None:
228242
return result
229243

230-
# if we're selecting from something that isn't a subquery e.g., Select,
231-
# Alias, Table
232-
if not isinstance(table_set, sa.sql.Subquery):
233-
return result.select_from(table_set)
234-
235-
final_froms = result.get_final_froms()
236-
num_froms = len(final_froms)
237-
238-
# if the result subquery has no FROMs then we can select from the
239-
# table_set since there's only a single possibility for FROM
240-
if not num_froms:
241-
return result.select_from(table_set)
242-
243-
# we need to replace every occurrence of `result`'s `FROM`
244-
# with `table_set` to handle correlated EXISTs coming from
245-
# semi/anti-join
246-
#
247-
# previously this was `replace_selectable`, but that's deprecated so we
248-
# inline its implementation here
249-
#
250-
# sqlalchemy suggests using the functionality in sa.sql.visitors, but
251-
# that would effectively require reimplementing ClauseAdapter
252-
replaced = sa.sql.util.ClauseAdapter(table_set).traverse(result)
253-
num_froms = len(replaced.get_final_froms())
254-
assert num_froms == 1, f"num_froms == {num_froms:d}"
255-
return replaced
244+
return result.select_from(table_set)
256245

257246
def _add_group_by(self, fragment):
258247
# GROUP BY and HAVING

ibis/backends/base/sql/alchemy/registry.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,16 @@ def get_col_or_deferred_col(sa_table, colname):
9898
try:
9999
return sa_table.exported_columns[colname]
100100
except KeyError:
101-
return sa.column(colname)
101+
# cols is a sqlalchemy column collection which contains column
102+
# names that are secretly prefixed by their containing table
103+
#
104+
# sqlalchemy doesn't let you select by the *un*prefixed column name
105+
# despite the uniqueness of `colname`
106+
#
107+
# however, in ibis we have already deduplicated column names so we can
108+
# refer to the name by position
109+
colindex = op.table.schema._name_locs[colname]
110+
return cols[colindex]
102111

103112

104113
def _table_column(t, op):
@@ -123,9 +132,18 @@ def _table_column(t, op):
123132

124133

125134
def _table_array_view(t, op):
135+
# the table that the TableArrayView op contains (op.table) has
136+
# one or more input relations that we need to "pin" for sqlalchemy's
137+
# auto correlation functionality -- this is what `.correlate_except` does
138+
#
139+
# every relation that is NOT passed to `correlate_except` is considered an
140+
# outer-query table
126141
ctx = t.context
127142
table = ctx.get_compiled_expr(op.table)
128-
return table
143+
# TODO: handle the case of `op.table` being a join
144+
first, *_ = an.find_immediate_parent_tables(op.table, keep_input=False)
145+
ref = ctx.get_ref(first)
146+
return table.correlate_except(ref)
129147

130148

131149
def _exists_subquery(t, op):

ibis/backends/clickhouse/compiler/values.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -548,11 +548,20 @@ def _truncate(op, **kw):
548548
@translate_val.register(ops.ExistsSubquery)
549549
@translate_val.register(ops.NotExistsSubquery)
550550
def _exists_subquery(op, **kw):
551-
foreign_table = translate_val(op.foreign_table, **kw)
551+
# https://github.com/ClickHouse/ClickHouse/issues/6697
552+
#
553+
# this would work, if clickhouse supported correlated subqueries
554+
from ibis.backends.clickhouse.compiler.relations import translate_rel
555+
556+
foreign_table = translate_rel(op.foreign_table, **kw)
552557
predicates = translate_val(op.predicates, **kw)
553-
subq = sg.subquery(foreign_table.where(predicates, dialect="clickhouse").select(1))
558+
subq = (
559+
sg.select(1)
560+
.from_(foreign_table, dialect="clickhouse")
561+
.where(sg.condition(predicates), dialect="clickhouse")
562+
)
554563
prefix = "NOT " * isinstance(op, ops.NotExistsSubquery)
555-
return f"{prefix}EXISTS {subq}"
564+
return f"{prefix}EXISTS ({subq})"
556565

557566

558567
@translate_val.register(ops.StringSplit)

ibis/backends/tests/test_generic.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import decimal
22
import io
33
from contextlib import redirect_stdout
4-
from operator import invert, neg
4+
from operator import invert, methodcaller, neg
55

66
import numpy as np
77
import pandas as pd
@@ -745,3 +745,27 @@ def test_int_scalar(alltypes):
745745
result = expr.execute()
746746
assert expr.type() == dt.int16
747747
assert result.dtype == np.int16
748+
749+
750+
@pytest.mark.notimpl(["dask", "datafusion", "pandas", "polars"])
751+
@pytest.mark.notyet(
752+
["clickhouse"], reason="https://github.com/ClickHouse/ClickHouse/issues/6697"
753+
)
754+
@pytest.mark.notyet(["pyspark"])
755+
@pytest.mark.parametrize(
756+
"method_name",
757+
[
758+
"any",
759+
param(
760+
"notany",
761+
marks=pytest.mark.broken(
762+
["impala"], reason="aliases are incorrectly elided"
763+
),
764+
),
765+
],
766+
)
767+
def test_exists(batting, awards_players, method_name):
768+
method = methodcaller(method_name)
769+
expr = batting[method(batting.yearID == awards_players.yearID)]
770+
result = expr.execute()
771+
assert not result.empty

0 commit comments

Comments
 (0)