Skip to content

Commit c1a6889

Browse files
author
Reinier Koops
authored
SHAPRFECV speedup for bigger use-cases and some simple refactoring. (#252)
This PR depends on the PR to be accepted: #248 ______ This cleanup removes some more unused code and simplifies parts of our implementations. It should allow for a boost in performance for bigger use-cases, although minimal. Also fix: - [x] #242 comments - [x] #255 - [x] #245
1 parent ab672d4 commit c1a6889

19 files changed

+179606
-5608
lines changed

LICENCE

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Copyright (c) 2020 ING Bank N.V.
1+
Copyright (c) ING Bank N.V.
22

33
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
44

docs/tutorials/nb_shap_feature_elimination.ipynb

+178,569-4,229
Large diffs are not rendered by default.
+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
from .feature_elimination import ShapRFECV, EarlyStoppingShapRFECV
1+
from .feature_elimination import ShapRFECV
2+
from .early_stopping_feature_elimination import EarlyStoppingShapRFECV
23

34
__all__ = ["ShapRFECV", "EarlyStoppingShapRFECV"]

probatus/feature_elimination/early_stopping_feature_elimination.py

+543
Large diffs are not rendered by default.

probatus/feature_elimination/feature_elimination.py

+416-982
Large diffs are not rendered by default.

probatus/interpret/model_interpret.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
BaseFitComputePlotClass,
1010
assure_list_of_strings,
1111
calculate_shap_importance,
12-
get_single_scorer,
1312
preprocess_data,
1413
preprocess_labels,
14+
get_single_scorer,
1515
shap_calc,
1616
)
1717

probatus/sample_similarity/resemblance_model.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sklearn.inspection import permutation_importance
99
from sklearn.model_selection import train_test_split
1010

11-
from probatus.utils import BaseFitComputePlotClass, get_single_scorer, preprocess_data, preprocess_labels
11+
from probatus.utils import BaseFitComputePlotClass, preprocess_data, preprocess_labels, get_single_scorer
1212
from probatus.utils.shap_helpers import calculate_shap_importance, shap_calc
1313

1414

@@ -108,10 +108,6 @@ def fit(self, X1, X2, column_names=None, class_names=None):
108108
(BaseResemblanceModel):
109109
Fitted object
110110
"""
111-
# Set seed for results reproducibility
112-
if self.random_state is not None:
113-
np.random.seed(self.random_state)
114-
115111
# Set class names
116112
self.class_names = class_names
117113
if self.class_names is None:

probatus/utils/__init__.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,27 @@
1-
from .exceptions import NotFittedError, UnsupportedModelError
2-
from .scoring import Scorer, get_scorers, get_single_scorer
1+
from .exceptions import NotFittedError
32
from .arrayfuncs import (
43
assure_pandas_df,
54
assure_pandas_series,
65
preprocess_data,
76
preprocess_labels,
87
)
8+
from .scoring import Scorer, get_single_scorer
99
from .shap_helpers import shap_calc, shap_to_df, calculate_shap_importance
1010
from ._utils import assure_list_of_strings
1111
from .base_class_interface import BaseFitComputeClass, BaseFitComputePlotClass
1212

1313
__all__ = [
14-
"NotFittedError",
15-
"UnsupportedModelError",
16-
"Scorer",
17-
"assure_pandas_df",
18-
"get_scorers",
1914
"assure_list_of_strings",
20-
"shap_calc",
21-
"shap_to_df",
22-
"calculate_shap_importance",
15+
"assure_pandas_df",
2316
"assure_pandas_series",
2417
"preprocess_data",
2518
"preprocess_labels",
2619
"BaseFitComputeClass",
2720
"BaseFitComputePlotClass",
21+
"NotFittedError",
2822
"get_single_scorer",
23+
"Scorer",
24+
"shap_calc",
25+
"shap_to_df",
26+
"calculate_shap_importance",
2927
]

probatus/utils/arrayfuncs.py

+27-38
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,15 @@ def assure_pandas_df(x, column_names=None):
1515
pandas DataFrame
1616
"""
1717
if isinstance(x, pd.DataFrame):
18-
# Check if column_names are passed correctly
1918
if column_names is not None:
2019
x.columns = column_names
21-
return x
22-
elif any(
23-
[
24-
isinstance(x, np.ndarray),
25-
isinstance(x, pd.core.series.Series),
26-
isinstance(x, list),
27-
]
28-
):
29-
return pd.DataFrame(x, columns=column_names)
20+
elif isinstance(x, (np.ndarray, pd.Series, list)):
21+
x = pd.DataFrame(x, columns=column_names)
3022
else:
3123
raise TypeError("Please supply a list, numpy array, pandas Series or pandas DataFrame")
3224

25+
return x
26+
3327

