Skip to content

Commit dd5907f

Browse files
committed
constant column patch
1 parent 73240c9 commit dd5907f

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

ml_grid/pipeline/data.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from ml_grid.pipeline import read_in
1010
from ml_grid.pipeline.column_names import get_pertubation_columns
1111
from ml_grid.pipeline.data_clean_up import clean_up_class
12-
from ml_grid.pipeline.data_constant_columns import remove_constant_columns
12+
from ml_grid.pipeline.data_constant_columns import remove_constant_columns, remove_constant_columns_with_debug
1313
from ml_grid.pipeline.data_correlation_matrix import handle_correlation_matrix
1414
from ml_grid.pipeline.data_feature_importance_methods import feature_importance_methods
1515
from ml_grid.pipeline.data_outcome_list import handle_outcome_list
@@ -292,6 +292,14 @@ def __init__(
292292
self.X_test_orig,
293293
self.y_test_orig,
294294
) = get_data_split(X=self.X, y=self.y, local_param_dict=self.local_param_dict)
295+
296+
# Handle columns made constant by splitting
297+
self.X_train, self.X_test, self.X_test_orig = remove_constant_columns_with_debug(
298+
self.X_train,
299+
self.X_test,
300+
self.X_test_orig,
301+
verbosity=self.verbose
302+
)
295303

296304
target_n_features = self.local_param_dict.get("feature_n")
297305

ml_grid/pipeline/data_constant_columns.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,53 @@ def remove_constant_columns(X: pd.DataFrame, drop_list: Optional[List[str]] = No
4747
print("Unhandled exception:", str(e))
4848
raise
4949

50-
return drop_list
50+
return drop_list
51+
52+
def remove_constant_columns_with_debug(X_train, X_test, X_test_orig, verbosity=2):
53+
if verbosity > 0:
54+
# Debug message: Initial shapes of X_train, X_test, X_test_orig
55+
print(f"Initial X_train shape: {X_train.shape}")
56+
print(f"Initial X_test shape: {X_test.shape}")
57+
print(f"Initial X_test_orig shape: {X_test_orig.shape}")
58+
59+
# Calculate the variance for each column in X_train
60+
train_variances = X_train.var(axis=0)
61+
if verbosity > 1:
62+
print(f"Variance of X_train columns:\n{train_variances}")
63+
64+
# Identify and remove constant columns in X_train
65+
constant_columns_train = train_variances[train_variances == 0].index
66+
if verbosity > 0:
67+
print(f"Constant columns in X_train: {list(constant_columns_train)}")
68+
69+
# Calculate the variance for each column in X_test
70+
test_variances = X_test.var(axis=0)
71+
if verbosity > 1:
72+
print(f"Variance of X_test columns:\n{test_variances}")
73+
74+
# Identify constant columns in X_test
75+
constant_columns_test = test_variances[test_variances == 0].index
76+
if verbosity > 0:
77+
print(f"Constant columns in X_test: {list(constant_columns_test)}")
78+
79+
# Combine constant columns from both X_train and X_test
80+
constant_columns = constant_columns_train.union(constant_columns_test)
81+
82+
# Remove the constant columns from both X_train and X_test
83+
X_train = X_train.loc[:, ~X_train.columns.isin(constant_columns)]
84+
X_test = X_test.loc[:, ~X_test.columns.isin(constant_columns)]
85+
86+
# Also remove the same constant columns from X_test_orig
87+
X_test_orig = X_test_orig.loc[:, ~X_test_orig.columns.isin(constant_columns)]
88+
89+
if verbosity > 0:
90+
# Debug message: Shape after removing constant columns from X_train, X_test, X_test_orig
91+
print(f"Shape of X_train after removing constant columns: {X_train.shape}")
92+
print(f"Shape of X_test after removing constant columns: {X_test.shape}")
93+
print(f"Shape of X_test_orig after removing constant columns: {X_test_orig.shape}")
94+
95+
# Return the modified X_train, X_test, and X_test_orig, with y_test_orig unchanged
96+
return X_train, X_test, X_test_orig
97+
98+
# Example usage with verbosity level 2 (most verbose)
99+
# X_train, X_test, X_test_orig = remove_constant_columns_with_debug(X_train, X_test, X_test_orig, verbosity=2)

0 commit comments

Comments
 (0)