1
1
import sqlite3
2
2
3
3
import duckdb
4
+ import sqlalchemy
4
5
import numpy as np
5
6
import pandas as pd
6
7
import pytest
@@ -41,7 +42,7 @@ def diabetes_data(self):
41
42
df = pd .concat ([X , y ], axis = 1 )
42
43
return df , feature_names
43
44
44
- @pytest .fixture (params = ["duckdb" , "sqlite" ])
45
+ @pytest .fixture (params = ["duckdb" , "sqlite" , "postgres" ])
45
46
def db_connection (self , request ):
46
47
dialect = request .param
47
48
if dialect == "duckdb" :
@@ -52,13 +53,19 @@ def db_connection(self, request):
52
53
conn = sqlite3 .connect (":memory:" )
53
54
yield conn , dialect
54
55
conn .close ()
56
+ elif dialect == "postgres" :
57
+ try :
58
+ conn = sqlalchemy .create_engine ("postgresql://mustelatestuser:mustelatestpassword@localhost:5432/mustelatestdb" )
59
+ except (sqlalchemy .exc .OperationalError , ImportError ):
60
+ pytest .skip ("Postgres database not available" )
61
+ yield conn , dialect
62
+ conn .dispose ()
55
63
56
64
def execute_sql (self , sql , conn , dialect , data ):
57
65
if dialect == "duckdb" :
58
66
conn .execute ("CREATE TABLE data AS SELECT * FROM data" )
59
- # print(conn.execute("SELECT * FROM data").fetchdf())
60
67
result = conn .execute (sql ).fetchdf ()
61
- elif dialect == "sqlite" :
68
+ elif dialect in ( "sqlite" , "postgres" ) :
62
69
data .to_sql ("data" , conn , index = False , if_exists = "replace" )
63
70
result = pd .read_sql (sql , conn )
64
71
return result
0 commit comments