Skip to content

Commit 1829169

Browse files
authored
ci(athena): use a different database for each python version to avoid clobbering data (#10930)
Ensure that Athena test runs do not clobber each other by creating a database per user+python-version pair.
1 parent a220b47 commit 1829169

File tree

2 files changed

+54
-27
lines changed

2 files changed

+54
-27
lines changed

ibis/backends/athena/__init__.py

+26-14
Original file line numberDiff line numberDiff line change
@@ -344,11 +344,17 @@ def do_connect(
344344
s3_staging_dir: str,
345345
cursor_class: type[Cursor] = ArrowCursor,
346346
memtable_volume: str | None = None,
347+
schema_name: str = "default",
348+
catalog_name: str = "awsdatacatalog",
347349
**config: Any,
348350
) -> None:
349351
"""Create an Ibis client connected to an Amazon Athena instance."""
350352
self.con = pyathena.connect(
351-
s3_staging_dir=s3_staging_dir, cursor_class=cursor_class, **config
353+
s3_staging_dir=s3_staging_dir,
354+
cursor_class=cursor_class,
355+
schema_name=schema_name,
356+
catalog_name=catalog_name,
357+
**config,
352358
)
353359

354360
if memtable_volume is None:
@@ -441,23 +447,29 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
441447
pass
442448

443449
def _finalize_memtable(self, name: str) -> None:
444-
path = f"{self._memtable_volume_path}/{name}"
445-
sql = sge.Drop(
446-
kind="TABLE",
447-
this=sg.to_identifier(name, quoted=self.compiler.quoted),
448-
exists=True,
449-
)
450-
451-
with self._safe_raw_sql(sql, unload=False):
452-
pass
453-
454-
self._fs.rm(path, recursive=True)
450+
self.drop_table(name, force=True)
451+
self._fs.rm(f"{self._memtable_volume_path}/{name}", recursive=True)
455452

456453
def create_database(
457-
self, name: str, /, *, catalog: str | None = None, force: bool = False
454+
self,
455+
name: str,
456+
/,
457+
*,
458+
location: str | None = None,
459+
catalog: str | None = None,
460+
force: bool = False,
458461
) -> None:
459462
name = sg.table(name, catalog=catalog, quoted=self.compiler.quoted)
460-
sql = sge.Create(this=name, kind="SCHEMA", exists=force)
463+
sql = sge.Create(
464+
this=name,
465+
kind="SCHEMA",
466+
exists=force,
467+
properties=None
468+
if location is None
469+
else sge.Properties(
470+
expressions=[sge.LocationProperty(this=sge.convert(location))]
471+
),
472+
)
461473
with self._safe_raw_sql(sql, unload=False):
462474
pass
463475

ibis/backends/athena/tests/conftest.py

+28-13
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import concurrent.futures
44
import getpass
55
import sys
6-
import uuid
76
from os import environ as env
87
from typing import TYPE_CHECKING, Any
98

@@ -30,6 +29,9 @@
3029
IBIS_ATHENA_S3_STAGING_DIR = env.get(
3130
"IBIS_ATHENA_S3_STAGING_DIR", "s3://aws-athena-query-results-ibis-testing"
3231
)
32+
IBIS_ATHENA_TEST_DATABASE = (
33+
f"{getpass.getuser()}_{''.join(map(str, sys.version_info[:3]))}"
34+
)
3335
AWS_REGION = env.get("AWS_REGION", "us-east-2")
3436
AWS_PROFILE = env.get("AWS_PROFILE")
3537
CONNECT_ARGS = dict(
@@ -49,7 +51,10 @@ def create_table(connection, *, fs: s3fs.S3FileSystem, file: Path, folder: str)
4951

5052
ddl = sge.Create(
5153
kind="TABLE",
52-
this=sge.Schema(this=sg.table(name), expressions=sg_schema),
54+
this=sge.Schema(
55+
this=sg.table(name, db=IBIS_ATHENA_TEST_DATABASE, quoted=True),
56+
expressions=sg_schema,
57+
),
5358
properties=sge.Properties(
5459
expressions=[
5560
sge.ExternalProperty(),
@@ -61,16 +66,19 @@ def create_table(connection, *, fs: s3fs.S3FileSystem, file: Path, folder: str)
6166

6267
fs.put(str(file), f"{folder.removeprefix('s3://')}/{name}/{file.name}")
6368

64-
drop_query = sge.Drop(kind="TABLE", this=sg.to_identifier(name, quoted=True)).sql(
65-
Athena
66-
)
69+
drop_query = sge.Drop(
70+
kind="TABLE", this=sg.table(name, db=IBIS_ATHENA_TEST_DATABASE), exists=True
71+
).sql(Athena)
72+
6773
create_query = ddl.sql(Athena)
6874

6975
with connection.con.cursor() as cur:
7076
cur.execute(drop_query)
7177
cur.execute(create_query)
7278

73-
assert connection.table(name).count().execute() > 0
79+
assert (
80+
connection.table(name, database=IBIS_ATHENA_TEST_DATABASE).count().execute() > 0
81+
)
7482

7583

7684
class TestConf(BackendTest):
@@ -82,35 +90,42 @@ class TestConf(BackendTest):
8290

8391
deps = ("pyathena", "fsspec")
8492

93+
@staticmethod
94+
def format_table(name: str) -> str:
95+
return sg.table(name, db=IBIS_ATHENA_TEST_DATABASE, quoted=True).sql(Athena)
96+
8597
def _load_data(self, **_: Any) -> None:
8698
import fsspec
8799

88100
files = self.data_dir.joinpath("parquet").glob("*.parquet")
89101

90-
user = getpass.getuser()
91-
python_version = "".join(map(str, sys.version_info[:3]))
92-
folder = f"{user}_{python_version}_{uuid.uuid4()}"
93-
94102
fs = fsspec.filesystem("s3")
95103

96104
connection = self.connection
97-
folder = f"{IBIS_ATHENA_S3_STAGING_DIR}/{folder}"
105+
db_dir = f"{IBIS_ATHENA_S3_STAGING_DIR}/{IBIS_ATHENA_TEST_DATABASE}"
106+
107+
connection.create_database(
108+
IBIS_ATHENA_TEST_DATABASE, location=db_dir, force=True
109+
)
98110

99111
with concurrent.futures.ThreadPoolExecutor() as executor:
100112
for future in concurrent.futures.as_completed(
101113
executor.submit(
102-
create_table, connection, fs=fs, file=file, folder=folder
114+
create_table, connection, fs=fs, file=file, folder=db_dir
103115
)
104116
for file in files
105117
):
106118
future.result()
107119

120+
def postload(self, **kw):
121+
self.connection = self.connect(schema_name=IBIS_ATHENA_TEST_DATABASE, **kw)
122+
108123
@staticmethod
109124
def connect(*, tmpdir, worker_id, **kw) -> BaseBackend:
110125
return ibis.athena.connect(**CONNECT_ARGS, **kw)
111126

112127
def _remap_column_names(self, table_name: str) -> dict[str, str]:
113-
table = self.connection.table(table_name)
128+
table = self.connection.table(table_name, database=IBIS_ATHENA_TEST_DATABASE)
114129
return table.rename(
115130
dict(zip(TEST_TABLES[table_name].names, table.schema().names))
116131
)

0 commit comments

Comments
 (0)