Skip to content

Commit d192e64

Browse files
committed
Progress on type hinting
1 parent 1fd0333 commit d192e64

12 files changed

+70
-57
lines changed

examples/pipeline_boosted_tree_classifier.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
PRINT_SQL = False
1818
logging.basicConfig(level=logging.INFO)
19-
logging.getLogger("mustela").setLevel(logging.DEBUG)
19+
logging.getLogger("mustela").setLevel(logging.INFO) # Set DEBUG to see translation process.
2020

2121
# Load Ames Housing for classification
2222
ames = fetch_openml(name="house_prices", as_frame=True)
@@ -91,7 +91,6 @@ def categorize_price(price: float) -> str:
9191

9292
# Convert types from numpy to mustela types
9393
features = mustela.types.guess_datatypes(X)
94-
print("Mustela Features:", features)
9594

9695
# Target only 5 rows, so that it's easier for a human to understand
9796
data_sample = X.head(5)

examples/pipeline_boosted_tree_regressor.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
PRINT_SQL = False
1818

1919
logging.basicConfig(level=logging.INFO)
20-
logging.getLogger("mustela").setLevel(logging.DEBUG)
20+
logging.getLogger("mustela").setLevel(logging.INFO) # Set DEBUG to see translation process.
2121

2222
ames = fetch_openml(name="house_prices", as_frame=True)
2323
ames = ames.frame
@@ -81,13 +81,11 @@
8181
)
8282
model.fit(X, y)
8383

84-
features = mustela.types.guess_datatypes(X)
85-
print("Mustela Features:", features)
86-
8784
# Create a small set of data for the prediction
8885
# It's easier to understand if it's small
8986
data_sample = X.head(5)
9087

88+
features = mustela.types.guess_datatypes(X)
9189
mustela_pipeline = mustela.parse_pipeline(model, features=features)
9290
print(mustela_pipeline)
9391

examples/pipeline_decision_tree_classifier.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
PRINT_SQL = False
1818

1919
logging.basicConfig(level=logging.INFO)
20-
logging.getLogger("mustela").setLevel(logging.DEBUG)
20+
logging.getLogger("mustela").setLevel(logging.INFO) # Change to DEBUG to see each translation step.
2121

2222
iris = load_iris()
2323
df = pd.DataFrame(

examples/pipeline_decision_tree_regressor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
PRINT_SQL = False
1818
logging.basicConfig(level=logging.INFO)
19-
logging.getLogger("mustela").setLevel(logging.DEBUG)
19+
logging.getLogger("mustela").setLevel(logging.INFO) # Set DEBUG to see translation process.
2020

2121
# Carica il dataset
2222
iris = load_iris()

examples/pipeline_elasticnet.py

+20-19
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,38 @@
1313
PRINT_SQL = False
1414

1515
logging.basicConfig(level=logging.INFO)
16-
logging.getLogger("mustela").setLevel(logging.DEBUG)
16+
logging.getLogger("mustela").setLevel(logging.INFO) # Set DEBUG to see translation process.
1717

1818
# Load the Iris dataset
1919
iris = load_iris(as_frame=True)
20+
iris_x = iris.data
2021

21-
# Define column names for consistency
22-
names = ["sepal.length", "sepal.width", "petal.length", "petal.width"]
22+
# SQL and Mustela don't like dots in column names, replace them with underscores
23+
iris_x.columns = [cname.replace(".", "_") for cname in iris_x.columns]
2324

24-
iris_x = iris.data.set_axis(names, axis=1)
25+
numeric_cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
26+
iris_x = iris_x.set_axis(numeric_cols, axis=1)
2527

2628
# Create a pipeline with ElasticNet instead of LinearRegression
2729
pipeline = Pipeline(
2830
[
2931
(
3032
"preprocess",
3133
ColumnTransformer(
32-
[("scaler", StandardScaler(with_std=False), names)],
34+
[("scaler", StandardScaler(with_std=False), numeric_cols)],
3335
remainder="passthrough",
3436
),
3537
),
3638
("elastic_net", ElasticNet(alpha=0.1, l1_ratio=0.5)), # ElasticNet with L1/L2 regularization
3739
]
3840
)
3941

40-
# Train the pipeline
4142
pipeline.fit(iris_x, iris.target)
4243

43-
print(iris_x.columns)
44-
45-
# Identify feature types for Mustela
44+
# Convenience for this example to avoid repeating the schema,
45+
# in real cases, the user would know the schema of its database.
4646
features = mustela.types.guess_datatypes(iris_x)
47-
print("Mustela Features:", features)
4847

49-
# Convert the pipeline into SQL with Mustela
5048
mustela_pipeline = mustela.parse_pipeline(pipeline, features=features)
5149
print(mustela_pipeline)
5250

@@ -60,20 +58,23 @@
6058
}
6159
)
6260

