Skip to content

Commit 2e67918

Browse files
jcrist authored and cpcloud committed
feat(duckdb): support unsigned integer types
1 parent 0cb8a63 commit 2e67918

File tree

6 files changed

+104
-17
lines changed

6 files changed

+104
-17
lines changed

ibis/backends/base/sql/alchemy/datatypes.py

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from sqlalchemy.dialects.postgresql.base import PGDialect
1010
from sqlalchemy.dialects.sqlite.base import SQLiteDialect
1111
from sqlalchemy.engine.interfaces import Dialect
12+
from sqlalchemy.ext.compiler import compiles
1213
from sqlalchemy.types import UserDefinedType
1314

1415
import ibis.expr.datatypes as dt
@@ -33,6 +34,41 @@ def get_col_spec(self, **_):
3334
return f"STRUCT({pairs})"
3435

3536

37+
class UInt64(sa.types.Integer):
38+
pass
39+
40+
41+
class UInt32(sa.types.Integer):
42+
pass
43+
44+
45+
class UInt16(sa.types.Integer):
46+
pass
47+
48+
49+
class UInt8(sa.types.Integer):
50+
pass
51+
52+
53+
@compiles(UInt64, "postgresql")
54+
@compiles(UInt32, "postgresql")
55+
@compiles(UInt16, "postgresql")
56+
@compiles(UInt8, "postgresql")
57+
@compiles(UInt64, "mysql")
58+
@compiles(UInt32, "mysql")
59+
@compiles(UInt16, "mysql")
60+
@compiles(UInt8, "mysql")
61+
@compiles(UInt64, "sqlite")
62+
@compiles(UInt32, "sqlite")
63+
@compiles(UInt16, "sqlite")
64+
@compiles(UInt8, "sqlite")
65+
def compile_uint(element, compiler, **kw):
66+
dialect_name = compiler.dialect.name
67+
raise TypeError(
68+
f"unsigned integers are not supported in the {dialect_name} backend"
69+
)
70+
71+
3672
def table_from_schema(name, meta, schema, database: str | None = None):
3773
# Convert Ibis schema to SQLA table
3874
columns = []
@@ -62,6 +98,10 @@ def table_from_schema(name, meta, schema, database: str | None = None):
6298
dt.Int16: sa.SmallInteger,
6399
dt.Int32: sa.Integer,
64100
dt.Int64: sa.BigInteger,
101+
dt.UInt8: UInt8,
102+
dt.UInt16: UInt16,
103+
dt.UInt32: UInt32,
104+
dt.UInt64: UInt64,
65105
dt.JSON: sa.JSON,
66106
}
67107

ibis/backends/duckdb/compiler.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,7 @@
1+
from sqlalchemy.ext.compiler import compiles
2+
3+
import ibis.backends.base.sql.alchemy.datatypes as sat
4+
import ibis.expr.datatypes as dt
15
import ibis.expr.operations as ops
26
from ibis.backends.base.sql.alchemy import (
37
AlchemyCompiler,
@@ -16,6 +20,28 @@ class DuckDBSQLExprTranslator(AlchemyExprTranslator):
1620
_has_reduction_filter_syntax = True
1721

1822

23+
@compiles(sat.UInt64, "duckdb")
24+
@compiles(sat.UInt32, "duckdb")
25+
@compiles(sat.UInt16, "duckdb")
26+
@compiles(sat.UInt8, "duckdb")
27+
def compile_uint(element, compiler, **kw):
28+
return element.__class__.__name__.upper()
29+
30+
31+
try:
32+
import duckdb_engine
33+
except ImportError:
34+
pass
35+
else:
36+
37+
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt64)
38+
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt32)
39+
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt16)
40+
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt8)
41+
def dtype_uint(_, satype, nullable=True):
42+
return getattr(dt, satype.__class__.__name__)(nullable=nullable)
43+
44+
1945
rewrites = DuckDBSQLExprTranslator.rewrites
2046

2147

