18
18
19
19
from __future__ import annotations
20
20
21
+ import re
22
+ import types
21
23
import typing
22
- from typing import List , Optional , Tuple , Union
24
+ from typing import cast , List , Optional , Tuple , Union
23
25
24
26
import bigframes_vendored .sklearn .compose ._column_transformer
27
+ from google .cloud import bigquery
25
28
29
+ import bigframes
30
+ from bigframes import constants
26
31
from bigframes .core import log_adapter
27
32
from bigframes .ml import base , core , globals , preprocessing , utils
28
33
import bigframes .pandas as bpd
29
34
30
- CompilablePreprocessorType = Union [
35
+ _PREPROCESSING_TYPES = Union [
31
36
preprocessing .OneHotEncoder ,
32
37
preprocessing .StandardScaler ,
33
38
preprocessing .MaxAbsScaler ,
36
41
preprocessing .LabelEncoder ,
37
42
]
38
43
44
# Read-only table keyed by the BQML function name that a column's
# "transformSql" starts with; each value is the preprocessing class used to
# parse that SQL back into a transformer object.
_BQML_TRANSFROM_TYPE_MAPPING = types.MappingProxyType(
    dict(
        (
            ("ML.STANDARD_SCALER", preprocessing.StandardScaler),
            ("ML.ONE_HOT_ENCODER", preprocessing.OneHotEncoder),
            ("ML.MAX_ABS_SCALER", preprocessing.MaxAbsScaler),
            ("ML.MIN_MAX_SCALER", preprocessing.MinMaxScaler),
            ("ML.BUCKETIZE", preprocessing.KBinsDiscretizer),
            ("ML.LABEL_ENCODER", preprocessing.LabelEncoder),
        )
    )
)
54
+
39
55
40
56
@log_adapter .class_logger
41
57
class ColumnTransformer (
@@ -51,7 +67,7 @@ def __init__(
51
67
transformers : List [
52
68
Tuple [
53
69
str ,
54
- CompilablePreprocessorType ,
70
+ _PREPROCESSING_TYPES ,
55
71
Union [str , List [str ]],
56
72
]
57
73
],
@@ -66,12 +82,12 @@ def __init__(
66
82
@property
67
83
def transformers_ (
68
84
self ,
69
- ) -> List [Tuple [str , CompilablePreprocessorType , str ,]]:
85
+ ) -> List [Tuple [str , _PREPROCESSING_TYPES , str ,]]:
70
86
"""The collection of transformers as tuples of (name, transformer, column)."""
71
87
result : List [
72
88
Tuple [
73
89
str ,
74
- CompilablePreprocessorType ,
90
+ _PREPROCESSING_TYPES ,
75
91
str ,
76
92
]
77
93
] = []
@@ -89,6 +105,96 @@ def transformers_(
89
105
90
106
return result
91
107
108
@classmethod
def _from_bq(
    cls, session: bigframes.Session, model: bigquery.Model
) -> ColumnTransformer:
    """Build a ColumnTransformer from a BigQuery model and attach it to a session.

    The transformers are recovered with ``_extract_from_bq_model``; the
    resulting object's ``_bqml_model`` is then set to a ``core.BqmlModel``
    created from ``session`` and ``model``.

    Args:
        session: the session handed to ``core.BqmlModel``.
        model: the BigQuery model to read the transformers from.

    Returns:
        ColumnTransformer: the reconstructed transformer.
    """
    restored = cls._extract_from_bq_model(model)
    restored._bqml_model = core.BqmlModel(session, model)
    return restored
116
+
117
@classmethod
def _extract_from_bq_model(
    cls,
    bq_model: bigquery.Model,
) -> ColumnTransformer:
    """Rebuild a ColumnTransformer from the transform columns of a BQ model.

    Only entries of ``bq_model._properties["transformColumns"]`` that carry a
    "transformSql" starting with "ML." are considered; every other column is
    skipped. The returned object's ``_bqml_model`` is not set here.

    Args:
        bq_model: model whose ``_properties`` contain "transformColumns".

    Returns:
        ColumnTransformer: built from the recovered (name, transformer,
        column) tuples.

    Raises:
        NotImplementedError: an "ML." transform does not start with any
            prefix known to ``_BQML_TRANSFROM_TYPE_MAPPING``.
    """
    assert "transformColumns" in bq_model._properties

    def camel_to_snake(name):
        # e.g. "OneHotEncoder" -> "one_hot_encoder"
        name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()

    extracted: List[
        Tuple[
            str,
            _PREPROCESSING_TYPES,
            Union[str, List[str]],
        ]
    ] = []

    for column_spec in bq_model._properties["transformColumns"]:
        # columns without a transform are passed through untouched
        if "transformSql" not in column_spec:
            continue
        sql: str = cast(dict, column_spec)["transformSql"]
        if not sql.startswith("ML."):
            continue

        for bqml_prefix, transformer_cls in _BQML_TRANSFROM_TYPE_MAPPING.items():
            if not sql.startswith(bqml_prefix):
                continue
            entry = (
                camel_to_snake(transformer_cls.__name__),
                *transformer_cls._parse_from_sql(sql),  # type: ignore
            )
            extracted.append(entry)
            break
        else:
            # no known prefix matched this "ML." expression
            raise NotImplementedError(
                f"Unsupported transformer type. {constants.FEEDBACK_LINK} "
            )

    return cls(transformers=extracted)
164
+
165
def _merge(
    self, bq_model: bigquery.Model
) -> Union[
    ColumnTransformer,
    preprocessing.StandardScaler,
    preprocessing.OneHotEncoder,
    preprocessing.MaxAbsScaler,
    preprocessing.MinMaxScaler,
    preprocessing.KBinsDiscretizer,
    preprocessing.LabelEncoder,
]:
    """Collapse this column transformer into a single transformer when possible.

    The first transformer in ``self.transformers_`` is returned when both
    conditions hold: every entry compares equal to that first transformer,
    and the transformed columns are exactly the names of
    ``bq_model.feature_columns``. Otherwise ``self`` is returned unchanged.
    """
    entries = self.transformers_
    assert len(entries) > 0

    (_, first_transformer, first_column), *remaining = entries

    # any differing transformer means there is nothing to merge
    if any(other != first_transformer for _, other, _ in remaining):
        return self

    merged_columns = [first_column, *(column for _, _, column in remaining)]

    # merge only if every feature column of the model is covered
    feature_names = sorted(
        cast(str, feature_column.name) for feature_column in bq_model.feature_columns
    )
    return first_transformer if feature_names == sorted(merged_columns) else self
197
+
92
198
def _compile_to_sql (
93
199
self ,
94
200
columns : List [str ],
@@ -143,3 +249,20 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
143
249
bpd .DataFrame ,
144
250
df [self ._output_names ],
145
251
)
252
+
253
def to_gbq(self, model_name: str, replace: bool = False) -> ColumnTransformer:
    """Persist this transformer as a BigQuery model and load it back.

    Args:
        model_name (str):
            name to save the model under.
        replace (bool, default False):
            whether an existing model of that name may be replaced.

    Returns:
        ColumnTransformer: the model read back from ``model_name``.

    Raises:
        RuntimeError: the transformer has no underlying BQML model yet.
    """
    fitted_model = self._bqml_model
    if fitted_model:
        saved = fitted_model.copy(model_name, replace)
        return saved.session.read_gbq_model(model_name)

    raise RuntimeError("A transformer must be fitted before it can be saved")
0 commit comments