3428
def assure_pandas_series(x, index=None):
3529
"""
@@ -42,7 +36,7 @@ def assure_pandas_series(x, index=None):
4236
pandas Series
4337
"""
4438
if isinstance(x, pd.Series):
45-
if isinstance(index, list) or isinstance(index, np.ndarray):
39+
if isinstance(index, (list, np.ndarray)):
4640
index = pd.Index(index)
4741
current_x_index = pd.Index(x.index.values)
4842
if current_x_index.equals(index):
@@ -55,7 +49,7 @@ def assure_pandas_series(x, index=None):
5549
# If indexes have different values, overwrite
5650
x.index = index
5751
return x
58-
elif any([isinstance(x, np.ndarray), isinstance(x, list)]):
52+
elif any([isinstance(x, (np.ndarray, list))]):
5953
return pd.Series(x, index=index)
6054
else:
6155
raise TypeError("Please supply a list, numpy array, pandas Series")
@@ -92,40 +86,36 @@ def preprocess_data(X, X_name=None, column_names=None, verbose=0):
9286
(pd.DataFrame):
9387
Preprocessed dataset.
9488
"""
95-
if X_name is None:
96-
X_name = "X"
89+
X_name = "X" if X_name is None else X_name
9790

9891
# Make sure that X is a pd.DataFrame with correct column names
9992
X = assure_pandas_df(X, column_names=column_names)
10093

101-
# Warn if missing
102-
columns_with_missing = [column for column in X.columns if X[column].isnull().values.any()]
103-
if len(columns_with_missing) > 0:
104-
if verbose > 0:
94+
if verbose > 0:
95+
# Warn if missing
96+
columns_with_missing = X.columns[X.isnull().any()].tolist()
97+
if columns_with_missing:
10598
warnings.warn(
10699
f"The following variables in {X_name} contains missing values {columns_with_missing}. "
107100
f"Make sure to impute missing or apply a model that handles them automatically."
108101
)
109102

110-
# Warn if categorical features and change to category
111-
indices_categorical_features = [
112-
column[0] for column in enumerate(X.dtypes) if column[1].name in ["category", "object"]
113-
]
114-
categorical_features = list(X.columns[indices_categorical_features])
115-
116-
# Set categorical features type to category
117-
if len(categorical_features) > 0:
118-
if verbose > 0:
119-
warnings.warn(
120-
f"The following variables in {X_name} contains categorical variables: "
121-
f"{categorical_features}. Make sure to use a model that handles them automatically or "
122-
f"encode them into numerical variables."
123-
)
103+
# Warn if categorical features and change to category
104+
categorical_features = X.select_dtypes(include=["category", "object"]).columns.tolist()
105+
# Set categorical features type to category
106+
if categorical_features:
107+
if verbose > 0:
108+
warnings.warn(
109+
f"The following variables in {X_name} contains categorical variables: "
110+
f"{categorical_features}. Make sure to use a model that handles them automatically or "
111+
f"encode them into numerical variables."
112+
)
113+
114+
# Ensure category dtype, to enable models e.g. LighGBM, handle them automatically
115+
object_columns = X.select_dtypes(include=["object"]).columns
116+
if not object_columns.empty:
117+
X[object_columns] = X[object_columns].astype("category")
124118

125-
# Ensure category dtype, to enable models e.g. LighGBM, handle them automatically
126-
for categorical_feature in categorical_features:
127-
if X[categorical_feature].dtype.name == "object":
128-
X[categorical_feature] = X[categorical_feature].astype("category")
129119
return X, X.columns.tolist()
130120

131121

@@ -157,8 +147,7 @@ def preprocess_labels(y, y_name=None, index=None, verbose=0):
157147
(pd.Series):
158148
Labels in the form of pd.Series.
159149
"""
160-
if y_name is None:
161-
y_name = "y"
150+
y_name = "y" if y_name is None else y_name
162151

163152
# Make sure that y is a series with correct index
164153
y = assure_pandas_series(y, index=index)

probatus/utils/exceptions.py

-13
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,3 @@ def __init__(self, message):
88
Init error.
99
"""
1010
self.message = message
11-
12-
13-
class UnsupportedModelError(Exception):
14-
"""
15-
Error.
16-
"""
17-
18-
def __init__(self, message):
19-
# TODO: Add this check for unsupported models to our implementations.
20-
"""
21-
Init error.
22-
"""
23-
self.message = message

probatus/utils/scoring.py

+3-27
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,9 @@
11
from sklearn.metrics import get_scorer
22

33

4-
def get_scorers(scoring):
5-
"""
6-
Returns Scorers list based on the provided scoring.
7-
8-
Args:
9-
scoring (string, list of strings, probatus.utils.Scorer or list of probatus.utils.Scorers):
10-
Metrics for which the score is calculated. It can be either a name or list of names metric names and
11-
needs to be aligned with predefined classification scorers names in sklearn
12-
([link](https://scikit-learn.org/stable/modules/model_evaluation.html)).
13-
Another option is using probatus.utils.Scorer to define a custom metric.
14-
15-
Returns:
16-
(list of probatus.utils.Scorer):
17-
List of scorers that can be used for scoring models
18-
"""
19-
scorers = []
20-
if isinstance(scoring, list):
21-
for scorer in scoring:
22-
scorers.append(get_single_scorer(scorer))
23-
else:
24-
scorers.append(get_single_scorer(scoring))
25-
return scorers
26-
27-
284
def get_single_scorer(scoring):
295
"""
30-
Returns single Scorer, based on provided input in scoring argument.
6+
Returns Scorer, based on provided input in scoring argument.
317
328
Args:
339
scoring (string or probatus.utils.Scorer, optional):
@@ -67,7 +43,7 @@ class Scorer:
6743
6844
# Make custom scorer with following function:
6945
def custom_metric(y_true, y_pred):
70-
return (y_true == y_pred).sum()
46+
return (y_true == y_pred).sum()
7147
scorer2 = Scorer('custom_metric', custom_scorer=make_scorer(custom_metric))
7248
7349
# Prepare two samples
@@ -110,7 +86,7 @@ def score(self, model, X, y):
11086
"""
11187
Scores the samples model based on the provided metric name.
11288
113-
Args:
89+
Args
11490
model (model object):
11591
Model to be scored.
11692

0 commit comments

Comments
 (0)