feat: add ml.preprocessing.LabelEncoder #50


Merged · 1 commit · Sep 22, 2023
1 change: 1 addition & 0 deletions bigframes/ml/compose.py
@@ -29,6 +29,7 @@
CompilablePreprocessorType = Union[
preprocessing.OneHotEncoder,
preprocessing.StandardScaler,
preprocessing.LabelEncoder,
]


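With LabelEncoder added to CompilablePreprocessorType, a ColumnTransformer can now carry a label-encoding step alongside the existing preprocessors. A minimal sketch (column names are illustrative and mirror the penguins tests further down):

from bigframes.ml import compose, preprocessing

transformer = compose.ColumnTransformer(
    [
        ("one_hot_encoder", preprocessing.OneHotEncoder(), "species"),
        ("standard_scaler", preprocessing.StandardScaler(), ["culmen_length_mm"]),
        ("label_encoder", preprocessing.LabelEncoder(), "species"),
    ]
)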
15 changes: 14 additions & 1 deletion bigframes/ml/pipeline.py
@@ -50,6 +50,7 @@ def __init__(self, steps: List[Tuple[str, base.BaseEstimator]]):
compose.ColumnTransformer,
preprocessing.StandardScaler,
preprocessing.OneHotEncoder,
preprocessing.LabelEncoder,
),
):
self._transform = transform
@@ -143,7 +144,11 @@ def _extract_as_column_transformer(
transformers: List[
Tuple[
str,
Union[preprocessing.OneHotEncoder, preprocessing.StandardScaler],
Union[
preprocessing.OneHotEncoder,
preprocessing.StandardScaler,
preprocessing.LabelEncoder,
],
Union[str, List[str]],
]
] = []
@@ -167,6 +172,13 @@
*preprocessing.OneHotEncoder._parse_from_sql(transform_sql),
)
)
elif transform_sql.startswith("ML.LABEL_ENCODER"):
transformers.append(
(
"label_encoder",
*preprocessing.LabelEncoder._parse_from_sql(transform_sql),
)
)
else:
raise NotImplementedError(
f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
@@ -181,6 +193,7 @@ def _merge_column_transformer(
compose.ColumnTransformer,
preprocessing.StandardScaler,
preprocessing.OneHotEncoder,
preprocessing.LabelEncoder,
]:
"""Try to merge the column transformer to a simple transformer."""
transformers = column_transformer.transformers_
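The new ML.LABEL_ENCODER branch lets a pipeline loaded from BQML rebuild the transformer from the TRANSFORM SQL stored with the model. A sketch of that round trip, following the parsing logic in preprocessing.LabelEncoder._parse_from_sql (the SQL string is illustrative):

encoder, column = preprocessing.LabelEncoder._parse_from_sql(
    "ML.LABEL_ENCODER(species, 1000000, 0) OVER()"
)
# column == "species"
# encoder.min_frequency == 0
# encoder.max_categories == 1000001  (top_k + 1)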
119 changes: 119 additions & 0 deletions bigframes/ml/preprocessing.py
@@ -24,6 +24,7 @@
import bigframes.pandas as bpd
import third_party.bigframes_vendored.sklearn.preprocessing._data
import third_party.bigframes_vendored.sklearn.preprocessing._encoder
import third_party.bigframes_vendored.sklearn.preprocessing._label


class StandardScaler(
@@ -229,3 +230,121 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
bpd.DataFrame,
df[self._output_names],
)


class LabelEncoder(
base.Transformer,
third_party.bigframes_vendored.sklearn.preprocessing._label.LabelEncoder,
):
# BQML's maximum top_k value; see https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-label-encoder
TOP_K_DEFAULT = 1000000
FREQUENCY_THRESHOLD_DEFAULT = 0

__doc__ = (
third_party.bigframes_vendored.sklearn.preprocessing._label.LabelEncoder.__doc__
)

# All estimators must implement __init__ to document their parameters, even
# if they don't have any
def __init__(
self,
min_frequency: Optional[int] = None,
max_categories: Optional[int] = None,
):
if max_categories is not None and max_categories < 2:
raise ValueError(
f"max_categories has to be larger than or equal to 2, input is {max_categories}."
)
self.min_frequency = min_frequency
self.max_categories = max_categories
self._bqml_model: Optional[core.BqmlModel] = None
self._bqml_model_factory = globals.bqml_model_factory()
self._base_sql_generator = globals.base_sql_generator()

# TODO(garrettwu): implement __hash__
def __eq__(self, other: Any) -> bool:
return (
type(other) is LabelEncoder
and self._bqml_model == other._bqml_model
and self.min_frequency == other.min_frequency
and self.max_categories == other.max_categories
)

def _compile_to_sql(self, columns: List[str]) -> List[Tuple[str, str]]:
"""Compile this transformer to a list of SQL expressions that can be included in
a BQML TRANSFORM clause.

Args:
columns:
a list of column names to transform

Returns: a list of tuples of (sql_expression, output_name)"""

# Minus one here because BQML's implementation always includes index 0, and top_k is applied on top of that.
top_k = (
(self.max_categories - 1)
if self.max_categories is not None
else LabelEncoder.TOP_K_DEFAULT
)
frequency_threshold = (
self.min_frequency
if self.min_frequency is not None
else LabelEncoder.FREQUENCY_THRESHOLD_DEFAULT
)
return [
(
self._base_sql_generator.ml_label_encoder(
column, top_k, frequency_threshold, f"labelencoded_{column}"
),
f"labelencoded_{column}",
)
for column in columns
]
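# For illustration: with max_categories=100 and min_frequency=5, a column "species"
# compiles to ML.LABEL_ENCODER(species, 99, 5) OVER() AS labelencoded_species.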

@classmethod
def _parse_from_sql(cls, sql: str) -> tuple[LabelEncoder, str]:
"""Parse SQL to tuple(LabelEncoder, column_label).

Args:
sql: SQL string of format "ML.LABEL_ENCODER({col_label}, {top_k}, {frequency_threshold}) OVER() "

Returns:
tuple(LabelEncoder, column_label)"""
s = sql[sql.find("(") + 1 : sql.find(")")]
col_label, top_k, frequency_threshold = s.split(", ")
max_categories = int(top_k) + 1
min_frequency = int(frequency_threshold)

return cls(min_frequency, max_categories), col_label

def fit(
self,
X: Union[bpd.DataFrame, bpd.Series],
y=None, # ignored
) -> LabelEncoder:
(X,) = utils.convert_to_dataframe(X)

compiled_transforms = self._compile_to_sql(X.columns.tolist())
transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]

self._bqml_model = self._bqml_model_factory.create_model(
X,
options={"model_type": "transform_only"},
transforms=transform_sqls,
)

# The schema of TRANSFORM output is not available in the model API, so save it during fitting
self._output_names = [name for _, name in compiled_transforms]
return self

def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
if not self._bqml_model:
raise RuntimeError("Must be fitted before transform")

(X,) = utils.convert_to_dataframe(X)

df = self._bqml_model.transform(X)
return typing.cast(
bpd.DataFrame,
df[self._output_names],
)
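As with the other preprocessors, LabelEncoder follows the fit/transform pattern backed by a transform_only BQML model. A minimal usage sketch (the DataFrame df and its columns are illustrative):

from bigframes.ml import preprocessing

encoder = preprocessing.LabelEncoder(min_frequency=5, max_categories=100)
encoder.fit(df[["species"]])  # creates a transform_only BQML model
encoded = encoder.transform(df[["species"]])
# `encoded` holds a single output column named "labelencoded_species".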
11 changes: 11 additions & 0 deletions bigframes/ml/sql.py
@@ -88,6 +88,17 @@ def ml_one_hot_encoder(
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-one-hot-encoder for params."""
return f"""ML.ONE_HOT_ENCODER({numeric_expr_sql}, '{drop}', {top_k}, {frequency_threshold}) OVER() AS {name}"""

def ml_label_encoder(
self,
numeric_expr_sql: str,
top_k: int,
frequency_threshold: int,
name: str,
) -> str:
"""Encode ML.LABEL_ENCODER for BQML.
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-label-encoder for params."""
return f"""ML.LABEL_ENCODER({numeric_expr_sql}, {top_k}, {frequency_threshold}) OVER() AS {name}"""


class ModelCreationSqlGenerator(BaseSqlGenerator):
"""Sql generator for creating a model entity. Model id is the standalone id without project id and dataset id."""
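A sketch of the SQL string this helper emits, using the same generator handle that preprocessing.py obtains via globals.base_sql_generator() (argument values are illustrative):

sql = globals.base_sql_generator().ml_label_encoder(
    "species", 1000000, 0, "labelencoded_species"
)
# sql == "ML.LABEL_ENCODER(species, 1000000, 0) OVER() AS labelencoded_species"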
61 changes: 56 additions & 5 deletions tests/system/large/ml/test_pipeline.py
@@ -570,6 +570,11 @@ def test_pipeline_columntransformer_fit_predict(session, penguins_df_default_ind
preprocessing.StandardScaler(),
["culmen_length_mm", "flipper_length_mm"],
),
(
"label",
preprocessing.LabelEncoder(),
"species",
),
]
),
),
@@ -632,6 +637,11 @@ def test_pipeline_columntransformer_to_gbq(penguins_df_default_index, dataset_id
preprocessing.StandardScaler(),
["culmen_length_mm", "flipper_length_mm"],
),
(
"label",
preprocessing.LabelEncoder(),
"species",
),
]
),
),
@@ -650,7 +660,7 @@ def test_pipeline_columntransformer_to_gbq(penguins_df_default_index, dataset_id

assert isinstance(pl_loaded._transform, compose.ColumnTransformer)
transformers = pl_loaded._transform.transformers_
assert len(transformers) == 3
assert len(transformers) == 4

assert transformers[0][0] == "ont_hot_encoder"
assert isinstance(transformers[0][1], preprocessing.OneHotEncoder)
@@ -660,13 +670,20 @@ def test_pipeline_columntransformer_to_gbq(penguins_df_default_index, dataset_id
assert one_hot_encoder.max_categories == 100
assert transformers[0][2] == "species"

assert transformers[1][0] == "standard_scaler"
assert isinstance(transformers[1][1], preprocessing.StandardScaler)
assert transformers[1][2] == "culmen_length_mm"
assert transformers[1][0] == "label_encoder"
assert isinstance(transformers[1][1], preprocessing.LabelEncoder)
label_encoder = transformers[1][1]
assert label_encoder.min_frequency == 0
assert label_encoder.max_categories == 1000001
assert transformers[1][2] == "species"

assert transformers[2][0] == "standard_scaler"
assert isinstance(transformers[2][1], preprocessing.StandardScaler)
assert transformers[2][2] == "flipper_length_mm"
assert transformers[2][2] == "culmen_length_mm"

assert transformers[3][0] == "standard_scaler"
assert isinstance(transformers[3][1], preprocessing.StandardScaler)
assert transformers[3][2] == "flipper_length_mm"

assert isinstance(pl_loaded._estimator, linear_model.LinearRegression)
assert pl_loaded._estimator.fit_intercept is False
@@ -735,3 +752,37 @@ def test_pipeline_one_hot_encoder_to_gbq(penguins_df_default_index, dataset_id):

assert isinstance(pl_loaded._estimator, linear_model.LinearRegression)
assert pl_loaded._estimator.fit_intercept is False


def test_pipeline_label_encoder_to_gbq(penguins_df_default_index, dataset_id):
pl = pipeline.Pipeline(
[
(
"transform",
preprocessing.LabelEncoder(min_frequency=5, max_categories=100),
),
("estimator", linear_model.LinearRegression(fit_intercept=False)),
]
)

df = penguins_df_default_index.dropna()
X_train = df[
[
"sex",
"species",
]
]
y_train = df[["body_mass_g"]]
pl.fit(X_train, y_train)

pl_loaded = pl.to_gbq(
f"{dataset_id}.test_penguins_pipeline_label_encoder", replace=True
)
assert isinstance(pl_loaded._transform, preprocessing.LabelEncoder)

label_encoder = pl_loaded._transform
assert label_encoder.min_frequency == 5
assert label_encoder.max_categories == 100

assert isinstance(pl_loaded._estimator, linear_model.LinearRegression)
assert pl_loaded._estimator.fit_intercept is False