Skip to content

Commit c0b9d0f

Browse files
authored
Fix binary trees (#37)
1 parent 639c215 commit c0b9d0f

File tree

3 files changed

+70
-47
lines changed

3 files changed

+70
-47
lines changed

src/mustela/translation/steps/trees/classifier.py

+56-27
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,22 @@ def build_classifier(
8181
) or self._attributes.get("classlabels_int64s")
8282
if classlabels is None:
8383
raise ValueError("Unable to detect classlabels for classification")
84-
classlabels = typing.cast(list[str] | list[int], classlabels)
84+
output_classlabels = classlabels = typing.cast(
85+
list[str] | list[int], classlabels
86+
)
87+
88+
# ONNX treats binary classification as a special case:
89+
# https://github.com/microsoft/onnxruntime/blob/5982430af66f52a288cb8b2181e0b5b2e09118c8/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h#L854C1-L871C4
90+
# https://github.com/microsoft/onnxruntime/blob/5982430af66f52a288cb8b2181e0b5b2e09118c8/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h#L469-L494
91+
# In this case there is only one weight and it's the probability of the positive class.
92+
# So we need to check if we are in a binary classification case.
93+
weights_classid = typing.cast(list[int], self._attributes["class_ids"])
94+
is_binary = len(classlabels) == 2 and len(set(weights_classid)) == 1
95+
if is_binary:
96+
# In this case there is only one label, the first one
97+
# which actually acts as the score of the prediction.
98+
# When > 0.5 then class 1, when < 0.5 then class 0
99+
classlabels = typing.cast(list[str] | list[int], [classlabels[0]])
85100

86101
if isinstance(input_expr, VariablesGroup):
87102
ordered_features = input_expr.values_value()
@@ -134,39 +149,53 @@ def build_tree_case(node: dict) -> dict[str | int, ibis.Expr]:
134149
)
135150

