Skip to content

enforce float32 dtype for target encoded features. #39

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions ml_garden/core/steps/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def _apply_encoding(

encoded_data = self._restore_column_order(df, encoded_data)
encoded_data = self._restore_numeric_dtypes(encoded_data, original_numeric_dtypes)
encoded_data = self._restore_target_encoded_dtypes(encoded_data, encoder)

encoded_data = self._convert_float64_to_float32(encoded_data)

feature_encoder_map = self._create_feature_encoder_map(encoder)
Expand Down Expand Up @@ -368,6 +370,22 @@ def _restore_numeric_dtypes(
)
return encoded_data

def _restore_target_encoded_dtypes(self, encoded_data, encoder):
"""Convert the columns handled by any TargetEncoder in the given encoder to float32."""
for name, transformer, cols in encoder.transformers:
# Direct TargetEncoder
if type(transformer).__name__ == "TargetEncoder":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can use isinstance(transformer, TargetEncoder) here?

for col in cols:
encoded_data[col] = encoded_data[col].astype("float32")

# Nested transformers (if inside pipelines or additional ColumnTransformers)
elif hasattr(transformer, "transformers"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we have cases like these?

for nested_name, nested_transformer, nested_cols in transformer.transformers:
if type(nested_transformer).__name__ == "TargetEncoder":
for col in nested_cols:
encoded_data[col] = encoded_data[col].astype("float32")
return encoded_data

def _convert_float64_to_float32(self, encoded_data: pd.DataFrame) -> pd.DataFrame:
"""Convert float64 columns to float32."""
float64_columns = encoded_data.select_dtypes(include=["float64"]).columns
Expand Down
Loading