Skip to content

Commit e8f96b6

Browse files
cpcloudkszucs
authored andcommitted
refactor(sqlalchemy): use exec_driver_sql everywhere
1 parent ee6d58a commit e8f96b6

20 files changed

+148
-185
lines changed

ci/schema/postgresql.sql

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
DROP SEQUENCE IF EXISTS test_sequence;
2-
CREATE SEQUENCE IF NOT EXISTS test_sequence;
3-
41
CREATE EXTENSION IF NOT EXISTS hstore;
52
CREATE EXTENSION IF NOT EXISTS postgis;
63
CREATE EXTENSION IF NOT EXISTS plpython3u;

ci/schema/trino.sql

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
DROP TABLE IF EXISTS map;
2+
CREATE TABLE map (kv MAP<VARCHAR, BIGINT>);
3+
INSERT INTO map VALUES
4+
(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3])),
5+
(MAP(ARRAY['d', 'e', 'f'], ARRAY[4, 5, 6]));
6+
7+
DROP TABLE IF EXISTS ts;
8+
CREATE TABLE ts (x TIMESTAMP(3), y TIMESTAMP(6), z TIMESTAMP(9));
9+
INSERT INTO ts VALUES
10+
(TIMESTAMP '2023-01-07 13:20:05.561',
11+
TIMESTAMP '2023-01-07 13:20:05.561021',
12+
TIMESTAMP '2023-01-07 13:20:05.561000231');