63-
# Generate an SQL query using Mustela
61+
# Generate a query expression using Mustela
6462
ibis_expression = mustela.translate(ibis.memtable(example_data), mustela_pipeline)
6563

64+
con = ibis.duckdb.connect()
65+
6666
if PRINT_SQL:
67+
sql = mustela.export_sql("DATA_TABLE", mustela_pipeline, dialect="duckdb")
6768
print("\nGenerated Query for DuckDB:")
68-
con = ibis.duckdb.connect()
69-
print(con.compile(ibis_expression))
69+
print(sql)
70+
print("\nPrediction with SQL")
71+
# We need to create the table for the SQL to query it.
72+
con.create_table(ibis_table.get_name(), obj=ibis_table)
73+
print(con.raw_sql(sql).df())
7074

71-
# Predictions using Ibis
7275
print("\nPrediction with Ibis")
7376
print(ibis_expression.execute())
7477

75-
# Predictions using SKLearn
76-
new_column_names = [name.replace("_", ".") for name in example_data.column_names] # SkLearn uses dots in column names
77-
renamed_example_data = example_data.rename_columns(new_column_names).to_pandas()
78-
predictions = pipeline.predict(renamed_example_data)
78+
print("\nPrediction with SKLearn")
79+
predictions = pipeline.predict(example_data.to_pandas())
7980
print(predictions)

examples/pipeline_lasso.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
PRINT_SQL = False
1515

1616
logging.basicConfig(level=logging.INFO)
17-
logging.getLogger("mustela").setLevel(logging.DEBUG)
17+
logging.getLogger("mustela").setLevel(logging.INFO) # Set DEBUG to see translation process.
1818

1919
iris = load_iris(as_frame=True)
2020

examples/pipeline_lineareg.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
PRINT_SQL = True
1515

1616
logging.basicConfig(level=logging.INFO)
17-
logging.getLogger("mustela").setLevel(logging.INFO)
17+
logging.getLogger("mustela").setLevel(logging.INFO) # Set DEBUG to see translation process.
1818

1919
iris = load_iris(as_frame=True)
2020
iris_x = iris.data
@@ -39,9 +39,9 @@
3939
)
4040
pipeline.fit(iris_x, iris.target)
4141

42-
42+
# Convenience for this example to avoid repeating the schema,
43+
# in real cases, the user would know the schema of its database.
4344
features = mustela.types.guess_datatypes(iris_x)
44-
print("Mustela Features:", features)
4545

4646
mustela_pipeline = mustela.parse_pipeline(pipeline, features=features)
4747
print(mustela_pipeline)

examples/pipeline_logisticreg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
PRINT_SQL = False
1717

1818
logging.basicConfig(level=logging.INFO)
19-
logging.getLogger("mustela").setLevel(logging.DEBUG)
19+
logging.getLogger("mustela").setLevel(logging.INFO) # Set DEBUG to see translation process.
2020

2121
# Carichiamo il dataset iris e creiamo un DataFrame
2222
iris = load_iris()

examples/pipeline_randforest_classifier.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
PRINT_SQL = False
1717

1818
logging.basicConfig(level=logging.INFO)
19-
logging.getLogger("mustela").setLevel(logging.DEBUG)
19+
logging.getLogger("mustela").setLevel(logging.INFO) # Set DEBUG to see translation process.
2020

2121
iris = load_iris()
2222
df = pd.DataFrame(iris.data, columns=["sepal_length", "sepal_width", "petal_length", "petal_width"])

src/mustela/_utils/repr_pipeline.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def _attributes(self, attributes: typing.Iterable[_onnx.AttributeProto]) -> str:
6565
def _attr_value(attr: _onnx.AttributeProto) -> str:
6666
return self._shorten(str(get_attr_value(attr)))
6767

68-
return ", ".join((f"{attr.name}={_attr_value(attr)}" for attr in attributes))
68+
indent = "\n "
69+
return indent + indent.join((f"{attr.name}={_attr_value(attr)}" for attr in attributes))
6970

