Skip to content

Commit 00e8087

Browse files
committed
feat(duckdb/postgres/mysql/pyspark): implement .sql on tables for mixing sql and expressions
1 parent a366d9c commit 00e8087

File tree

18 files changed

+441
-86
lines changed

18 files changed

+441
-86
lines changed

ibis/backends/base/sql/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,9 @@ def has_operation(cls, operation: type[ops.ValueOp]) -> bool:
263263
translator = cls.compiler.translator_class
264264
op_classes = translator._registry.keys() | translator._rewrites.keys()
265265
return operation in op_classes
266+
267+
def _create_temp_view(self, view, definition):
268+
raise NotImplementedError(
269+
f"The {self.name} backend does not implement temporary view "
270+
"creation"
271+
)

ibis/backends/base/sql/alchemy/__init__.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def do_connect(self, con: sa.engine.Engine) -> None:
8686
self._inspector = sa.inspect(self.con)
8787
self.meta = sa.MetaData(bind=self.con)
8888
self._schemas: dict[str, sch.Schema] = {}
89+
self._temp_views: set[str] = set()
8990

9091
@property
9192
def version(self):
@@ -478,3 +479,33 @@ def insert(
478479
"is not a pandas DataFrame or is not a ibis TableExpr."
479480
f"The given obj is of type {type(obj).__name__} ."
480481
)
482+
483+
def _get_temp_view_definition(
484+
self,
485+
name: str,
486+
definition: sa.sql.compiler.Compiled,
487+
) -> str:
488+
raise NotImplementedError(
489+
f"The {self.name} backend does not implement temporary view "
490+
"creation"
491+
)
492+
493+
def _register_temp_view_cleanup(self, name: str, raw_name: str) -> None:
494+
pass
495+
496+
def _create_temp_view(
497+
self,
498+
view: sa.Table,
499+
definition: sa.sql.Selectable,
500+
) -> None:
501+
raw_name = view.name
502+
if raw_name not in self._temp_views and raw_name in self.list_tables():
503+
raise ValueError(f"{raw_name} already exists as a table or view")
504+
505+
name = self.con.dialect.identifier_preparer.quote_identifier(raw_name)
506+
compiled = definition.compile()
507+
defn = self._get_temp_view_definition(name, definition=compiled)
508+
query = sa.text(defn).bindparams(**compiled.params)
509+
self.con.execute(query, definition)
510+
self._temp_views.add(raw_name)
511+
self._register_temp_view_cleanup(name, raw_name)

ibis/backends/base/sql/alchemy/query_builder.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from __future__ import annotations
2+
13
import functools
24

35
import sqlalchemy as sa
46
import sqlalchemy.sql as sql
57

