Skip to content

Commit b775281

Browse files
committed
Expose a simple way to perform projections
1 parent 08835d5 commit b775281

File tree

5 files changed

+102
-11
lines changed

5 files changed

+102
-11
lines changed

examples/minimal.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,14 @@
4141
})
4242
print(orbitalml_pipeline)
4343

44-
sql = orbitalml.export_sql("DATA_TABLE", orbitalml_pipeline, dialect="duckdb")
44+
sql = orbitalml.export_sql("DATA_TABLE", orbitalml_pipeline, projection=orbitalml.ResultsProjection(["sepal_width"]), dialect="duckdb")
4545
print("\nGenerated Query for DuckDB:")
4646
print(sql)
4747
print("\nPrediction with SQL")
4848
duckdb.register("DATA_TABLE", X_test)
49-
print(duckdb.sql(sql).df()["variable"][:5].to_numpy())
49+
result = duckdb.sql(sql).df()
50+
print(result.head())
51+
print("---")
52+
print(result["variable"][:5].to_numpy())
5053
print("\nPrediction with SciKit-Learn")
5154
print(pipeline.predict(X_test)[:5])

src/orbitalml/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@
99

1010
from .ast import parse_pipeline
1111
from .sql import export_sql
12-
from .translate import translate
12+
from .translate import ResultsProjection, translate
1313

14-
__all__ = ["parse_pipeline", "translate", "export_sql"]
14+
__all__ = ["parse_pipeline", "translate", "export_sql", "ResultsProjection"]

src/orbitalml/sql.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ibis.expr.sql import Catalog
1313

1414
from .ast import ParsedPipeline
15-
from .translate import translate
15+
from .translate import ResultsProjection, translate
1616

