Skip to content

Commit 0835dfd

Browse files
GarrettWuGenesis929
authored andcommitted
feat: add ColumnTransformer save/load (#541)
1 parent db0afc9 commit 0835dfd

File tree

6 files changed

+204
-135
lines changed

6 files changed

+204
-135
lines changed

bigframes/ml/compose.py

+128-5
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,21 @@
1818

1919
from __future__ import annotations
2020

21+
import re
22+
import types
2123
import typing
22-
from typing import List, Optional, Tuple, Union
24+
from typing import cast, List, Optional, Tuple, Union
2325

2426
import bigframes_vendored.sklearn.compose._column_transformer
27+
from google.cloud import bigquery
2528

29+
import bigframes
30+
from bigframes import constants
2631
from bigframes.core import log_adapter
2732
from bigframes.ml import base, core, globals, preprocessing, utils
2833
import bigframes.pandas as bpd
2934

30-
CompilablePreprocessorType = Union[
35+
_PREPROCESSING_TYPES = Union[
3136
preprocessing.OneHotEncoder,
3237
preprocessing.StandardScaler,
3338
preprocessing.MaxAbsScaler,
@@ -36,6 +41,17 @@
3641
preprocessing.LabelEncoder,
3742
]
3843

44+
_BQML_TRANSFROM_TYPE_MAPPING = types.MappingProxyType(
45+
{
46+
"ML.STANDARD_SCALER": preprocessing.StandardScaler,
47+
"ML.ONE_HOT_ENCODER": preprocessing.OneHotEncoder,
48+
"ML.MAX_ABS_SCALER": preprocessing.MaxAbsScaler,
49+
"ML.MIN_MAX_SCALER": preprocessing.MinMaxScaler,
50+
"ML.BUCKETIZE": preprocessing.KBinsDiscretizer,
51+
"ML.LABEL_ENCODER": preprocessing.LabelEncoder,
52+
}
53+
)
54+
3955

4056
@log_adapter.class_logger
4157
class ColumnTransformer(
@@ -51,7 +67,7 @@ def __init__(
5167
transformers: List[
5268
Tuple[
5369
str,
54-
CompilablePreprocessorType,
70+
_PREPROCESSING_TYPES,
5571
Union[str, List[str]],
5672
]
5773
],
@@ -66,12 +82,12 @@ def __init__(
6682
@property
6783
def transformers_(
6884
self,
69-
) -> List[Tuple[str, CompilablePreprocessorType, str,]]:
85+
) -> List[Tuple[str, _PREPROCESSING_TYPES, str,]]:
7086
"""The collection of transformers as tuples of (name, transformer, column)."""
7187
result: List[
7288
Tuple[
7389
str,
74-
CompilablePreprocessorType,
90+
_PREPROCESSING_TYPES,
7591
str,
7692
]
7793
] = []
@@ -89,6 +105,96 @@ def transformers_(
89105

90106
return result
91107

108+
@classmethod
109+
def _from_bq(
110+
cls, session: bigframes.Session, model: bigquery.Model
111+
) -> ColumnTransformer:
112+
col_transformer = cls._extract_from_bq_model(model)
113+
col_transformer._bqml_model = core.BqmlModel(session, model)
114+
115+
return col_transformer
116+
117+
@classmethod
118+
def _extract_from_bq_model(
119+
cls,
120+
bq_model: bigquery.Model,
121+
) -> ColumnTransformer:
122+
"""Extract transformers as ColumnTransformer obj from a BQ Model. Keep the _bqml_model field as None."""
123+
assert "transformColumns" in bq_model._properties
124+
125+
transformers: List[
126+
Tuple[
127+
str,
128+
_PREPROCESSING_TYPES,
129+
Union[str, List[str]],
130+
]
131+
] = []
132+
133+
def camel_to_snake(name):
134+
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
135+
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
136+
137+
for transform_col in bq_model._properties["transformColumns"]:
138+
# pass the columns that are not transformed
139+
if "transformSql" not in transform_col:
140+
continue
141+
transform_sql: str = cast(dict, transform_col)["transformSql"]
142+
if not transform_sql.startswith("ML."):
143+
continue
144+
145+
found_transformer = False
146+
for prefix in _BQML_TRANSFROM_TYPE_MAPPING:
147+
if transform_sql.startswith(prefix):
148+
transformer_cls = _BQML_TRANSFROM_TYPE_MAPPING[prefix]
149+
transformers.append(
150+
(
151+
camel_to_snake(transformer_cls.__name__),
152+
*transformer_cls._parse_from_sql(transform_sql), # type: ignore
153+
)
154+
)
155+
156+
found_transformer = True
157+
break
158+
if not found_transformer:
159+
raise NotImplementedError(
160+
f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
161+
)
162+
163+
return cls(transformers=transformers)
164+
165+
def _merge(
166+
self, bq_model: bigquery.Model
167+
) -> Union[
168+
ColumnTransformer,
169+
preprocessing.StandardScaler,
170+
preprocessing.OneHotEncoder,
171+
preprocessing.MaxAbsScaler,
172+
preprocessing.MinMaxScaler,
173+
preprocessing.KBinsDiscretizer,
174+
preprocessing.LabelEncoder,
175+
]:
176+
"""Try to merge the column transformer to a simple transformer. Depends on all the columns in bq_model are transformed with the same transformer."""
177+
transformers = self.transformers_
178+
179+
assert len(transformers) > 0
180+
_, transformer_0, column_0 = transformers[0]
181+
columns = [column_0]
182+
for _, transformer, column in transformers[1:]:
183+
# all transformers are the same
184+
if transformer != transformer_0:
185+
return self
186+
columns.append(column)
187+
# all feature columns are transformed
188+
if sorted(
189+
[
190+
cast(str, feature_column.name)
191+
for feature_column in bq_model.feature_columns
192+
]
193+
) == sorted(columns):
194+
return transformer_0
195+
196+
return self
197+
92198
def _compile_to_sql(
93199
self,
94200
columns: List[str],
@@ -143,3 +249,20 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
143249
bpd.DataFrame,
144250
df[self._output_names],
145251
)
252+
253+
def to_gbq(self, model_name: str, replace: bool = False) -> ColumnTransformer:
254+
"""Save the transformer as a BigQuery model.
255+
256+
Args:
257+
model_name (str):
258+
the name of the model.
259+
replace (bool, default False):
260+
whether to replace if the model already exists. Default to False.
261+
262+
Returns:
263+
ColumnTransformer: saved model."""
264+
if not self._bqml_model:
265+
raise RuntimeError("A transformer must be fitted before it can be saved")
266+
267+
new_model = self._bqml_model.copy(model_name, replace)
268+
return new_model.session.read_gbq_model(model_name)

bigframes/ml/loader.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import bigframes.constants as constants
2424
from bigframes.ml import (
2525
cluster,
26+
compose,
2627
decomposition,
2728
ensemble,
2829
forecasting,
@@ -79,6 +80,7 @@ def from_bq(
7980
llm.PaLM2TextGenerator,
8081
llm.PaLM2TextEmbeddingGenerator,
8182
pipeline.Pipeline,
83+
compose.ColumnTransformer,
8284
]:
8385
"""Load a BQML model to BigQuery DataFrames ML.
8486
@@ -89,22 +91,32 @@ def from_bq(
8991
Returns:
9092
A BigQuery DataFrames ML model object.
9193
"""
94+
# TODO(garrettwu): the entire condition only to TRANSFORM_ONLY when b/331679273 is fixed.
95+
if (
96+
bq_model.model_type == "TRANSFORM_ONLY"
97+
or bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
98+
and "transformColumns" in bq_model._properties
99+
and not _is_bq_model_remote(bq_model)
100+
):
101+
return _transformer_from_bq(session, bq_model)
102+
92103
if _is_bq_model_pipeline(bq_model):
93104
return pipeline.Pipeline._from_bq(session, bq_model)
94105

95106
return _model_from_bq(session, bq_model)
96107

97108

109+
def _transformer_from_bq(session: bigframes.Session, bq_model: bigquery.Model):
110+
# TODO(garrettwu): add other transformers
111+
return compose.ColumnTransformer._from_bq(session, bq_model)
112+
113+
98114
def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model):
99115
if bq_model.model_type in _BQML_MODEL_TYPE_MAPPING:
100116
return _BQML_MODEL_TYPE_MAPPING[bq_model.model_type]._from_bq( # type: ignore
101117
session=session, model=bq_model
102118
)
103-
if (
104-
bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
105-
and "remoteModelInfo" in bq_model._properties
106-
and "endpoint" in bq_model._properties["remoteModelInfo"]
107-
):
119+
if _is_bq_model_remote(bq_model):
108120
# Parse the remote model endpoint
109121
bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]
110122
model_endpoint = bqml_endpoint.split("/")[-1]
@@ -121,3 +133,11 @@ def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model):
121133

122134
def _is_bq_model_pipeline(bq_model: bigquery.Model) -> bool:
123135
return "transformColumns" in bq_model._properties
136+
137+
138+
def _is_bq_model_remote(bq_model: bigquery.Model) -> bool:
139+
return (
140+
bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
141+
and "remoteModelInfo" in bq_model._properties
142+
and "endpoint" in bq_model._properties["remoteModelInfo"]
143+
)

bigframes/ml/pipeline.py

+3-110
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from __future__ import annotations
2020

21-
from typing import cast, List, Optional, Tuple, Union
21+
from typing import List, Optional, Tuple, Union
2222

2323
import bigframes_vendored.sklearn.pipeline
2424
from google.cloud import bigquery
@@ -83,8 +83,8 @@ def __init__(self, steps: List[Tuple[str, base.BaseEstimator]]):
8383

8484
@classmethod
8585
def _from_bq(cls, session: bigframes.Session, bq_model: bigquery.Model) -> Pipeline:
86-
col_transformer = _extract_as_column_transformer(bq_model)
87-
transform = _merge_column_transformer(bq_model, col_transformer)
86+
col_transformer = compose.ColumnTransformer._extract_from_bq_model(bq_model)
87+
transform = col_transformer._merge(bq_model)
8888

8989
estimator = loader._model_from_bq(session, bq_model)
9090
return cls([("transform", transform), ("estimator", estimator)])
@@ -138,110 +138,3 @@ def to_gbq(self, model_name: str, replace: bool = False) -> Pipeline:
138138
new_model = self._estimator._bqml_model.copy(model_name, replace)
139139

140140
return new_model.session.read_gbq_model(model_name)
141-
142-
143-
def _extract_as_column_transformer(
144-
bq_model: bigquery.Model,
145-
) -> compose.ColumnTransformer:
146-
"""Extract transformers as ColumnTransformer obj from a BQ Model."""
147-
assert "transformColumns" in bq_model._properties
148-
149-
transformers: List[
150-
Tuple[
151-
str,
152-
Union[
153-
preprocessing.OneHotEncoder,
154-
preprocessing.StandardScaler,
155-
preprocessing.MaxAbsScaler,
156-
preprocessing.MinMaxScaler,
157-
preprocessing.KBinsDiscretizer,
158-
preprocessing.LabelEncoder,
159-
],
160-
Union[str, List[str]],
161-
]
162-
] = []
163-
for transform_col in bq_model._properties["transformColumns"]:
164-
# pass the columns that are not transformed
165-
if "transformSql" not in transform_col:
166-
continue
167-
168-
transform_sql: str = cast(dict, transform_col)["transformSql"]
169-
if transform_sql.startswith("ML.STANDARD_SCALER"):
170-
transformers.append(
171-
(
172-
"standard_scaler",
173-
*preprocessing.StandardScaler._parse_from_sql(transform_sql),
174-
)
175-
)
176-
elif transform_sql.startswith("ML.ONE_HOT_ENCODER"):
177-
transformers.append(
178-
(
179-
"ont_hot_encoder",
180-
*preprocessing.OneHotEncoder._parse_from_sql(transform_sql),
181-
)
182-
)
183-
elif transform_sql.startswith("ML.MAX_ABS_SCALER"):
184-
transformers.append(
185-
(
186-
"max_abs_scaler",
187-
*preprocessing.MaxAbsScaler._parse_from_sql(transform_sql),
188-
)
189-
)
190-
elif transform_sql.startswith("ML.MIN_MAX_SCALER"):
191-
transformers.append(
192-
(
193-
"min_max_scaler",
194-
*preprocessing.MinMaxScaler._parse_from_sql(transform_sql),
195-
)
196-
)
197-
elif transform_sql.startswith("ML.BUCKETIZE"):
198-
transformers.append(
199-
(
200-
"k_bins_discretizer",
201-
*preprocessing.KBinsDiscretizer._parse_from_sql(transform_sql),
202-
)
203-
)
204-
elif transform_sql.startswith("ML.LABEL_ENCODER"):
205-
transformers.append(
206-
(
207-
"label_encoder",
208-
*preprocessing.LabelEncoder._parse_from_sql(transform_sql),
209-
)
210-
)
211-
else:
212-
raise NotImplementedError(
213-
f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
214-
)
215-
216-
return compose.ColumnTransformer(transformers=transformers)
217-
218-
219-
def _merge_column_transformer(
220-
bq_model: bigquery.Model, column_transformer: compose.ColumnTransformer
221-
) -> Union[
222-
compose.ColumnTransformer,
223-
preprocessing.StandardScaler,
224-
preprocessing.OneHotEncoder,
225-
preprocessing.MaxAbsScaler,
226-
preprocessing.MinMaxScaler,
227-
preprocessing.KBinsDiscretizer,
228-
preprocessing.LabelEncoder,
229-
]:
230-
"""Try to merge the column transformer to a simple transformer."""
231-
transformers = column_transformer.transformers_
232-
233-
assert len(transformers) > 0
234-
_, transformer_0, column_0 = transformers[0]
235-
columns = [column_0]
236-
for _, transformer, column in transformers[1:]:
237-
# all transformers are the same
238-
if transformer != transformer_0:
239-
return column_transformer
240-
columns.append(column)
241-
# all feature columns are transformed
242-
if sorted(
243-
[cast(str, feature_column.name) for feature_column in bq_model.feature_columns]
244-
) == sorted(columns):
245-
return transformer_0
246-
247-
return column_transformer

bigframes/session/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,7 @@ def read_gbq_model(self, model_name: str):
953953
to load from the default project.
954954
955955
Returns:
956-
A bigframes.ml Model wrapping the model.
956+
A bigframes.ml Model, Transformer or Pipeline wrapping the model.
957957
"""
958958
import bigframes.ml.loader
959959

0 commit comments

Comments
 (0)