-
Notifications
You must be signed in to change notification settings - Fork 23
[MRG] Subsampling transformer #259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 11 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
2b427e5
try again
rflamary 5b94809
stuff
rflamary 5b825a7
A few quick fixes
kachayev c6b5a90
Change the order of rng init
kachayev 65b017f
test with full pipeline
rflamary eb498a4
working pipe
rflamary 7d87144
update test
rflamary f34b1a6
import SubsampleTransformer
rflamary 9af315e
update to more general parameter and use of sklearn splitter
rflamary 9ffbeda
add stratified sampler
rflamary 3196ead
update init file
rflamary 2574f89
better tests
rflamary ad6a2ad
Merge branch 'main' into subsample_transfomer
antoinecollas 47b9e31
better test
rflamary e6ec846
Merge branch 'main' into subsample_transfomer
antoinecollas File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Author: Yanis Lalou <[email protected]> | ||
# | ||
# License: BSD 3-Clause | ||
|
||
|
||
from sklearn.preprocessing import StandardScaler | ||
|
||
from skada import CORAL, make_da_pipeline | ||
from skada.transformers import ( | ||
DomainStratifiedSubsampleTransformer, | ||
SubsampleTransformer, | ||
) | ||
|
||
|
||
def test_SubsampleTransformer(da_dataset):
    """Check that SubsampleTransformer subsamples at fit time only.

    Covers the bare transformer, a transformer-only pipeline, and a full
    pipeline ending in an estimator (CORAL).
    """
    X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"])

    train_size = 10

    transformer = SubsampleTransformer(train_size=train_size)

    X_subsampled, y_subsampled, params = transformer.fit_transform(
        X, y, sample_domain=sample_domain
    )

    # Training data is reduced to exactly `train_size` samples and the
    # matching sample_domain is forwarded through the params dict.
    assert X_subsampled.shape[0] == train_size
    assert y_subsampled.shape[0] == train_size
    assert "sample_domain" in params

    X_target, y_target, sample_domain_target = da_dataset.pack_test(as_targets=["t"])

    X_target_subsampled = transformer.transform(
        X_target, y_target, sample_domain=sample_domain_target
    )

    # transform() must be a no-op: test data passes through unchanged.
    assert X_target_subsampled.shape[0] == X_target.shape[0]

    # within a pipeline
    transformer = SubsampleTransformer(train_size=train_size)
    pipeline = make_da_pipeline(StandardScaler(), transformer)

    temp = pipeline.fit_transform(X, y, sample_domain=sample_domain)

    assert temp is not None

    # now with a pipeline with end task
    transformer = SubsampleTransformer(train_size=train_size)
    pipeline = make_da_pipeline(StandardScaler(), transformer, CORAL())

    pipeline.fit(X, y, sample_domain=sample_domain)

    ypred = pipeline.predict(X_target, sample_domain=sample_domain_target)
    # Prediction must cover every test sample (no subsampling at predict time).
    assert ypred.shape[0] == X_target.shape[0]
||
def test_DomainStratifiedSubsampleTransformer(da_dataset):
    """Check that DomainStratifiedSubsampleTransformer subsamples at fit time only.

    Mirrors test_SubsampleTransformer: bare transformer, transformer-only
    pipeline, and a full pipeline ending in an estimator (CORAL).
    """
    X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"])

    train_size = 10

    transformer = DomainStratifiedSubsampleTransformer(train_size=train_size)

    X_subsampled, y_subsampled, params = transformer.fit_transform(
        X, y, sample_domain=sample_domain
    )

    # Training data is reduced to exactly `train_size` samples and the
    # matching sample_domain is forwarded through the params dict.
    assert X_subsampled.shape[0] == train_size
    assert y_subsampled.shape[0] == train_size
    assert "sample_domain" in params

    X_target, y_target, sample_domain_target = da_dataset.pack_test(as_targets=["t"])

    X_target_subsampled = transformer.transform(
        X_target, y_target, sample_domain=sample_domain_target
    )

    # transform() must be a no-op: test data passes through unchanged.
    assert X_target_subsampled.shape[0] == X_target.shape[0]

    # within a pipeline
    transformer = DomainStratifiedSubsampleTransformer(train_size=train_size)
    pipeline = make_da_pipeline(StandardScaler(), transformer)

    temp = pipeline.fit_transform(X, y, sample_domain=sample_domain)

    assert temp is not None

    # now with a pipeline with end task
    transformer = DomainStratifiedSubsampleTransformer(train_size=train_size)
    pipeline = make_da_pipeline(StandardScaler(), transformer, CORAL())

    pipeline.fit(X, y, sample_domain=sample_domain)

    ypred = pipeline.predict(X_target, sample_domain=sample_domain_target)
    # Prediction must cover every test sample (no subsampling at predict time).
    assert ypred.shape[0] == X_target.shape[0]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# Author: Remi Flamary <[email protected]> | ||
