Skip to content

Commit dec70f5

Browse files
committed
refactor(datatype): add custom sqlalchemy nested types for backend differentiation
1 parent a8bbc00 commit dec70f5

File tree

9 files changed

+76
-25
lines changed

9 files changed

+76
-25
lines changed

ibis/backends/base/sql/alchemy/datatypes.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,43 @@
2222
import geoalchemy2 as ga
2323

2424

25+
class ArrayType(UserDefinedType):
26+
def __init__(self, value_type: sa.types.TypeEngine):
27+
self.value_type = sa.types.to_instance(value_type)
28+
29+
30+
@compiles(ArrayType, "default")
31+
def compiles_array(element, compiler, **kw):
32+
return f"ARRAY({compiler.process(element.value_type, **kw)})"
33+
34+
2535
class StructType(UserDefinedType):
2636
def __init__(
2737
self,
2838
pairs: Iterable[tuple[str, sa.types.TypeEngine]],
2939
):
3040
self.pairs = [(name, sa.types.to_instance(type)) for name, type in pairs]
3141

32-
def get_col_spec(self, **_):
33-
pairs = ", ".join(f"{k} {v}" for k, v in self.pairs)
34-
return f"STRUCT({pairs})"
42+
43+
@compiles(StructType, "default")
44+
def compiles_struct(element, compiler, **kw):
45+
content = ", ".join(
46+
f"{field} {compiler.process(typ, **kw)}" for field, typ in element.pairs
47+
)
48+
return f"STRUCT({content})"
3549

3650

3751
class MapType(UserDefinedType):
3852
def __init__(self, key_type: sa.types.TypeEngine, value_type: sa.types.TypeEngine):
3953
self.key_type = sa.types.to_instance(key_type)
4054
self.value_type = sa.types.to_instance(value_type)
4155

42-
def get_col_spec(self, **_):
43-
return f"MAP({self.key_type}, {self.value_type})"
56+
57+
@compiles(MapType, "default")
58+
def compiles_map(element, compiler, **kw):
59+
key_type = compiler.process(element.key_type, **kw)
60+
value_type = compiler.process(element.value_type, **kw)
61+
return f"MAP({key_type}, {value_type})"
4462

4563

4664
class UInt64(sa.types.Integer):
@@ -426,6 +444,11 @@ def sa_struct(dialect, satype, nullable=True):
426444
return dt.Struct.from_tuples(pairs, nullable=nullable)
427445

428446

