Skip to content

Commit 296cd7d

Browse files
committed
refactor(sqlalchemy): move _get_schema_using_query to base class
1 parent cfd5061 commit 296cd7d

File tree

8 files changed

+54
-57
lines changed

8 files changed

+54
-57
lines changed

ibis/backends/base/sql/alchemy/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

3+
import abc
34
import contextlib
45
import getpass
56
from operator import methodcaller
6-
from typing import TYPE_CHECKING, Any, Literal, Mapping
7+
from typing import TYPE_CHECKING, Any, Iterable, Literal, Mapping
78

89
import sqlalchemy as sa
910

@@ -39,6 +40,8 @@
3940
if TYPE_CHECKING:
4041
import pandas as pd
4142

43+
import ibis.expr.datatypes as dt
44+
4245

4346
__all__ = (
4447
'BaseAlchemyBackend',
@@ -554,3 +557,11 @@ def _create_temp_view(self, view: sa.Table, definition: sa.sql.Selectable) -> No
554557
con.execute(compiled, **params)
555558
self._temp_views.add(raw_name)
556559
self._register_temp_view_cleanup(name, raw_name)
560+
561+
@abc.abstractmethod
562+
def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
563+
...
564+
565+
def _get_schema_using_query(self, query: str) -> sch.Schema:
566+
"""Return an ibis Schema from a backend-specific SQL string."""
567+
return sch.Schema.from_tuples(self._metadata(query))

ibis/backends/duckdb/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -471,10 +471,6 @@ def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]:
471471
ibis_type = parse(type)
472472
yield name, ibis_type(nullable=null.lower() == "yes")
473473

474-
def _get_schema_using_query(self, query: str) -> sch.Schema:
475-
"""Return an ibis Schema from a DuckDB SQL string."""
476-
return sch.Schema.from_tuples(self._metadata(query))
477-
478474
def _register_in_memory_table(self, table_op):
479475
df = table_op.data.to_frame()
480476
self.con.execute("register", (table_op.name, df))

ibis/backends/mssql/__init__.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44

55
import atexit
66
import contextlib
7-
from typing import Literal
7+
from typing import TYPE_CHECKING, Iterable, Literal
88

99
import sqlalchemy as sa
1010

11-
import ibis.expr.schema as sch
1211
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
1312
from ibis.backends.mssql.compiler import MsSqlCompiler
14-
from ibis.backends.mssql.datatypes import _FieldDescription, _type_from_result_set_info
13+
from ibis.backends.mssql.datatypes import _type_from_result_set_info
14+
15+
if TYPE_CHECKING:
16+
import ibis.expr.datatypes as dt
1517

1618

1719
class Backend(BaseAlchemyBackend):
@@ -52,17 +54,12 @@ def begin(self):
5254
finally:
5355
bind.execute(f"SET DATEFIRST {previous_datefirst}")
5456

55-
def _get_schema_using_query(self, query):
57+
def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
5658
with self.begin() as bind:
57-
result = bind.execute(
59+
for column in bind.execute(
5860
f"EXEC sp_describe_first_result_set @tsql = N'{query}';"
59-
)
60-
result_set_info: list[_FieldDescription] = result.mappings().fetchall()
61-
fields = [
62-
(column['name'], _type_from_result_set_info(column))
63-
for column in result_set_info
64-
]
65-
return sch.Schema.from_tuples(fields)
61+
).mappings():
62+
yield column["name"], _type_from_result_set_info(column)
6663

