1
+ """Parse tree definitions and return a graph of nodes."""
1
2
import itertools
3
+ import typing
2
4
3
5
import ibis
4
6
@@ -10,16 +12,16 @@ def build_tree(translator: Translator) -> dict[int, dict[int, dict]]:
10
12
11
13
The tree is built based on the node and attributes of the translator.
12
14
"""
13
- nodes_treeids : list [int ] = translator ._attributes ["nodes_treeids" ]
14
- nodes_nodeids : list [int ] = translator ._attributes ["nodes_nodeids" ]
15
- nodes_modes : list [str ] = translator ._attributes ["nodes_modes" ]
16
- nodes_truenodeids : list [int ] = translator ._attributes ["nodes_truenodeids" ]
17
- nodes_falsenodeids : list [int ] = translator ._attributes ["nodes_falsenodeids" ]
18
- nodes_thresholds : list [float ] = translator ._attributes ["nodes_values" ]
19
- nodes_featureids : list [int ] = translator ._attributes ["nodes_featureids" ]
20
- nodes_missing_value_tracks_true : list [int ] = translator ._attributes [
15
+ nodes_treeids = typing . cast ( list [int ], translator ._attributes ["nodes_treeids" ])
16
+ nodes_nodeids = typing . cast ( list [int ], translator ._attributes ["nodes_nodeids" ])
17
+ nodes_modes = typing . cast ( list [str ], translator ._attributes ["nodes_modes" ])
18
+ nodes_truenodeids = typing . cast ( list [int ], translator ._attributes ["nodes_truenodeids" ])
19
+ nodes_falsenodeids = typing . cast ( list [int ], translator ._attributes ["nodes_falsenodeids" ])
20
+ nodes_thresholds = typing . cast ( list [float ], translator ._attributes ["nodes_values" ])
21
+ nodes_featureids = typing . cast ( list [int ], translator ._attributes ["nodes_featureids" ])
22
+ nodes_missing_value_tracks_true = typing . cast ( list [int ], translator ._attributes [
21
23
"nodes_missing_value_tracks_true"
22
- ]
24
+ ])
23
25
node = translator ._node
24
26
25
27
# Assert a few things to ensure we don't end up generating a tree with wrong data
@@ -37,19 +39,19 @@ def build_tree(translator: Translator) -> dict[int, dict[int, dict]]:
37
39
weights = {}
38
40
if node .op_type == "TreeEnsembleClassifier" :
39
41
# Weights for classifier, in this case the weights are per-class
40
- class_nodeids : list [int ] = translator ._attributes ["class_nodeids" ]
41
- class_treeids : list [int ] = translator ._attributes ["class_treeids" ]
42
- class_weights : list [float ] = translator ._attributes ["class_weights" ]
43
- weights_classid : list [int ] = translator ._attributes ["class_ids" ]
42
+ class_nodeids = typing . cast ( list [int ], translator ._attributes ["class_nodeids" ])
43
+ class_treeids = typing . cast ( list [int ], translator ._attributes ["class_treeids" ])
44
+ class_weights = typing . cast ( list [float ], translator ._attributes ["class_weights" ])
45
+ weights_classid = typing . cast ( list [int ], translator ._attributes ["class_ids" ])
44
46
assert (
45
47
len (class_treeids )
46
48
== len (class_nodeids )
47
49
== len (class_weights )
48
50
== len (weights_classid )
49
51
)
50
- classlabels : None | list [str | int ] = translator ._attributes .get (
52
+ classlabels = typing . cast ( None | list [str | int ], translator ._attributes .get (
51
53
"classlabels_strings"
52
- ) or translator ._attributes .get ("classlabels_int64s" )
54
+ ) or translator ._attributes .get ("classlabels_int64s" ))
53
55
if not classlabels :
54
56
raise ValueError ("Missing class labels when building tree" )
55
57
@@ -61,9 +63,9 @@ def build_tree(translator: Translator) -> dict[int, dict[int, dict]]:
61
63
)
62
64
elif node .op_type == "TreeEnsembleRegressor" :
63
65
# Weights for the regressor, in this case leaf nodes have only 1 weight
64
- target_weights : list [float ] = translator ._attributes ["target_weights" ]
65
- target_nodeids : list [int ] = translator ._attributes ["target_nodeids" ]
66
- target_treeids : list [int ] = translator ._attributes ["target_treeids" ]
66
+ target_weights = typing . cast ( list [float ], translator ._attributes ["target_weights" ])
67
+ target_nodeids = typing . cast ( list [int ], translator ._attributes ["target_nodeids" ])
68
+ target_treeids = typing . cast ( list [int ], translator ._attributes ["target_treeids" ])
67
69
assert len (target_treeids ) == len (target_nodeids ) == len (target_weights )
68
70
for tree_id , node_id , weight in zip (
69
71
target_treeids , target_nodeids , target_weights
@@ -125,7 +127,13 @@ def build_tree(translator: Translator) -> dict[int, dict[int, dict]]:
125
127
return {tree_id : trees [tree_id ][0 ] for tree_id in trees }
126
128
127
129
128
- def mode_to_condition (node , feature_expr : ibis .Expr ) -> ibis .Expr :
130
+ def mode_to_condition (node : dict , feature_expr : ibis .Expr ) -> ibis .Expr :
131
+ """Build a comparison expression for a branch node.
132
+
133
+ The comparison is based on the mode of the node and the threshold
134
+ for that node. The feature will be compared to the threshold
135
+ using the operator defined by the mode.
136
+ """
129
137
threshold = node ["treshold" ]
130
138
if node ["mode" ] == "BRANCH_LEQ" :
131
139
condition = feature_expr <= threshold
0 commit comments