Skip to content

Fix issues regarding sklearn param name change #65

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions lofo/lofo_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings
from lofo.infer_defaults import infer_model
from lofo.utils import lofo_to_df, parallel_apply
import sklearn


class LOFOImportance:
Expand All @@ -25,7 +26,7 @@ class LOFOImportance:
n_jobs: int, optional
Number of jobs for parallel computation
cv_groups: array-like, with shape (n_samples,), optional
Group labels for the samples used while splitting the dataset into train/test set.
Group labels for the samples used while splitting the dataset into train/test set.
Only used in conjunction with a “Group” cv instance (e.g., GroupKFold).
"""

Expand All @@ -48,13 +49,19 @@ def __init__(self, dataset, scoring, model=None, fit_params=None, cv=4, n_jobs=N
"of jobs of LOFO to be equal to 1, otherwise you may experience performance issues.")
warnings.warn(warning_str)

sklearn_version = tuple(map(int, sklearn.__version__.split(".")[:2]))
self._cv_param_name = "params" if sklearn_version >= (1, 4) else "fit_params"

def _get_cv_score(self, feature_to_remove):
X, fit_params = self.dataset.getX(feature_to_remove=feature_to_remove, fit_params=self.fit_params)
y = self.dataset.y

kwargs = {self._cv_param_name: fit_params,
"cv": self.cv, "scoring": self.scoring, "groups": self.cv_groups}

with warnings.catch_warnings():
warnings.simplefilter("ignore")
cv_results = cross_validate(self.model, X, y, cv=self.cv, scoring=self.scoring, fit_params=fit_params, groups=self.cv_groups)
cv_results = cross_validate(self.model, X, y, **kwargs)
return cv_results['test_score']

def _get_cv_score_parallel(self, feature, result_queue):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setup(
name='lofo-importance',
version='0.3.4',
version='0.3.5',
url="https://github.com/aerdem4/lofo-importance",
author="Ahmet Erdem",
author_email="[email protected]",
Expand Down