6764
def _get_temp_view_definition(
6865
self,

ibis/backends/mysql/__init__.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
import atexit
66
import contextlib
77
import warnings
8-
from typing import Literal
8+
from typing import Iterable, Literal
99

1010
import sqlalchemy as sa
1111
from sqlalchemy.dialects import mysql
1212

1313
import ibis.expr.datatypes as dt
14-
import ibis.expr.schema as sch
1514
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
1615
from ibis.backends.mysql.compiler import MySQLCompiler
1716
from ibis.backends.mysql.datatypes import _type_from_cursor_info
@@ -121,15 +120,14 @@ def begin(self):
121120
except Exception as e: # noqa: BLE001
122121
warnings.warn(f"Couldn't reset MySQL timezone: {str(e)}")
123122

124-
def _get_schema_using_query(self, query: str) -> sch.Schema:
125-
"""Infer the schema of `query`."""
126-
result = self.con.execute(f"SELECT * FROM ({query}) _ LIMIT 0")
127-
cursor = result.cursor
128-
fields = [
129-
(field.name, _type_from_cursor_info(descr, field))
130-
for descr, field in zip(cursor.description, cursor._result.fields)
131-
]
132-
return sch.Schema.from_tuples(fields)
123+
def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
124+
with self.con.begin() as con:
125+
result = con.execute(f"SELECT * FROM ({query}) _ LIMIT 0")
126+
cursor = result.cursor
127+
yield from (
128+
(field.name, _type_from_cursor_info(descr, field))
129+
for descr, field in zip(cursor.description, cursor._result.fields)
130+
)
133131

134132
def _get_temp_view_definition(
135133
self,

ibis/backends/postgres/__init__.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33
from __future__ import annotations
44

55
import contextlib
6-
from typing import Literal
6+
from typing import TYPE_CHECKING, Iterable, Literal
77

88
import sqlalchemy as sa
99

10-
import ibis.expr.schema as sch
1110
from ibis import util
1211
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
1312
from ibis.backends.postgres.compiler import PostgreSQLCompiler
1413
from ibis.backends.postgres.datatypes import _get_type
1514
from ibis.backends.postgres.udf import udf as _udf
1615

16+
if TYPE_CHECKING:
17+
import ibis.expr.datatypes as dt
18+
1719

1820
class Backend(BaseAlchemyBackend):
1921
name = 'postgres'
@@ -172,7 +174,7 @@ def udf(
172174
language=language,
173175
)
174176

175-
def _get_schema_using_query(self, query: str) -> sch.Schema:
177+
def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
176178
raw_name = util.guid()
177179
name = self.con.dialect.identifier_preparer.quote_identifier(raw_name)
178180
type_info_sql = f"""\
@@ -187,12 +189,9 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
187189
"""
188190
with self.begin() as con:
189191
con.execute(f"CREATE TEMPORARY VIEW {name} AS {query}")
190-
try:
191-
type_info = con.execute(type_info_sql).fetchall()
192-
finally:
193-
con.execute(f"DROP VIEW {name}")
194-
tuples = [(col, _get_type(typestr)) for col, typestr in type_info]
195-
return sch.Schema.from_tuples(tuples)
192+
type_info = con.execute(type_info_sql)
193+
yield from ((col, _get_type(typestr)) for col, typestr in type_info)
194+
con.execute(f"DROP VIEW IF EXISTS {name}")
196195

197196
def _get_temp_view_definition(
198197
self,

ibis/backends/snowflake/__init__.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

33
import contextlib
4-
from typing import TYPE_CHECKING, Any
4+
from typing import TYPE_CHECKING, Any, Iterable
55

66
import snowflake.connector as sfc
77
import sqlalchemy as sa
8+
import toolz
89
from snowflake.sqlalchemy import URL
910

1011
import ibis.expr.operations as ops
@@ -20,6 +21,8 @@
2021
if TYPE_CHECKING:
2122
import pandas as pd
2223

24+
import ibis.expr.datatypes as dt
25+
2326

2427
_NATIVE_ARROW = True
2528

@@ -117,16 +120,16 @@ def fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
117120
return schema.apply_to(df)
118121
return super().fetch_from_cursor(cursor, schema)
119122

120-
def _get_schema_using_query(self, query):
123+
def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
121124
with self.begin() as bind:
122125
result = bind.execute(f"SELECT * FROM ({query}) t0 LIMIT 0")
123126
info_rows = bind.execute(f"DESCRIBE RESULT {result.cursor.sfqid!r}")
124127

125-
schema = {}
126-
for name, raw_type, _, null, *_ in info_rows:
127-
typ = parse(raw_type)
128-
schema[name] = typ(nullable=null.upper() == "Y")
129-
return sch.Schema.from_dict(schema)
128+
for name, raw_type, null in toolz.pluck(
129+
["name", "type", "null?"], info_rows
130+
):
131+
typ = parse(raw_type)
132+
yield name, typ(nullable=null.upper() == "Y")
130133

131134
def list_databases(self, like=None) -> list[str]:
132135
databases = [

ibis/backends/sqlite/__init__.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,21 @@
1818
import sqlite3
1919
import warnings
2020
from pathlib import Path
21-
from typing import TYPE_CHECKING
21+
from typing import TYPE_CHECKING, Iterable
2222

2323
import sqlalchemy as sa
2424
from sqlalchemy.dialects.sqlite import DATETIME, TIMESTAMP
2525

26-
if TYPE_CHECKING:
27-
import ibis.expr.datatypes as dt
28-
import ibis.expr.types as ir
29-
30-
import ibis.expr.schema as sch
3126
from ibis.backends.base import Database
3227
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend, to_sqla_type
3328
from ibis.backends.sqlite import udf
3429
from ibis.backends.sqlite.compiler import SQLiteCompiler
3530
from ibis.expr.schema import datatype
3631

32+
if TYPE_CHECKING:
33+
import ibis.expr.datatypes as dt
34+
import ibis.expr.types as ir
35+
3736

3837
def to_datetime(value: str | None) -> datetime.datetime | None:
3938
"""Convert a `str` to a `datetime` according to SQLite's rules.
@@ -217,7 +216,7 @@ def _table_from_schema(self, name, schema, database: str | None = None) -> sa.Ta
217216
def _current_schema(self) -> str | None:
218217
return self.current_database
219218

220-
def _get_schema_using_query(self, query: str) -> sch.Schema:
219+
def _metadata(self, _: str) -> Iterable[tuple[str, dt.DataType]]:
221220
raise ValueError(
222221
"The SQLite backend cannot infer schemas from raw SQL - "
223222
"please specify the schema directly when calling `.sql` "

ibis/backends/trino/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import toolz
99

1010
import ibis.expr.datatypes as dt
11-
import ibis.expr.schema as sch
1211
from ibis import util
1312
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
1413
from ibis.backends.trino.compiler import TrinoSQLCompiler
@@ -69,8 +68,3 @@ def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]:
6968
for name, type in toolz.pluck(["Column Name", "Type"], rows):
7069
ibis_type = parse(type)
7170
yield name, ibis_type(nullable=True)
72-
73-
def _get_schema_using_query(self, query: str) -> sch.Schema:
74-
"""Return an ibis Schema from a DuckDB SQL string."""
75-
pairs = self._metadata(query)
76-
return sch.Schema.from_tuples(pairs)

0 commit comments

Comments
 (0)