Skip to content

Commit 9fdebf6

Browse files
authored
Tests on Postgres (#42)
1 parent 6400276 commit 9fdebf6

File tree

8 files changed

+38
-19
lines changed

8 files changed

+38
-19
lines changed

.github/workflows/tests.yml

+7
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ jobs:
2727
run: |
2828
uv sync --dev
2929
30+
- name: Set up PostgreSQL
31+
uses: harmon758/postgresql-action@v1
32+
with:
33+
postgresql db: mustelatestdb
34+
postgresql user: mustelatestuser
35+
postgresql password: mustelatestpassword
36+
3037
- name: Run Test Suite
3138
run: |
3239
uv run pytest -v --tb=short --disable-warnings --maxfail=1 --cov=mustela

pyproject.toml

+6
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ test = [
4242
"ibis-framework[duckdb]>=5.1.0",
4343
"pytest-cov>=5.0.0",
4444
"pytest>=8.3.2",
45+
"sqlalchemy",
46+
"psycopg2",
47+
"duckdb",
4548
]
4649

4750
[tool.uv]
@@ -65,6 +68,9 @@ dev-dependencies = [
6568
"pydot",
6669
"onnxruntime",
6770
"onnxscript",
71+
"sqlalchemy",
72+
"psycopg2",
73+
"duckdb",
6874
]
6975

7076

src/mustela/translation/optimizer.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,9 @@ def fold_case(self, expr: ibis.Value | ibis.Deferred) -> ibis.Value:
174174
else:
175175
return op.default.to_expr()
176176
elif len(op.cases) == 1 and results_are_literals and possible_values == {1, 0}:
177-
# results are 1 or 0, we can fold it to a boolean
178-
# expression.
177+
# results are 1 or 0, we can fold it to a boolean expression.
178+
# FIXME: This doesn't work on postgresql so we need to disable it for the moment.
179+
return expr
179180
if op.results[0].value == 1:
180181
return (op.cases[0].to_expr()).cast("float64")
181182
else:

src/mustela/translation/steps/onehotencoder.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,9 @@ def process(self) -> None:
3232
raise ValueError("OneHotEncoder: input expression not found")
3333

3434
casted_variables = [
35-
self._optimizer.fold_cast(
36-
typing.cast(ibis.expr.types.BooleanValue, (input_expr == cat)).cast(
37-
"float64"
38-
)
39-
).name(self.variable_unique_short_alias("onehot"))
35+
ibis.ifelse(input_expr == cat, 1, 0)
36+
.cast("float64")
37+
.name(self.variable_unique_short_alias("oh"))
4038
for cat in cats
4139
]
4240

src/mustela/translation/steps/softmax.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def compute_softmax(
5353
if isinstance(data, VariablesGroup):
5454
data = NumericVariablesGroup(data)
5555
max_value = ibis.greatest(*data.values()).name(
56-
translator.variable_unique_short_alias("sfmmax")
56+
translator.variable_unique_short_alias("sfmx")
5757
)
5858
translator.preserve(max_value)
5959

src/mustela/translation/steps/trees/classifier.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -103,20 +103,20 @@ def build_classifier(
103103
else:
104104
ordered_features = typing.cast(list[ibis.Value], [input_expr])
105105
ordered_features = [
106-
feature.name(self.variable_unique_short_alias("tclass"))
106+
feature.name(self.variable_unique_short_alias("tcl"))
107107
for feature in ordered_features
108108
]
109109
ordered_features = self.preserve(*ordered_features)
110110

111111
def build_tree_case(node: dict) -> dict[str | int, ibis.Expr]:
112112
# Leaf node, return the votes
113113
if node["mode"] == "LEAF":
114-
votes = {}
115-
for clslabel in classlabels:
116-
# We can assume missing class = weight 0
117-
# The optimizer will remove this if both true and false have 0.
118-
votes[clslabel] = ibis.literal(node["weight"].get(clslabel, 0))
119-
return votes
114+
# We can assume missing class = weight 0
115+
# The optimizer will remove this if both true and false have 0.
116+
return {
117+
clslabel: ibis.literal(node["weight"].get(clslabel, 0.0))
118+
for clslabel in classlabels
119+
}
120120

121121
# Branch node, build a CASE statement
122122
feature_expr = ordered_features[node["feature_id"]]

src/mustela/translation/steps/trees/regressor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def build_regressor(self, input_expr: VariablesGroup | ibis.Expr) -> ibis.Expr:
4646
else:
4747
ordered_features = typing.cast(list[ibis.Value], [input_expr])
4848
ordered_features = [
49-
feature.name(self.variable_unique_short_alias("tclass"))
49+
feature.name(self.variable_unique_short_alias("tcl"))
5050
for feature in ordered_features
5151
]
5252
ordered_features = self.preserve(*ordered_features)

tests/test_pipeline_e2e.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sqlite3
22

33
import duckdb
4+
import sqlalchemy
45
import numpy as np
56
import pandas as pd
67
import pytest
@@ -41,7 +42,7 @@ def diabetes_data(self):
4142
df = pd.concat([X, y], axis=1)
4243
return df, feature_names
4344

44-
@pytest.fixture(params=["duckdb", "sqlite"])
45+
@pytest.fixture(params=["duckdb", "sqlite", "postgres"])
4546
def db_connection(self, request):
4647
dialect = request.param
4748
if dialect == "duckdb":
@@ -52,13 +53,19 @@ def db_connection(self, request):
5253
conn = sqlite3.connect(":memory:")
5354
yield conn, dialect
5455
conn.close()
56+
elif dialect == "postgres":
57+
try:
58+
conn = sqlalchemy.create_engine("postgresql://mustelatestuser:mustelatestpassword@localhost:5432/mustelatestdb")
59+
except (sqlalchemy.exc.OperationalError, ImportError):
60+
pytest.skip("Postgres database not available")
61+
yield conn, dialect
62+
conn.dispose()
5563

5664
def execute_sql(self, sql, conn, dialect, data):
5765
if dialect == "duckdb":
5866
conn.execute("CREATE TABLE data AS SELECT * FROM data")
59-
# print(conn.execute("SELECT * FROM data").fetchdf())
6067
result = conn.execute(sql).fetchdf()
61-
elif dialect == "sqlite":
68+
elif dialect in ("sqlite", "postgres"):
6269
data.to_sql("data", conn, index=False, if_exists="replace")
6370
result = pd.read_sql(sql, conn)
6471
return result

0 commit comments

Comments
 (0)