@@ -81,7 +81,22 @@ def build_classifier(
81
81
) or self ._attributes .get ("classlabels_int64s" )
82
82
if classlabels is None :
83
83
raise ValueError ("Unable to detect classlabels for classification" )
84
- classlabels = typing .cast (list [str ] | list [int ], classlabels )
84
+ output_classlabels = classlabels = typing .cast (
85
+ list [str ] | list [int ], classlabels
86
+ )
87
+
88
+ # ONNX treats binary classification as a special case:
89
+ # https://github.com/microsoft/onnxruntime/blob/5982430af66f52a288cb8b2181e0b5b2e09118c8/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h#L854C1-L871C4
90
+ # https://github.com/microsoft/onnxruntime/blob/5982430af66f52a288cb8b2181e0b5b2e09118c8/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h#L469-L494
91
+ # In this case there is only one weight and it's the probability of the positive class.
92
+ # So we need to check if we are in a binary classification case.
93
+ weights_classid = typing .cast (list [int ], self ._attributes ["class_ids" ])
94
+ is_binary = len (classlabels ) == 2 and len (set (weights_classid )) == 1
95
+ if is_binary :
96
+ # In this case there is only one label, the first one
97
+ # which actually acts as the score of the prediction.
98
+ # When > 0.5 then class 1, when < 0.5 then class 0
99
+ classlabels = typing .cast (list [str ] | list [int ], [classlabels [0 ]])
85
100
86
101
if isinstance (input_expr , VariablesGroup ):
87
102
ordered_features = input_expr .values_value ()
@@ -134,39 +149,53 @@ def build_tree_case(node: dict) -> dict[str | int, ibis.Expr]:
134
149
)
135
150
136
151
# Compute prediction of class itself.
137
- candidate_cls = classlabels [0 ]
138
- candidate_vote = total_votes [candidate_cls ]
139
- for clslabel in classlabels [1 :]:
140
- candidate_cls = optimizer .fold_case (
152
+ if is_binary :
153
+ total_score = total_votes [classlabels [0 ]]
154
+ label_expr = optimizer .fold_case (
141
155
ibis .case ()
142
- .when (total_votes [ clslabel ] > candidate_vote , clslabel )
143
- .else_ (candidate_cls )
156
+ .when (total_score > 0.5 , output_classlabels [ 1 ] )
157
+ .else_ (output_classlabels [ 0 ] )
144
158
.end ()
145
159
)
146
- candidate_vote = optimizer .fold_case (
147
- ibis .case ()
148
- .when (total_votes [clslabel ] > candidate_vote , total_votes [clslabel ])
149
- .else_ (candidate_vote )
150
- .end ()
160
+ # The order matters, for ONNX the VariableGroup is a list of subvariables
161
+ # the names are not important.
162
+ prob_dict = VariablesGroup (
163
+ {
164
+ str (output_classlabels [0 ]): 1.0 - total_score ,
165
+ str (output_classlabels [1 ]): total_score ,
166
+ }
151
167
)
168
+ else :
169
+ candidate_cls = classlabels [0 ]
170
+ candidate_vote = total_votes [candidate_cls ]
171
+ for clslabel in classlabels [1 :]:
172
+ candidate_cls = optimizer .fold_case (
173
+ ibis .case ()
174
+ .when (total_votes [clslabel ] > candidate_vote , clslabel )
175
+ .else_ (candidate_cls )
176
+ .end ()
177
+ )
178
+ candidate_vote = optimizer .fold_case (
179
+ ibis .case ()
180
+ .when (total_votes [clslabel ] > candidate_vote , total_votes [clslabel ])
181
+ .else_ (candidate_vote )
182
+ .end ()
183
+ )
152
184
153
- label_expr = ibis .case ()
154
- for clslabel in classlabels :
155
- label_expr = label_expr .when (candidate_cls == clslabel , clslabel )
156
- label_expr = label_expr .else_ (ibis .null ()).end ()
157
- label_expr = optimizer .fold_case (label_expr )
185
+ label_expr = ibis .case ()
186
+ for clslabel in classlabels :
187
+ label_expr = label_expr .when (candidate_cls == clslabel , clslabel )
188
+ label_expr = label_expr .else_ (ibis .null ()).end ()
189
+ label_expr = optimizer .fold_case (label_expr )
158
190
159
- # Compute probability to return it too.
160
- sum_votes = None
161
- for clslabel in classlabels :
162
- if sum_votes is None :
163
- sum_votes = total_votes [clslabel ]
164
- else :
191
+ # Compute probability to return it too.
192
+ sum_votes = ibis .literal (0.0 )
193
+ for clslabel in classlabels :
165
194
sum_votes = optimizer .fold_operation (sum_votes + total_votes [clslabel ])
166
195
167
- # FIXME: Probabilities are currently broken for gradient boosted trees.
168
- prob_dict = VariablesGroup ()
169
- for clslabel in classlabels :
170
- prob_dict [str (clslabel )] = total_votes [clslabel ] / sum_votes
196
+ # FIXME: Probabilities are currently broken for gradient boosted trees.
197
+ prob_dict = VariablesGroup ()
198
+ for clslabel in classlabels :
199
+ prob_dict [str (clslabel )] = total_votes [clslabel ] / sum_votes
171
200
172
201
return label_expr , prob_dict
0 commit comments