# | ||
# License: BSD 3-Clause | ||
|
||
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit | ||
from sklearn.utils import check_random_state | ||
|
||
from .base import BaseAdapter | ||
from .utils import check_X_y_domain | ||
|
||
|
||
class SubsampleTransformer(BaseAdapter):
    """Transformer that subsamples the data.

    This transformer is useful to speed up computations when the data is too
    large. It randomly selects a subset of the data to work with during training
    but does not change the data during testing.

    .. note::
        This transformer should not be used as the last step of a pipeline
        because it returns non standard output.

    Parameters
    ----------
    train_size : int, float
        Number of samples to keep (keep all if data smaller) if integer, or
        proportion of train sample if float 0<= train_size <= 1.
    random_state : int, RandomState instance or None, default=None
        Controls the random resampling of the data.
    """

    def __init__(self, train_size, random_state=None):
        self.train_size = train_size
        self.random_state = random_state

    def _pack_params(self, idx, **params):
        """Subset metadata arrays with ``idx``, dropping ``None`` entries.

        When ``idx`` is None the arrays are forwarded unchanged (used for the
        no-subsampling fast path).
        """
        return {
            k: (v[idx] if idx is not None else v)
            for k, v in params.items()
            if v is not None
        }

    def fit_transform(self, X, y=None, *, sample_domain=None, sample_weight=None):
        """Fit and transform the data.

        Randomly keeps ``train_size`` samples (all of them if the data is
        already smaller) and subsets ``y``, ``sample_domain`` and
        ``sample_weight`` accordingly. Returns ``(X, y, params)`` where
        ``params`` carries the subsampled metadata.
        """
        X, y, sample_domain2 = check_X_y_domain(X, y, sample_domain)

        self.rng_ = check_random_state(self.random_state)

        if self.train_size >= X.shape[0]:
            # Nothing to subsample: keep everything. Return the validated
            # sample_domain2 (not the raw input) for consistency with the
            # subsampled path below.
            return (
                X,
                y,
                self._pack_params(
                    None, sample_domain=sample_domain2, sample_weight=sample_weight
                ),
            )

        splitter = ShuffleSplit(
            n_splits=1, train_size=self.train_size, random_state=self.rng_
        )

        # Only the "train" half of the single split is kept.
        idx = next(splitter.split(X))[0]
        X_subsampled = X[idx]
        y_subsampled = y[idx] if y is not None else None
        params = self._pack_params(
            idx, sample_domain=sample_domain2, sample_weight=sample_weight
        )
        return X_subsampled, y_subsampled, params

    def transform(self, X, y=None, *, sample_domain=None, sample_weight=None):
        """Transform the data.

        No subsampling at test time: ``X`` is returned unchanged.
        """
        return X
|
||
|
||
class DomainStratifiedSubsampleTransformer(BaseAdapter):
    """Transformer that subsamples the data in a domain stratified way.

    This transformer is useful to speed up computations when the data is too
    large. It randomly selects a subset of the data to work with during training
    but does not change the data during testing.

    .. note::
        This transformer should not be used as the last step of a pipeline
        because it returns non standard output.

    Parameters
    ----------
    train_size : int, float
        Number of samples to keep (keep all if data smaller) if integer, or
        proportion of train sample if float 0<= train_size <= 1.
    random_state : int, RandomState instance or None, default=None
        Controls the random resampling of the data.
    """

    def __init__(self, train_size, random_state=None):
        self.train_size = train_size
        self.random_state = random_state

    def _pack_params(self, idx, **params):
        """Subset metadata arrays with ``idx``, dropping ``None`` entries.

        When ``idx`` is None the arrays are forwarded unchanged (used for the
        no-subsampling fast path).
        """
        return {
            k: (v[idx] if idx is not None else v)
            for k, v in params.items()
            if v is not None
        }

    def fit_transform(self, X, y=None, *, sample_domain=None, sample_weight=None):
        """Fit and transform the data.

        Randomly keeps ``train_size`` samples, stratified on ``sample_domain``
        so that each domain keeps its proportion, and subsets ``y``,
        ``sample_domain`` and ``sample_weight`` accordingly.
        """
        X, y, sample_domain2 = check_X_y_domain(X, y, sample_domain)

        self.rng_ = check_random_state(self.random_state)

        if self.train_size >= X.shape[0]:
            # Nothing to subsample: keep everything. Return the validated
            # sample_domain2 (not the raw input) for consistency with the
            # subsampled path below.
            return (
                X,
                y,
                self._pack_params(
                    None, sample_domain=sample_domain2, sample_weight=sample_weight
                ),
            )

        splitter = StratifiedShuffleSplit(
            n_splits=1, train_size=self.train_size, random_state=self.rng_
        )
        # Stratify on the domain labels so each domain keeps its share.
        idx = next(splitter.split(X, sample_domain2))[0]
        X_subsampled = X[idx]
        y_subsampled = y[idx] if y is not None else None
        params = self._pack_params(
            idx, sample_domain=sample_domain2, sample_weight=sample_weight
        )
        return X_subsampled, y_subsampled, params

    def transform(self, X, y=None, *, sample_domain=None, sample_weight=None):
        """Transform the data.

        No subsampling at test time: ``X`` is returned unchanged.
        """
        return X
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.