ibis/backends/tests/test_client.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -359,3 +359,22 @@ def test_in_memory(alchemy_backend):
359359
finally:
360360
con.raw_sql(f"DROP TABLE IF EXISTS {table_name}")
361361
assert table_name not in con.list_tables()
362+
363+
364+
@pytest.mark.parametrize(
365+
"coltype", [dt.uint8, dt.uint16, dt.uint32, dt.uint64]
366+
)
367+
@pytest.mark.notyet(
368+
["postgres", "mysql", "sqlite"],
369+
raises=TypeError,
370+
reason="postgres, mysql and sqlite do not support unsigned integer types",
371+
)
372+
def test_unsigned_integer_type(alchemy_con, coltype):
373+
tname = guid()
374+
alchemy_con.create_table(
375+
tname, schema=ibis.schema(dict(a=coltype)), force=True
376+
)
377+
try:
378+
assert tname in alchemy_con.list_tables()
379+
finally:
380+
alchemy_con.drop_table(tname, force=True)

ibis/expr/datatypes/core.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1344,11 +1344,13 @@ def infer_floating(value: float) -> Float64:
13441344

13451345

13461346
@infer.register(int)
1347-
def infer_integer(value: int) -> Integer:
1348-
for dtype in (int8, int16, int32, int64):
1347+
def infer_integer(value: int, prefer_unsigned: bool = False) -> Integer:
1348+
types = (uint8, uint16, uint32, uint64) if prefer_unsigned else ()
1349+
types += (int8, int16, int32, int64)
1350+
for dtype in types:
13491351
if dtype.bounds.lower <= value <= dtype.bounds.upper:
13501352
return dtype
1351-
return int64
1353+
return uint64 if prefer_unsigned else int64
13521354

13531355

13541356
@infer.register(enum.Enum)

ibis/expr/operations/temporal.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -368,7 +368,7 @@ def output_dtype(self):
368368
else arg
369369
for arg in self.args
370370
]
371-
value_dtype = rlz._promote_numeric_binop(integer_args, self.op)
371+
value_dtype = rlz._promote_integral_binop(integer_args, self.op)
372372
left_dtype = self.left.type()
373373
return dt.Interval(
374374
unit=left_dtype.unit,

ibis/expr/rules.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -275,22 +275,22 @@ def output_shape(self):
275275
# TODO(kszucs): might just use bounds instead of actual literal values
276276
# that could simplify interval binop output_type methods
277277
# TODO(kszucs): pre-generate mapping?
278-
def _promote_numeric_binop(exprs, op):
279-
bounds, dtypes = [], []
280-
for arg in exprs:
281-
dtypes.append(arg.type())
282-
if hasattr(arg.op(), 'value'):
283-
# arg.op() is a literal
284-
bounds.append([arg.op().value])
285-
else:
286-
bounds.append(arg.type().bounds)
278+
def _promote_integral_binop(exprs, op):
279+
dtypes = []
280+
bounds = []
281+
for expr in exprs:
282+
try:
283+
bounds.append([expr.op().value])
284+
except AttributeError:
285+
dtypes.append(expr.type())
286+
bounds.append(expr.type().bounds)
287287

288+
all_unsigned = dtypes and util.all_of(dtypes, dt.UnsignedInteger)
288289
# In some cases, the bounding type might be int8, even though neither
289290
# of the types are that small. We want to ensure the containing type is
290291
# _at least_ as large as the smallest type in the expression.
291-
values = starmap(op, product(*bounds))
292-
dtypes += [dt.infer(value) for value in values]
293-
292+
values = list(starmap(op, product(*bounds)))
293+
dtypes.extend(dt.infer(v, prefer_unsigned=all_unsigned) for v in values)
294294
return dt.highest_precedence(dtypes)
295295

296296

@@ -299,7 +299,7 @@ def numeric_like(name, op):
299299
def output_dtype(self):
300300
args = getattr(self, name)
301301
if util.all_of(args, ir.IntegerValue):
302-
result = _promote_numeric_binop(args, op)
302+
result = _promote_integral_binop(args, op)
303303
else:
304304
result = highest_precedence_dtype(args)
305305

0 commit comments

Comments
 (0)