Skip to content

Commit 238cebd

Browse files
committed
Set explicit variable groups
1 parent c127960 commit 238cebd

19 files changed

+78
-62
lines changed

src/mustela/translation/steps/add.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ibis
66

77
from ..translator import Translator
8-
from ..variables import NumericVariablesGroup, VariablesGroup
8+
from ..variables import NumericVariablesGroup, ValueVariablesGroup, VariablesGroup
99

1010

1111
class AddTranslator(Translator):
@@ -45,7 +45,7 @@ def process(self) -> None:
4545
"When the first operand is a group of columns, the second operand must contain the same number of values"
4646
)
4747
self.set_output(
48-
VariablesGroup(
48+
ValueVariablesGroup(
4949
{
5050
field: (
5151
self._optimizer.fold_operation(

src/mustela/translation/steps/arrayfeatureextractor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ibis.expr.types
66

77
from ..translator import Translator
8-
from ..variables import VariablesGroup
8+
from ..variables import ValueVariablesGroup, VariablesGroup
99

1010

1111
class ArrayFeatureExtractorTranslator(Translator):
@@ -58,7 +58,9 @@ def process(self) -> None:
5858
)
5959

6060
# Pick only the columns that are in the list of indices.
61-
result = VariablesGroup({data_keys[i]: data_values[i] for i in indices})
61+
result = ValueVariablesGroup(
62+
{data_keys[i]: data_values[i] for i in indices}
63+
)
6264
elif isinstance(data, (tuple, list)):
6365
# We are selecting values out of a list of values
6466
# This is usually used to select "classes" out of a list of

src/mustela/translation/steps/cast.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import onnx
77

88
from ..translator import Translator
9-
from ..variables import VariablesGroup
9+
from ..variables import ValueVariablesGroup, VariablesGroup
1010

1111
ONNX_TYPES_TO_IBIS: dict[int, ibis.expr.datatypes.DataType] = {
1212
onnx.TensorProto.FLOAT: ibis.expr.datatypes.float32, # 1: FLOAT
@@ -34,7 +34,7 @@ def process(self) -> None:
3434

3535
target_type = ONNX_TYPES_TO_IBIS[to_type]
3636
if isinstance(expr, VariablesGroup):
37-
casted = VariablesGroup(
37+
casted = ValueVariablesGroup(
3838
{
3939
k: self._optimizer.fold_cast(expr.as_value(k).cast(target_type))
4040
for k in expr
@@ -91,7 +91,7 @@ def process(self) -> None:
9191
target_type: ibis.DataType = like_expr.type()
9292

9393
# Now cast each field in the dictionary to the target type.
94-
casted = VariablesGroup(
94+
casted = ValueVariablesGroup(
9595
{
9696
key: self._optimizer.fold_cast(expr.as_value(key).cast(target_type))
9797
for key in expr

src/mustela/translation/steps/concat.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ibis
66

77
from ..translator import Translator
8-
from ..variables import VariablesGroup
8+
from ..variables import ValueVariablesGroup, VariablesGroup
99

1010

1111
class ConcatTranslator(Translator):
@@ -42,7 +42,8 @@ def _concatenate_columns(cls, translator: Translator) -> VariablesGroup:
4242
This is used by both Concat and FeatureVectorizer translators,
4343
as they both need to concatenate columns.
4444
"""
45-
result = VariablesGroup()
45+
result = ValueVariablesGroup()
46+
4647
for col in translator.inputs:
4748
feature = translator._variables.consume(col)
4849
if isinstance(feature, dict):

src/mustela/translation/steps/div.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44

55
import ibis
66

7-
from mustela.translation.variables import VariablesGroup
8-
97
from ..translator import Translator
8+
from ..variables import NumericVariablesGroup, ValueVariablesGroup, VariablesGroup
109

1110

1211
class DivTranslator(Translator):
@@ -36,7 +35,8 @@ def process(self) -> None:
3635
"Div: Second input (divisor) must be a constant list."
3736
)
3837

39-
if isinstance(first_operand, dict):
38+
if isinstance(first_operand, VariablesGroup):
39+
first_operand = NumericVariablesGroup(first_operand)
4040
struct_fields = list(first_operand.keys())
4141
for value in first_operand.values():
4242
if not isinstance(value, ibis.expr.types.NumericValue):
@@ -50,7 +50,7 @@ def process(self) -> None:
5050
if not isinstance(second_arg, (int, float)):
5151
raise ValueError("Div: The second operand must be a numeric value.")
5252
self.set_output(
53-
VariablesGroup(
53+
ValueVariablesGroup(
5454
{
5555
field: (
5656
self._optimizer.fold_operation(
@@ -67,7 +67,7 @@ def process(self) -> None:
6767
"The number of elements in the second operand must match the number of columns in the first operand."
6868
)
6969
self.set_output(
70-
VariablesGroup(
70+
ValueVariablesGroup(
7171
{
7272
field: (
7373
self._optimizer.fold_operation(

src/mustela/translation/steps/imputer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import ibis
44

55
from ..translator import Translator
6-
from ..variables import VariablesGroup
6+
from ..variables import ValueVariablesGroup, VariablesGroup
77

88

99
class ImputerTranslator(Translator):
@@ -29,7 +29,7 @@ def process(self) -> None:
2929
raise ValueError(
3030
"Imputer: number of imputed values does not match number of columns"
3131
)
32-
new_expr = VariablesGroup()
32+
new_expr = ValueVariablesGroup()
3333
for i, key in enumerate(keys):
3434
new_expr[key] = ibis.coalesce(expr[key], imputed_values[i])
3535
self.set_output(new_expr)

src/mustela/translation/steps/linearclass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ibis
66

77
from ..translator import Translator
8-
from ..variables import NumericVariablesGroup, VariablesGroup
8+
from ..variables import NumericVariablesGroup, ValueVariablesGroup, VariablesGroup
99

1010

1111
class LinearClassifierTranslator(Translator):
@@ -48,7 +48,7 @@ def process(self) -> None:
4848
# Standardize input_operand to a columns group,
4949
# so that we can reuse a single implementation.
5050
if not isinstance(input_operand, VariablesGroup):
51-
input_operand = VariablesGroup({"feature": input_operand})
51+
input_operand = ValueVariablesGroup({"feature": input_operand})
5252

5353
num_features = len(input_operand)
5454
num_classes = len(classlabels)
@@ -75,7 +75,7 @@ def process(self) -> None:
7575
score = self._apply_post_transform(score, post_transform)
7676
scores.append(self._optimizer.fold_operation(score))
7777

78-
scores_struct = VariablesGroup(
78+
scores_struct = ValueVariablesGroup(
7979
{str(label): score for label, score in zip(classlabels, scores)}
8080
)
8181

src/mustela/translation/steps/linearreg.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ibis
66

77
from ..translator import Translator
8-
from ..variables import VariablesGroup
8+
from ..variables import NumericVariablesGroup, ValueVariablesGroup, VariablesGroup
99

1010

1111
class LinearRegressorTranslator(Translator):
@@ -39,10 +39,8 @@ def process(self) -> None:
3939

4040
input_operand = self._variables.consume(self._inputs[0])
4141

42-
if isinstance(input_operand, dict):
43-
input_operand = typing.cast(
44-
dict[str, ibis.expr.types.NumericValue], input_operand
45-
)
42+
if isinstance(input_operand, VariablesGroup):
43+
input_operand = NumericVariablesGroup(input_operand)
4644
num_features = len(input_operand)
4745

4846
if len(coefficients) != targets * num_features:
@@ -70,7 +68,7 @@ def process(self) -> None:
7068
prediction
7169
)
7270

73-
self.set_output(VariablesGroup(results))
71+
self.set_output(ValueVariablesGroup(results))
7472

7573
else:
7674
input_operand = typing.cast(ibis.expr.types.NumericValue, input_operand)

src/mustela/translation/steps/matmul.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ibis
66

77
from ..translator import Translator
8-
from ..variables import VariablesGroup
8+
from ..variables import ValueVariablesGroup
99

1010

1111
class MatMulTranslator(Translator):
@@ -101,7 +101,7 @@ def process(self) -> None:
101101
result = result_list[0]
102102
else:
103103
# Return a dict of output expressions if there are multiple output columns.
104-
result = VariablesGroup(
104+
result = ValueVariablesGroup(
105105
{f"out_{j}": result_list[j] for j in range(output_dim)}
106106
)
107107
self.set_output(result)
@@ -131,7 +131,7 @@ def process(self) -> None:
131131
result = result_list[0]
132132
self.set_output(result_list[0])
133133
else:
134-
result = VariablesGroup(
134+
result = ValueVariablesGroup(
135135
{f"out_{j}": result_list[j] for j in range(output_dim)}
136136
)
137137
self.set_output(result)

src/mustela/translation/steps/mul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ibis
66

77
from ..translator import Translator
8-
from ..variables import NumericVariablesGroup, VariablesGroup
8+
from ..variables import NumericVariablesGroup, ValueVariablesGroup, VariablesGroup
99

1010

1111
class MulTranslator(Translator):
@@ -45,7 +45,7 @@ def process(self) -> None:
4545
"When the first operand is a group of columns, the second operand must contain the same number of values"
4646
)
4747
self.set_output(
48-
VariablesGroup(
48+
ValueVariablesGroup(
4949
{
5050
field: (
5151
self._optimizer.fold_operation(

src/mustela/translation/steps/onehotencoder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ibis
66

77
from ..translator import Translator
8-
from ..variables import VariablesGroup
8+
from ..variables import ValueVariablesGroup
99

1010

1111
class OneHotEncoderTranslator(Translator):
@@ -44,5 +44,7 @@ def process(self) -> None:
4444
# by subsequent operations, so preserving them makes sense.
4545
casted_variables = self.preserve(*casted_variables)
4646
self.set_output(
47-
VariablesGroup({cat: casted_variables[i] for i, cat in enumerate(cats)})
47+
ValueVariablesGroup(
48+
{cat: casted_variables[i] for i, cat in enumerate(cats)}
49+
)
4850
)

src/mustela/translation/steps/scaler.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ibis
66

77
from ..translator import Translator
8-
from ..variables import VariablesGroup
8+
from ..variables import NumericVariablesGroup, ValueVariablesGroup, VariablesGroup
99

1010

1111
class ScalerTranslator(Translator):
@@ -36,10 +36,8 @@ def process(self) -> None:
3636
if not isinstance(type_check_var, ibis.expr.types.NumericValue):
3737
raise ValueError("Scaler: The input operand must be numeric.")
3838

39-
if isinstance(input_operand, dict):
40-
input_operand = typing.cast(
41-
dict[str, ibis.expr.types.NumericValue], input_operand
42-
)
39+
if isinstance(input_operand, VariablesGroup):
40+
input_operand = NumericVariablesGroup(input_operand)
4341

4442
# If the attributes are len=1,
4543
# it means to apply the same value to all inputs.
@@ -55,7 +53,7 @@ def process(self) -> None:
5553
)
5654

5755
self.set_output(
58-
VariablesGroup(
56+
ValueVariablesGroup(
5957
{
6058
field: self._optimizer.fold_operation(
6159
(val - offset[i]) * scale[i]

src/mustela/translation/steps/softmax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ibis
66

77
from ..translator import Translator
8-
from ..variables import NumericVariablesGroup, VariablesGroup
8+
from ..variables import NumericVariablesGroup, ValueVariablesGroup, VariablesGroup
99

1010

1111
class SoftmaxTranslator(Translator):
@@ -64,7 +64,7 @@ def compute_softmax(
6464
sum_exp = sum(exp_dict.values())
6565

6666
# Multi columns case: softmax = exp(column_exp) / (exponents_sum)
67-
return VariablesGroup({k: exp_dict[k] / sum_exp for k in data.keys()})
67+
return ValueVariablesGroup({k: exp_dict[k] / sum_exp for k in data.keys()})
6868
elif isinstance(data, ibis.Expr):
6969
# Single column case: softmax(x) = exp(x) / exp(x) = 1
7070
return ibis.literal(1.0)

src/mustela/translation/steps/sub.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55
import ibis
66

7-
from mustela.translation.variables import VariablesGroup
7+
from mustela.translation.variables import (
8+
NumericVariablesGroup,
9+
ValueVariablesGroup,
10+
VariablesGroup,
11+
)
812

913
from ..translator import Translator
1014

@@ -37,16 +41,14 @@ def process(self) -> None:
3741
raise ValueError("Sub: The first operand must be a numeric value.")
3842

3943
sub_values = list(second_operand)
40-
if isinstance(first_operand, dict):
41-
first_operand = typing.cast(
42-
dict[str, ibis.expr.types.NumericValue], first_operand
43-
)
44+
if isinstance(first_operand, VariablesGroup):
45+
first_operand = NumericVariablesGroup(first_operand)
4446
struct_fields = list(first_operand.keys())
4547
assert len(sub_values) == len(struct_fields), (
4648
f"The number of values in the initializer ({len(sub_values)}) must match the number of fields ({len(struct_fields)}"
4749
)
4850
self.set_output(
49-
VariablesGroup(
51+
ValueVariablesGroup(
5052
{
5153
field: (
5254
self._optimizer.fold_operation(

src/mustela/translation/steps/trees/classifier.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import ibis
66

77
from ...translator import Translator
8-
from ...variables import VariablesGroup
8+
from ...variables import ValueVariablesGroup, VariablesGroup
99
from ..linearclass import LinearClassifierTranslator
1010
from ..softmax import SoftmaxTranslator
1111
from .tree import build_tree, mode_to_condition
@@ -49,7 +49,7 @@ def process(self) -> None:
4949
if post_transform == "SOFTMAX":
5050
prob_colgroup = SoftmaxTranslator.compute_softmax(self, prob_colgroup)
5151
elif post_transform == "LOGISTIC":
52-
prob_colgroup = VariablesGroup(
52+
prob_colgroup = ValueVariablesGroup(
5353
{
5454
lbl: LinearClassifierTranslator._apply_post_transform(
5555
prob_col, post_transform
@@ -159,7 +159,7 @@ def build_tree_case(node: dict) -> dict[str | int, ibis.Expr]:
159159
)
160160
# The order matters: for ONNX the VariablesGroup is a list of subvariables,
161161
# the names are not important.
162-
prob_dict = VariablesGroup(
162+
prob_dict = ValueVariablesGroup(
163163
{
164164
str(output_classlabels[0]): 1.0 - total_score,
165165
str(output_classlabels[1]): total_score,
@@ -194,13 +194,13 @@ def build_tree_case(node: dict) -> dict[str | int, ibis.Expr]:
194194
if post_transform == "SOFTMAX":
195195
# Use softmax as a hint that we are doing a gradient boosted tree,
196196
# thus the probability is the same as the score and should not be normalized
197-
prob_dict = VariablesGroup(
197+
prob_dict = ValueVariablesGroup(
198198
{str(clslabel): total_votes[clslabel] for clslabel in classlabels}
199199
)
200200
else:
201201
# Compute probability to return it too.
202202
sum_votes = sum(total_votes[clslabel] for clslabel in classlabels)
203-
prob_dict = VariablesGroup(
203+
prob_dict = ValueVariablesGroup(
204204
{
205205
str(clslabel): total_votes[clslabel] / sum_votes
206206
for clslabel in classlabels

0 commit comments

Comments
 (0)