Skip to content

Commit a416755

Browse files
YStrauchP4cpcloud
authored andcommitted
fix(ir): make ibis.to_sql() parse join limits and projections; fixes #11105
1 parent 6af1dbf commit a416755

File tree

5 files changed

+113
-10
lines changed

5 files changed

+113
-10
lines changed

ibis/expr/sql.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,25 @@ def overlay(self, step):
7171
return Catalog({**self, **updates})
7272

7373

74+
def apply_limit(table, step):
75+
"""Applies a LIMIT, if applicable."""
76+
77+
if not isinstance(step.limit, int):
78+
return table
79+
80+
return table.limit(step.limit)
81+
82+
83+
def apply_projections(table, step, catalog):
84+
"""Applies a SELECT projection, if applicable."""
85+
86+
if not step.projections:
87+
return table
88+
89+
projs = [convert(proj, catalog=catalog) for proj in step.projections]
90+
return table.select(projs)
91+
92+
7493
@singledispatch
7594
def convert(step, catalog):
7695
raise TypeError(type(step))
@@ -86,12 +105,8 @@ def convert_scan(scan, catalog):
86105
pred = convert(scan.condition, catalog=catalog)
87106
table = table.filter(pred)
88107

89-
if scan.projections:
90-
projs = [convert(proj, catalog=catalog) for proj in scan.projections]
91-
table = table.select(projs)
92-
93-
if isinstance(scan.limit, int):
94-
table = table.limit(scan.limit)
108+
table = apply_projections(table, scan, catalog)
109+
table = apply_limit(table, scan)
95110

96111
return table
97112

@@ -156,8 +171,7 @@ def convert_sort(sort, catalog):
156171
]
157172
table = table.select(projs)
158173

159-
if isinstance(sort.limit, int):
160-
table = table.limit(sort.limit)
174+
table = apply_limit(table, sort)
161175

162176
return table
163177

@@ -203,6 +217,9 @@ def convert_join(join, catalog):
203217
predicate = convert(join.condition, catalog=catalog)
204218
left_table = left_table.filter(predicate)
205219

220+
left_table = apply_projections(left_table, join, catalog)
221+
left_table = apply_limit(left_table, join)
222+
206223
catalog[left_name] = left_table
207224

208225
return left_table
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import ibis
2+
3+
4+
call = ibis.table(
5+
name="call",
6+
schema={
7+
"start_time": "timestamp",
8+
"end_time": "timestamp",
9+
"employee_id": "int64",
10+
"call_outcome_id": "int64",
11+
"call_attempts": "int64",
12+
},
13+
)
14+
agg = call.aggregate([call.call_attempts.mean().name("mean")])
15+
16+
result = call.inner_join(
17+
agg, [(call.call_attempts > agg.mean), ibis.literal(True)]
18+
).select(
19+
call.start_time,
20+
call.end_time,
21+
call.employee_id,
22+
call.call_outcome_id,
23+
call.call_attempts,
24+
agg.mean,
25+
)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import ibis
2+
3+
4+
employee = ibis.table(
5+
name="employee",
6+
schema={"first_name": "string", "last_name": "string", "id": "int64"},
7+
)
8+
call = ibis.table(
9+
name="call",
10+
schema={
11+
"start_time": "timestamp",
12+
"end_time": "timestamp",
13+
"employee_id": "int64",
14+
"call_outcome_id": "int64",
15+
"call_attempts": "int64",
16+
},
17+
)
18+
19+
result = (
20+
employee.inner_join(call, [(employee.id == call.employee_id), ibis.literal(True)])
21+
.select(
22+
employee.first_name,
23+
employee.last_name,
24+
employee.id,
25+
call.start_time,
26+
call.end_time,
27+
call.employee_id,
28+
call.call_outcome_id,
29+
call.call_attempts,
30+
employee.first_name.name("first"),
31+
)
32+
.limit(3)
33+
)

ibis/expr/tests/snapshots/test_sql/test_parse_sql_scalar_subquery/decompiled.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,4 @@
2121
call.employee_id,
2222
call.call_outcome_id,
2323
call.call_attempts,
24-
agg.mean,
2524
)

ibis/expr/tests/test_sql.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,23 @@ def test_parse_sql_basic_join(how, snapshot):
4444
snapshot.assert_match(code, "decompiled.py")
4545

4646

47+
def test_parse_sql_limited_join(snapshot):
48+
sql = """
49+
SELECT
50+
*,
51+
first_name as first
52+
FROM employee
53+
JOIN call ON
54+
employee.id = call.employee_id
55+
LIMIT 3"""
56+
expr = ibis.parse_sql(sql, catalog)
57+
code = ibis.decompile(expr, format=True)
58+
snapshot.assert_match(code, "decompiled.py")
59+
60+
4761
def test_parse_sql_multiple_joins(snapshot):
4862
sql = """
49-
SELECT *
63+
SELECT employee.*, call.*, call_outcome.outcome_text, call_outcome.id as id_right
5064
FROM employee
5165
JOIN call
5266
ON employee.id = call.employee_id
@@ -118,6 +132,21 @@ def test_parse_sql_scalar_subquery(snapshot):
118132
snapshot.assert_match(code, "decompiled.py")
119133

120134

135+
def test_parse_sql_join_subquery(snapshot):
136+
sql = """
137+
SELECT *
138+
FROM call
139+
INNER JOIN (
140+
SELECT
141+
AVG(call.call_attempts) AS mean
142+
FROM call
143+
) AS subq
144+
ON subq.mean < call.call_attempts"""
145+
expr = ibis.parse_sql(sql, catalog)
146+
code = ibis.decompile(expr, format=True)
147+
snapshot.assert_match(code, "decompiled.py")
148+
149+
121150
def test_parse_sql_simple_select_count(snapshot):
122151
sql = """SELECT COUNT(first_name) FROM employee"""
123152
expr = ibis.parse_sql(sql, catalog)

0 commit comments

Comments
 (0)