|
24 | 24 | from skada.deep.modules import ToyModule2D
|
25 | 25 |
|
26 | 26 |
|
| 27 | +def test_domainawaremodule_features_differ_between_domains(): |
| 28 | + num_features = 10 |
| 29 | + module = ToyModule2D(num_features=num_features) |
| 30 | + module.eval() |
| 31 | + |
| 32 | + n_samples = 20 |
| 33 | + dataset = make_shifted_datasets( |
| 34 | + n_samples_source=n_samples, |
| 35 | + n_samples_target=n_samples, |
| 36 | + shift="concept_drift", |
| 37 | + noise=0.1, |
| 38 | + random_state=42, |
| 39 | + return_dataset=True, |
| 40 | + ) |
| 41 | + |
| 42 | + # Prepare data |
| 43 | + X, y, sample_domain = dataset.pack_train(as_sources=["s"], as_targets=["t"]) |
| 44 | + X = X.astype(np.float32) |
| 45 | + sample_domain = np.array(sample_domain) |
| 46 | + |
| 47 | + # Convert to torch tensors |
| 48 | + X_tensor = torch.tensor(X) |
| 49 | + sample_domain_tensor = torch.tensor(sample_domain) |
| 50 | + |
| 51 | + # Create an instance of DomainAwareModule |
| 52 | + domain_module = DomainAwareModule(module, layer_name="dropout") |
| 53 | + |
| 54 | + # Run forward pass |
| 55 | + with torch.no_grad(): |
| 56 | + output = domain_module( |
| 57 | + X_tensor, |
| 58 | + sample_domain=sample_domain_tensor, |
| 59 | + is_fit=True, |
| 60 | + return_features=True, |
| 61 | + ) |
| 62 | + |
| 63 | + # Unpack output |
| 64 | + y_pred, domain_pred, features, sample_domain_output = output |
| 65 | + |
| 66 | + # Separate features for source and target domains |
| 67 | + source_mask = sample_domain_tensor >= 0 |
| 68 | + target_mask = sample_domain_tensor < 0 |
| 69 | + features_s = features[source_mask] |
| 70 | + features_t = features[target_mask] |
| 71 | + |
| 72 | + # Ensure we have features from both domains |
| 73 | + assert features_s.size(0) > 0, "No source domain features extracted." |
| 74 | + assert features_t.size(0) > 0, "No target domain features extracted." |
| 75 | + |
| 76 | + # Compute mean features for source and target |
| 77 | + mean_features_s = features_s.mean(dim=0) |
| 78 | + mean_features_t = features_t.mean(dim=0) |
| 79 | + |
| 80 | + # Check that the mean features are different |
| 81 | + difference = torch.abs(mean_features_s - mean_features_t) |
| 82 | + max_difference = difference.max().item() |
| 83 | + |
| 84 | + assert ( |
| 85 | + max_difference > 0.1 |
| 86 | + ), "Features of source and target domains are too similar." |
| 87 | + |
| 88 | + |
27 | 89 | def test_domainawaretraining():
|
28 | 90 | module = ToyModule2D()
|
29 | 91 | module.eval()
|
|
0 commit comments