 import concurrent.futures
 import getpass
 import sys
-import uuid
 from os import environ as env
 from typing import TYPE_CHECKING, Any
 
@@ -30,6 +29,9 @@
 IBIS_ATHENA_S3_STAGING_DIR = env.get(
     "IBIS_ATHENA_S3_STAGING_DIR", "s3://aws-athena-query-results-ibis-testing"
 )
+IBIS_ATHENA_TEST_DATABASE = (
+    f"{getpass.getuser()}_{''.join(map(str, sys.version_info[:3]))}"
+)
 AWS_REGION = env.get("AWS_REGION", "us-east-2")
 AWS_PROFILE = env.get("AWS_PROFILE")
 CONNECT_ARGS = dict(
@@ -49,7 +51,10 @@ def create_table(connection, *, fs: s3fs.S3FileSystem, file: Path, folder: str)
 
     ddl = sge.Create(
         kind="TABLE",
-        this=sge.Schema(this=sg.table(name), expressions=sg_schema),
+        this=sge.Schema(
+            this=sg.table(name, db=IBIS_ATHENA_TEST_DATABASE, quoted=True),
+            expressions=sg_schema,
+        ),
         properties=sge.Properties(
             expressions=[
                 sge.ExternalProperty(),
@@ -61,16 +66,19 @@ def create_table(connection, *, fs: s3fs.S3FileSystem, file: Path, folder: str)
 
     fs.put(str(file), f"{folder.removeprefix('s3://')}/{name}/{file.name}")
 
-    drop_query = sge.Drop(kind="TABLE", this=sg.to_identifier(name, quoted=True)).sql(
-        Athena
-    )
+    drop_query = sge.Drop(
+        kind="TABLE", this=sg.table(name, db=IBIS_ATHENA_TEST_DATABASE), exists=True
+    ).sql(Athena)
+
     create_query = ddl.sql(Athena)
 
     with connection.con.cursor() as cur:
         cur.execute(drop_query)
         cur.execute(create_query)
 
-    assert connection.table(name).count().execute() > 0
+    assert (
+        connection.table(name, database=IBIS_ATHENA_TEST_DATABASE).count().execute() > 0
+    )
 
 
 class TestConf(BackendTest):
@@ -82,35 +90,42 @@ class TestConf(BackendTest):
 
     deps = ("pyathena", "fsspec")
 
+    @staticmethod
+    def format_table(name: str) -> str:
+        return sg.table(name, db=IBIS_ATHENA_TEST_DATABASE, quoted=True).sql(Athena)
+
     def _load_data(self, **_: Any) -> None:
         import fsspec
 
         files = self.data_dir.joinpath("parquet").glob("*.parquet")
 
-        user = getpass.getuser()
-        python_version = "".join(map(str, sys.version_info[:3]))
-        folder = f"{user}_{python_version}_{uuid.uuid4()}"
-
         fs = fsspec.filesystem("s3")
 
         connection = self.connection
-        folder = f"{IBIS_ATHENA_S3_STAGING_DIR}/{folder}"
+        db_dir = f"{IBIS_ATHENA_S3_STAGING_DIR}/{IBIS_ATHENA_TEST_DATABASE}"
+
+        connection.create_database(
+            IBIS_ATHENA_TEST_DATABASE, location=db_dir, force=True
+        )
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
             for future in concurrent.futures.as_completed(
                 executor.submit(
-                    create_table, connection, fs=fs, file=file, folder=folder
+                    create_table, connection, fs=fs, file=file, folder=db_dir
                 )
                 for file in files
             ):
                 future.result()
 
+    def postload(self, **kw):
+        self.connection = self.connect(schema_name=IBIS_ATHENA_TEST_DATABASE, **kw)
+
     @staticmethod
     def connect(*, tmpdir, worker_id, **kw) -> BaseBackend:
         return ibis.athena.connect(**CONNECT_ARGS, **kw)
 
     def _remap_column_names(self, table_name: str) -> dict[str, str]:
-        table = self.connection.table(table_name)
+        table = self.connection.table(table_name, database=IBIS_ATHENA_TEST_DATABASE)
        return table.rename(
             dict(zip(TEST_TABLES[table_name].names, table.schema().names))
         )