68
import ibis.expr.operations as ops
9+
import ibis.expr.schema as sch
710
import ibis.expr.types as ir
811
from ibis.backends.base.sql.compiler import (
912
Compiler,
@@ -18,6 +21,10 @@
1821
from .translator import AlchemyContext, AlchemyExprTranslator
1922

2023

24+
def _schema_to_sqlalchemy_columns(schema: sch.Schema) -> list[sa.Column]:
25+
return [sa.column(n, to_sqla_type(t)) for n, t in schema.items()]
26+
27+
2128
class _AlchemyTableSetFormatter(TableSetFormatter):
2229
def get_result(self):
2330
# Got to unravel the join stack; the nesting order could be
@@ -85,8 +92,19 @@ def _format_table(self, expr):
8592
schema = ref_op.schema
8693
result = sa.table(
8794
ref_op.name,
88-
*(sa.column(n, to_sqla_type(t)) for n, t in schema.items()),
95+
*_schema_to_sqlalchemy_columns(schema),
96+
)
97+
elif isinstance(ref_op, ops.SQLStringView):
98+
columns = _schema_to_sqlalchemy_columns(ref_op.schema)
99+
result = sa.text(ref_op.query).columns(*columns).cte(ref_op.name)
100+
elif isinstance(ref_op, ops.View):
101+
definition = ref_op.child.compile()
102+
result = sa.table(
103+
ref_op.name,
104+
*_schema_to_sqlalchemy_columns(ref_op.schema),
89105
)
106+
backend = ref_op.child._find_backend()
107+
backend._create_temp_view(view=result, definition=definition)
90108
else:
91109
# A subquery
92110
if ctx.is_extracted(ref_expr):

ibis/backends/base/sql/compiler/extract_subqueries.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,17 +109,21 @@ def visit_Difference(self, expr):
109109
self.visit(op.right)
110110
self.observe(expr)
111111

112-
def visit_MaterializedJoin(self, expr):
113-
self.visit(expr.op().join)
114-
self.observe(expr)
115-
116112
def visit_Selection(self, expr):
    """Visit the ``table`` of the selection's op, then observe ``expr``."""
    source_table = expr.op().table
    self.visit(source_table)
    self.observe(expr)
119115

120116
def visit_SQLQueryResult(self, expr):
    """Observe ``expr`` directly; nothing beneath it is visited."""
    observe = self.observe
    observe(expr)
122118

119+
def visit_View(self, expr):
    """Visit the ``child`` of the view's op, then observe ``expr``."""
    child = expr.op().child
    self.visit(child)
    self.observe(expr)
122+
123+
def visit_SQLStringView(self, expr):
    """Visit the ``child`` of the SQL string view's op, then observe ``expr``."""
    child = expr.op().child
    self.visit(child)
    self.observe(expr)
126+
123127
def visit_TableColumn(self, expr):
124128
table = expr.op().table
125129
if not self.seen(table):

ibis/backends/base/sql/compiler/translator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def get_compiled_expr(self, expr):
7575
pass
7676

7777
op = expr.op()
78-
if isinstance(op, ops.SQLQueryResult):
78+
if isinstance(op, (ops.SQLQueryResult, ops.SQLStringView)):
7979
result = op.query
8080
else:
8181
result = self._compile_subquery(expr)

ibis/backends/conftest.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,11 @@ def pytest_runtest_call(item):
262262
def backend(request, data_directory):
263263
"""Return an instance of BackendTest."""
264264
cls = _get_backend_conf(request.param)
265-
return cls(data_directory)
265+
result = cls(data_directory)
266+
try:
267+
yield result
268+
finally:
269+
result.cleanup()
266270

267271

268272
@pytest.fixture(scope='session')
@@ -286,7 +290,11 @@ def ddl_backend(request, data_directory):
286290
(sqlite, postgres, mysql, datafusion, clickhouse, pyspark, impala)
287291
"""
288292
cls = _get_backend_conf(request.param)
289-
return cls(data_directory)
293+
result = cls(data_directory)
294+
try:
295+
yield result
296+
finally:
297+
result.cleanup()
290298

291299

292300
@pytest.fixture(scope='session')
@@ -315,7 +323,11 @@ def alchemy_backend(request, data_directory):
315323
)
316324
else:
317325
cls = _get_backend_conf(request.param)
318-
return cls(data_directory)
326+
result = cls(data_directory)
327+
try:
328+
yield result
329+
finally:
330+
result.cleanup()
319331

320332

321333
@pytest.fixture(scope='session')
@@ -335,7 +347,11 @@ def udf_backend(request, data_directory):
335347
Runs the UDF-supporting backends
336348
"""
337349
cls = _get_backend_conf(request.param)
338-
return cls(data_directory)
350+
result = cls(data_directory)
351+
try:
352+
yield result
353+
finally:
354+
result.cleanup()
339355

340356

341357
@pytest.fixture(scope='session')

ibis/backends/duckdb/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
1515

1616
from .compiler import DuckDBSQLCompiler
17+
from .datatypes import parse_type
1718

1819

1920
class Backend(BaseAlchemyBackend):
@@ -69,4 +70,16 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
6970
"""Return an ibis Schema from a SQL string."""
7071
with self.con.connect() as con:
7172
rel = con.connection.c.query(query)
72-
return sch.infer(rel)
73+
return sch.Schema.from_dict(
74+
{
75+
name: parse_type(type)
76+
for name, type in zip(rel.columns, rel.types)
77+
}
78+
)
79+
80+
def _get_temp_view_definition(
81+
self,
82+
name: str,
83+
definition: sa.sql.compiler.Compiled,
84+
) -> str:
85+
return f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}"

ibis/backends/mysql/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import atexit
56
import contextlib
67
import warnings
78
from typing import Literal
@@ -133,6 +134,22 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
133134
]
134135
return sch.Schema.from_tuples(fields)
135136

