Skip to content

Commit bdeb617

Browse files
committed
Moveing forward with typing
1 parent 3ad2810 commit bdeb617

17 files changed

+178
-100
lines changed

src/mustela/translation/steps/add.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import ibis
55

66
from ..translator import Translator
7+
from ..variables import NumericVariablesGroup, VariablesGroup
78

89

910
class AddTranslator(Translator):
@@ -25,33 +26,33 @@ def process(self) -> None:
2526
raise NotImplementedError("Add: Second input (divisor) must be a constant list.")
2627

2728
type_check_var = first_operand
28-
if isinstance(type_check_var, dict):
29+
if isinstance(type_check_var, VariablesGroup):
2930
type_check_var = next(iter(type_check_var.values()), None)
3031
if not isinstance(type_check_var, ibis.expr.types.NumericValue):
3132
raise ValueError("Add: The first operand must be a numeric value.")
3233

3334
add_values = list(second_operand)
34-
if isinstance(first_operand, dict):
35-
first_operand = typing.cast(dict[str, ibis.expr.types.NumericValue], first_operand)
35+
if isinstance(first_operand, VariablesGroup):
36+
first_operand = NumericVariablesGroup(first_operand)
3637
struct_fields = list(first_operand.keys())
3738
if len(add_values) != len(struct_fields):
3839
# TODO: Implement dividing by a single value,
3940
# see Div implementation.
4041
raise ValueError(
4142
"When the first operand is a group of columns, the second operand must contain the same number of values"
4243
)
43-
self._variables[self._output_name] = {
44+
self.set_output(VariablesGroup({
4445
field: (
4546
self._optimizer.fold_operation(first_operand[field] + add_values[i])
4647
)
4748
for i, field in enumerate(struct_fields)
48-
}
49+
}))
4950
else:
5051
if len(add_values) != 1:
5152
raise ValueError(
5253
"When the first operand is a single column, the second operand must contain exactly 1 value"
5354
)
5455
first_operand = typing.cast(ibis.expr.types.NumericValue, first_operand)
55-
self._variables[self._output_name] = self._optimizer.fold_operation(
56+
self.set_output(self._optimizer.fold_operation(
5657
first_operand + add_values[0]
57-
)
58+
))

src/mustela/translation/steps/arrayfeatureextractor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import ibis.expr.types
44

55
from ..translator import Translator
6+
from ..variables import VariablesGroup
67

78

89
class ArrayFeatureExtractorTranslator(Translator):
@@ -34,7 +35,7 @@ def process(self) -> None:
3435
data = self._variables.consume(self.inputs[0])
3536
indices = self._variables.consume(self.inputs[1])
3637

37-
if isinstance(data, dict):
38+
if isinstance(data, VariablesGroup):
3839
# We are selecting a set of columns out of a column group
3940

4041
# This expects that dictionaries are sorted by insertion order
@@ -49,7 +50,7 @@ def process(self) -> None:
4950
raise ValueError("Indices requested are more than the available numer of columns.")
5051

5152
# Pick only the columns that are in the list of indicies.
52-
result = {data_keys[i]: data[i] for i in indices}
53+
result = VariablesGroup({data_keys[i]: data[i] for i in indices})
5354
elif isinstance(data, (tuple, list)):
5455
# We are selecting values out of a list of values
5556
# This is usually used to select "classes" out of a list of

src/mustela/translation/steps/cast.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
"""Translators for Cast and CastLike operations"""
2+
import typing
23

34
import onnx
45

6+
import ibis
7+
58
from ..translator import Translator
9+
from ..variables import VariablesGroup
610

