@@ -43,8 +43,9 @@ def build_tree(translator: Translator) -> dict[int, dict[int, dict]]:
43
43
)
44
44
45
45
# Weight could be a float or a dictionary of class labels weights
46
- weights : dict [ tuple [ int , int ], dict [ str | int , float ] | float ] = {}
46
+ weights = {}
47
47
if node .op_type == "TreeEnsembleClassifier" :
48
+ weights = typing .cast (dict [tuple [int , int ], dict [str | int , float ]], weights )
48
49
# Weights for classifier, in this case the weights are per-class
49
50
class_nodeids = typing .cast (list [int ], translator ._attributes ["class_nodeids" ])
50
51
class_treeids = typing .cast (list [int ], translator ._attributes ["class_treeids" ])
@@ -66,15 +67,31 @@ def build_tree(translator: Translator) -> dict[int, dict[int, dict]]:
66
67
if not classlabels :
67
68
raise ValueError ("Missing class labels when building tree" )
68
69
70
+ is_binary = (len (classlabels ) == 2 and len (set (weights_classid )) == 1 )
69
71
for tree_id , node_id , weight , weight_classid in zip (
70
72
class_treeids , class_nodeids , class_weights , weights_classid
71
73
):
72
- node_weights = weights .setdefault ((tree_id , node_id ), {})
73
- typing .cast (dict [str | int , float ], node_weights )[
74
- classlabels [weight_classid ]
75
- ] = weight
74
+ node_weights = typing .cast (dict [str | int , float ], weights .setdefault ((tree_id , node_id ), {}))
75
+ node_weights [classlabels [weight_classid ]] = weight
76
+
77
+ if is_binary :
78
+ # ONNX treats binary classification as a special case:
79
+ # https://github.com/microsoft/onnxruntime/blob/5982430af66f52a288cb8b2181e0b5b2e09118c8/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h#L854C1-L871C4
80
+ # https://github.com/microsoft/onnxruntime/blob/5982430af66f52a288cb8b2181e0b5b2e09118c8/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h#L469-L494
81
+ # In this case there is only one weight and it's the probability of the positive class.
82
+ for node_weights in weights .values ():
83
+ assert len (node_weights ) == 1 , f"Binary classification expected to have only one class, got: { node_weights } "
84
+ score = list (node_weights .values ())[0 ]
85
+ if score > 0.5 :
86
+ node_weights [classlabels [1 ]] = 1.0
87
+ node_weights [classlabels [0 ]] = 0.0
88
+ else :
89
+ node_weights [classlabels [1 ]] = 0.0
90
+ node_weights [classlabels [0 ]] = 1.0
91
+
76
92
elif node .op_type == "TreeEnsembleRegressor" :
77
93
# Weights for the regressor, in this case leaf nodes have only 1 weight
94
+ weights = typing .cast (dict [tuple [int , int ], float ], weights )
78
95
target_weights = typing .cast (
79
96
list [float ], translator ._attributes ["target_weights" ]
80
97
)
0 commit comments