Skip to content

Commit a8bbc00

Browse files
committed
refactor(datatype): introduce to_sqla_type dispatching on dialect
1 parent 8ce8c16 commit a8bbc00

File tree

27 files changed

+347
-189
lines changed

27 files changed

+347
-189
lines changed

ibis/backends/base/sql/alchemy/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,14 +402,17 @@ def _handle_failed_column_type_inference(
402402
"""Handle cases where SQLAlchemy cannot infer the column types of `table`."""
403403

404404
self.inspector.reflect_table(table, table.columns)
405-
quoted_name = self.con.dialect.identifier_preparer.quote(table.name)
405+
dialect = self.con.dialect
406+
quoted_name = dialect.identifier_preparer.quote(table.name)
406407

407408
for colname, type in self._metadata(quoted_name):
408409
if colname in nulltype_cols:
409410
# replace null types discovered by sqlalchemy with non null
410411
# types
411412
table.append_column(
412-
sa.Column(colname, to_sqla_type(type), nullable=type.nullable),
413+
sa.Column(
414+
colname, to_sqla_type(dialect, type), nullable=type.nullable
415+
),
413416
replace_existing=True,
414417
)
415418
return table

ibis/backends/base/sql/alchemy/datatypes.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

3-
import functools
43
from typing import Iterable
54

65
import sqlalchemy as sa
6+
from multipledispatch import Dispatcher
77
from sqlalchemy.dialects import mssql, mysql, postgresql, sqlite
88
from sqlalchemy.dialects.mssql.base import MSDialect
99
from sqlalchemy.dialects.mysql.base import MySQLDialect
@@ -82,8 +82,9 @@ def table_from_schema(name, meta, schema, database: str | None = None):
8282
# Convert Ibis schema to SQLA table
8383
columns = []
8484

85+
dialect = getattr(meta.bind, "dialect", _DEFAULT_DIALECT)
8586
for colname, dtype in zip(schema.names, schema.types):
86-
satype = to_sqla_type(dtype)
87+
satype = to_sqla_type(dialect, dtype)
8788
column = sa.Column(colname, satype, nullable=dtype.nullable)
8889
columns.append(column)
8990

@@ -115,55 +116,56 @@ def table_from_schema(name, meta, schema, database: str | None = None):
115116
dt.UInt32: UInt32,
116117
dt.UInt64: UInt64,
117118
dt.JSON: sa.JSON,
119+
dt.Interval: sa.Interval,
118120
}
119121

120122

121-
@functools.singledispatch
122-
def to_sqla_type(itype, type_map=None):
123-
if type_map is None:
124-
type_map = ibis_type_to_sqla
125-
return type_map[type(itype)]
123+
_DEFAULT_DIALECT = DefaultDialect()
126124

125+
to_sqla_type = Dispatcher("to_sqla_type")
127126

128-
@to_sqla_type.register(dt.Decimal)
129-
def _(itype, **kwargs):
130-
return sa.types.NUMERIC(itype.precision, itype.scale)
127+
128+
@to_sqla_type.register(Dialect, dt.DataType)
129+
def _default(_, itype):
130+
return ibis_type_to_sqla[type(itype)]
131131

132132

133-
@to_sqla_type.register(dt.Interval)
134-
def _(itype, **kwargs):
135-
return sa.types.Interval()
133+
@to_sqla_type.register(Dialect, dt.Decimal)
134+
def _decimal(_, itype):
135+
return sa.types.NUMERIC(itype.precision, itype.scale)
136136

137137

138-
@to_sqla_type.register(dt.Date)
139-
def _(itype, **kwargs):
140-
return sa.Date()
138+
@to_sqla_type.register(Dialect, dt.Timestamp)
139+
def _timestamp(_, itype):
140+
return sa.TIMESTAMP(timezone=bool(itype.timezone))
141141

142142

143-
@to_sqla_type.register(dt.Timestamp)
144-
def _(itype, **kwargs):
145-
return sa.TIMESTAMP(bool(itype.timezone))
143+
@to_sqla_type.register(Dialect, dt.Array)
144+
def _array(dialect, itype):
145+
return ArrayType(to_sqla_type(dialect, itype.value_type))
146146

147147

