Skip to content

Commit b0f4e4c

Browse files
committed
refactor(sqlalchemy): generalize handling of failed type inference
1 parent 23c35e1 commit b0f4e4c

File tree

3 files changed

+50
-75
lines changed

3 files changed

+50
-75
lines changed

ibis/backends/base/sql/alchemy/__init__.py

Lines changed: 32 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import abc
44
import contextlib
55
import getpass
6+
import warnings
67
from operator import methodcaller
78
from typing import TYPE_CHECKING, Any, Iterable, Literal, Mapping
89

@@ -372,13 +373,38 @@ def _log(self, sql):
372373
util.log(query_str)
373374

374375
def _get_sqla_table(
375-
self,
376-
name: str,
377-
schema: str | None = None,
378-
autoload: bool = True,
379-
**kwargs: Any,
376+
self, name: str, schema: str | None = None, autoload: bool = True, **kwargs: Any
380377
) -> sa.Table:
381-
return sa.Table(name, self.meta, schema=schema, autoload=autoload)
378+
with warnings.catch_warnings():
379+
warnings.filterwarnings(
380+
"ignore", message="Did not recognize type", category=sa.exc.SAWarning
381+
)
382+
table = sa.Table(name, self.meta, schema=schema, autoload=autoload)
383+
nulltype_cols = frozenset(
384+
col.name for col in table.c if isinstance(col.type, sa.types.NullType)
385+
)
386+
387+
if not nulltype_cols:
388+
return table
389+
return self._handle_failed_column_type_inference(table, nulltype_cols)
390+
391+
def _handle_failed_column_type_inference(
392+
self, table: sa.Table, nulltype_cols: Iterable[str]
393+
) -> sa.Table:
394+
"""Handle cases where SQLAlchemy cannot infer the column types of `table`."""
395+
396+
self.inspector.reflect_table(table, table.columns)
397+
quoted_name = self.con.dialect.identifier_preparer.quote(table.name)
398+
399+
for colname, type in self._metadata(quoted_name):
400+
if colname in nulltype_cols:
401+
# replace null types discovered by sqlalchemy with non null
402+
# types
403+
table.append_column(
404+
sa.Column(colname, to_sqla_type(type), nullable=type.nullable),
405+
replace_existing=True,
406+
)
407+
return table
382408

383409
def _sqla_table_to_expr(self, table: sa.Table) -> ir.Table:
384410
schema = self._schemas.get(table.name)

ibis/backends/duckdb/__init__.py

Lines changed: 2 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414

1515
import ibis.expr.datatypes as dt
1616
from ibis import util
17-
from ibis.backends.base.sql.alchemy.datatypes import to_sqla_type
1817

1918
if TYPE_CHECKING:
2019
import duckdb
@@ -371,22 +370,7 @@ def read_postgres(self, uri, table_name=None):
371370
return self._read(table_name)
372371

373372
def _read(self, table_name):
374-
375373
_table = self.table(table_name)
376-
with warnings.catch_warnings():
377-
# don't fail or warn if duckdb-engine fails to discover types
378-
# mostly (tinyint)
379-
warnings.filterwarnings(
380-
"ignore",
381-
message="Did not recognize type",
382-
category=sa.exc.SAWarning,
383-
)
384-
# We don't rely on index reflection, ignore this warning
385-
warnings.filterwarnings(
386-
"ignore",
387-
message="duckdb-engine doesn't yet support reflection on indices",
388-
)
389-
self.inspector.reflect_table(_table.op().sqla_table, _table.columns)
390374
return self.table(table_name)
391375

392376
def to_pyarrow_batches(
@@ -476,48 +460,15 @@ def _register_in_memory_table(self, table_op):
476460
self.con.execute("register", (table_op.name, df))
477461

478462
def _get_sqla_table(
479-
self,
480-
name: str,
481-
schema: str | None = None,
482-
**kwargs: Any,
463+
self, name: str, schema: str | None = None, **kwargs: Any
483464
) -> sa.Table:
484465
with warnings.catch_warnings():
485-
# don't fail or warn if duckdb-engine fails to discover types
486-
warnings.filterwarnings(
487-
"ignore",
488-
message="Did not recognize type",
489-
category=sa.exc.SAWarning,
490-
)
491466
# We don't rely on index reflection, ignore this warning
492467
warnings.filterwarnings(
493468
"ignore",
494469
message="duckdb-engine doesn't yet support reflection on indices",
495470
)
496-
497-
table = super()._get_sqla_table(name, schema, **kwargs)
498-
499-
nulltype_cols = frozenset(
500-
col.name for col in table.c if isinstance(col.type, sa.types.NullType)
501-
)
502-
503-
if not nulltype_cols:
504-
return table
505-
506-
quoted_name = self.con.dialect.identifier_preparer.quote(name)
507-
508-
for colname, type in self._metadata(quoted_name):
509-
if colname in nulltype_cols:
510-
# replace null types discovered by sqlalchemy with non null
511-
# types
512-
table.append_column(
513-
sa.Column(
514-
colname,
515-
to_sqla_type(type),
516-
nullable=type.nullable,
517-
),
518-
replace_existing=True,
519-
)
520-
return table
471+
return super()._get_sqla_table(name, schema, **kwargs)
521472

522473
def _get_temp_view_definition(
523474
self,

ibis/backends/mssql/__init__.py

Lines changed: 16 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -2,18 +2,17 @@
22

33
from __future__ import annotations
44

5-
import atexit
65
import contextlib
7-
from typing import TYPE_CHECKING, Iterable, Literal
6+
from typing import TYPE_CHECKING, Literal
87

98
import sqlalchemy as sa
109

1110
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
1211
from ibis.backends.mssql.compiler import MsSqlCompiler
13-
from ibis.backends.mssql.datatypes import _type_from_result_set_info
12+
from ibis.backends.mssql.datatypes import _FieldDescription, _type_from_result_set_info
1413

1514
if TYPE_CHECKING:
16-
import ibis.expr.datatypes as dt
15+
pass
1716

1817

1918
class Backend(BaseAlchemyBackend):
@@ -54,25 +53,24 @@ def begin(self):
5453
finally:
5554
bind.execute(f"SET DATEFIRST {previous_datefirst}")
5655

57-
def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
56+
def _metadata(self, query):
57+
if query in self.list_tables():
58+
query = f"SELECT * FROM [{query}]"
59+
5860
with self.begin() as bind:
59-
for column in bind.execute(
60-
f"EXEC sp_describe_first_result_set @tsql = N'{query}';"
61-
).mappings():
62-
yield column["name"], _type_from_result_set_info(column)
61+
result_set_info: list[_FieldDescription] = (
62+
bind.execute(f"EXEC sp_describe_first_result_set @tsql = N'{query}';")
63+
.mappings()
64+
.fetchall()
65+
)
66+
return [
67+
(column['name'], _type_from_result_set_info(column))
68+
for column in result_set_info
69+
]
6370

6471
def _get_temp_view_definition(
6572
self,
6673
name: str,
6774
definition: sa.sql.compiler.Compiled,
6875
) -> str:
6976
return f"CREATE OR ALTER VIEW {name} AS {definition}"
70-
71-
def _register_temp_view_cleanup(self, name: str, raw_name: str) -> None:
72-
query = f"DROP VIEW IF EXISTS {name}"
73-
74-
def drop(self, raw_name: str, query: str):
75-
self.con.execute(query)
76-
self._temp_views.discard(raw_name)
77-
78-
atexit.register(drop, self, raw_name, query)

0 commit comments

Comments
 (0)