Skip to content

Commit 85550f8

Browse files
[MRG] Allow model selection cv to handle nd inputs (#225)
* Allow model selection cv to handle nd inputs
* Add test_check_X_y_domain_multi_nd
* Add test_cv_with_nd_dimensional_X
* Remove unused comment
* Add nd support to scorers
* Fix IW scorer when data have more than 2 dimensions
* Test all scorers
* Repeat data to test scorers with multi-dimensional data
* Test SupervisedScorer with multi-dimensional data
* Fix DEV when features are multi-dimensional
* Fix PredictionEntropyScorer when proba == 0
* Remove raise error in test_scorer_with_nd_input cv

---------

Co-authored-by: Antoine Collas <[email protected]>
1 parent dccc59e commit 85550f8

File tree

5 files changed

+141
-16
lines changed

5 files changed

+141
-16
lines changed

skada/metrics.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _score(
7272
):
7373
scorer = check_scoring(estimator, self.scoring)
7474

75-
X, y, sample_domain = check_X_y_domain(X, y, sample_domain)
75+
X, y, sample_domain = check_X_y_domain(X, y, sample_domain, allow_nd=True)
7676
source_idx = extract_source_indices(sample_domain)
7777

7878
return self._sign * scorer(
@@ -136,9 +136,11 @@ def _fit(self, X_source, X_target):
136136
137137
Parameters
138138
----------
139-
X : array-like, shape (n_samples, n_features)
139+
X : array-like, shape (n_samples, *), where * is any number
140+
of dimensions of at least 1
140141
The source data.
141-
X_target : array-like, shape (n_samples, n_features)
142+
X_target : array-like, shape (n_samples, *), where * is any number
143+
of dimensions of at least 1
142144
The target data.
143145
144146
Returns
@@ -151,6 +153,8 @@ def _fit(self, X_source, X_target):
151153
weight_estimator = KernelDensity()
152154
self.weight_estimator_source_ = clone(weight_estimator)
153155
self.weight_estimator_target_ = clone(weight_estimator)
156+
X_source = X_source.reshape(X_source.shape[0], -1)
157+
X_target = X_target.reshape(X_target.shape[0], -1)
154158
self.weight_estimator_source_.fit(X_source)
155159
self.weight_estimator_target_.fit(X_target)
156160
return self
@@ -165,10 +169,12 @@ def _score(self, estimator, X, y, sample_domain=None, **params):
165169
f"The estimator {estimator!r} does not."
166170
)
167171

168-
X, y, sample_domain = check_X_y_domain(X, y, sample_domain)
172+
X, y, sample_domain = check_X_y_domain(X, y, sample_domain, allow_nd=True)
169173
X_source, X_target, y_source, _ = source_target_split(
170174
X, y, sample_domain=sample_domain
171175
)
176+
X_source = X_source.reshape(X_source.shape[0], -1)
177+
X_target = X_target.reshape(X_target.shape[0], -1)
172178
self._fit(X_source, X_target)
173179
ws = self.weight_estimator_source_.score_samples(X_source)
174180
wt = self.weight_estimator_target_.score_samples(X_source)
@@ -239,7 +245,7 @@ def _score(self, estimator, X, y, sample_domain=None, **params):
239245
"The estimator passed should have a 'predict_proba' method. "
240246
f"The estimator {estimator!r} does not."
241247
)
242-
X, y, sample_domain = check_X_y_domain(X, y, sample_domain)
248+
X, y, sample_domain = check_X_y_domain(X, y, sample_domain, allow_nd=True)
243249
source_idx = extract_source_indices(sample_domain)
244250
proba = estimator.predict_proba(
245251
X[~source_idx], sample_domain=sample_domain[~source_idx], **params
@@ -250,7 +256,9 @@ def _score(self, estimator, X, y, sample_domain=None, **params):
250256
)
251257
else:
252258
log_proba = np.log(proba + 1e-7)
259+
infty_mask = np.isneginf(log_proba)
253260
entropy_per_sample = -proba * log_proba
261+
entropy_per_sample[infty_mask] = 0 # x*log(x) -> 0 as x -> 0
254262
if self.reduction == "none":
255263
return self._sign * entropy_per_sample
256264
elif self.reduction == "sum":
@@ -298,7 +306,7 @@ def _score(self, estimator, X, y, sample_domain=None, **params):
298306
f"The estimator {estimator!r} does not."
299307
)
300308