447+
@dt.dtype.register(Dialect, ArrayType)
448+
def sa_array(dialect, satype, nullable=True):
449+
return dt.Array(dt.dtype(dialect, satype.value_type), nullable=nullable)
450+
451+
429452
@sch.infer.register((sa.Table, sa.sql.TableClause))
430453
def schema_from_table(
431454
table: sa.Table,

ibis/backends/base/sql/alchemy/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _cast(t, op):
185185
if typ.is_json() and not t.native_json_type:
186186
return sa_arg
187187

188-
return t.cast(sa_arg, typ)
188+
return sa.cast(sa_arg, t.get_sqla_type(typ))
189189

190190

191191
def _contains(func):

ibis/backends/base/sql/alchemy/translator.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import functools
4+
35
import sqlalchemy as sa
46

57
import ibis
@@ -56,7 +58,7 @@ class AlchemyExprTranslator(ExprTranslator):
5658

5759
_dialect_name = "default"
5860

59-
@property
61+
@functools.cached_property
6062
def dialect(self) -> sa.engine.interfaces.Dialect:
6163
if (name := self._dialect_name) == "default":
6264
return _DEFAULT_DIALECT
@@ -65,8 +67,7 @@ def dialect(self) -> sa.engine.interfaces.Dialect:
6567

6668
def _schema_to_sqlalchemy_columns(self, schema):
6769
return [
68-
sa.column(name, to_sqla_type(self.dialect, dtype))
69-
for name, dtype in schema.items()
70+
sa.column(name, self.get_sqla_type(dtype)) for name, dtype in schema.items()
7071
]
7172

7273
def name(self, translated, name, force=True):
@@ -103,9 +104,6 @@ def _reduction(self, sa_func, op):
103104

104105
return sa_func(*sa_args)
105106

106-
def cast(self, sa_expr, ibis_type: dt.DataType):
107-
return sa.cast(sa_expr, self.get_sqla_type(ibis_type))
108-
109107

110108
rewrites = AlchemyExprTranslator.rewrites
111109

ibis/backends/duckdb/compiler.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
import ibis.backends.base.sql.alchemy.datatypes as sat
66
import ibis.expr.datatypes as dt
77
import ibis.expr.operations as ops
8-
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
8+
from ibis.backends.base.sql.alchemy import (
9+
AlchemyCompiler,
10+
AlchemyExprTranslator,
11+
to_sqla_type,
12+
)
913
from ibis.backends.duckdb.registry import operation_registry
1014

1115

@@ -24,6 +28,11 @@ def compile_uint(element, compiler, **kw):
2428
return element.__class__.__name__.upper()
2529

2630

31+
@compiles(sat.ArrayType, "duckdb")
32+
def compile_array(element, compiler, **kw):
33+
return f"{compiler.process(element.value_type, **kw)}[]"
34+
35+
2736
try:
2837
import duckdb_engine
2938
except ImportError:
@@ -37,6 +46,14 @@ def compile_uint(element, compiler, **kw):
3746
def dtype_uint(_, satype, nullable=True):
3847
return getattr(dt, satype.__class__.__name__)(nullable=nullable)
3948

49+
@dt.dtype.register(duckdb_engine.Dialect, sat.ArrayType)
50+
def _(dialect, satype, nullable=True):
51+
return dt.Array(dt.dtype(dialect, satype.value_type), nullable=nullable)
52+
53+
@to_sqla_type.register(duckdb_engine.Dialect, dt.Array)
54+
def _(dialect, itype):
55+
return sat.ArrayType(to_sqla_type(dialect, itype.value_type))
56+
4057

4158
rewrites = DuckDBSQLExprTranslator.rewrites
4259

ibis/backends/snowflake/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,6 @@ class SnowflakeExprTranslator(AlchemyExprTranslator):
4141
_require_order_by = (*AlchemyExprTranslator._require_order_by, ops.Reduction)
4242
_dialect_name = "snowflake"
4343

44-
def cast(self, sa_expr, ibis_type: dt.DataType):
45-
if ibis_type.is_array() or ibis_type.is_map() or ibis_type.is_struct():
46-
return sa.type_coerce(sa_expr, self.get_sqla_type(ibis_type))
47-
return super().cast(sa_expr, ibis_type)
48-
4944

5045
class SnowflakeCompiler(AlchemyCompiler):
5146
translator_class = SnowflakeExprTranslator

ibis/backends/snowflake/datatypes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect
1616

1717
import ibis.expr.datatypes as dt
18+
from ibis.backends.base.sql.alchemy import to_sqla_type
1819

1920
if TYPE_CHECKING:
2021
from ibis.expr.datatypes import DataType
@@ -177,3 +178,18 @@ def sa_sf_numeric(_, satype, nullable=True):
177178
@dt.dtype.register(SnowflakeDialect, (sa.REAL, sa.FLOAT, sa.Float))
178179
def sa_sf_real_float(_, satype, nullable=True):
179180
return dt.Float64(nullable=nullable)
181+
182+
183+
@to_sqla_type.register(SnowflakeDialect, dt.Array)
184+
def _sf_array(_, itype):
185+
return ARRAY
186+
187+
188+
@to_sqla_type.register(SnowflakeDialect, (dt.Map, dt.Struct))
189+
def _sf_map_struct(_, itype):
190+
return OBJECT
191+
192+
193+
@to_sqla_type.register(SnowflakeDialect, dt.JSON)
194+
def _sf_json(_, itype):
195+
return VARIANT

ibis/backends/tests/test_array.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,11 @@ def test_array_index(con, idx):
150150

151151
@builtin_array
152152
@pytest.mark.never(
153-
["clickhouse", "pandas", "pyspark", "snowflake", "polars"],
153+
["clickhouse", "duckdb", "pandas", "pyspark", "snowflake", "polars"],
154154
reason="backend does not flatten array types",
155155
)
156156
@pytest.mark.never(["bigquery"], reason="doesn't support arrays of arrays")
157-
def test_array_discovery_postgres_duckdb(con):
157+
def test_array_discovery_postgres(con):
158158
t = con.table("array_types")
159159
expected = ibis.schema(
160160
dict(
@@ -195,8 +195,7 @@ def test_array_discovery_clickhouse(con):
195195

196196
@builtin_array
197197
@pytest.mark.notyet(
198-
["clickhouse", "duckdb", "postgres"],
199-
reason="backend does not support nullable nested types",
198+
["clickhouse", "postgres"], reason="backend does not support nullable nested types"
200199
)
201200
@pytest.mark.notimpl(
202201
["trino"],

ibis/backends/trino/datatypes.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from trino.sqlalchemy.datatype import DOUBLE, JSON, MAP, ROW
77
from trino.sqlalchemy.dialect import TrinoDialect
88

9-
import ibis.backends.base.sql.alchemy.datatypes as sat
109
import ibis.expr.datatypes as dt
10+
from ibis.backends.base.sql.alchemy import to_sqla_type
1111
from ibis.common.parsing import (
1212
COMMA,
1313
FIELD,
@@ -192,7 +192,8 @@ def _compiles_row(element, compiler, **kw):
192192
# TODO: @compiles should live in the dialect
193193
quote = compiler.dialect.identifier_preparer.quote
194194
content = ", ".join(
195-
f"{field} {compiler.process(typ, **kw)}" for field, typ in element.pairs
195+
f"{quote(field)} {compiler.process(typ, **kw)}"
196+
for field, typ in element.attr_types
196197
)
197198
return f"ROW({content})"
198199

@@ -206,6 +207,7 @@ def _map(dialect, itype):
206207

207208
@compiles(MAP, "trino")
208209
def compiles_map(typ, compiler, **kw):
210+
# TODO: @compiles should live in the dialect
209211
key_type = compiler.process(typ.key_type, **kw)
210212
value_type = compiler.process(typ.value_type, **kw)
211213
return f"MAP({key_type}, {value_type})"

ibis/tests/sql/test_sqlalchemy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
schema_from_table,
3030
to_sqla_type,
3131
)
32+
from ibis.backends.base.sql.alchemy.datatypes import ArrayType
3233
from ibis.tests.expr.mocks import MockAlchemyBackend
3334
from ibis.tests.util import assert_decompile_roundtrip, assert_equal
3435

@@ -1107,7 +1108,7 @@ def test_to_sqla_type_array_of_non_primitive():
11071108
expected_type = sa.BigInteger()
11081109
assert result_name == expected_name
11091110
assert type(result_type) == type(expected_type)
1110-
assert isinstance(result, sa.ARRAY)
1111+
assert isinstance(result, ArrayType)
11111112

11121113

11131114
def test_no_cart_join(con, snapshot):

0 commit comments

Comments
 (0)