Skip to content

Commit 12f6438

Browse files
committed
fix(mysql): fix mysql query schema inference
1 parent b0f4e4c commit 12f6438

File tree

6 files changed

+84
-60
lines changed

6 files changed

+84
-60
lines changed

ci/schema/mysql.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ CREATE TABLE functional_alltypes (
6666
double_col DOUBLE,
6767
date_string_col TEXT,
6868
string_col TEXT,
69-
timestamp_col TIMESTAMP,
69+
timestamp_col DATETIME,
7070
year INTEGER,
7171
month INTEGER
7272
) DEFAULT CHARACTER SET = utf8;

ibis/backends/base/sql/alchemy/datatypes.py

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -165,16 +165,6 @@ def _(itype, **_):
165165
return MapType(to_sqla_type(itype.key_type), to_sqla_type(itype.value_type))
166166

167167

168-
@to_sqla_type.register(dt.GeoSpatial)
169-
def _(itype, **kwargs):
170-
if itype.geotype == 'geometry':
171-
return ga.Geometry
172-
elif itype.geotype == 'geography':
173-
return ga.Geography
174-
else:
175-
return ga.types._GISType
176-
177-
178168
@dt.dtype.register(Dialect, sa.types.NullType)
179169
def sa_null(_, satype, nullable=True):
180170
return dt.null
@@ -185,21 +175,12 @@ def sa_boolean(_, satype, nullable=True):
185175
return dt.Boolean(nullable=nullable)
186176

187177

188-
@dt.dtype.register(MySQLDialect, mysql.NUMERIC)
189-
@dt.dtype.register(MySQLDialect, sa.NUMERIC)
178+
@dt.dtype.register(MySQLDialect, (sa.NUMERIC, mysql.NUMERIC))
190179
def sa_mysql_numeric(_, satype, nullable=True):
191180
# https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html
192181
return dt.Decimal(satype.precision or 10, satype.scale or 0, nullable=nullable)
193182

194183