301-
X, y, sample_domain = check_X_y_domain(X, y, sample_domain)
309+
X, y, sample_domain = check_X_y_domain(X, y, sample_domain, allow_nd=True)
302310
source_idx = extract_source_indices(sample_domain)
303311
proba = estimator.predict_proba(
304312
X[~source_idx], sample_domain=sample_domain[~source_idx], **params
@@ -403,7 +411,8 @@ def identity(x):
403411
# We use the input data as features
404412
transformer = identity
405413

406-
X, y, sample_domain = check_X_y_domain(X, y, sample_domain)
414+
X, y, sample_domain = check_X_y_domain(X, y, sample_domain, allow_nd=True)
415+
X = X.reshape(X.shape[0], -1)
407416
source_idx = extract_source_indices(sample_domain)
408417
rng = check_random_state(self.random_state)
409418
X_train, X_val, _, y_val, _, sample_domain_val = train_test_split(
@@ -550,7 +559,7 @@ def _score(self, estimator, X, y, sample_domain=None):
550559
float
551560
The computed score.
552561
"""
553-
X, y, sample_domain = check_X_y_domain(X, y, sample_domain)
562+
X, y, sample_domain = check_X_y_domain(X, y, sample_domain, allow_nd=True)
554563

555564
try:
556565
_check_y_masking(y)
@@ -707,7 +716,7 @@ def _score(self, estimator, X, y=None, sample_domain=None, **params):
707716
"""
708717
scorer = check_scoring(estimator, self.scoring)
709718

710-
X, _, sample_domain = check_X_y_domain(X, y, sample_domain)
719+
X, _, sample_domain = check_X_y_domain(X, y, sample_domain, allow_nd=True)
711720
source_idx = extract_source_indices(sample_domain)
712721
X_target = X[~source_idx]
713722

skada/model_selection.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ def split(self, X, y=None, sample_domain=None):
8080
"""
8181
# automatically derive sample_domain if it is not provided
8282
X, sample_domain = check_X_domain(
83-
X, sample_domain, allow_auto_sample_domain=True
83+
X,
84+
sample_domain,
85+
allow_auto_sample_domain=True,
86+
allow_nd=True,
8487
)
8588
X, y, sample_domain = indexable(X, y, sample_domain)
8689
yield from self._iter_indices(X, y, sample_domain=sample_domain)
@@ -138,7 +141,7 @@ def __init__(
138141
self._default_test_size = 0.1
139142

140143
def _iter_indices(self, X, y=None, sample_domain=None):
141-
X, sample_domain = check_X_domain(X, sample_domain)
144+
X, sample_domain = check_X_domain(X, sample_domain, allow_nd=True)
142145
indices = extract_source_indices(sample_domain)
143146
(source_idx,) = np.where(indices)
144147
(target_idx,) = np.where(~indices)
@@ -225,7 +228,10 @@ def split(self, X, y=None, sample_domain=None):
225228
"""
226229
# automatically derive sample_domain if it is not provided
227230
X, sample_domain = check_X_domain(
228-
X, sample_domain, allow_auto_sample_domain=True
231+
X,
232+
sample_domain,
233+
allow_auto_sample_domain=True,
234+
allow_nd=True,
229235
)
230236
X, y, sample_domain = indexable(X, y, sample_domain)
231237
# xxx(okachaiev): make sure all domains are given both as sources and targets
@@ -253,7 +259,7 @@ def split(self, X, y=None, sample_domain=None):
253259
yield split_idx[train_idx], split_idx[test_idx]
254260

255261
def _iter_indices(self, X, y=None, sample_domain=None):
256-
X, sample_domain = check_X_domain(X, sample_domain)
262+
X, sample_domain = check_X_domain(X, sample_domain, allow_nd=True)
257263
indices = extract_source_indices(sample_domain)
258264
(source_idx,) = np.where(indices)
259265
(target_idx,) = np.where(~indices)
@@ -383,7 +389,10 @@ def _iter_indices(self, X, y, sample_domain=None):
383389
# License: BSD
384390

385391
X, sample_domain = check_X_domain(
386-
X, sample_domain, allow_auto_sample_domain=True
392+
X,
393+
sample_domain,
394+
allow_auto_sample_domain=True,
395+
allow_nd=True,
387396
)
388397
X, y, sample_domain = indexable(X, y, sample_domain)
389398

@@ -532,7 +541,7 @@ def __init__(
532541
raise ValueError("under_sampling should be between 0 and 1")
533542

534543
def _iter_indices(self, X, y=None, sample_domain=None):
535-
X, sample_domain = check_X_domain(X, sample_domain)
544+
X, sample_domain = check_X_domain(X, sample_domain, allow_nd=True)
536545
domain_source_idx_dict, domain_target_idx_dict = extract_domains_indices(
537546
sample_domain, split_source_target=True
538547
)

skada/tests/test_cv.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,38 @@ def test_stratified_domain_shuffle_split_exceptions():
166166
splitter = StratifiedDomainShuffleSplit(n_splits=4, test_size=0.1, random_state=0)
167167
with pytest.raises(ValueError):
168168
next(iter(splitter.split(X, y, sample_domain)))
169+
170+
171+
@pytest.mark.parametrize(
    "cv",
    [
        GroupShuffleSplit(n_splits=2, test_size=0.3, random_state=0),
        GroupKFold(n_splits=2),
        LeaveOneGroupOut(),
        SourceTargetShuffleSplit(n_splits=2, test_size=0.3, random_state=0),
        DomainShuffleSplit(
            n_splits=2, test_size=0.3, random_state=0, under_sampling=1
        ),
        StratifiedDomainShuffleSplit(n_splits=2, test_size=0.3, random_state=0),
    ],
)
def test_cv_with_nd_dimensional_X(da_dataset, cv):
    """Check that each splitter accepts a 3D ``X`` and yields disjoint,
    exhaustive train/test index arrays.
    """
    X, y, sample_domain = da_dataset.pack_lodo()
    # Append a trailing singleton axis: (n_samples, n_features, 1).
    X = X.reshape(X.shape[0], -1, 1)
    assert X.ndim == 3, "X should be 3-dimensional after reshaping"

    # Materialize every split up front so a failure inside the splitter
    # surfaces before any per-split assertion runs.
    all_splits = list(cv.split(X, y, sample_domain))
    n_samples = len(X)

    for train, test in all_splits:
        both_arrays = isinstance(train, np.ndarray) and isinstance(test, np.ndarray)
        assert both_arrays, "split indices should be numpy arrays"

        n_covered = len(train) + len(test)
        assert n_covered == n_samples, "train and test indices should cover all samples"

        shared = np.intersect1d(train, test)
        assert len(shared) == 0, "train and test indices should not overlap"

skada/tests/test_scorer.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import numpy as np
99
import pytest
10-
from sklearn.dummy import DummyRegressor
10+
from sklearn.dummy import DummyClassifier, DummyRegressor
1111
from sklearn.linear_model import LinearRegression, LogisticRegression
1212
from sklearn.metrics import mean_squared_error
1313
from sklearn.model_selection import ShuffleSplit, cross_validate
@@ -307,3 +307,47 @@ def test_mixval_scorer_regression(da_reg_dataset):
307307
scorer = MixValScorer(alpha=0.55, random_state=42)
308308
with pytest.raises(ValueError):
309309
scorer(estimator, X, y, sample_domain)
310+
311+
312+
@pytest.mark.parametrize(
    "scorer",
    [
        SupervisedScorer(),
        ImportanceWeightedScorer(),
        PredictionEntropyScorer(),
        SoftNeighborhoodDensity(),
        DeepEmbeddedValidation(),
        CircularValidation(),
        MixValScorer(alpha=0.55, random_state=42),
    ],
)
def test_scorer_with_nd_input(scorer, da_dataset):
    """Run each scorer through ``cross_validate`` on a 3D ``X`` and check
    that one non-NaN score is produced per split.
    """
    X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"])

    # Stack three copies of X along a new trailing axis to get a 3D input.
    X_3d = np.repeat(X[:, :, np.newaxis], repeats=3, axis=2)

    classifier = DummyClassifier(strategy="stratified", random_state=42)
    classifier = classifier.set_fit_request(sample_weight=True)
    classifier = classifier.set_score_request(sample_weight=True)
    estimator = make_da_pipeline(classifier)

    splitter = ShuffleSplit(n_splits=3, test_size=0.3, random_state=0)

    # Every scorer receives the domain labels; the supervised scorer is
    # additionally given the labels packed with ``train=False``.
    params = {"sample_domain": sample_domain}
    if isinstance(scorer, SupervisedScorer):
        _, target_labels, _ = da_dataset.pack(
            as_sources=["s"], as_targets=["t"], train=False
        )
        params["target_labels"] = target_labels

    cv_results = cross_validate(
        estimator,
        X_3d,
        y,
        cv=splitter,
        params=params,
        scoring=scorer,
    )
    scores = cv_results["test_score"]

    assert scores.shape[0] == 3, "evaluate 3 splits"
    assert not np.isnan(scores).any(), "all scores are computed"

skada/tests/test_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,34 @@ def test_check_X_allow_exceptions():
337337
)
338338

339339

340+
def test_check_X_domain_multi_nd():
    """Check that ``check_X_domain`` accepts a 3D ``X`` only when
    ``allow_nd=True`` and raises ``ValueError`` otherwise.
    """
    # Seeded generator so the test data is reproducible across runs.
    rng = np.random.default_rng(0)
    # 3D array: (10 samples, 2 features, 3 channels)
    X = rng.random((10, 2, 3))
    sample_domain = np.array([1] * 5 + [-1] * 5)

    # With allow_nd=True the call must not raise.
    check_X_domain(X, sample_domain=sample_domain, allow_nd=True)

    # With allow_nd=False a 3D input must be rejected. The dots are escaped
    # because ``match`` is applied as a regular expression, where a bare
    # "." would match any character.
    with pytest.raises(
        ValueError, match=r"Found array with dim 3\. None expected <= 2\."
    ):
        check_X_domain(X, sample_domain=sample_domain, allow_nd=False)
351+
352+
353+
def test_check_X_y_domain_multi_nd():
    """Check that ``check_X_y_domain`` accepts a 3D ``X`` only when
    ``allow_nd=True`` and raises ``ValueError`` otherwise.
    """
    # Seeded generator so the test data is reproducible across runs.
    rng = np.random.default_rng(0)
    # 3D array for X: (10 samples, 2 features, 3 channels)
    X = rng.random((10, 2, 3))
    # 2D array for y: (10 samples, 2 outputs)
    y = rng.random((10, 2))
    sample_domain = np.array([1] * 5 + [-1] * 5)

    # With allow_nd=True the call must not raise.
    check_X_y_domain(X, y, sample_domain=sample_domain, allow_nd=True)

    # With allow_nd=False the 3D X must be rejected. The dots are escaped
    # because ``match`` is applied as a regular expression, where a bare
    # "." would match any character.
    with pytest.raises(
        ValueError, match=r"Found array with dim 3\. None expected <= 2\."
    ):
        check_X_y_domain(X, y, sample_domain=sample_domain, allow_nd=False)
366+
367+
340368
def test_extract_source_indices():
341369
n_samples_source = 50
342370
n_samples_target = 20

0 commit comments

Comments
 (0)