Skip to content

Commit 8071255

Browse files
committed
fix(snowflake): make sqlalchemy 2.0 compatible
1 parent 2cb96e9 commit 8071255

File tree

3 files changed

+34
-45
lines changed

3 files changed

+34
-45
lines changed

ci/schema/snowflake.sql

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
1-
CREATE OR REPLACE FILE FORMAT ibis_csv_fmt
1+
USE WAREHOUSE ibis_testing;
2+
DROP DATABASE IF EXISTS ibis_testing;
3+
CREATE DATABASE IF NOT EXISTS ibis_testing;
4+
CREATE SCHEMA IF NOT EXISTS ibis_testing.ibis_testing;
5+
USE SCHEMA ibis_testing.ibis_testing;
6+
7+
CREATE OR REPLACE FILE FORMAT ibis_testing
28
type = 'CSV'
39
field_delimiter = ','
410
skip_header = 1
511
field_optionally_enclosed_by = '"';
612

7-
CREATE OR REPLACE STAGE ibis_testing_stage file_format = ibis_csv_fmt;
13+
CREATE OR REPLACE STAGE ibis_testing file_format = ibis_testing;
814

915
CREATE OR REPLACE TABLE diamonds (
1016
"carat" FLOAT,

ibis/backends/snowflake/__init__.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,29 +55,18 @@ def _convert_kwargs(self, kwargs):
5555
@property
5656
def version(self) -> str:
5757
with self.begin() as con:
58-
return con.execute("SELECT CURRENT_VERSION()").scalar()
58+
return con.exec_driver_sql("SELECT CURRENT_VERSION()").scalar()
5959

6060
def do_connect(
61-
self,
62-
user: str,
63-
password: str,
64-
account: str,
65-
database: str,
66-
**kwargs,
61+
self, user: str, password: str, account: str, database: str, **kwargs: Any
6762
):
6863
dbparams = dict(zip(("database", "schema"), database.split("/", 1)))
6964
if dbparams.get("schema") is None:
7065
raise ValueError(
7166
"Schema must be non-None. Pass the schema as part of the "
7267
f"database e.g., {dbparams['database']}/my_schema"
7368
)
74-
url = URL(
75-
account=account,
76-
user=user,
77-
password=password,
78-
**dbparams,
79-
**kwargs,
80-
)
69+
url = URL(account=account, user=user, password=password, **dbparams, **kwargs)
8170
self.database_name = dbparams["database"]
8271
return super().do_connect(sa.create_engine(url))
8372

@@ -88,11 +77,11 @@ def _get_sqla_table(
8877
cols = []
8978
identifier = name if not schema else schema + "." + name
9079
with self.begin() as con:
91-
cur = con.execute(sa.text(f"DESCRIBE TABLE {identifier}"))
92-
for (colname, *_), colinfo in zip(cur, inspected):
93-
del colinfo["name"]
80+
cur = con.execute(sa.text(f"DESCRIBE TABLE {identifier}")).mappings()
81+
for colname, colinfo in zip(toolz.pluck("name", cur), inspected):
82+
colinfo["name"] = colname
9483
colinfo["type_"] = colinfo.pop("type")
95-
cols.append(sa.Column(colname, **colinfo, quote=True))
84+
cols.append(sa.Column(**colinfo, quote=True))
9685
return sa.Table(
9786
name,
9887
self.meta,
@@ -139,6 +128,6 @@ def list_databases(self, like=None) -> list[str]:
139128
row.database_name
140129
for row in con.execute(
141130
sa.text('SELECT database_name FROM information_schema.databases')
142-
)
131+
).mappings()
143132
]
144133
return self._filter_with_like(databases, like)

ibis/backends/snowflake/tests/conftest.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,21 @@
1111
import sqlalchemy as sa
1212

1313
import ibis
14-
from ibis.backends.conftest import TEST_TABLES
14+
from ibis.backends.conftest import TEST_TABLES, init_database
1515
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero
1616

1717
if TYPE_CHECKING:
1818
from ibis.backends.base import BaseBackend
1919

2020

2121
def copy_into(con, data_dir: Path, table: str) -> None:
22-
stage = "ibis_testing_stage"
22+
stage = "ibis_testing"
2323
csv = f"{table}.csv"
24-
src = data_dir / csv
25-
con.execute(sa.text(f"PUT file://{src.absolute()} @{stage}/{csv}"))
26-
con.execute(
27-
sa.text(
28-
f"COPY INTO {table} FROM @{stage}/{csv} FILE_FORMAT = (FORMAT_NAME = ibis_csv_fmt)"
29-
)
24+
con.exec_driver_sql(
25+
f"PUT file://{data_dir.joinpath(csv).absolute()} @{stage}/{csv}"
26+
)
27+
con.exec_driver_sql(
28+
f"COPY INTO {table} FROM @{stage}/{csv} FILE_FORMAT = (FORMAT_NAME = ibis_testing)"
3029
)
3130

3231

@@ -36,10 +35,7 @@ def __init__(self, data_directory: Path) -> None:
3635

3736
@staticmethod
3837
def _load_data(
39-
data_dir,
40-
script_dir,
41-
database: str = "ibis_testing",
42-
**_: Any,
38+
data_dir, script_dir, database: str = "ibis_testing", **_: Any
4339
) -> None:
4440
"""Load test data into a Snowflake backend instance.
4541
@@ -53,26 +49,24 @@ def _load_data(
5349

5450
pytest.importorskip("snowflake.connector")
5551
pytest.importorskip("snowflake.sqlalchemy")
56-
schema = (script_dir / 'schema' / 'snowflake.sql').read_text()
57-
58-
con = TestConf.connect(data_dir)
5952

60-
with con.begin() as con:
61-
con.execute(sa.text("USE WAREHOUSE ibis_testing"))
62-
con.execute(sa.text(f"DROP DATABASE IF EXISTS {database}"))
63-
con.execute(sa.text(f"CREATE DATABASE IF NOT EXISTS {database}"))
64-
con.execute(sa.text(f"CREATE SCHEMA IF NOT EXISTS {database}.ibis_testing"))
65-
con.execute(sa.text(f"USE SCHEMA {database}.ibis_testing"))
53+
if (snowflake_url := os.environ.get("SNOWFLAKE_URL")) is None:
54+
pytest.skip("SNOWFLAKE_URL environment variable is not defined")
6655

67-
for stmt in filter(None, map(str.strip, schema.split(';'))):
68-
con.execute(sa.text(stmt))
56+
with script_dir.joinpath('schema', 'snowflake.sql').open() as schema:
57+
con = init_database(
58+
url=sa.engine.make_url(snowflake_url).set(database=""),
59+
database=database,
60+
schema=schema,
61+
)
6962

63+
with con.begin() as c:
7064
# not much we can do to make this faster, but running these in
7165
# multiple threads seems to save about 2x
7266
with concurrent.futures.ThreadPoolExecutor() as exe:
7367
for result in concurrent.futures.as_completed(
7468
map(
75-
partial(exe.submit, partial(copy_into, con, data_dir)),
69+
partial(exe.submit, partial(copy_into, c, data_dir)),
7670
TEST_TABLES,
7771
)
7872
):

0 commit comments

Comments
 (0)