137+
def _get_temp_view_definition(
138+
self,
139+
name: str,
140+
definition: sa.sql.compiler.Compiled,
141+
) -> str:
142+
return f"CREATE OR REPLACE VIEW {name} AS {definition}"
143+
144+
def _register_temp_view_cleanup(self, name: str, raw_name: str) -> None:
145+
query = f"DROP VIEW IF EXISTS {name}"
146+
147+
def drop(self, raw_name: str, query: str):
148+
self.con.execute(query)
149+
self._temp_views.discard(raw_name)
150+
151+
atexit.register(drop, self, raw_name, query)
152+
136153

137154
# TODO(kszucs): unsigned integers
138155

ibis/backends/postgres/__init__.py

Lines changed: 7 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77

88
import sqlalchemy as sa
99

10-
import ibis.backends.duckdb.datatypes as ddb
11-
import ibis.expr.datatypes as dt
1210
import ibis.expr.schema as sch
1311
from ibis import util
1412
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
1513

1614
from .compiler import PostgreSQLCompiler
15+
from .datatypes import _get_type
1716
from .udf import udf
1817

1918

@@ -205,71 +204,9 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
205204
tuples = [(col, _get_type(typestr)) for col, typestr in type_info]
206205
return sch.Schema.from_tuples(tuples)
207206

208-
209-
def _get_type(typestr: str) -> dt.DataType:
210-
try:
211-
return _type_mapping[typestr]
212-
except KeyError:
213-
return ddb.parse_type(typestr)
214-
215-
216-
_type_mapping = {
217-
"boolean": dt.bool,
218-
"boolean[]": dt.Array(dt.bool),
219-
"bytea": dt.binary,
220-
"bytea[]": dt.Array(dt.binary),
221-
"character(1)": dt.string,
222-
"character(1)[]": dt.Array(dt.string),
223-
"bigint": dt.int64,
224-
"bigint[]": dt.Array(dt.int64),
225-
"smallint": dt.int16,
226-
"smallint[]": dt.Array(dt.int16),
227-
"integer": dt.int32,
228-
"integer[]": dt.Array(dt.int32),
229-
"text": dt.string,
230-
"text[]": dt.Array(dt.string),
231-
"json": dt.json,
232-
"json[]": dt.Array(dt.json),
233-
"point": dt.point,
234-
"point[]": dt.Array(dt.point),
235-
"polygon": dt.polygon,
236-
"polygon[]": dt.Array(dt.polygon),
237-
"line": dt.linestring,
238-
"line[]": dt.Array(dt.linestring),
239-
"real": dt.float32,
240-
"real[]": dt.Array(dt.float32),
241-
"double precision": dt.float64,
242-
"double precision[]": dt.Array(dt.float64),
243-
"macaddr8": dt.macaddr,
244-
"macaddr8[]": dt.Array(dt.macaddr),
245-
"macaddr": dt.macaddr,
246-
"macaddr[]": dt.Array(dt.macaddr),
247-
"inet": dt.inet,
248-
"inet[]": dt.Array(dt.inet),
249-
"character": dt.string,
250-
"character[]": dt.Array(dt.string),
251-
"character varying": dt.string,
252-
"character varying[]": dt.Array(dt.string),
253-
"date": dt.date,
254-
"date[]": dt.Array(dt.date),
255-
"time without time zone": dt.time,
256-
"time without time zone[]": dt.Array(dt.time),
257-
"timestamp without time zone": dt.timestamp,
258-
"timestamp without time zone[]": dt.Array(dt.timestamp),
259-
"timestamp with time zone": dt.Timestamp("UTC"),
260-
"timestamp with time zone[]": dt.Array(dt.Timestamp("UTC")),
261-
"interval": dt.interval,
262-
"interval[]": dt.Array(dt.interval),
263-
# NB: this isn't correct, but we try not to fail
264-
"time with time zone": "time",
265-
"numeric": dt.decimal,
266-
"numeric[]": dt.Array(dt.decimal),
267-
"uuid": dt.uuid,
268-
"uuid[]": dt.Array(dt.uuid),
269-
"jsonb": dt.jsonb,
270-
"jsonb[]": dt.Array(dt.jsonb),
271-
"geometry": dt.geometry,
272-
"geometry[]": dt.Array(dt.geometry),
273-
"geography": dt.geography,
274-
"geography[]": dt.Array(dt.geography),
275-
}
207+
def _get_temp_view_definition(
208+
self,
209+
name: str,
210+
definition: sa.sql.compiler.Compiled,
211+
) -> str:
212+
return f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}"

0 commit comments

Comments
 (0)