148-
@to_sqla_type.register(dt.Array)
149-
def _(itype, **kwargs):
148+
@to_sqla_type.register(PGDialect, dt.Array)
149+
def _pg_array(dialect, itype):
150150
# Unwrap the array element type because sqlalchemy doesn't allow arrays of
151151
# arrays. This doesn't affect the underlying data.
152152
while itype.is_array():
153153
itype = itype.value_type
154-
return sa.ARRAY(to_sqla_type(itype, **kwargs))
154+
return sa.ARRAY(to_sqla_type(dialect, itype))
155155

156156

157-
@to_sqla_type.register(dt.Struct)
158-
def _(itype, **_):
157+
@to_sqla_type.register(Dialect, dt.Struct)
158+
def _struct(dialect, itype):
159159
return StructType(
160-
[(name, to_sqla_type(type)) for name, type in itype.pairs.items()]
160+
[(name, to_sqla_type(dialect, type)) for name, type in itype.pairs.items()]
161161
)
162162

163163

164-
@to_sqla_type.register(dt.Map)
165-
def _(itype, **_):
166-
return MapType(to_sqla_type(itype.key_type), to_sqla_type(itype.value_type))
164+
@to_sqla_type.register(Dialect, dt.Map)
165+
def _map(dialect, itype):
166+
return MapType(
167+
to_sqla_type(dialect, itype.key_type), to_sqla_type(dialect, itype.value_type)
168+
)
167169

168170

169171
@dt.dtype.register(Dialect, sa.types.NullType)
@@ -322,8 +324,8 @@ def ga_geometry(_, gatype, nullable=True):
322324
else:
323325
raise ValueError(f"Unrecognized geometry type: {t}")
324326

325-
@to_sqla_type.register(dt.GeoSpatial)
326-
def _(itype, **kwargs):
327+
@to_sqla_type.register(Dialect, dt.GeoSpatial)
328+
def _(_, itype, **kwargs):
327329
if itype.geotype == 'geometry':
328330
return ga.Geometry
329331
elif itype.geotype == 'geography':
@@ -406,8 +408,8 @@ def sa_datetime(_, satype, nullable=True, default_timezone='UTC'):
406408
return dt.Timestamp(timezone=timezone, nullable=nullable)
407409

408410

