class SQLScalarColumnTransformer:
    r"""
    Wrapper for plain SQL code contained in a ColumnTransformer.

    Create a single column transformer in plain sql.
    This transformer can only be used inside ColumnTransformer.

    When creating an instance '{0}' can be used as placeholder
    for the column to transform:

        SQLScalarColumnTransformer("{0}+1")

    The default target column gets the prefix 'transformed\_'
    but can also be changed when creating an instance:

        SQLScalarColumnTransformer("{0}+1", "inc_{0}")

    Both templates are expanded with ``str.format``. If ``str.format``
    rejects a template (for example SQL containing literal braces such as
    ``'{}'``), only the literal text ``{0}`` is replaced instead.

    **Examples:**

        >>> from bigframes.ml.compose import ColumnTransformer, SQLScalarColumnTransformer
        >>> import bigframes.pandas as bpd

        >>> df = bpd.DataFrame({'name': ["James", None, "Mary"], 'city': ["New York", "Boston", None]})
        >>> col_trans = ColumnTransformer([
        ...     ("strlen",
        ...      SQLScalarColumnTransformer("CASE WHEN {0} IS NULL THEN 15 ELSE LENGTH({0}) END"),
        ...      ['name', 'city']),
        ... ])
        >>> col_trans = col_trans.fit(df)
        >>> df_transformed = col_trans.transform(df)
        >>> df_transformed
           transformed_name  transformed_city
        0                 5                 8
        1                15                 6
        2                 4                15

        [3 rows x 2 columns]

    SQLScalarColumnTransformer can be combined with other transformers, like StandardScaler:

        >>> col_trans = ColumnTransformer([
        ...     ("identity", SQLScalarColumnTransformer("{0}", target_column="{0}"), ["col1", "col5"]),
        ...     ("increment", SQLScalarColumnTransformer("{0}+1", target_column="inc_{0}"), "col2"),
        ...     ("stdscale", preprocessing.StandardScaler(), "col3"),
        ...     # ...
        ... ])

    Args:
        sql:
            SQL expression template; '{0}' stands for the (escaped) input column.
        target_column:
            Template for the output column name; '{0}' stands for the input
            column name. Backticks are removed from it.
    """

    # Names matching this pattern need no backticks, unless they are reserved.
    PLAIN_COLNAME_RX = re.compile("^[a-z][a-z0-9_]*$", re.IGNORECASE)

    # GoogleSQL reserved keywords: they look like plain identifiers but must be
    # backtick-quoted when used as column names. Quoting a non-reserved name is
    # harmless, so an over-complete list is safe.
    _RESERVED_KEYWORDS = frozenset(
        {
            "ALL", "AND", "ANY", "ARRAY", "AS", "ASC", "ASSERT_ROWS_MODIFIED",
            "AT", "BETWEEN", "BY", "CASE", "CAST", "COLLATE", "CONTAINS",
            "CREATE", "CROSS", "CUBE", "CURRENT", "DEFAULT", "DEFINE", "DESC",
            "DISTINCT", "ELSE", "END", "ENUM", "ESCAPE", "EXCEPT", "EXCLUDE",
            "EXISTS", "EXTRACT", "FALSE", "FETCH", "FOLLOWING", "FOR", "FROM",
            "FULL", "GROUP", "GROUPING", "GROUPS", "HASH", "HAVING", "IF",
            "IGNORE", "IN", "INNER", "INTERSECT", "INTERVAL", "INTO", "IS",
            "JOIN", "LATERAL", "LEFT", "LIKE", "LIMIT", "LOOKUP", "MERGE",
            "NATURAL", "NEW", "NO", "NOT", "NULL", "NULLS", "OF", "ON", "OR",
            "ORDER", "OUTER", "OVER", "PARTITION", "PRECEDING", "PROTO",
            "QUALIFY", "RANGE", "RECURSIVE", "RESPECT", "RIGHT", "ROLLUP",
            "ROWS", "SELECT", "SET", "SOME", "STRUCT", "TABLESAMPLE", "THEN",
            "TO", "TREAT", "TRUE", "UNBOUNDED", "UNION", "UNNEST", "USING",
            "WHEN", "WHERE", "WINDOW", "WITH", "WITHIN",
        }
    )

    def __init__(self, sql: str, target_column: str = "transformed_{0}"):
        super().__init__()
        self._sql = sql
        # Backticks are added again by escape() when the name requires them.
        self._target_column = target_column.replace("`", "")

    def escape(self, colname: str):
        """Return ``colname`` usable as a SQL identifier.

        Existing backticks are dropped. The name is returned unquoted only if
        it fully matches ``PLAIN_COLNAME_RX`` and is not a reserved keyword;
        otherwise it is wrapped in backticks.
        """
        colname = colname.replace("`", "")
        # fullmatch: '$' alone would also accept a name ending in a newline.
        is_plain = self.PLAIN_COLNAME_RX.fullmatch(colname) is not None
        if is_plain and colname.upper() not in self._RESERVED_KEYWORDS:
            return colname
        return f"`{colname}`"

    @staticmethod
    def _fill_placeholder(template: str, value: str) -> str:
        """Expand ``template`` with ``value`` as positional argument 0.

        Templates accepted by ``str.format`` are expanded exactly as before.
        If ``str.format`` raises because of stray braces, only the literal
        text "{0}" is substituted and everything else is kept verbatim.
        """
        try:
            return template.format(value)
        except (KeyError, IndexError, ValueError):
            return template.replace("{0}", value)

    def _compile_to_sql(
        self, X: "bpd.DataFrame", columns: Optional[Iterable[str]] = None
    ) -> List[str]:
        """Return one ``<expression> AS <target>`` string per input column.

        Args:
            X: Only read (``X.columns``) when ``columns`` is None.
            columns: Input column names; a single name may be passed as a str.
        """
        if columns is None:
            columns = X.columns
        elif isinstance(columns, str):
            # A bare string would otherwise be iterated character by character.
            columns = [columns]
        result = []
        for column in columns:
            current_sql = self._fill_placeholder(self._sql, self.escape(column))
            current_target_column = self.escape(
                self._fill_placeholder(self._target_column, column)
            )
            result.append(f"{current_sql} AS {current_target_column}")
        return result

    def __repr__(self):
        return f"SQLScalarColumnTransformer(sql='{self._sql}', target_column='{self._target_column}')"

    def __eq__(self, other) -> bool:
        return type(self) is type(other) and self._keys() == other._keys()

    def __hash__(self) -> int:
        return hash(self._keys())

    def _keys(self):
        # Identity of a transformer: both templates, as stored.
        return (self._sql, self._target_column)