1717
OPTIMIZER_RULES = (
1818
sqlglot.optimizer.optimizer.qualify,
@@ -36,6 +36,7 @@ def export_sql(
3636
table_name: str,
3737
pipeline: ParsedPipeline,
3838
dialect: str = "duckdb",
39+
projection: ResultsProjection = ResultsProjection(),
3940
optimize: bool = True,
4041
) -> str:
4142
"""Export SQL for a given pipeline.
@@ -58,7 +59,12 @@ def export_sql(
5859
name=table_name,
5960
)
6061

61-
ibis_expr = translate(unbound_table, pipeline)
62+
if projection.is_empty():
63+
raise ValueError(
64+
"Projection is empty. Please provide a projection to export SQL."
65+
)
66+
67+
ibis_expr = translate(unbound_table, pipeline, projection=projection)
6268
sqlglot_expr = getattr(sc, dialect).compiler.to_sqlglot(ibis_expr)
6369

6470
if optimize:

src/orbitalml/translate.py

+63-5
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,57 @@
7676
LOG_SQL = False
7777

7878

79-
def translate(table: ibis.Table, pipeline: ParsedPipeline) -> ibis.Table:
79+
class ResultsProjection:
    """Projection of the results of the pipeline.

    Describes which columns are returned by the translated pipeline:
    the pipeline results, optionally followed by additional columns.

    Constructed with ``ResultsProjection.OMIT`` the projection is "empty":
    ``is_empty()`` returns ``True`` and ``_expand()`` returns ``None`` so
    that callers can skip the select step altogether.
    """

    # Placeholder stored in the selection list; ``_expand`` replaces it
    # with the names of the pipeline results.
    RESULTS = object()
    # Sentinel accepted as ``select`` to request that no projection is applied.
    OMIT = object()

    def __init__(self, select: typing.Optional[list[str]] = None) -> None:
        """
        :param select: Additional column names to select after the pipeline
                       results, ``None`` (or an empty list) to select only
                       the results, or ``ResultsProjection.OMIT`` to skip
                       the selection.
        :raises TypeError: If ``select`` is a bare string, or contains an
                           entry that is not a string.
        """
        self._select: typing.Optional[list[typing.Any]]
        if select is self.OMIT:
            self._select = None
            return

        # A bare string is iterable, so extending the selection with it
        # would silently add one column per character.
        if isinstance(select, (str, bytes)):
            raise TypeError(
                "select must be a list of column names, "
                f"not a single string: {select!r}"
            )

        extra = list(select) if select else []
        # Fail early instead of silently dropping the entry at expansion time.
        invalid = [
            item
            for item in extra
            if item is not self.RESULTS and not isinstance(item, str)
        ]
        if invalid:
            raise TypeError(f"select entries must be column names, got: {invalid!r}")

        # The results always come first, followed by the extra columns.
        self._select = [self.RESULTS, *extra]

    def __repr__(self) -> str:
        """Return a debugging representation showing the extra columns."""
        if self._select is None:
            return f"{type(self).__name__}(OMIT)"
        extra = [item for item in self._select if item is not self.RESULTS]
        return f"{type(self).__name__}({extra!r})"

    def is_empty(self) -> bool:
        """Check if the projection step should be skipped."""
        return self._select is None

    def _expand(self, results: typing.Iterable[str]) -> typing.Optional[list[str]]:
        """Return the concrete list of column names to select.

        :param results: Names of the columns produced by the pipeline; they
                        replace the ``RESULTS`` placeholder.
        :return: The column names in selection order, or ``None`` when the
                 projection was created with ``OMIT``.
        """
        if self._select is None:
            return None

        expanded: list[str] = []
        for item in self._select:
            if item is self.RESULTS:
                expanded.extend(results)
            else:
                expanded.append(item)
        return expanded
123+
124+
125+
def translate(
126+
table: ibis.Table,
127+
pipeline: ParsedPipeline,
128+
projection: ResultsProjection = ResultsProjection(),
129+
) -> ibis.Table:
80130
"""Translate a pipeline into an Ibis expression.
81131
82132
This function takes a pipeline and a table and translates the pipeline
@@ -98,14 +148,18 @@ def translate(table: ibis.Table, pipeline: ParsedPipeline) -> ibis.Table:
98148
translator.process()
99149
table = translator.mutated_table # Translator might return a new table.
100150
_log_debug_end(translator, variables)
101-
return _projection_results(table, variables)
151+
return _projection_results(table, variables, projection)
102152

103153

104-
def _projection_results(table: ibis.Table, variables: GraphVariables) -> ibis.Table:
154+
def _projection_results(
155+
table: ibis.Table,
156+
variables: GraphVariables,
157+
projection: ResultsProjection = ResultsProjection(),
158+
) -> ibis.Table:
105159
# As we pop out the variables as we use them
106160
# the remaining ones are the values resulting from all
107161
# graph branches.
108-
final_projections = {}
162+
final_projections: dict[str, typing.Any] = {}
109163
for key, value in variables.remaining().items():
110164
if isinstance(value, dict):
111165
for field in value:
@@ -116,7 +170,11 @@ def _projection_results(table: ibis.Table, variables: GraphVariables) -> ibis.Ta
116170
final_projections[colkey] = colvalue
117171
else:
118172
final_projections[key] = value
119-
return table.mutate(**final_projections).select(final_projections.keys())
173+
query = table.mutate(**final_projections)
174+
selection = projection._expand(final_projections.keys())
175+
if selection is not None:
176+
query = query.select(*selection)
177+
return query
120178

121179

122180
def _log_debug_start(translator: Translator, variables: GraphVariables) -> None:

tests/test_pipeline_e2e.py

+24
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,30 @@ def test_simple_linear_regression(self, iris_data, db_connection):
9494
sql_results.values.flatten(), sklearn_preds.flatten(), rtol=1e-4, atol=1e-4
9595
)
9696

97+
def test_simple_linear_with_projection(self, iris_data, db_connection):
    """Exported SQL returns the prediction plus the extra projected column."""
    df, feature_names = iris_data
    conn, dialect = db_connection

    # Reference model: scale the features, then fit a linear regression.
    steps = [("scaler", StandardScaler()), ("regression", LinearRegression())]
    sklearn_pipeline = Pipeline(steps)
    X = df[feature_names]
    y = df["target"]
    sklearn_pipeline.fit(X, y)
    sklearn_preds = sklearn_pipeline.predict(X)

    # Every feature is declared as a float column for the parser.
    features = {}
    for fname in feature_names:
        features[fname] = types.FloatColumnType()
    parsed_pipeline = orbitalml.parse_pipeline(sklearn_pipeline, features=features)

    # Ask for one input column to be carried along with the results.
    projection = orbitalml.ResultsProjection(["sepal_length"])
    sql = orbitalml.export_sql(
        "data",
        parsed_pipeline,
        projection=projection,
        dialect=dialect,
    )

    sql_results = self.execute_sql(sql, conn, dialect, df)
    print(sql_results)

    expected_columns = {"sepal_length", "variable.target_0"}
    assert set(sql_results.columns) == expected_columns

    predicted = sql_results["variable.target_0"].values.flatten()
    np.testing.assert_allclose(
        predicted,
        sklearn_preds.flatten(),
        rtol=1e-4,
        atol=1e-4,
    )
120+
97121
def test_feature_selection_pipeline(self, diabetes_data, db_connection):
98122
df, feature_names = diabetes_data
99123
conn, dialect = db_connection

0 commit comments

Comments
 (0)