Skip to content

Commit 0b9e889

Browse files
[MRG] Fix order of feature acquisition for deep module (#235)
* fix order of feature acquisition * add test * linter --------- Co-authored-by: Antoine Collas <[email protected]>
1 parent 09771d3 commit 0b9e889

File tree

2 files changed

+72
-7
lines changed

2 files changed

+72
-7
lines changed

skada/deep/base.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def forward(
100100
else:
101101
domain_pred_s = None
102102
domain_pred_t = None
103-
103+
104104
if features is not None:
105105
features_s = features[source_idx]
106106
features_t = features[~source_idx]
@@ -312,21 +312,24 @@ def forward(
312312
# Pass sample_weight to base_module_
313313
if sample_weight is not None:
314314
sample_weight_s = sample_weight[source_idx]
315-
sample_weight_t = sample_weight[~source_idx]
316-
317315
y_pred_s = self.base_module_(X_s, sample_weight=sample_weight_s)
318-
319-
y_pred_t = self.base_module_(X_t, sample_weight=sample_weight_t)
320316
else:
321317
y_pred_s = self.base_module_(X_s)
322318

319+
if self.layer_name is not None:
320+
features_s = self.intermediate_layers[self.layer_name]
321+
else:
322+
features_s = None
323+
324+
if sample_weight is not None:
325+
sample_weight_t = sample_weight[~source_idx]
326+
y_pred_t = self.base_module_(X_t, sample_weight=sample_weight_t)
327+
else:
323328
y_pred_t = self.base_module_(X_t)
324329

325330
if self.layer_name is not None:
326-
features_s = self.intermediate_layers[self.layer_name]
327331
features_t = self.intermediate_layers[self.layer_name]
328332
else:
329-
features_s = None
330333
features_t = None
331334

332335
if self.domain_classifier_ is not None:

skada/deep/tests/test_deep_base.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,68 @@
2424
from skada.deep.modules import ToyModule2D
2525

2626

27+
def test_domainawaremodule_features_differ_between_domains():
28+
num_features = 10
29+
module = ToyModule2D(num_features=num_features)
30+
module.eval()
31+
32+
n_samples = 20
33+
dataset = make_shifted_datasets(
34+
n_samples_source=n_samples,
35+
n_samples_target=n_samples,
36+
shift="concept_drift",
37+
noise=0.1,
38+
random_state=42,
39+
return_dataset=True,
40+
)
41+
42+
# Prepare data
43+
X, y, sample_domain = dataset.pack_train(as_sources=["s"], as_targets=["t"])
44+
X = X.astype(np.float32)
45+
sample_domain = np.array(sample_domain)
46+
47+
# Convert to torch tensors
48+
X_tensor = torch.tensor(X)
49+
sample_domain_tensor = torch.tensor(sample_domain)
50+
51+
# Create an instance of DomainAwareModule
52+
domain_module = DomainAwareModule(module, layer_name="dropout")
53+
54+
# Run forward pass
55+
with torch.no_grad():
56+
output = domain_module(
57+
X_tensor,
58+
sample_domain=sample_domain_tensor,
59+
is_fit=True,
60+
return_features=True,
61+
)
62+
63+
# Unpack output
64+
y_pred, domain_pred, features, sample_domain_output = output
65+
66+
# Separate features for source and target domains
67+
source_mask = sample_domain_tensor >= 0
68+
target_mask = sample_domain_tensor < 0
69+
features_s = features[source_mask]
70+
features_t = features[target_mask]
71+
72+
# Ensure we have features from both domains
73+
assert features_s.size(0) > 0, "No source domain features extracted."
74+
assert features_t.size(0) > 0, "No target domain features extracted."
75+
76+
# Compute mean features for source and target
77+
mean_features_s = features_s.mean(dim=0)
78+
mean_features_t = features_t.mean(dim=0)
79+
80+
# Check that the mean features are different
81+
difference = torch.abs(mean_features_s - mean_features_t)
82+
max_difference = difference.max().item()
83+
84+
assert (
85+
max_difference > 0.1
86+
), "Features of source and target domains are too similar."
87+
88+
2789
def test_domainawaretraining():
2890
module = ToyModule2D()
2991
module.eval()

0 commit comments

Comments
 (0)