Skip to content

Commit 7f7a091

Browse files
committed
Fix binary classification on trees
1 parent 4285bb1 commit 7f7a091

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

src/mustela/translation/steps/trees/classifier.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ...translator import Translator
88
from ...variables import VariablesGroup
99
from ..softmax import SoftmaxTranslator
10+
from ..linearclass import LinearClassifierTranslator
1011
from .tree import build_tree, mode_to_condition
1112

1213

@@ -40,11 +41,16 @@ def process(self) -> None:
4041
)
4142

4243
label_expr, prob_colgroup = self.build_classifier(input_exr)
43-
post_transform = self._attributes.get("post_transform", "NONE")
44+
post_transform = typing.cast(str, self._attributes.get("post_transform", "NONE"))
4445

4546
if post_transform != "NONE":
4647
if post_transform == "SOFTMAX":
4748
prob_colgroup = SoftmaxTranslator.compute_softmax(prob_colgroup)
49+
elif post_transform == "LOGISTIC":
50+
prob_colgroup = VariablesGroup({
51+
lbl: LinearClassifierTranslator._apply_post_transform(prob_col, post_transform)
52+
for lbl, prob_col in prob_colgroup.items()
53+
})
4854
else:
4955
raise NotImplementedError(
5056
f"Post transform {post_transform} not implemented."

src/mustela/translation/steps/trees/tree.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ def build_tree(translator: Translator) -> dict[int, dict[int, dict]]:
4343
)
4444

4545
# Weight could be a float or a dictionary of class labels weights
46-
weights: dict[tuple[int, int], dict[str | int, float] | float] = {}
46+
weights = {}
4747
if node.op_type == "TreeEnsembleClassifier":
48+
weights = typing.cast(dict[tuple[int, int], dict[str | int, float]], weights)
4849
# Weights for classifier, in this case the weights are per-class
4950
class_nodeids = typing.cast(list[int], translator._attributes["class_nodeids"])
5051
class_treeids = typing.cast(list[int], translator._attributes["class_treeids"])
@@ -66,15 +67,31 @@ def build_tree(translator: Translator) -> dict[int, dict[int, dict]]:
6667
if not classlabels:
6768
raise ValueError("Missing class labels when building tree")
6869

70+
is_binary = (len(classlabels) == 2 and len(set(weights_classid)) == 1)
6971
for tree_id, node_id, weight, weight_classid in zip(
7072
class_treeids, class_nodeids, class_weights, weights_classid
7173
):
72-
node_weights = weights.setdefault((tree_id, node_id), {})
73-
typing.cast(dict[str | int, float], node_weights)[
74-
classlabels[weight_classid]
75-
] = weight
74+
node_weights = typing.cast(dict[str | int, float], weights.setdefault((tree_id, node_id), {}))
75+
node_weights[classlabels[weight_classid]] = weight
76+
77+
if is_binary:
78+
# ONNX treats binary classification as a special case:
79+
# https://github.com/microsoft/onnxruntime/blob/5982430af66f52a288cb8b2181e0b5b2e09118c8/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h#L854C1-L871C4
80+
# https://github.com/microsoft/onnxruntime/blob/5982430af66f52a288cb8b2181e0b5b2e09118c8/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h#L469-L494
81+
# In this case there is only one weight and it's the probability of the positive class.
82+
for node_weights in weights.values():
83+
assert len(node_weights) == 1, f"Binary classification expected to have only one class, got: {node_weights}"
84+
score = list(node_weights.values())[0]
85+
if score > 0.5:
86+
node_weights[classlabels[1]] = 1.0
87+
node_weights[classlabels[0]] = 0.0
88+
else:
89+
node_weights[classlabels[1]] = 0.0
90+
node_weights[classlabels[0]] = 1.0
91+
7692
elif node.op_type == "TreeEnsembleRegressor":
7793
# Weights for the regressor, in this case leaf nodes have only 1 weight
94+
weights = typing.cast(dict[tuple[int, int], float], weights)
7895
target_weights = typing.cast(
7996
list[float], translator._attributes["target_weights"]
8097
)

0 commit comments

Comments
 (0)