Skip to content

Commit 3647fd9

Browse files
[MRG] In DEV, reshape features to 2D instead of input (#226)
* In DEV, reshape features to 2D instead of input
* Add tests for dev scorer
* Remove raise arg in cv
* comments
* improve docstring
* use da_dataset fixture
* automatic number of channels, ...
* modify ToyCNN so that features are more than 2D
* rm not necessary reshapes
* float32
* change test name
* check features are actually more than 2D
* spaces
* rm from test_scorer_with_nd_input

---------

Co-authored-by: Antoine Collas <[email protected]>
1 parent 85550f8 commit 3647fd9

File tree

4 files changed

+67
-10
lines changed

4 files changed

+67
-10
lines changed

skada/deep/modules.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def forward(
4141

4242

4343
class ToyCNN(nn.Module):
44-
"""Toy CNN for examples and tests.
44+
"""Toy CNN for examples and tests on classification tasks.
45+
46+
Made for 2D data (e.g. time series) with shape (batch_size, n_channels, input_size).
4547
4648
Parameters
4749
----------
@@ -63,15 +65,31 @@ def __init__(
6365
self.feature_extractor = nn.Sequential(
6466
nn.Conv1d(n_channels, out_channels, kernel_size),
6567
nn.ReLU(),
66-
nn.AvgPool1d(kernel_size),
6768
)
6869
self.num_features = self._num_features(n_channels, input_size)
69-
self.fc = nn.Linear(self.num_features, n_classes)
70+
self.fc = nn.Sequential(
71+
nn.AdaptiveAvgPool1d(1),
72+
nn.Flatten(start_dim=1),
73+
nn.Linear(out_channels, n_classes),
74+
)
7075

7176
def forward(self, x, sample_weight=None):
72-
"""XXX add docstring here."""
77+
"""Forward pass of the network.
78+
79+
Parameters
80+
----------
81+
x : torch.Tensor
82+
Input tensor of shape (batch_size, n_channels, input_size).
83+
sample_weight : torch.Tensor, optional
84+
Sample weights for the loss computation of shape (batch_size,).
85+
86+
Returns
87+
-------
88+
torch.Tensor
89+
Output tensor of shape (batch_size, n_classes).
90+
"""
7391
x = self.feature_extractor(x)
74-
x = self.fc(x.flatten(start_dim=1))
92+
x = self.fc(x)
7593
return x
7694

7795
def _num_features(self, n_channels, input_size):

skada/deep/tests/test_deep_scorer.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
import numpy as np
66
import pytest
7+
import torch
78
from sklearn.model_selection import ShuffleSplit, cross_validate
89
from sklearn.preprocessing import StandardScaler
910

10-
from skada import make_da_pipeline
11+
from skada import make_da_pipeline, source_target_split
1112
from skada.deep import DeepCoral
12-
from skada.deep.modules import ToyModule2D
13+
from skada.deep.modules import ToyCNN, ToyModule2D
1314
from skada.metrics import (
1415
CircularValidation,
1516
DeepEmbeddedValidation,
@@ -87,7 +88,47 @@ def test_generic_scorer(scorer, da_dataset):
8788
cv=cv,
8889
params={"sample_domain": sample_domain},
8990
scoring=scorer,
90-
error_score="raise",
91+
)["test_score"]
92+
assert scores.shape[0] == 3, "evaluate 3 splits"
93+
assert np.all(~np.isnan(scores)), "all scores are computed"
94+
95+
96+
def test_dev_cnn_features_nd(da_dataset):
97+
X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"])
98+
X = np.repeat(X[..., np.newaxis], repeats=5, axis=-1) # Make it batched 2D data
99+
X = X.astype(np.float32)
100+
101+
scorer = DeepEmbeddedValidation()
102+
_, n_channels, input_size = X.shape
103+
y_source, _ = source_target_split(y, sample_domain=sample_domain)
104+
n_classes = len(np.unique(y_source))
105+
module = ToyCNN(
106+
n_channels=n_channels,
107+
input_size=input_size,
108+
n_classes=n_classes,
109+
kernel_size=3,
110+
out_channels=2,
111+
)
112+
# Assert features more than 2D
113+
assert module.feature_extractor(torch.tensor(X)).ndim > 2
114+
115+
net = DeepCoral(
116+
module,
117+
reg=1,
118+
layer_name="feature_extractor",
119+
batch_size=10,
120+
max_epochs=10,
121+
train_split=None,
122+
)
123+
124+
cv = ShuffleSplit(n_splits=3, test_size=0.3, random_state=0)
125+
scores = cross_validate(
126+
net,
127+
X,
128+
y,
129+
cv=cv,
130+
params={"sample_domain": sample_domain},
131+
scoring=scorer,
91132
)["test_score"]
92133
assert scores.shape[0] == 3, "evaluate 3 splits"
93134
assert np.all(~np.isnan(scores)), "all scores are computed"

skada/metrics.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,6 @@ def identity(x):
412412
transformer = identity
413413

414414
X, y, sample_domain = check_X_y_domain(X, y, sample_domain, allow_nd=True)
415-
X = X.reshape(X.shape[0], -1)
416415
source_idx = extract_source_indices(sample_domain)
417416
rng = check_random_state(self.random_state)
418417
X_train, X_val, _, y_val, _, sample_domain_val = train_test_split(

skada/tests/test_scorer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,6 @@ def test_mixval_scorer_regression(da_reg_dataset):
316316
ImportanceWeightedScorer(),
317317
PredictionEntropyScorer(),
318318
SoftNeighborhoodDensity(),
319-
DeepEmbeddedValidation(),
320319
CircularValidation(),
321320
MixValScorer(alpha=0.55, random_state=42),
322321
],

0 commit comments

Comments
 (0)