7071
def _shorten(self, value: str) -> str:
7172
"""Shorten a string to maxlen characters."""

src/mustela/translation/steps/argmax.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,9 @@ class ArgMaxTranslator(Translator):
2929
on which to perform a prediction/classification (row).
3030
"""
3131

32-
# https://onnx.ai/onnx/operators/onnx__ArgMax.html
33-
3432
def process(self) -> None:
3533
"""Performs the translation and set the output variable."""
34+
# https://onnx.ai/onnx/operators/onnx__ArgMax.html
3635
data = self._variables.consume(self.inputs[0])
3736
axis = self._attributes.get("axis", 1)
3837
keepdims = self._attributes.get("keepdims", 1)

src/mustela/translation/steps/arrayfeatureextractor.py

+36-21
Original file line numberDiff line numberDiff line change
@@ -6,35 +6,50 @@
66

77

88
class ArrayFeatureExtractorTranslator(Translator):
9-
"""Processes an ArgMax node and updates the variables with the output expression."""
10-
11-
12-
def process(self):
9+
"""Processes an ArrayFeatureExtractor node and updates the variables with the output expression.
10+
11+
ArrayFeatureExtractor can be considered the opposit of :class:`ConactTranslator`, as
12+
in most cases it will be used to pick one or more features out of a group of column
13+
previously concatenated, or to pick a specific feature out of the result of an ArgMax operation.
14+
15+
The provided indices always refer to the **last** axis of the input tensor.
16+
If the input is a 2D tensor, the last axis is the column axis. So an index
17+
of ``0`` would mean the first column. If the input is a 1D tensor instead the
18+
last axis is the row axis. So an index of ``0`` would mean the first row.
19+
20+
This could be confusing because axis are inverted between tensors and mustela column groups.
21+
In the case of Tensors, axis=0 means row=0, while instead of mustela
22+
column groups (by virtue of being a group of columns), axis=0 means
23+
the first column.
24+
25+
We have to consider that the indices we receive, in case of column groups,
26+
are actually column indices, not row indices as in case of a tensor,
27+
the last index would be the column index. In case of single columns,
28+
instead the index is the index of a row like it would be with a 1D tensor.
29+
"""
30+
def process(self) -> None:
31+
"""Performs the translation and set the output variable."""
1332
# https://onnx.ai/onnx/operators/onnx_aionnxml_ArrayFeatureExtractor.html
1433

15-
# Given an array of features, grab only one of them
16-
# This probably is used to extract a single feature from a list of features
17-
# Previously made by Concat.
18-
# Or to pick the right feature from the result of ArgMax
1934
data = self._variables.consume(self.inputs[0])
2035
indices = self._variables.consume(self.inputs[1])
2136

22-
data_keys = None
23-
if isinstance(data, dict):
24-
# This expects that dictionaries are sorted by insertion order
25-
# AND that all values of the dictionary are featues with dim_value: 1
26-
# TODO: Implement a class for Concatenaed values
27-
# that implements support based on dimensions
28-
data_keys = list(data.keys())
29-
data = list(data.values())
37+
if not isinstance(data, dict):
38+
# TODO: Implement support for selecting rows from a 1D tensor
39+
raise NotImplementedError("ArrayFeatureExtractor only supports column groups as inputs")
40+
41+
# This expects that dictionaries are sorted by insertion order
42+
# AND that all values of the dictionary are columns.
43+
data_keys = list(data.keys())
44+
data = list(data.values())
3045

3146
if isinstance(indices, (list, tuple)):
32-
# We only work with dictionaries of faturename: feature
33-
# So when we are expected to output a list of features
34-
# we should output a dictionary of features as they are just sorted.
47+
if data_keys is None:
48+
raise ValueError("ArrayFeatureExtractor expects a group of columns as input when receiving a list of indices")
49+
if len(indices) > len(data_keys):
50+
raise ValueError("Indices requested are more than the available numer of columns.")
51+
# Pick only the columns that are in the list of indicies.
3552
result = {data_keys[i]: data[i] for i in indices}
36-
elif isinstance(indices, int):
37-
result = data[indices]
3853
elif isinstance(indices, ibis.expr.types.Column):
3954
# The indices that we need to pick are contained in
4055
# another column of the table.

0 commit comments

Comments
 (0)