195-
@dt.dtype.register(MySQLDialect, mysql.TINYBLOB)
196-
@dt.dtype.register(MySQLDialect, mysql.MEDIUMBLOB)
197-
@dt.dtype.register(MySQLDialect, mysql.BLOB)
198-
@dt.dtype.register(MySQLDialect, mysql.LONGBLOB)
199-
def sa_mysql_blob(_, satype, nullable=True):
200-
return dt.Binary(nullable=nullable)
201-
202-
203184
_FLOAT_PREC_TO_TYPE = {
204185
11: dt.Float16,
205186
24: dt.Float32,
@@ -233,15 +214,30 @@ def sa_integer(_, satype, nullable=True):
233214

234215
@dt.dtype.register(Dialect, mysql.TINYINT)
235216
@dt.dtype.register(MSDialect, mssql.TINYINT)
217+
@dt.dtype.register(MySQLDialect, mysql.YEAR)
236218
def sa_mysql_tinyint(_, satype, nullable=True):
237219
return dt.Int8(nullable=nullable)
238220

239221

240222
@dt.dtype.register(MSDialect, mssql.BIT)
241-
def sa_mysql_bit(_, satype, nullable=True):
223+
def sa_mssql_bit(_, satype, nullable=True):
242224
return dt.Boolean(nullable=nullable)
243225

244226

227+
@dt.dtype.register(MySQLDialect, mysql.BIT)
228+
def sa_mysql_bit(_, satype, nullable=True):
229+
if 1 <= (length := satype.length) <= 8:
230+
return dt.Int8(nullable=nullable)
231+
elif 9 <= length <= 16:
232+
return dt.Int16(nullable=nullable)
233+
elif 17 <= length <= 32:
234+
return dt.Int32(nullable=nullable)
235+
elif 33 <= length <= 64:
236+
return dt.Int64(nullable=nullable)
237+
else:
238+
raise ValueError(f"Invalid MySQL BIT length: {length:d}")
239+
240+
245241
@dt.dtype.register(Dialect, sa.types.BigInteger)
246242
@dt.dtype.register(MSDialect, mssql.MONEY)
247243
def sa_bigint(_, satype, nullable=True):
@@ -254,6 +250,7 @@ def sa_mssql_smallmoney(_, satype, nullable=True):
254250

255251

256252
@dt.dtype.register(Dialect, sa.REAL)
253+
@dt.dtype.register(MySQLDialect, mysql.FLOAT)
257254
def sa_real(_, satype, nullable=True):
258255
return dt.Float32(nullable=nullable)
259256

@@ -271,11 +268,6 @@ def sa_uuid(_, satype, nullable=True):
271268
return dt.UUID(nullable=nullable)
272269

273270

274-
@dt.dtype.register(MSDialect, (mssql.BINARY, mssql.TIMESTAMP))
275-
def sa_mssql_timestamp(_, satype, nullable=True):
276-
return dt.Binary(nullable=nullable)
277-
278-
279271
@dt.dtype.register(PGDialect, postgresql.MACADDR)
280272
def sa_macaddr(_, satype, nullable=True):
281273
return dt.MACADDR(nullable=nullable)
@@ -292,6 +284,21 @@ def sa_json(_, satype, nullable=True):
292284
return dt.JSON(nullable=nullable)
293285

294286

287+
@dt.dtype.register(MySQLDialect, mysql.TIMESTAMP)
288+
def sa_mysql_timestamp(_, satype, nullable=True):
289+
return dt.Timestamp(timezone="UTC", nullable=nullable)
290+
291+
292+
@dt.dtype.register(MySQLDialect, mysql.DATETIME)
293+
def sa_mysql_datetime(_, satype, nullable=True):
294+
return dt.Timestamp(nullable=nullable)
295+
296+
297+
@dt.dtype.register(MySQLDialect, mysql.SET)
298+
def sa_mysql_set(_, satype, nullable=True):
299+
return dt.Set(dt.string, nullable=nullable)
300+
301+
295302
if geospatial_supported:
296303

297304
@dt.dtype.register(Dialect, (ga.Geometry, ga.types._GISType))
@@ -314,6 +321,15 @@ def ga_geometry(_, gatype, nullable=True):
314321
else:
315322
raise ValueError(f"Unrecognized geometry type: {t}")
316323

324+
@to_sqla_type.register(dt.GeoSpatial)
325+
def _(itype, **kwargs):
326+
if itype.geotype == 'geometry':
327+
return ga.Geometry
328+
elif itype.geotype == 'geography':
329+
return ga.Geography
330+
else:
331+
return ga.types._GISType
332+
317333

318334
POSTGRES_FIELD_TO_IBIS_UNIT = {
319335
"YEAR": "Y",
@@ -357,6 +373,18 @@ def sa_string(_, satype, nullable=True):
357373

358374

359375
@dt.dtype.register(Dialect, sa.LargeBinary)
376+
@dt.dtype.register(MSDialect, (mssql.BINARY, mssql.TIMESTAMP))
377+
@dt.dtype.register(
378+
MySQLDialect,
379+
(
380+
mysql.TINYBLOB,
381+
mysql.MEDIUMBLOB,
382+
mysql.BLOB,
383+
mysql.LONGBLOB,
384+
mysql.BINARY,
385+
mysql.VARBINARY,
386+
),
387+
)
360388
def sa_binary(_, satype, nullable=True):
361389
return dt.Binary(nullable=nullable)
362390

ibis/backends/mssql/tests/test_client.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import ibis
55
import ibis.expr.datatypes as dt
6+
from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported
67

78
DB_TYPES = [
89
# Exact numbers
@@ -44,18 +45,23 @@
4445
('IMAGE', dt.binary),
4546
# Other data types
4647
('UNIQUEIDENTIFIER', dt.uuid),
47-
('GEOMETRY', dt.geometry),
48-
('GEOGRAPHY', dt.geography),
4948
('TIMESTAMP', dt.binary(nullable=False)),
5049
]
5150

5251

52+
skipif_no_geospatial_deps = pytest.mark.skipif(
53+
not geospatial_supported, reason="geospatial dependencies not installed"
54+
)
55+
56+
5357
@pytest.mark.parametrize(
5458
("server_type", "expected_type"),
55-
[
56-
param(server_type, ibis_type, id=server_type)
57-
for server_type, ibis_type in DB_TYPES
59+
DB_TYPES
60+
+ [
61+
param("GEOMETRY", dt.geometry, marks=[skipif_no_geospatial_deps]),
62+
param("GEOGRAPHY", dt.geography, marks=[skipif_no_geospatial_deps]),
5863
],
64+
ids=str,
5965
)
6066
def test_get_schema_from_query(con, server_type, expected_type):
6167
raw_name = f"tmp_{ibis.util.guid()}"

ibis/backends/mysql/__init__.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from __future__ import annotations
44

5-
import atexit
65
import contextlib
6+
import re
77
import warnings
88
from typing import Iterable, Literal
99

@@ -121,8 +121,14 @@ def begin(self):
121121
warnings.warn(f"Couldn't reset MySQL timezone: {str(e)}")
122122

123123
def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
124-
with self.con.begin() as con:
125-
result = con.execute(f"SELECT * FROM ({query}) _ LIMIT 0")
124+
if (
125+
re.search(r"^\s*SELECT\s", query, flags=re.MULTILINE | re.IGNORECASE)
126+
is not None
127+
):
128+
query = f"({query})"
129+
130+
with self.begin() as con:
131+
result = con.execute(f"SELECT * FROM {query} _ LIMIT 0")
126132
cursor = result.cursor
127133
yield from (
128134
(field.name, _type_from_cursor_info(descr, field))
@@ -136,15 +142,6 @@ def _get_temp_view_definition(
136142
) -> str:
137143
return f"CREATE OR REPLACE VIEW {name} AS {definition}"
138144

139-
def _register_temp_view_cleanup(self, name: str, raw_name: str) -> None:
140-
query = f"DROP VIEW IF EXISTS {name}"
141-
142-
def drop(self, raw_name: str, query: str):
143-
self.con.execute(query)
144-
self._temp_views.discard(raw_name)
145-
146-
atexit.register(drop, self, raw_name, query)
147-
148145

149146
# TODO(kszucs): unsigned integers
150147

ibis/backends/mysql/datatypes.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,13 @@ def _type_from_cursor_info(descr, field) -> dt.DataType:
2929
if typename is None:
3030
raise NotImplementedError(f"MySQL type code {type_code:d} is not supported")
3131

32-
typ = _type_mapping[typename]
33-
3432
if typename in ("DECIMAL", "NEWDECIMAL"):
3533
precision = _decimal_length_to_precision(
3634
length=field_length,
3735
scale=scale,
3836
is_unsigned=flags.is_unsigned,
3937
)
40-
typ = partial(typ, precision=precision, scale=scale)
38+
typ = partial(_type_mapping[typename], precision=precision, scale=scale)
4139
elif typename == "BIT":
4240
if field_length <= 8:
4341
typ = dt.int8
@@ -61,18 +59,15 @@ def _type_from_cursor_info(descr, field) -> dt.DataType:
6159
typ = dt.Binary
6260
else:
6361
typ = dt.String
62+
else:
63+
typ = _type_mapping[typename]
6464

6565
# projection columns are always nullable
6666
return typ(nullable=True)
6767

6868

6969
# ported from my_decimal.h:my_decimal_length_to_precision in mariadb
70-
def _decimal_length_to_precision(
71-
*,
72-
length: int,
73-
scale: int,
74-
is_unsigned: bool,
75-
) -> int:
70+
def _decimal_length_to_precision(*, length: int, scale: int, is_unsigned: bool) -> int:
7671
return length - (scale > 0) - (not (is_unsigned or not length))
7772

7873

@@ -115,18 +110,14 @@ def _decimal_length_to_precision(
115110
"FLOAT": dt.Float32,
116111
"DOUBLE": dt.Float64,
117112
"NULL": dt.Null,
118-
"TIMESTAMP": lambda nullable: dt.Timestamp(
119-
timezone="UTC",
120-
nullable=nullable,
121-
),
113+
"TIMESTAMP": lambda nullable: dt.Timestamp(timezone="UTC", nullable=nullable),
122114
"LONGLONG": dt.Int64,
123115
"INT24": dt.Int32,
124116
"DATE": dt.Date,
125117
"TIME": dt.Time,
126118
"DATETIME": dt.Timestamp,
127-
"YEAR": dt.Int16,
119+
"YEAR": dt.Int8,
128120
"VARCHAR": dt.String,
129-
"BIT": dt.Int8,
130121
"JSON": dt.JSON,
131122
"NEWDECIMAL": dt.Decimal,
132123
"ENUM": dt.String,

ibis/backends/mysql/tests/test_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
("date", dt.date),
2828
("time", dt.time),
2929
("datetime", dt.timestamp),
30-
("year", dt.int16),
30+
("year", dt.int8),
3131
("char(32)", dt.string),
3232
("char byte", dt.binary),
3333
("varchar(42)", dt.string),
@@ -64,7 +64,9 @@ def test_get_schema_from_query(con, mysql_type, expected_type):
6464
# don't need to explicitly drop the table
6565
con.raw_sql(f"CREATE TEMPORARY TABLE {name} (x {mysql_type})")
6666
expected_schema = ibis.schema(dict(x=expected_type))
67+
t = con.table(raw_name)
6768
result_schema = con._get_schema_using_query(f"SELECT * FROM {name}")
69+
assert t.schema() == expected_schema
6870
assert result_schema == expected_schema
6971

7072

0 commit comments

Comments
 (0)