7-
ONNX_TYPES_TO_IBIS = {
8-
onnx.TensorProto.FLOAT: "float32", # 1: FLOAT
9-
onnx.TensorProto.DOUBLE: "float64", # 11: DOUBLE
10-
onnx.TensorProto.STRING: "string", # 8: STRING
11-
onnx.TensorProto.INT64: "int64", # 7: INT64
12-
onnx.TensorProto.BOOL: "bool", # 9: BOOL
11+
ONNX_TYPES_TO_IBIS: dict[int, ibis.expr.datatypes.DataType] = {
12+
onnx.TensorProto.FLOAT: ibis.expr.datatypes.float32, # 1: FLOAT
13+
onnx.TensorProto.DOUBLE: ibis.expr.datatypes.float64, # 11: DOUBLE
14+
onnx.TensorProto.STRING: ibis.expr.datatypes.string, # 8: STRING
15+
onnx.TensorProto.INT64: ibis.expr.datatypes.int64, # 7: INT64
16+
onnx.TensorProto.BOOL: ibis.expr.datatypes.boolean, # 9: BOOL
1317
}
1418

1519

@@ -23,19 +27,23 @@ def process(self) -> None:
2327
"""Performs the translation and set the output variable."""
2428
# https://onnx.ai/onnx/operators/onnx__Cast.html
2529
expr = self._variables.consume(self.inputs[0])
26-
to_type = self._attributes["to"]
27-
if to_type in ONNX_TYPES_TO_IBIS:
28-
target_type = ONNX_TYPES_TO_IBIS[to_type]
29-
if isinstance(expr, dict):
30-
casted = {
31-
k: self._optimizer.fold_cast(expr[k].cast(target_type))
32-
for k in expr
33-
}
34-
self.set_output(casted)
35-
else:
36-
self.set_output(self._optimizer.fold_cast(expr.cast(target_type)))
37-
else:
30+
to_type: int = typing.cast(int, self._attributes["to"])
31+
if to_type not in ONNX_TYPES_TO_IBIS:
3832
raise NotImplementedError(f"Cast: type {to_type} not supported")
33+
34+
target_type = ONNX_TYPES_TO_IBIS[to_type]
35+
if isinstance(expr, VariablesGroup):
36+
casted = VariablesGroup({
37+
k: self._optimizer.fold_cast(expr.as_value(k).cast(target_type))
38+
for k in expr
39+
})
40+
self.set_output(casted)
41+
elif isinstance(expr, ibis.Value):
42+
self.set_output(self._optimizer.fold_cast(expr.cast(target_type)))
43+
else:
44+
raise ValueError(
45+
f"Cast: expected a column group or a single column. Got {type(expr)}"
46+
)
3947

4048

4149
class CastLikeTranslator(Translator):
@@ -56,25 +64,26 @@ def process(self) -> None:
5664
like_expr = self._variables.consume(self.inputs[1])
5765

5866
# Assert that the first input is a dict (multiple concatenated columns).
59-
if not isinstance(expr, dict):
67+
if not isinstance(expr, VariablesGroup):
6068
# TODO: Support single variables as well.
6169
# This should be fairly straightforward to implement,
6270
# but there hasn't been the need for it yet.
6371
raise NotImplementedError("CastLike currently only supports casting a group of columns.")
6472

6573
# Assert that the second input is a single expression.
66-
if isinstance(like_expr, dict):
74+
if isinstance(like_expr, VariablesGroup):
6775
raise NotImplementedError("CastLike currently only supports casting to a single column type, not a group.")
6876

69-
assert hasattr(like_expr, "type"), (
70-
"CastLike: second input must have a 'type' attribute."
71-
)
77+
if not isinstance(like_expr, ibis.Value):
78+
raise ValueError(
79+
f"CastLike: expected a single column. Got {type(like_expr)}"
80+
)
7281

7382
# Get the target type from the second input.
74-
target_type = like_expr.type()
83+
target_type: ibis.DataType = like_expr.type()
7584

7685
# Now cast each field in the dictionary to the target type.
77-
casted = {
78-
key: self._optimizer.fold_cast(expr[key].cast(target_type)) for key in expr
79-
}
86+
casted = VariablesGroup({
87+
key: self._optimizer.fold_cast(expr.as_value(key).cast(target_type)) for key in expr
88+
})
8089
self.set_output(casted)

src/mustela/translation/steps/concat.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
"""Translator for Concat and FeatureVectorizer operations."""
22

3+
import ibis
4+
5+
import typing
6+
37
from ..translator import Translator
8+
from ..variables import VariablesGroup
49

510

611
class ConcatTranslator(Translator):
@@ -30,13 +35,13 @@ def process(self) -> None:
3035
self.set_output(self._concatenate_columns(self))
3136

3237
@classmethod
33-
def _concatenate_columns(cls, translator: Translator) -> dict:
38+
def _concatenate_columns(cls, translator: Translator) -> VariablesGroup:
3439
"""Implement actual operation of concatenating columns.
3540
3641
This is used by both Concat and FeatureVectorizer translators,
3742
as they both need to concatenate columns.
3843
"""
39-
result = {}
44+
result = VariablesGroup()
4045
for col in translator.inputs:
4146
feature = translator._variables.consume(col)
4247
if isinstance(feature, dict):
@@ -47,8 +52,12 @@ def _concatenate_columns(cls, translator: Translator) -> dict:
4752
for key in feature:
4853
varname = col + "." + key
4954
result[varname] = feature[key]
50-
else:
55+
elif isinstance(feature, ibis.Expr):
5156
result[col] = feature
57+
else:
58+
raise ValueError(
59+
f"Concat: expected a column group or a single column. Got {type(feature)}"
60+
)
5261

5362
return result
5463

@@ -68,7 +77,7 @@ def process(self) -> None:
6877

6978
# We can support this by doing the same as Concat,
7079
# in most cases it's sufficient
71-
ninputdimensions = self._attributes["inputdimensions"]
80+
ninputdimensions = typing.cast(list[int], self._attributes["inputdimensions"])
7281

7382
if len(ninputdimensions) != len(self._inputs):
7483
raise ValueError("Number of input dimensions should be equal to number of inputs.")

src/mustela/translation/steps/div.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import ibis
55

6+
from mustela.translation.variables import VariablesGroup
7+
68
from ..translator import Translator
79

810

@@ -42,29 +44,29 @@ def process(self) -> None:
4244
if not isinstance(second_arg, (int, float)):
4345
raise ValueError("Div: The second operand must be a numeric value.")
4446
self.set_output(
45-
{
47+
VariablesGroup({
4648
field: (
4749
self._optimizer.fold_operation(
4850
first_operand[field] / ibis.literal(second_arg)
4951
)
5052
)
5153
for field in struct_fields
52-
}
54+
})
5355
)
5456
else:
5557
if len(second_arg) != len(first_operand):
5658
raise ValueError(
5759
"The number of elements in the second operand must match the number of columns in the first operand."
5860
)
5961
self.set_output(
60-
{
62+
VariablesGroup({
6163
field: (
6264
self._optimizer.fold_operation(
6365
first_operand[field] / second_arg[i]
6466
)
6567
)
6668
for i, field in enumerate(struct_fields)
67-
}
69+
})
6870
)
6971
else:
7072
if not isinstance(first_operand, ibis.expr.types.NumericValue):

src/mustela/translation/steps/gather.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""Defines the translation step for the Gather operation."""
22

3+
from ibis.common.egraph import Variable
4+
5+
from mustela.translation.variables import VariablesGroup
36
from ..translator import Translator
47

58

@@ -38,7 +41,7 @@ def process(self) -> None:
3841
if not isinstance(idx, int):
3942
raise ValueError("Gather: index must be an integer constant")
4043

41-
if isinstance(expr, dict):
44+
if isinstance(expr, VariablesGroup):
4245
keys = list(expr.keys())
4346
if idx < 0 or idx >= len(keys):
4447
raise IndexError("Gather: index out of bounds")
@@ -48,6 +51,6 @@ def process(self) -> None:
4851
# support axis=1, then the index must be 0.
4952
if idx != 0:
5053
raise NotImplementedError(
51-
f"Gather: index {idx} not supported for non-dict expression of type {type(expr)}"
54+
f"Gather: index {idx} not supported for single columns"
5255
)
5356
self.set_output(expr)

src/mustela/translation/steps/imputer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import ibis
33

44
from ..translator import Translator
5+
from ..variables import VariablesGroup
56

67

78
class ImputerTranslator(Translator):
@@ -23,13 +24,13 @@ def process(self) -> None:
2324
)
2425

2526
expr = self._variables.consume(self.inputs[0])
26-
if isinstance(expr, dict):
27+
if isinstance(expr, VariablesGroup):
2728
keys = list(expr.keys())
2829
if len(keys) != len(imputed_values):
2930
raise ValueError(
3031
"Imputer: number of imputed values does not match number of columns"
3132
)
32-
new_expr = {}
33+
new_expr = VariablesGroup()
3334
for i, key in enumerate(keys):
3435
new_expr[key] = ibis.coalesce(expr[key], imputed_values[i])
3536
self.set_output(new_expr)

src/mustela/translation/steps/matmul.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import ibis
33

44
from ..translator import Translator
5+
from ..variables import VariablesGroup
56

67

78
class MatMulTranslator(Translator):
@@ -86,9 +87,9 @@ def process(self) -> None:
8687
result = result_list[0]
8788
else:
8889
# Return a dict of output expressions if there are multiple output columns.
89-
result = {
90+
result = VariablesGroup({
9091
f"out_{j}": result_list[j] for j in range(output_dim)
91-
}
92+
})
9293
self.set_output(result)
9394
else:
9495
raise NotImplementedError(
@@ -118,9 +119,9 @@ def process(self) -> None:
118119
result = result_list[0]
119120
self._variables[self._output_name] = result_list[0]
120121
else:
121-
result = {
122+
result = VariablesGroup({
122123
f"out_{j}": result_list[j] for j in range(output_dim)
123-
}
124+
})
124125
self.set_output(result)
125126
elif coef_shape[1] == 1:
126127
# This case implies the left operand is a vector of length matching coef_shape[0],

src/mustela/translation/steps/onehotencoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import ibis
55

66
from ..translator import Translator
7+
from ..variables import VariablesGroup
78

89

910
class OneHotEncoderTranslator(Translator):
@@ -39,4 +40,4 @@ def process(self) -> None:
3940
# OneHot encoded features are usually consumed multiple times
4041
# by subsequent operations, so preserving them makes sense.
4142
casted_variables = self.preserve(*casted_variables)
42-
self.set_output({cat: casted_variables[i] for i, cat in enumerate(cats)})
43+
self.set_output(VariablesGroup({cat: casted_variables[i] for i, cat in enumerate(cats)}))

src/mustela/translation/steps/softmax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def compute_softmax(cls, data: ibis.expr.types.NumericValue | dict[str, ibis.exp
4646
sum_exp = expr if sum_exp is None else sum_exp + expr
4747

4848
# Multi columns case: softmax = exp(column_exp) / (exponents_sum)
49-
softmax_result = {k: exp_dict[k] / sum_exp for k in data.keys()}
49+
softmax_result = VariablesGroup({k: exp_dict[k] / sum_exp for k in data.keys()})
5050
else:
5151
# Single column case: softmax(x) = exp(x) / exp(x) = 1
5252
softmax_result = ibis.literal(1.0)

src/mustela/translation/steps/trees/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Translators for trees based models."""
2+
13
from .classifier import TreeEnsembleClassifierTranslator
24
from .regressor import TreeEnsembleRegressorTranslator
35

0 commit comments

Comments
 (0)