Commit 9a976ce

[MRG] Add DomainAndLabelStratifiedSubsampleTransformer + Fix DomainStratifiedSubsampleTransformer (#268)
* Add DomainAndLabelStratifiedSubsampleTransformer + fix DomainStratifiedSubsampleTransformer
* Add test to check stratification proportions
* Rename subsamplers

Co-authored-by: Antoine Collas <[email protected]>
1 parent 65f1659 commit 9a976ce
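
In short, SubsampleTransformer and DomainStratifiedSubsampleTransformer become Subsampler and DomainSubsampler, and a new StratifiedDomainSubsampler subsamples while preserving the joint (domain, label) proportions. A minimal usage sketch follows, based on the updated tests; it is not part of the commit, and the synthetic arrays, the domain labels (1 for source, -2 for target), the masked target labels, and train_size=16 are illustrative assumptions:

import numpy as np

from skada.transformers import StratifiedDomainSubsampler

# Illustrative toy data: 80 labeled source samples (domain 1) and 80 target
# samples (domain -2) whose labels are masked with -1, as in skada training data.
rng = np.random.default_rng(0)
X = rng.normal(size=(160, 5))
y = np.concatenate([np.tile([0, 1], 40), -np.ones(80, dtype=int)])
sample_domain = np.concatenate([np.ones(80, dtype=int), -2 * np.ones(80, dtype=int)])

# Keep 16 training samples while roughly preserving (domain, label) proportions.
subsampler = StratifiedDomainSubsampler(train_size=16, random_state=42)
X_sub, y_sub, params = subsampler.fit_transform(X, y, sample_domain=sample_domain)
print(X_sub.shape)                    # (16, 5)
print(params["sample_domain"].shape)  # (16,)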

File tree

3 files changed (+141 −16 lines)

skada/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -61,7 +61,7 @@
     OTLabelProp,
     JCPOTLabelPropAdapter,
     JCPOTLabelProp)
-from .transformers import SubsampleTransformer, DomainStratifiedSubsampleTransformer
+from .transformers import Subsampler, DomainSubsampler, StratifiedDomainSubsampler
 from ._self_labeling import DASVMClassifier
 from ._pipeline import make_da_pipeline
 from .utils import source_target_split, per_domain_split

skada/tests/test_transformers.py

Lines changed: 67 additions & 11 deletions

@@ -2,24 +2,27 @@
 #
 # License: BSD 3-Clause

+from collections import Counter
+
 import numpy as np
 from sklearn.preprocessing import StandardScaler

 from skada import CORAL, make_da_pipeline
 from skada.transformers import (
-    DomainStratifiedSubsampleTransformer,
-    SubsampleTransformer,
+    DomainSubsampler,
+    StratifiedDomainSubsampler,
+    Subsampler,
 )


-def test_SubsampleTransformer(da_dataset):
+def test_Subsampler(da_dataset):
     X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"])
     sample_weight = np.ones_like(y)

     train_size = 10

     # test size of output on fit_transform
-    transformer = SubsampleTransformer(train_size=train_size, random_state=42)
+    transformer = Subsampler(train_size=train_size, random_state=42)

     X_subsampled, y_subsampled, params = transformer.fit_transform(
         X, y, sample_domain=sample_domain, sample_weight=sample_weight
@@ -40,26 +43,26 @@ def test_SubsampleTransformer(da_dataset):
     assert X_target_subsampled.shape[0] == X_target.shape[0]

     # now with a pipeline with end task
-    transformer = SubsampleTransformer(train_size=train_size)
+    transformer = Subsampler(train_size=train_size)
     pipeline = make_da_pipeline(StandardScaler(), transformer, CORAL())

     pipeline.fit(X, y, sample_domain=sample_domain)

     ypred = pipeline.predict(X_target, sample_domain=sample_domain_target)
     assert ypred.shape[0] == X_target.shape[0]
-    assert ypred.shape[0] == X_target.shape[0]
+
+    ypred = pipeline.predict(X, sample_domain=sample_domain, allow_source=True)
+    assert ypred.shape[0] == X.shape[0]


-def test_DomainStratifiedSubsampleTransformer(da_dataset):
+def test_DomainSubsampler(da_dataset):
     X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"])
     sample_weight = np.ones_like(y)

     train_size = 10

     # test size of output on fit_transform
-    transformer = DomainStratifiedSubsampleTransformer(
-        train_size=train_size, random_state=42
-    )
+    transformer = DomainSubsampler(train_size=train_size, random_state=42)

     X_subsampled, y_subsampled, params = transformer.fit_transform(
         X, y, sample_domain=sample_domain, sample_weight=sample_weight
@@ -82,11 +85,64 @@ def test_DomainStratifiedSubsampleTransformer(da_dataset):
     assert X_target_subsampled.shape[0] == X_target.shape[0]

     # now with a pipeline with end task
-    transformer = DomainStratifiedSubsampleTransformer(train_size=train_size)
+    transformer = DomainSubsampler(train_size=train_size)
     pipeline = make_da_pipeline(StandardScaler(), transformer, CORAL())

     pipeline.fit(X, y, sample_domain=sample_domain)

     ypred = pipeline.predict(X_target, sample_domain=sample_domain_target)
     assert ypred.shape[0] == X_target.shape[0]
+
+    ypred = pipeline.predict(X, sample_domain=sample_domain, allow_source=True)
+    assert ypred.shape[0] == X.shape[0]
+
+
+def test_StratifiedDomainSubsampler(da_dataset):
+    X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"])
+    sample_weight = np.ones_like(y)
+
+    train_size = 10
+
+    # test size of output on fit_transform
+    transformer = StratifiedDomainSubsampler(train_size=train_size, random_state=42)
+
+    X_subsampled, y_subsampled, params = transformer.fit_transform(
+        X, y, sample_domain=sample_domain, sample_weight=sample_weight
+    )
+
+    assert X_subsampled.shape == (train_size, X.shape[1])
+    assert y_subsampled.shape[0] == train_size
+    assert params["sample_domain"].shape[0] == train_size
+    assert params["sample_weight"].shape[0] == train_size
+
+    # Check stratification proportions
+    original_freq = Counter(zip(sample_domain, y))
+    subsampled_freq = Counter(zip(params["sample_domain"], y_subsampled))
+
+    for key in original_freq:
+        original_ratio = original_freq[key] / len(y)
+        subsampled_ratio = subsampled_freq[key] / train_size
+        assert np.isclose(
+            original_ratio, subsampled_ratio, atol=0.1
+        ), f"Stratification not preserved for {key}"
+
+    # test size of output on transform
+    X_target, y_target, sample_domain_target = da_dataset.pack_test(as_targets=["t"])
+
+    X_target_subsampled = transformer.transform(
+        X_target, y_target, sample_domain=sample_domain_target
+    )
+
+    assert X_target_subsampled.shape[0] == X_target.shape[0]
+
+    # now with a pipeline with end task
+    transformer = StratifiedDomainSubsampler(train_size=train_size)
+    pipeline = make_da_pipeline(StandardScaler(), transformer, CORAL())
+
+    pipeline.fit(X, y, sample_domain=sample_domain)
+
+    ypred = pipeline.predict(X_target, sample_domain=sample_domain_target)
     assert ypred.shape[0] == X_target.shape[0]
+
+    ypred = pipeline.predict(X, sample_domain=sample_domain, allow_source=True)
+    assert ypred.shape[0] == X.shape[0]
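
The core of the new test is the proportion check: the joint (domain, label) frequencies of the subsampled set should stay close to those of the full set. Below is a standalone sketch of that check, not part of the commit; the helper name and the atol value are illustrative:

from collections import Counter

import numpy as np


def stratification_preserved(sample_domain, y, sample_domain_sub, y_sub, atol=0.1):
    """Return True if (domain, label) proportions survive subsampling within atol."""
    original = Counter(zip(sample_domain, y))
    subsampled = Counter(zip(sample_domain_sub, y_sub))
    return all(
        np.isclose(original[key] / len(y), subsampled[key] / len(y_sub), atol=atol)
        for key in original
    )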

skada/transformers.py

Lines changed: 73 additions & 4 deletions

@@ -6,10 +6,11 @@
 from sklearn.utils import check_random_state

 from .base import BaseAdapter
+from .model_selection import StratifiedDomainShuffleSplit
 from .utils import check_X_y_domain


-class SubsampleTransformer(BaseAdapter):
+class Subsampler(BaseAdapter):
     """Transformer that subsamples the data.

     This transformer is useful to speed up computations when the data is too
@@ -67,12 +68,14 @@ def fit_transform(self, X, y=None, *, sample_domain=None, sample_weight=None):
         )
         return X_subsampled, y_subsampled, params

-    def transform(self, X, y=None, *, sample_domain=None, sample_weight=None):
+    def transform(
+        self, X, y=None, *, sample_domain=None, sample_weight=None, allow_source=None
+    ):
         """Transform the data."""
         return X


-class DomainStratifiedSubsampleTransformer(BaseAdapter):
+class DomainSubsampler(BaseAdapter):
     """Transformer that subsamples the data in a domain stratified way.

     This transformer is useful to speed up computations when the data is too
@@ -129,6 +132,72 @@ def fit_transform(self, X, y=None, *, sample_domain=None, sample_weight=None):
         )
         return X_subsampled, y_subsampled, params

-    def transform(self, X, y=None, *, sample_domain=None, sample_weight=None):
+    def transform(
+        self, X, y=None, *, sample_domain=None, sample_weight=None, allow_source=None
+    ):
+        """Transform the data."""
+        return X
+
+
+class StratifiedDomainSubsampler(BaseAdapter):
+    """Transformer that subsamples the data in a domain and label stratified way.
+    This transformer is useful to speed up computations when the data is too
+    large. It randomly selects a subset of the data to work with during training
+    but does not change the data during testing.
+
+    .. note::
+        This transformer should not be used as the last step of a pipeline
+        because it returns non standard output.
+
+    Parameters
+    ----------
+    train_size : int, float
+        Number of samples to keep (keep all if data smaller) if integer, or
+        proportion of train sample if float 0 <= train_size <= 1.
+    random_state : int, RandomState instance or None, default=None
+        Controls the random resampling of the data.
+    """
+
+    def __init__(self, train_size, random_state=None):
+        self.train_size = train_size
+        self.random_state = random_state
+
+    def _pack_params(self, idx, **params):
+        return {
+            k: (v[idx] if idx is not None else v)
+            for k, v in params.items()
+            if v is not None
+        }
+
+    def fit_transform(self, X, y=None, *, sample_domain=None, sample_weight=None):
+        """Fit and transform the data."""
+        X, y, sample_domain = check_X_y_domain(X, y, sample_domain)
+
+        self.rng_ = check_random_state(self.random_state)
+
+        if self.train_size >= X.shape[0]:
+            return (
+                X,
+                y,
+                self._pack_params(
+                    None, sample_domain=sample_domain, sample_weight=sample_weight
+                ),
+            )
+
+        splitter = StratifiedDomainShuffleSplit(
+            n_splits=1, train_size=self.train_size, random_state=self.rng_
+        )
+
+        train_idx, _ = next(splitter.split(X, y, sample_domain))
+        X_subsampled = X[train_idx]
+        y_subsampled = y[train_idx] if y is not None else None
+        params = self._pack_params(
+            train_idx, sample_domain=sample_domain, sample_weight=sample_weight
+        )
+        return X_subsampled, y_subsampled, params
+
+    def transform(
+        self, X, y=None, *, sample_domain=None, sample_weight=None, allow_source=None
+    ):
         """Transform the data."""
         return X
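
As the diff shows, StratifiedDomainSubsampler delegates index selection to skada.model_selection.StratifiedDomainShuffleSplit, so the same stratified draw can be reproduced with the splitter directly. A small sketch, reusing the illustrative data layout from the sketch near the top of this page (train_size=16 is again an assumption):

import numpy as np
from sklearn.utils import check_random_state

from skada.model_selection import StratifiedDomainShuffleSplit

# Same illustrative data layout as the earlier sketch: 80 source and 80 target samples.
rng = np.random.default_rng(0)
X = rng.normal(size=(160, 5))
y = np.concatenate([np.tile([0, 1], 40), -np.ones(80, dtype=int)])
sample_domain = np.concatenate([np.ones(80, dtype=int), -2 * np.ones(80, dtype=int)])

# One shuffle split stratified on the joint (sample_domain, y) combination;
# train_idx is what fit_transform uses to index X, y and the fit params.
splitter = StratifiedDomainShuffleSplit(
    n_splits=1, train_size=16, random_state=check_random_state(42)
)
train_idx, _ = next(splitter.split(X, y, sample_domain))
print(train_idx.shape)  # expected: (16,)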