136151
# Compute prediction of class itself.
137-
candidate_cls = classlabels[0]
138-
candidate_vote = total_votes[candidate_cls]
139-
for clslabel in classlabels[1:]:
140-
candidate_cls = optimizer.fold_case(
152+
if is_binary:
153+
total_score = total_votes[classlabels[0]]
154+
label_expr = optimizer.fold_case(
141155
ibis.case()
142-
.when(total_votes[clslabel] > candidate_vote, clslabel)
143-
.else_(candidate_cls)
156+
.when(total_score > 0.5, output_classlabels[1])
157+
.else_(output_classlabels[0])
144158
.end()
145159
)
146-
candidate_vote = optimizer.fold_case(
147-
ibis.case()
148-
.when(total_votes[clslabel] > candidate_vote, total_votes[clslabel])
149-
.else_(candidate_vote)
150-
.end()
160+
# The order matters, for ONNX the VariableGroup is a list of subvariables
161+
# the names are not important.
162+
prob_dict = VariablesGroup(
163+
{
164+
str(output_classlabels[0]): 1.0 - total_score,
165+
str(output_classlabels[1]): total_score,
166+
}
151167
)
168+
else:
169+
candidate_cls = classlabels[0]
170+
candidate_vote = total_votes[candidate_cls]
171+
for clslabel in classlabels[1:]:
172+
candidate_cls = optimizer.fold_case(
173+
ibis.case()
174+
.when(total_votes[clslabel] > candidate_vote, clslabel)
175+
.else_(candidate_cls)
176+
.end()
177+
)
178+
candidate_vote = optimizer.fold_case(
179+
ibis.case()
180+
.when(total_votes[clslabel] > candidate_vote, total_votes[clslabel])
181+
.else_(candidate_vote)
182+
.end()
183+
)
152184

153-
label_expr = ibis.case()
154-
for clslabel in classlabels:
155-
label_expr = label_expr.when(candidate_cls == clslabel, clslabel)
156-
label_expr = label_expr.else_(ibis.null()).end()
157-
label_expr = optimizer.fold_case(label_expr)
185+
label_expr = ibis.case()
186+
for clslabel in classlabels:
187+
label_expr = label_expr.when(candidate_cls == clslabel, clslabel)
188+
label_expr = label_expr.else_(ibis.null()).end()
189+
label_expr = optimizer.fold_case(label_expr)
158190

159-
# Compute probability to return it too.
160-
sum_votes = None
161-
for clslabel in classlabels:
162-
if sum_votes is None:
163-
sum_votes = total_votes[clslabel]
164-
else:
191+
# Compute probability to return it too.
192+
sum_votes = ibis.literal(0.0)
193+
for clslabel in classlabels:
165194
sum_votes = optimizer.fold_operation(sum_votes + total_votes[clslabel])
166195

167-
# FIXME: Probabilities are currently broken for gradient boosted trees.
168-
prob_dict = VariablesGroup()
169-
for clslabel in classlabels:
170-
prob_dict[str(clslabel)] = total_votes[clslabel] / sum_votes
196+
# FIXME: Probabilities are currently broken for gradient boosted trees.
197+
prob_dict = VariablesGroup()
198+
for clslabel in classlabels:
199+
prob_dict[str(clslabel)] = total_votes[clslabel] / sum_votes
171200

172201
return label_expr, prob_dict

src/mustela/translation/steps/trees/tree.py

-18
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def build_tree(translator: Translator) -> dict[int, dict[int, dict]]:
6767
if not classlabels:
6868
raise ValueError("Missing class labels when building tree")
6969

70-
is_binary = len(classlabels) == 2 and len(set(weights_classid)) == 1
7170
for tree_id, node_id, weight, weight_classid in zip(
7271
class_treeids, class_nodeids, class_weights, weights_classid
7372
):
@@ -76,23 +75,6 @@ def build_tree(translator: Translator) -> dict[int, dict[int, dict]]:
7675
)
7776
node_weights[classlabels[weight_classid]] = weight
7877

79-
if is_binary:
80-
# ONNX treats binary classification as a special case:
81-
# https://github.com/microsoft/onnxruntime/blob/5982430af66f52a288cb8b2181e0b5b2e09118c8/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h#L854C1-L871C4
82-
# https://github.com/microsoft/onnxruntime/blob/5982430af66f52a288cb8b2181e0b5b2e09118c8/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h#L469-L494
83-
# In this case there is only one weight and it's the probability of the positive class.
84-
for node_weights in weights.values():
85-
assert len(node_weights) == 1, (
86-
f"Binary classification expected to have only one class, got: {node_weights}"
87-
)
88-
score = list(node_weights.values())[0]
89-
if score > 0.5:
90-
node_weights[classlabels[1]] = 1.0
91-
node_weights[classlabels[0]] = 0.0
92-
else:
93-
node_weights[classlabels[1]] = 0.0
94-
node_weights[classlabels[0]] = 1.0
95-
9678
elif node.op_type == "TreeEnsembleRegressor":
9779
# Weights for the regressor, in this case leaf nodes have only 1 weight
9880
weights = typing.cast(dict[tuple[int, int], float], weights)

tests/test_pipeline_e2e.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,6 @@ def assign_region(width):
463463

464464
def test_binary_random_forest_classifier(self, iris_data, db_connection):
465465
"""Test a binary random forest classifier with mixed preprocessing."""
466-
pytest.skip("Binary classification on trees is currently not implemented.")
467466
df, feature_names = iris_data
468467
conn, dialect = db_connection
469468

@@ -504,9 +503,22 @@ def test_binary_random_forest_classifier(self, iris_data, db_connection):
504503
)
505504
parsed_pipeline = mustela.parse_pipeline(sklearn_pipeline, features=features)
506505

506+
# Test prediction
507507
sql = mustela.export_sql("data", parsed_pipeline, dialect=dialect)
508508
sql_results = self.execute_sql(sql, conn, dialect, binary_df)
509-
510509
np.testing.assert_allclose(
511510
sql_results["output_label"].to_numpy(), sklearn_class
512511
)
512+
513+
# Test probabilities
514+
sklearn_proba = sklearn_pipeline.predict_proba(X)
515+
sklearn_proba_df = pd.DataFrame(
516+
sklearn_proba, columns=sklearn_pipeline.classes_
517+
)
518+
for class_label in sklearn_pipeline.classes_:
519+
np.testing.assert_allclose(
520+
sql_results[f"output_probability.{class_label}"].to_numpy(),
521+
sklearn_proba_df[class_label].values.flatten(),
522+
rtol=1e-4,
523+
atol=1e-4,
524+
)

0 commit comments

Comments
 (0)