Skip to content

Commit 0dea718

Browse files
committed
tree parser done
1 parent 96f9d0f commit 0dea718

File tree

1 file changed

+27
-19
lines changed
  • src/mustela/translation/steps/trees

1 file changed

+27
-19
lines changed

src/mustela/translation/steps/trees/tree.py

+27-19
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
"""Prase tree definitions and return a graph of nodes."""
12
import itertools
3+
import typing
24

35
import ibis
46

@@ -10,16 +12,16 @@ def build_tree(translator: Translator) -> dict[int, dict[int, dict]]:
1012
1113
The tree is built based on the node and attributes of the translator.
1214
"""
13-
nodes_treeids: list[int] = translator._attributes["nodes_treeids"]
14-
nodes_nodeids: list[int] = translator._attributes["nodes_nodeids"]
15-
nodes_modes: list[str] = translator._attributes["nodes_modes"]
16-
nodes_truenodeids: list[int] = translator._attributes["nodes_truenodeids"]
17-
nodes_falsenodeids: list[int] = translator._attributes["nodes_falsenodeids"]
18-
nodes_thresholds: list[float] = translator._attributes["nodes_values"]
19-
nodes_featureids: list[int] = translator._attributes["nodes_featureids"]
20-
nodes_missing_value_tracks_true: list[int] = translator._attributes[
15+
nodes_treeids = typing.cast(list[int], translator._attributes["nodes_treeids"])
16+
nodes_nodeids = typing.cast(list[int], translator._attributes["nodes_nodeids"])
17+
nodes_modes = typing.cast(list[str], translator._attributes["nodes_modes"])
18+
nodes_truenodeids = typing.cast(list[int], translator._attributes["nodes_truenodeids"])
19+
nodes_falsenodeids = typing.cast(list[int], translator._attributes["nodes_falsenodeids"])
20+
nodes_thresholds = typing.cast(list[float], translator._attributes["nodes_values"])
21+
nodes_featureids = typing.cast(list[int], translator._attributes["nodes_featureids"])
22+
nodes_missing_value_tracks_true = typing.cast(list[int], translator._attributes[
2123
"nodes_missing_value_tracks_true"
22-
]
24+
])
2325
node = translator._node
2426

2527
# Assert a few things to ensure we don't ed up genearting a tree with wrong data
@@ -37,19 +39,19 @@ def build_tree(translator: Translator) -> dict[int, dict[int, dict]]:
3739
weights = {}
3840
if node.op_type == "TreeEnsembleClassifier":
3941
# Weights for classifier, in this case the weights are per-class
40-
class_nodeids: list[int] = translator._attributes["class_nodeids"]
41-
class_treeids: list[int] = translator._attributes["class_treeids"]
42-
class_weights: list[float] = translator._attributes["class_weights"]
43-
weights_classid: list[int] = translator._attributes["class_ids"]
42+
class_nodeids = typing.cast(list[int], translator._attributes["class_nodeids"])
43+
class_treeids = typing.cast(list[int], translator._attributes["class_treeids"])
44+
class_weights = typing.cast(list[float], translator._attributes["class_weights"])
45+
weights_classid = typing.cast(list[int], translator._attributes["class_ids"])
4446
assert (
4547
len(class_treeids)
4648
== len(class_nodeids)
4749
== len(class_weights)
4850
== len(weights_classid)
4951
)
50-
classlabels: None|list[str|int] = translator._attributes.get(
52+
classlabels = typing.cast(None|list[str|int], translator._attributes.get(
5153
"classlabels_strings"
52-
) or translator._attributes.get("classlabels_int64s")
54+
) or translator._attributes.get("classlabels_int64s"))
5355
if not classlabels:
5456
raise ValueError("Missing class labels when building tree")
5557

@@ -61,9 +63,9 @@ def build_tree(translator: Translator) -> dict[int, dict[int, dict]]:
6163
)
6264
elif node.op_type == "TreeEnsembleRegressor":
6365
# Weights for the regressor, in this case leaf nodes have only 1 weight
64-
target_weights: list[float] = translator._attributes["target_weights"]
65-
target_nodeids: list[int] = translator._attributes["target_nodeids"]
66-
target_treeids: list[int] = translator._attributes["target_treeids"]
66+
target_weights = typing.cast(list[float], translator._attributes["target_weights"])
67+
target_nodeids = typing.cast(list[int], translator._attributes["target_nodeids"])
68+
target_treeids = typing.cast(list[int], translator._attributes["target_treeids"])
6769
assert len(target_treeids) == len(target_nodeids) == len(target_weights)
6870
for tree_id, node_id, weight in zip(
6971
target_treeids, target_nodeids, target_weights
@@ -125,7 +127,13 @@ def build_tree(translator: Translator) -> dict[int, dict[int, dict]]:
125127
return {tree_id: trees[tree_id][0] for tree_id in trees}
126128

127129

128-
def mode_to_condition(node, feature_expr: ibis.Expr) -> ibis.Expr:
130+
def mode_to_condition(node: dict, feature_expr: ibis.Expr) -> ibis.Expr:
131+
"""Build a comparison expression for a branch node.
132+
133+
The comparison is based on the mode of the node and the threshold
134+
for that noode. The feature will be compared to the threshold
135+
using the operator defined by the mode.
136+
"""
129137
threshold = node["treshold"]
130138
if node["mode"] == "BRANCH_LEQ":
131139
condition = feature_expr <= threshold

0 commit comments

Comments
 (0)