ibis/backends/conftest.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,16 +137,17 @@ def recreate_database(
137137
engine = sa.create_engine(url.set(database=""), **kwargs)
138138

139139
if url.database is not None:
140-
with engine.begin() as conn:
141-
conn.execute(sa.text(f'DROP DATABASE IF EXISTS {database}'))
142-
conn.execute(sa.text(f'CREATE DATABASE {database}'))
140+
with engine.begin() as con:
141+
con.exec_driver_sql(f"DROP DATABASE IF EXISTS {database}")
142+
con.exec_driver_sql(f"CREATE DATABASE {database}")
143143

144144

145145
def init_database(
146146
url: sa.engine.url.URL,
147147
database: str,
148148
schema: TextIO | None = None,
149149
recreate: bool = True,
150+
isolation_level: str = "AUTOCOMMIT",
150151
**kwargs: Any,
151152
) -> sa.engine.Engine:
152153
"""Initialise `database` at `url` with `schema`.
@@ -163,20 +164,23 @@ def init_database(
163164
File object containing schema to use
164165
recreate : bool
165166
If true, drop the database if it exists
167+
isolation_level : str
168+
Transaction isolation_level
166169
167170
Returns
168171
-------
169-
sa.engine.Engine for the database created
172+
sa.engine.Engine
173+
SQLAlchemy engine object
170174
"""
171175
if recreate:
172-
recreate_database(url, database, **kwargs)
176+
recreate_database(url, database, isolation_level=isolation_level, **kwargs)
173177

174178
try:
175179
url.database = database
176180
except AttributeError:
177181
url = url.set(database=database)
178182

179-
engine = sa.create_engine(url, **kwargs)
183+
engine = sa.create_engine(url, isolation_level=isolation_level, **kwargs)
180184

181185
if schema:
182186
with engine.begin() as conn:

ibis/backends/duckdb/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def _load_extensions(self, extensions):
141141
for extension in extensions:
142142
if extension not in self._extensions:
143143
with self.begin() as con:
144-
con.execute(sa.text(f"INSTALL '{extension}'"))
145-
con.execute(sa.text(f"LOAD '{extension}'"))
144+
con.exec_driver_sql(f"INSTALL '{extension}'")
145+
con.exec_driver_sql(f"LOAD '{extension}'")
146146
self._extensions.add(extension)
147147

148148
def register(
@@ -449,7 +449,7 @@ def fetch_from_cursor(
449449

450450
def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]:
451451
with self.begin() as con:
452-
rows = con.execute(sa.text(f"DESCRIBE {query}"))
452+
rows = con.exec_driver_sql(f"DESCRIBE {query}")
453453

454454
for name, type, null in toolz.pluck(
455455
["column_name", "column_type", "null"], rows.mappings()

ibis/backends/mssql/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def do_connect(
4343
@contextlib.contextmanager
4444
def begin(self):
4545
with super().begin() as bind:
46-
prev = bind.execute(sa.text('SELECT @@DATEFIRST')).scalar()
47-
bind.execute(sa.text('SET DATEFIRST 1'))
46+
prev = bind.exec_driver_sql("SELECT @@DATEFIRST").scalar()
47+
bind.exec_driver_sql("SET DATEFIRST 1")
4848
yield bind
4949
bind.execute(sa.text("SET DATEFIRST :prev").bindparams(prev=prev))
5050

@@ -53,8 +53,8 @@ def _metadata(self, query):
5353
query = f"SELECT * FROM [{query}]"
5454

5555
with self.begin() as bind:
56-
for column in bind.execute(
57-
sa.text(f"EXEC sp_describe_first_result_set @tsql = N'{query}';")
56+
for column in bind.exec_driver_sql(
57+
f"EXEC sp_describe_first_result_set @tsql = N'{query}'"
5858
).mappings():
5959
yield column["name"], _type_from_result_set_info(column)
6060

ibis/backends/mssql/tests/test_client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import pytest
2-
import sqlalchemy as sa
32
from pytest import param
43

54
import ibis
@@ -84,12 +83,12 @@ def test_get_schema_from_query(con, server_type, expected_type):
8483
expected_schema = ibis.schema(dict(x=expected_type))
8584
try:
8685
with con.begin() as c:
87-
c.execute(sa.text(f"CREATE TABLE {name} (x {server_type})"))
86+
c.exec_driver_sql(f"CREATE TABLE {name} (x {server_type})")
8887
expected_schema = ibis.schema(dict(x=expected_type))
8988
result_schema = con._get_schema_using_query(f"SELECT * FROM {name}")
9089
assert result_schema == expected_schema
9190
t = con.table(raw_name)
9291
assert t.schema() == expected_schema
9392
finally:
9493
with con.begin() as c:
95-
c.execute(sa.text(f"DROP TABLE IF EXISTS {name}"))
94+
c.exec_driver_sql(f"DROP TABLE IF EXISTS {name}")

ibis/backends/mysql/__init__.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,19 +106,18 @@ def do_connect(
106106
@contextlib.contextmanager
107107
def begin(self):
108108
with super().begin() as bind:
109-
prev = bind.execute(sa.text('SELECT @@session.time_zone')).scalar()
109+
prev = bind.exec_driver_sql('SELECT @@session.time_zone').scalar()
110110
try:
111-
bind.execute(sa.text("SET @@session.time_zone = 'UTC'"))
111+
bind.exec_driver_sql("SET @@session.time_zone = 'UTC'")
112112
except Exception as e: # noqa: BLE001
113113
warnings.warn(f"Couldn't set MySQL timezone: {e}")
114114

115115
yield bind
116+
stmt = sa.text("SET @@session.time_zone = :prev").bindparams(prev=prev)
116117
try:
117-
bind.execute(
118-
sa.text("SET @@session.time_zone = :prev").bindparams(prev=prev)
119-
)
118+
bind.execute(stmt)
120119
except Exception as e: # noqa: BLE001
121-
warnings.warn(sa.text(f"Couldn't reset MySQL timezone: {e}"))
120+
warnings.warn(f"Couldn't reset MySQL timezone: {e}")
122121

123122
def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
124123
if (
@@ -128,7 +127,7 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
128127
query = f"({query})"
129128

130129
with self.begin() as con:
131-
result = con.execute(sa.text(f"SELECT * FROM {query} _ LIMIT 0"))
130+
result = con.exec_driver_sql(f"SELECT * FROM {query} _ LIMIT 0")
132131
cursor = result.cursor
133132
yield from (
134133
(field.name, _type_from_cursor_info(descr, field))

ibis/backends/mysql/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _load_data(
9393
"LINES TERMINATED BY '\\n'",
9494
"IGNORE 1 LINES",
9595
]
96-
con.execute(sa.text("\n".join(lines)))
96+
con.exec_driver_sql("\n".join(lines))
9797

9898
@staticmethod
9999
def connect(_: Path):

ibis/backends/mysql/tests/test_client.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import pytest
2-
import sqlalchemy as sa
32
from pytest import param
43

54
import ibis
@@ -64,21 +63,26 @@ def test_get_schema_from_query(con, mysql_type, expected_type):
6463
# temporary tables get cleaned up by the db when the session ends, so we
6564
# don't need to explicitly drop the table
6665
with con.begin() as c:
67-
c.execute(sa.text(f"CREATE TEMPORARY TABLE {name} (x {mysql_type})"))
68-
expected_schema = ibis.schema(dict(x=expected_type))
69-
t = con.table(raw_name)
70-
result_schema = con._get_schema_using_query(f"SELECT * FROM {name}")
71-
assert t.schema() == expected_schema
72-
assert result_schema == expected_schema
66+
c.exec_driver_sql(f"CREATE TEMPORARY TABLE {name} (x {mysql_type})")
67+
try:
68+
expected_schema = ibis.schema(dict(x=expected_type))
69+
t = con.table(raw_name)
70+
result_schema = con._get_schema_using_query(f"SELECT * FROM {name}")
71+
assert t.schema() == expected_schema
72+
assert result_schema == expected_schema
73+
finally:
74+
with con.begin() as c:
75+
c.exec_driver_sql(f"DROP TABLE {name}")
7376

7477

75-
@pytest.mark.parametrize(
76-
"coltype",
77-
["TINYBLOB", "MEDIUMBLOB", "BLOB", "LONGBLOB"],
78-
)
78+
@pytest.mark.parametrize("coltype", ["TINYBLOB", "MEDIUMBLOB", "BLOB", "LONGBLOB"])
7979
def test_blob_type(con, coltype):
8080
tmp = f"tmp_{ibis.util.guid()}"
8181
with con.begin() as c:
82-
c.execute(sa.text(f"CREATE TEMPORARY TABLE {tmp} (a {coltype})"))
83-
t = con.table(tmp)
84-
assert t.schema() == ibis.schema({"a": dt.binary})
82+
c.exec_driver_sql(f"CREATE TEMPORARY TABLE {tmp} (a {coltype})")
83+
try:
84+
t = con.table(tmp)
85+
assert t.schema() == ibis.schema({"a": dt.binary})
86+
finally:
87+
with con.begin() as c:
88+
c.exec_driver_sql(f"DROP TABLE {tmp}")

ibis/backends/postgres/__init__.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,19 +110,18 @@ def list_databases(self, like=None):
110110
# http://dba.stackexchange.com/a/1304/58517
111111
databases = [
112112
row.datname
113-
for row in con.execute(
114-
sa.text('SELECT datname FROM pg_database WHERE NOT datistemplate')
115-
)
113+
for row in con.exec_driver_sql(
114+
"SELECT datname FROM pg_database WHERE NOT datistemplate"
115+
).mappings()
116116
]
117117
return self._filter_with_like(databases, like)
118118

119119
@contextlib.contextmanager
120120
def begin(self):
121121
with super().begin() as bind:
122-
prev = bind.execute(sa.text('SHOW TIMEZONE')).scalar()
123-
bind.execute(sa.text('SET TIMEZONE = UTC'))
122+
# LOCAL takes effect for the current transaction only
123+
bind.exec_driver_sql("SET LOCAL TIMEZONE = UTC")
124124
yield bind
125-
bind.execute(sa.text("SET TIMEZONE = :prev").bindparams(prev=prev))
126125

127126
def udf(
128127
self,
@@ -186,12 +185,12 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
186185
AND NOT attisdropped
187186
ORDER BY attnum"""
188187
with self.begin() as con:
189-
con.execute(sa.text(f"CREATE TEMPORARY VIEW {name} AS {query}"))
188+
con.exec_driver_sql(f"CREATE TEMPORARY VIEW {name} AS {query}")
190189
type_info = con.execute(
191190
sa.text(type_info_sql).bindparams(raw_name=raw_name)
192191
)
193192
yield from ((col, _get_type(typestr)) for col, typestr in type_info)
194-
con.execute(sa.text(f"DROP VIEW IF EXISTS {name}"))
193+
con.exec_driver_sql(f"DROP VIEW IF EXISTS {name}")
195194

196195
def _get_temp_view_definition(
197196
self, name: str, definition: sa.sql.compiler.Compiled

ibis/backends/postgres/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _load_data(
9191
with data_dir.joinpath(f'{table}.csv').open('r') as file:
9292
cur.copy_expert(sql=sql, file=file)
9393

94-
con.execute(sa.text("VACUUM FULL ANALYZE"))
94+
con.exec_driver_sql("VACUUM FULL ANALYZE")
9595

9696
@staticmethod
9797
def connect(data_directory: Path):

ibis/backends/postgres/tests/test_client.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,12 +217,9 @@ def test_create_and_drop_table(con, temp_table, params):
217217
],
218218
)
219219
def test_get_schema_from_query(con, pg_type, expected_type):
220-
raw_name = ibis.util.guid()
221-
name = con._quote(raw_name)
220+
name = con._quote(ibis.util.guid())
222221
with con.begin() as c:
223-
c.execute(
224-
sa.text(f"CREATE TEMPORARY TABLE {name} (x {pg_type}, y {pg_type}[])")
225-
)
222+
c.exec_driver_sql(f"CREATE TEMP TABLE {name} (x {pg_type}, y {pg_type}[])")
226223
expected_schema = ibis.schema(dict(x=expected_type, y=dt.Array(expected_type)))
227224
result_schema = con._get_schema_using_query(f"SELECT x, y FROM {name}")
228225
assert result_schema == expected_schema

ibis/backends/postgres/tests/test_functions.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,8 +1041,8 @@ def test_array_concat_mixed_types(array_types):
10411041
@pytest.fixture
10421042
def t(con, guid):
10431043
with con.begin() as c:
1044-
c.execute(
1045-
sa.text(f"CREATE TABLE \"{guid}\" (id SERIAL PRIMARY KEY, name TEXT)")
1044+
c.exec_driver_sql(
1045+
f"CREATE TABLE {con._quote(guid)} (id SERIAL PRIMARY KEY, name TEXT)"
10461046
)
10471047
return con.table(guid)
10481048

@@ -1053,27 +1053,24 @@ def s(con, t, guid, guid2):
10531053
assert t.op().name != guid2
10541054

10551055
with con.begin() as c:
1056-
c.execute(
1057-
sa.text(
1058-
f"""
1059-
CREATE TABLE \"{guid2}\" (
1056+
c.exec_driver_sql(
1057+
f"""
1058+
CREATE TABLE {con._quote(guid2)} (
10601059
id SERIAL PRIMARY KEY,
1061-
left_t_id INTEGER REFERENCES "{guid}",
1060+
left_t_id INTEGER REFERENCES {con._quote(guid)},
10621061
cost DOUBLE PRECISION
10631062
)
10641063
"""
1065-
)
10661064
)
10671065
return con.table(guid2)
10681066

10691067

10701068
@pytest.fixture
10711069
def trunc(con, guid):
1070+
quoted = con._quote(guid)
10721071
with con.begin() as c:
1073-
c.execute(
1074-
sa.text(f"CREATE TABLE \"{guid}\" (id SERIAL PRIMARY KEY, name TEXT)")
1075-
)
1076-
c.execute(sa.text(f"INSERT INTO \"{guid}\" (name) VALUES ('a'), ('b'), ('c')"))
1072+
c.exec_driver_sql(f"CREATE TABLE {quoted} (id SERIAL PRIMARY KEY, name TEXT)")
1073+
c.exec_driver_sql(f"INSERT INTO {quoted} (name) VALUES ('a'), ('b'), ('c')")
10771074
return con.table(guid)
10781075

10791076

@@ -1314,9 +1311,8 @@ def test_timestamp_with_timezone_select(tzone_compute, tz):
13141311

13151312

13161313
def test_timestamp_type_accepts_all_timezones(con):
1317-
query = 'SELECT name FROM pg_timezone_names'
13181314
with con.begin() as c:
1319-
cur = c.execute(sa.text(query)).fetchall()
1315+
cur = c.exec_driver_sql("SELECT name FROM pg_timezone_names").fetchall()
13201316
assert all(dt.Timestamp(row.name).timezone == row.name for row in cur)
13211317

13221318

@@ -1416,7 +1412,7 @@ def test_string_to_binary_cast(con):
14161412
"FROM functional_alltypes LIMIT 10"
14171413
)
14181414
with con.begin() as c:
1419-
cur = c.execute(sa.text(sql_string))
1415+
cur = c.exec_driver_sql(sql_string)
14201416
raw_data = [row[0][0] for row in cur]
14211417
expected = pd.Series(raw_data, name=name)
14221418
tm.assert_series_equal(result, expected)
@@ -1433,6 +1429,6 @@ def test_string_to_binary_round_trip(con):
14331429
"FROM functional_alltypes LIMIT 10"
14341430
)
14351431
with con.begin() as c:
1436-
cur = c.execute(sa.text(sql_string))
1432+
cur = c.exec_driver_sql(sql_string)
14371433
expected = pd.Series([row[0][0] for row in cur], name=name)
14381434
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)