impute.SimpleImputer], str] - ]: + ) -> List[Tuple[str, SingleColTransformer, str,]]: """The collection of transformers as tuples of (name, transformer, column).""" result: List[ Tuple[ str, - Union[preprocessing.PreprocessingType, impute.SimpleImputer], + SingleColTransformer, str, ] ] = [] @@ -103,6 +196,8 @@ def transformers_( return result + AS_FLEXNAME_SUFFIX_RX = re.compile("^(.*)\\bAS\\s*`[^`]+`\\s*$", re.IGNORECASE) + @classmethod def _extract_from_bq_model( cls, @@ -114,7 +209,7 @@ def _extract_from_bq_model( transformers_set: Set[ Tuple[ str, - Union[preprocessing.PreprocessingType, impute.SimpleImputer], + SingleColTransformer, Union[str, List[str]], ] ] = set() @@ -130,8 +225,11 @@ def camel_to_snake(name): if "transformSql" not in transform_col_dict: continue transform_sql: str = transform_col_dict["transformSql"] - if not transform_sql.startswith("ML."): - continue + + # workaround for bug in bq_model returning " AS `...`" suffix for flexible names + flex_name_match = cls.AS_FLEXNAME_SUFFIX_RX.match(transform_sql) + if flex_name_match: + transform_sql = flex_name_match.group(1) output_names.append(transform_col_dict["name"]) found_transformer = False @@ -148,8 +246,22 @@ def camel_to_snake(name): found_transformer = True break if not found_transformer: - raise NotImplementedError( - f"Unsupported transformer type. {constants.FEEDBACK_LINK}" + if transform_sql.startswith("ML."): + raise NotImplementedError( + f"Unsupported transformer type. 
{constants.FEEDBACK_LINK}" + ) + + target_column = transform_col_dict["name"] + sql_transformer = SQLScalarColumnTransformer( + transform_sql, target_column=target_column + ) + input_column_name = f"?{target_column}" + transformers_set.add( + ( + camel_to_snake(sql_transformer.__class__.__name__), + sql_transformer, + input_column_name, + ) ) transformer = cls(transformers=list(transformers_set)) @@ -167,6 +279,8 @@ def _merge( assert len(transformers) > 0 _, transformer_0, column_0 = transformers[0] + if isinstance(transformer_0, SQLScalarColumnTransformer): + return self # SQLScalarColumnTransformer only work inside ColumnTransformer feature_columns_sorted = sorted( [ cast(str, feature_column.name) diff --git a/tests/system/large/ml/test_compose.py b/tests/system/large/ml/test_compose.py index 59c5a1538f..ba963837e5 100644 --- a/tests/system/large/ml/test_compose.py +++ b/tests/system/large/ml/test_compose.py @@ -36,6 +36,32 @@ def test_columntransformer_standalone_fit_and_transform( preprocessing.MinMaxScaler(), ["culmen_length_mm"], ), + ( + "increment", + compose.SQLScalarColumnTransformer("{0}+1"), + ["culmen_length_mm", "flipper_length_mm"], + ), + ( + "length", + compose.SQLScalarColumnTransformer( + "CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END", + target_column="len_{0}", + ), + "species", + ), + ( + "ohe", + compose.SQLScalarColumnTransformer( + "CASE WHEN {0}='Adelie Penguin (Pygoscelis adeliae)' THEN 1 ELSE 0 END", + target_column="ohe_adelie", + ), + "species", + ), + ( + "identity", + compose.SQLScalarColumnTransformer("{0}", target_column="{0}"), + ["culmen_length_mm", "flipper_length_mm"], + ), ] ) @@ -51,6 +77,12 @@ def test_columntransformer_standalone_fit_and_transform( "standard_scaled_culmen_length_mm", "min_max_scaled_culmen_length_mm", "standard_scaled_flipper_length_mm", + "transformed_culmen_length_mm", + "transformed_flipper_length_mm", + "len_species", + "ohe_adelie", + "culmen_length_mm", + "flipper_length_mm", ], index=[1633, 
1672, 1690], col_exact=False, @@ -70,6 +102,19 @@ def test_columntransformer_standalone_fit_transform(new_penguins_df): preprocessing.StandardScaler(), ["culmen_length_mm", "flipper_length_mm"], ), + ( + "length", + compose.SQLScalarColumnTransformer( + "CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END", + target_column="len_{0}", + ), + "species", + ), + ( + "identity", + compose.SQLScalarColumnTransformer("{0}", target_column="{0}"), + ["culmen_length_mm", "flipper_length_mm"], + ), ] ) @@ -83,6 +128,9 @@ def test_columntransformer_standalone_fit_transform(new_penguins_df): "onehotencoded_species", "standard_scaled_culmen_length_mm", "standard_scaled_flipper_length_mm", + "len_species", + "culmen_length_mm", + "flipper_length_mm", ], index=[1633, 1672, 1690], col_exact=False, @@ -102,6 +150,27 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id): preprocessing.StandardScaler(), ["culmen_length_mm", "flipper_length_mm"], ), + ( + "length", + compose.SQLScalarColumnTransformer( + "CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END", + target_column="len_{0}", + ), + "species", + ), + ( + "identity", + compose.SQLScalarColumnTransformer("{0}", target_column="{0}"), + ["culmen_length_mm", "flipper_length_mm"], + ), + ( + "flexname", + compose.SQLScalarColumnTransformer( + "CASE WHEN {0} IS NULL THEN -1 ELSE LENGTH({0}) END", + target_column="Flex {0} Name", + ), + "species", + ), ] ) transformer.fit( @@ -122,6 +191,36 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id): ), ("standard_scaler", preprocessing.StandardScaler(), "culmen_length_mm"), ("standard_scaler", preprocessing.StandardScaler(), "flipper_length_mm"), + ( + "sql_scalar_column_transformer", + compose.SQLScalarColumnTransformer( + "CASE WHEN species IS NULL THEN -1 ELSE LENGTH(species) END", + target_column="len_species", + ), + "?len_species", + ), + ( + "sql_scalar_column_transformer", + compose.SQLScalarColumnTransformer( + "flipper_length_mm", 
target_column="flipper_length_mm" + ), + "?flipper_length_mm", + ), + ( + "sql_scalar_column_transformer", + compose.SQLScalarColumnTransformer( + "culmen_length_mm", target_column="culmen_length_mm" + ), + "?culmen_length_mm", + ), + ( + "sql_scalar_column_transformer", + compose.SQLScalarColumnTransformer( + "CASE WHEN species IS NULL THEN -1 ELSE LENGTH(species) END ", + target_column="Flex species Name", + ), + "?Flex species Name", + ), ] assert set(reloaded_transformer.transformers) == set(expected) assert reloaded_transformer._bqml_model is not None @@ -136,6 +235,10 @@ def test_columntransformer_save_load(new_penguins_df, dataset_id): "onehotencoded_species", "standard_scaled_culmen_length_mm", "standard_scaled_flipper_length_mm", + "len_species", + "culmen_length_mm", + "flipper_length_mm", + "Flex species Name", ], index=[1633, 1672, 1690], col_exact=False, diff --git a/tests/unit/ml/test_compose.py b/tests/unit/ml/test_compose.py index 60dcc75b63..7643f76e56 100644 --- a/tests/unit/ml/test_compose.py +++ b/tests/unit/ml/test_compose.py @@ -11,11 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from unittest import mock +from google.cloud import bigquery +import pytest import sklearn.compose as sklearn_compose # type: ignore import sklearn.preprocessing as sklearn_preprocessing # type: ignore from bigframes.ml import compose, preprocessing +from bigframes.ml.compose import ColumnTransformer, SQLScalarColumnTransformer +from bigframes.ml.core import BqmlModel +import bigframes.pandas as bpd def test_columntransformer_init_expectedtransforms(): @@ -173,3 +179,403 @@ def test_columntransformer_repr_matches_sklearn(): ) assert bf_column_transformer.__repr__() == sk_column_transformer.__repr__() + + +@pytest.fixture(scope="session") +def mock_X(): + mock_df = mock.create_autospec(spec=bpd.DataFrame) + return mock_df + + +def test_columntransformer_init_with_sqltransformers(): + ident_transformer = SQLScalarColumnTransformer("{0}", target_column="ident_{0}") + len1_transformer = SQLScalarColumnTransformer( + "CASE WHEN {0} IS NULL THEN -2 ELSE LENGTH({0}) END", target_column="len1_{0}" + ) + len2_transformer = SQLScalarColumnTransformer( + "CASE WHEN {0} IS NULL THEN 99 ELSE LENGTH({0}) END", target_column="len2_{0}" + ) + label_transformer = preprocessing.LabelEncoder() + column_transformer = compose.ColumnTransformer( + [ + ( + "ident_trafo", + ident_transformer, + ["culmen_length_mm", "flipper_length_mm"], + ), + ("len1_trafo", len1_transformer, ["species"]), + ("len2_trafo", len2_transformer, ["species"]), + ("label", label_transformer, "species"), + ] + ) + + assert column_transformer.transformers_ == [ + ("ident_trafo", ident_transformer, "culmen_length_mm"), + ("ident_trafo", ident_transformer, "flipper_length_mm"), + ("len1_trafo", len1_transformer, "species"), + ("len2_trafo", len2_transformer, "species"), + ("label", label_transformer, "species"), + ] + + +def test_columntransformer_repr_sqltransformers(): + ident_transformer = SQLScalarColumnTransformer("{0}", target_column="ident_{0}") + len1_transformer = SQLScalarColumnTransformer( + "CASE WHEN 
{0} IS NULL THEN -2 ELSE LENGTH({0}) END", target_column="len1_{0}" + ) + len2_transformer = SQLScalarColumnTransformer( + "CASE WHEN {0} IS NULL THEN 99 ELSE LENGTH({0}) END", target_column="len2_{0}" + ) + label_transformer = preprocessing.LabelEncoder() + column_transformer = compose.ColumnTransformer( + [ + ( + "ident_trafo", + ident_transformer, + ["culmen_length_mm", "flipper_length_mm"], + ), + ("len1_trafo", len1_transformer, ["species"]), + ("len2_trafo", len2_transformer, ["species"]), + ("label", label_transformer, "species"), + ] + ) + + expected = """ColumnTransformer(transformers=[('ident_trafo', + SQLScalarColumnTransformer(sql='{0}', target_column='ident_{0}'), + ['culmen_length_mm', 'flipper_length_mm']), + ('len1_trafo', + SQLScalarColumnTransformer(sql='CASE WHEN {0} IS NULL THEN -2 ELSE LENGTH({0}) END', target_column='len1_{0}'), + ['species']), + ('len2_trafo', + SQLScalarColumnTransformer(sql='CASE WHEN {0} IS NULL THEN 99 ELSE LENGTH({0}) END', target_column='len2_{0}'), + ['species']), + ('label', LabelEncoder(), 'species')])""" + actual = column_transformer.__repr__() + assert expected == actual + + +def test_customtransformer_compile_sql(mock_X): + ident_trafo = SQLScalarColumnTransformer("{0}", target_column="ident_{0}") + sqls = ident_trafo._compile_to_sql(X=mock_X, columns=["col1", "col2"]) + assert sqls == [ + "col1 AS ident_col1", + "col2 AS ident_col2", + ] + + len1_trafo = SQLScalarColumnTransformer( + "CASE WHEN {0} IS NULL THEN -5 ELSE LENGTH({0}) END", target_column="len1_{0}" + ) + sqls = len1_trafo._compile_to_sql(X=mock_X, columns=["col1", "col2"]) + assert sqls == [ + "CASE WHEN col1 IS NULL THEN -5 ELSE LENGTH(col1) END AS len1_col1", + "CASE WHEN col2 IS NULL THEN -5 ELSE LENGTH(col2) END AS len1_col2", + ] + + len2_trafo = SQLScalarColumnTransformer( + "CASE WHEN {0} IS NULL THEN 99 ELSE LENGTH({0}) END", target_column="len2_{0}" + ) + sqls = len2_trafo._compile_to_sql(X=mock_X, columns=["col1", "col2"]) + assert sqls == 
def create_bq_model_mock(mocker, transform_columns, feature_columns=None):
    """Build a ``bigquery.Model`` whose properties are replaced by test data.

    Args:
        mocker: The pytest-mock fixture; every patch made here is registered
            with it so it is undone when the requesting test finishes.
        transform_columns: Value exposed as ``_properties["transformColumns"]``.
        feature_columns: Optional column names; when given (and non-empty),
            ``Model.feature_columns`` returns one ``StandardSqlField`` per name.

    Returns:
        The ``bigquery.Model`` instance to hand to the code under test.
    """
    properties = {"transformColumns": transform_columns}
    mock_bq_model = bigquery.Model("model_project.model_dataset.model_id")
    # Patch through mocker rather than assigning onto the class directly:
    # a direct assignment is never undone and leaks into every later test.
    # create=True because `_properties` is normally set per instance.
    mocker.patch.object(
        bigquery.Model,
        "_properties",
        new_callable=mock.PropertyMock,
        create=True,
        return_value=properties,
    )
    if feature_columns:
        result = [
            bigquery.standard_sql.StandardSqlField(col, None)
            for col in feature_columns
        ]
        # new_callable expects the factory itself; extra keyword arguments
        # (return_value) are forwarded to it.
        mocker.patch(
            "google.cloud.bigquery.model.Model.feature_columns",
            new_callable=mock.PropertyMock,
            return_value=result,
        )

    return mock_bq_model
"transformSql": "ML.LABEL_ENCODER(species, 1000000, 0) OVER()", + }, + ], + ["county", "species"], + ) + + +@pytest.fixture +def bq_model_no_merge(mocker): + return create_bq_model_mock( + mocker, + [ + { + "name": "ident_culmen_length_mm", + "type": {"typeKind": "INT64"}, + "transformSql": "culmen_length_mm /*CT.IDENT()*/", + } + ], + ["culmen_length_mm"], + ) + + +@pytest.fixture +def bq_model_unknown_ML(mocker): + return create_bq_model_mock( + mocker, + [ + { + "name": "unknownml_culmen_length_mm", + "type": {"typeKind": "INT64"}, + "transformSql": "ML.UNKNOWN(culmen_length_mm)", + }, + { + "name": "labelencoded_county", + "type": {"typeKind": "INT64"}, + "transformSql": "ML.LABEL_ENCODER(county, 1000000, 0) OVER()", + }, + ], + ) + + +@pytest.fixture +def bq_model_flexnames(mocker): + return create_bq_model_mock( + mocker, + [ + { + "name": "Flex Name culmen_length_mm", + "type": {"typeKind": "INT64"}, + "transformSql": "culmen_length_mm", + }, + { + "name": "transformed_Culmen Length MM", + "type": {"typeKind": "INT64"}, + "transformSql": "`Culmen Length MM`*/", + }, + # test workaround for bug in get_model + { + "name": "Flex Name flipper_length_mm", + "type": {"typeKind": "INT64"}, + "transformSql": "flipper_length_mm AS `Flex Name flipper_length_mm`", + }, + { + "name": "transformed_Flipper Length MM", + "type": {"typeKind": "INT64"}, + "transformSql": "`Flipper Length MM` AS `transformed_Flipper Length MM`*/", + }, + ], + ) + + +def test_columntransformer_extract_from_bq_model_good(bq_model_good): + col_trans = ColumnTransformer._extract_from_bq_model(bq_model_good) + assert len(col_trans.transformers) == 6 + # normalize the representation for string comparing + col_trans.transformers.sort(key=lambda trafo: str(trafo)) + actual = col_trans.__repr__() + expected = """ColumnTransformer(transformers=[('label_encoder', + LabelEncoder(max_categories=1000001, + min_frequency=0), + 'county'), + ('label_encoder', + LabelEncoder(max_categories=1000001, + 
def test_columntransformer_extract_from_bq_model_unknown_ML(bq_model_unknown_ML):
    """An unrecognised ``ML.*`` transform must be rejected, not wrapped as plain SQL."""
    # pytest.raises fails the test when nothing is raised; the previous
    # try / `assert False` / except form relied on a bare assert for that.
    with pytest.raises(NotImplementedError, match="Unsupported transformer type"):
        ColumnTransformer._extract_from_bq_model(bq_model_unknown_ML)
SQLScalarColumnTransformer( + "CASE WHEN {0} IS NULL THEN -2 ELSE LENGTH({0}) END", target_column="len1_{0}" + ) + len2_transformer = SQLScalarColumnTransformer( + "CASE WHEN {0} IS NULL THEN 99 ELSE LENGTH({0}) END", target_column="len2_{0}" + ) + column_transformer = compose.ColumnTransformer( + [ + ( + "ident_trafo", + ident_transformer, + ["culmen_length_mm", "flipper_length_mm"], + ), + ("len1_trafo", len1_transformer, ["species shortname"]), + ("len2_trafo", len2_transformer, ["`species longname`"]), + ] + ) + sqls = column_transformer._compile_to_sql(mock_X) + assert sqls == [ + "culmen_length_mm AS `ident culmen_length_mm`", + "flipper_length_mm AS `ident flipper_length_mm`", + "CASE WHEN `species shortname` IS NULL THEN -2 ELSE LENGTH(`species shortname`) END AS `len1_species shortname`", + "CASE WHEN `species longname` IS NULL THEN 99 ELSE LENGTH(`species longname`) END AS `len2_species longname`", + ] + + +def test_columntransformer_extract_from_bq_model_flexnames(bq_model_flexnames): + col_trans = ColumnTransformer._extract_from_bq_model(bq_model_flexnames) + assert len(col_trans.transformers) == 4 + # normalize the representation for string comparing + col_trans.transformers.sort(key=lambda trafo: str(trafo)) + actual = col_trans.__repr__() + expected = """ColumnTransformer(transformers=[('sql_scalar_column_transformer', + SQLScalarColumnTransformer(sql='`Culmen Length MM`*/', target_column='transformed_Culmen Length MM'), + '?transformed_Culmen Length MM'), + ('sql_scalar_column_transformer', + SQLScalarColumnTransformer(sql='`Flipper Length MM` AS `transformed_Flipper Length MM`*/', target_column='transformed_Flipper Length MM'), + '?transformed_Flipper Length MM'), + ('sql_scalar_column_transformer', + SQLScalarColumnTransformer(sql='culmen_length_mm', target_column='Flex Name culmen_length_mm'), + '?Flex Name culmen_length_mm'), + ('sql_scalar_column_transformer', + SQLScalarColumnTransformer(sql='flipper_length_mm ', target_column='Flex Name 
flipper_length_mm'), + '?Flex Name flipper_length_mm')])""" + assert expected == actual