409-
@dt.dtype.register(Dialect, sa.ARRAY)
410-
def sa_array(dialect, satype, nullable=True):
411+
@dt.dtype.register(PGDialect, sa.ARRAY)
412+
def sa_pg_array(dialect, satype, nullable=True):
411413
dimensions = satype.dimensions
412414
if dimensions is not None and dimensions != 1:
413415
raise NotImplementedError(

ibis/backends/base/sql/alchemy/query_builder.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
from sqlalchemy import sql
77

88
import ibis.expr.operations as ops
9-
import ibis.expr.schema as sch
109
from ibis.backends.base.sql.alchemy.database import AlchemyTable
11-
from ibis.backends.base.sql.alchemy.datatypes import to_sqla_type
1210
from ibis.backends.base.sql.alchemy.translator import (
1311
AlchemyContext,
1412
AlchemyExprTranslator,
@@ -22,10 +20,6 @@
2220
from ibis.backends.base.sql.compiler.base import SetOp
2321

2422

25-
def _schema_to_sqlalchemy_columns(schema: sch.Schema) -> list[sa.Column]:
26-
return [sa.column(n, to_sqla_type(t)) for n, t in schema.items()]
27-
28-
2923
class _AlchemyTableSetFormatter(TableSetFormatter):
3024
def get_result(self):
3125
# Got to unravel the join stack; the nesting order could be
@@ -94,32 +88,32 @@ def _format_table(self, op):
9488

9589
alias = ctx.get_ref(op)
9690

91+
translator = ctx.compiler.translator_class(ref_op, ctx)
92+
9793
if isinstance(ref_op, AlchemyTable):
9894
result = ref_op.sqla_table
9995
elif isinstance(ref_op, ops.UnboundTable):
10096
# use SQLAlchemy's TableClause for unbound tables
10197
result = sa.table(
102-
ref_op.name,
103-
*_schema_to_sqlalchemy_columns(ref_op.schema),
98+
ref_op.name, *translator._schema_to_sqlalchemy_columns(ref_op.schema)
10499
)
105100
elif isinstance(ref_op, ops.SQLQueryResult):
106-
columns = _schema_to_sqlalchemy_columns(ref_op.schema)
101+
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
107102
result = sa.text(ref_op.query).columns(*columns)
108103
elif isinstance(ref_op, ops.SQLStringView):
109-
columns = _schema_to_sqlalchemy_columns(ref_op.schema)
104+
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
110105
result = sa.text(ref_op.query).columns(*columns).cte(ref_op.name)
111106
elif isinstance(ref_op, ops.View):
112107
# TODO(kszucs): avoid converting to expression
113108
child_expr = ref_op.child.to_expr()
114109
definition = child_expr.compile()
115110
result = sa.table(
116-
ref_op.name,
117-
*_schema_to_sqlalchemy_columns(ref_op.schema),
111+
ref_op.name, *translator._schema_to_sqlalchemy_columns(ref_op.schema)
118112
)
119113
backend = child_expr._find_backend()
120114
backend._create_temp_view(view=result, definition=definition)
121115
elif isinstance(ref_op, ops.InMemoryTable):
122-
columns = _schema_to_sqlalchemy_columns(ref_op.schema)
116+
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
123117

124118
if self.context.compiler.cheap_in_memory_tables:
125119
result = sa.table(ref_op.name, *columns)

ibis/backends/base/sql/alchemy/translator.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import ibis
66
import ibis.expr.datatypes as dt
77
import ibis.expr.operations as ops
8-
from ibis.backends.base.sql.alchemy.datatypes import ibis_type_to_sqla, to_sqla_type
8+
from ibis.backends.base.sql.alchemy import to_sqla_type
9+
from ibis.backends.base.sql.alchemy.datatypes import _DEFAULT_DIALECT
910
from ibis.backends.base.sql.alchemy.registry import (
1011
fixed_arity,
1112
sqlalchemy_operation_registry,
@@ -35,7 +36,6 @@ def subcontext(self):
3536
class AlchemyExprTranslator(ExprTranslator):
3637
_registry = sqlalchemy_operation_registry
3738
_rewrites = ExprTranslator._rewrites.copy()
38-
_type_map = ibis_type_to_sqla
3939

4040
context_class = AlchemyContext
4141

@@ -54,11 +54,26 @@ class AlchemyExprTranslator(ExprTranslator):
5454
ops.CumeDist,
5555
)
5656

57+
_dialect_name = "default"
58+
59+
@property
60+
def dialect(self) -> sa.engine.interfaces.Dialect:
61+
if (name := self._dialect_name) == "default":
62+
return _DEFAULT_DIALECT
63+
dialect_cls = sa.dialects.registry.load(name)
64+
return dialect_cls()
65+
66+
def _schema_to_sqlalchemy_columns(self, schema):
67+
return [
68+
sa.column(name, to_sqla_type(self.dialect, dtype))
69+
for name, dtype in schema.items()
70+
]
71+
5772
def name(self, translated, name, force=True):
5873
return translated.label(name)
5974

6075
def get_sqla_type(self, data_type):
61-
return to_sqla_type(data_type, type_map=self._type_map)
76+
return to_sqla_type(self.dialect, data_type)
6277

6378
def _maybe_cast_bool(self, op, arg):
6479
if (

ibis/backends/clickhouse/datatypes.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,4 +257,9 @@ def _(ty: dt.Struct) -> str:
257257

258258
@serialize_raw.register(dt.Timestamp)
259259
def _(ty: dt.Timestamp) -> str:
260-
return "DateTime64(6)" if ty.timezone is None else f"DateTime64(6, {ty.timezone!r})"
260+
if (scale := ty.scale) is None:
261+
scale = 3
262+
263+
if (timezone := ty.timezone) is not None:
264+
return f"DateTime64({scale:d}, {timezone})"
265+
return f"DateTime64({scale:d})"
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
CAST(string_col AS Nullable(DateTime64(6)))
1+
CAST(string_col AS Nullable(DateTime64(3)))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
CAST(timestamp_col AS DateTime64(6))
1+
CAST(timestamp_col AS DateTime64(3))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
CAST(int_col AS DateTime64(6))
1+
CAST(int_col AS DateTime64(3))

0 commit comments

Comments
 (0)