
Commit f4f68b1

[MRG] Handle edge cases for CAN (#269)
* Handle edge cases for CAN
* Prevent computing grad in spherical kmeans + use batches to compute cosine_similarities
* Add eps as an arg
1 parent 781ef0d commit f4f68b1

File tree

5 files changed (+188 additions, -114 deletions)


skada/deep/_divergence.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -205,6 +205,8 @@ class CANLoss(BaseDALoss):
         If None, uses sigmas proposed in [1]_.
     target_kmeans : sklearn KMeans instance, default=None,
         Pre-computed target KMeans clustering model.
+    eps : float, default=1e-7
+        Small constant added to median distance calculation for numerical stability.
 
     References
     ----------
@@ -220,12 +222,14 @@ def __init__(
         class_threshold=3,
         sigmas=None,
         target_kmeans=None,
+        eps=1e-7,
     ):
         super().__init__()
         self.distance_threshold = distance_threshold
         self.class_threshold = class_threshold
         self.sigmas = sigmas
         self.target_kmeans = target_kmeans
+        self.eps = eps
 
     def forward(
         self,
@@ -245,6 +249,7 @@ def forward(
             target_kmeans=self.target_kmeans,
             distance_threshold=self.distance_threshold,
             class_threshold=self.class_threshold,
+            eps=self.eps,
         )
 
         return loss
```
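For orientation, here is a minimal, hypothetical sketch of how the new `eps` argument surfaces on `CANLoss`. The keyword names come from the `__init__` shown above and the import path from the diff header; the default values used here are assumptions, not verified against a release.

```python
# Hypothetical usage sketch of the updated CANLoss constructor; the import
# path mirrors the file shown in this diff.
from skada.deep._divergence import CANLoss

criterion = CANLoss(
    distance_threshold=0.5,  # assumed default, mirrors cdd_loss below
    class_threshold=3,       # minimum samples per class to enter the loss
    sigmas=None,             # None -> sigmas derived from the median pairwise distance
    target_kmeans=None,      # None -> target clustering recomputed per batch
    eps=1e-7,                # new: keeps the median-distance sigmas strictly positive
)
```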

skada/deep/callbacks.py

Lines changed: 31 additions & 29 deletions
```diff
@@ -40,34 +40,36 @@ def on_epoch_begin(self, net, dataset_train=None, **kwargs):
 
         X_t = X["X"][X["sample_domain"] < 0]
 
-        features_s = net.predict_features(X_s)
-        features_t = net.predict_features(X_t)
-
-        features_s = torch.tensor(features_s, device=net.device)
-        y_s = torch.tensor(y_s, device=net.device)
-
-        features_t = torch.tensor(features_t, device=net.device)
-
-        n_classes = len(y_s.unique())
-        source_centroids = []
-
-        for c in range(n_classes):
-            mask = y_s == c
-            if mask.sum() > 0:
-                class_features = features_s[mask]
-                normalized_features = F.normalize(class_features, p=2, dim=1)
-                centroid = normalized_features.mean(dim=0)
-                source_centroids.append(centroid)
-
-        source_centroids = torch.stack(source_centroids)
-
-        # Use source centroids to initialize target clustering
-        target_kmeans = SphericalKMeans(
-            n_clusters=n_classes,
-            random_state=0,
-            centroids=source_centroids,
-            device=features_t.device,
-        )
-        target_kmeans.fit(features_t)
+        # Disable gradient computation for feature extraction
+        with torch.no_grad():
+            features_s = net.predict_features(X_s)
+            features_t = net.predict_features(X_t)
+
+            features_s = torch.tensor(features_s, device=net.device)
+            y_s = torch.tensor(y_s, device=net.device)
+
+            features_t = torch.tensor(features_t, device=net.device)
+
+            n_classes = len(y_s.unique())
+            source_centroids = []
+
+            for c in range(n_classes):
+                mask = y_s == c
+                if mask.sum() > 0:
+                    class_features = features_s[mask]
+                    normalized_features = F.normalize(class_features, p=2, dim=1)
+                    centroid = normalized_features.sum(dim=0)
+                    source_centroids.append(centroid)
+
+            source_centroids = torch.stack(source_centroids)
+
+            # Use source centroids to initialize target clustering
+            target_kmeans = SphericalKMeans(
+                n_clusters=n_classes,
+                random_state=0,
+                centroids=source_centroids,
+                device=features_t.device,
+            )
+            target_kmeans.fit(features_t)
 
         net.criterion__adapt_criterion.target_kmeans = target_kmeans
```
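As a standalone illustration of the pattern the callback now uses, the sketch below reproduces the no-grad, per-class centroid computation with illustrative names (`class_centroids` is not part of skada):

```python
import torch
import torch.nn.functional as F

def class_centroids(features, labels, n_classes):
    # Sum of L2-normalized features per class, skipping empty classes,
    # as in the callback above; no gradients are needed for clustering init.
    centroids = []
    with torch.no_grad():
        for c in range(n_classes):
            mask = labels == c
            if mask.sum() > 0:
                normalized = F.normalize(features[mask], p=2, dim=1)
                centroids.append(normalized.sum(dim=0))
    return torch.stack(centroids)

features = torch.randn(6, 4)
labels = torch.tensor([0, 0, 0, 1, 1, 1])
print(class_centroids(features, labels, n_classes=2).shape)  # torch.Size([2, 4])
```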

skada/deep/losses.py

Lines changed: 53 additions & 29 deletions
```diff
@@ -204,6 +204,7 @@ def cdd_loss(
     sigmas=None,
     distance_threshold=0.5,
     class_threshold=3,
+    eps=1e-7,
 ):
     """Define the contrastive domain discrepancy loss based on [33]_.
 
@@ -225,6 +226,8 @@ def cdd_loss(
         to far from the centroids.
     class_threshold : int, optional (default=3)
         Minimum number of samples in a class to be considered for the loss.
+    eps : float, default=1e-7
+        Small constant added to median distance calculation for numerical stability.
 
     Returns
     -------
@@ -240,31 +243,34 @@ def cdd_loss(
     """
     n_classes = len(y_s.unique())
 
-    # Use pre-computed cluster_labels_t
+    # Use pre-computed target_kmeans
     if target_kmeans is None:
-        warnings.warn(
-            "Source centroids are not computed for the whole training set, "
-            "computing them on the current batch set."
-        )
+        with torch.no_grad():
+            warnings.warn(
+                "Source centroids are not computed for the whole training set, "
+                "computing them on the current batch set."
+            )
 
-        source_centroids = []
-
-        for c in range(n_classes):
-            mask = y_s == c
-            if mask.sum() > 0:
-                class_features = features_s[mask]
-                normalized_features = F.normalize(class_features, p=2, dim=1)
-                centroid = normalized_features.sum(dim=0)
-                source_centroids.append(centroid)
-
-        # Use source centroids to initialize target clustering
-        target_kmeans = SphericalKMeans(
-            n_clusters=n_classes,
-            random_state=0,
-            centroids=source_centroids,
-            device=features_t.device,
-        )
-        target_kmeans.fit(features_t)
+            source_centroids = []
+
+            for c in range(n_classes):
+                mask = y_s == c
+                if mask.sum() > 0:
+                    class_features = features_s[mask]
+                    normalized_features = F.normalize(class_features, p=2, dim=1)
+                    centroid = normalized_features.sum(dim=0)
+                    source_centroids.append(centroid)
+
+            source_centroids = torch.stack(source_centroids)
+
+            # Use source centroids to initialize target clustering
+            target_kmeans = SphericalKMeans(
+                n_clusters=n_classes,
+                random_state=0,
+                centroids=source_centroids,
+                device=features_t.device,
+            )
+            target_kmeans.fit(features_t)
 
     # Predict clusters for target samples
     cluster_labels_t = target_kmeans.predict(features_t)
@@ -283,10 +289,11 @@ def cdd_loss(
     mask_t = valid_classes[cluster_labels_t]
     features_t = features_t[mask_t]
     cluster_labels_t = cluster_labels_t[mask_t]
-
     # Define sigmas
     if sigmas is None:
-        median_pairwise_distance = torch.median(torch.cdist(features_s, features_s))
+        median_pairwise_distance = (
+            torch.median(torch.cdist(features_s, features_s)) + eps
+        )
         sigmas = (
             torch.tensor([2 ** (-8) * 2 ** (i * 1 / 2) for i in range(33)]).to(
                 features_s.device
```
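The `eps` change above guards a degenerate but reachable case: if all source features in a batch are identical, the median pairwise distance is exactly zero, so the sigmas derived from it collapse and the Gaussian kernel produces NaNs. A quick standalone check (not from the commit):

```python
import torch

features_s = torch.ones((4, 2))  # identical features, as in the new test below
median = torch.median(torch.cdist(features_s, features_s))
print(median)        # tensor(0.) -> zero sigmas without the guard

eps = 1e-7           # the commit's guard
print(median + eps)  # tensor(1.0000e-07): sigmas stay strictly positive
```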
```diff
@@ -299,26 +306,43 @@ def cdd_loss(
     # Compute CDD
     intraclass = 0
     interclass = 0
-
     for c1 in range(n_classes):
         for c2 in range(c1, n_classes):
             if valid_classes[c1] and valid_classes[c2]:
                 # Compute e1
                 kernel_ss = _gaussian_kernel(features_s, features_s, sigmas)
                 mask_c1_c1 = (y_s == c1).float()
-                e1 = (kernel_ss * mask_c1_c1).sum() / (mask_c1_c1.sum() ** 2)
+
+                # e1 measure the intra-class domain discrepancy
+                # Thus if mask_c1_c1.sum() = 0 --> e1 = 0
+                if mask_c1_c1.sum() > 0:
+                    e1 = (kernel_ss * mask_c1_c1).sum() / (mask_c1_c1.sum() ** 2)
+                else:
+                    e1 = 0
 
                 # Compute e2
                 kernel_tt = _gaussian_kernel(features_t, features_t, sigmas)
                 mask_c2_c2 = (cluster_labels_t == c2).float()
-                e2 = (kernel_tt * mask_c2_c2).sum() / (mask_c2_c2.sum() ** 2)
+
+                # e2 measure the intra-class domain discrepancy
+                # Thus if mask_c2_c2.sum() = 0 --> e2 = 0
+                if mask_c2_c2.sum() > 0:
+                    e2 = (kernel_tt * mask_c2_c2).sum() / (mask_c2_c2.sum() ** 2)
+                else:
+                    e2 = 0
 
                 # Compute e3
                 kernel_st = _gaussian_kernel(features_s, features_t, sigmas)
                 mask_c1 = (y_s == c1).float().unsqueeze(1)
                 mask_c2 = (cluster_labels_t == c2).float().unsqueeze(0)
                 mask_c1_c2 = mask_c1 * mask_c2
-                e3 = (kernel_st * mask_c1_c2).sum() / (mask_c1_c2.sum() ** 2)
+
+                # e3 measure the inter-class domain discrepancy
+                # Thus if mask_c1_c2.sum() = 0 --> e3 = 0
+                if mask_c1_c2.sum() > 0:
+                    e3 = (kernel_st * mask_c1_c2).sum() / (mask_c1_c2.sum() ** 2)
+                else:
+                    e3 = 0
 
                 if c1 == c2:
                     intraclass += e1 + e2 - 2 * e3
```
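The guarded e1/e2/e3 terms all share the same shape: a mask-weighted kernel sum normalized by the squared mask count, which evaluates to 0/0 when a class receives no samples. A minimal illustration of that guard, using a hypothetical helper name:

```python
import torch

def masked_kernel_term(kernel, mask):
    # Mask-weighted kernel sum normalized by the squared mask count;
    # returns 0 instead of NaN when the mask is empty (illustrative only,
    # not skada's API).
    if mask.sum() > 0:
        return (kernel * mask).sum() / (mask.sum() ** 2)
    return torch.tensor(0.0)

kernel = torch.rand(4, 4)
empty_mask = torch.zeros(4, 4)  # e.g. mask_c1_c2 when no target sample falls in c2
print(masked_kernel_term(kernel, empty_mask))  # tensor(0.) instead of nan
```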

skada/deep/tests/test_deep_divergence.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -13,7 +13,7 @@
 
 from skada.datasets import make_shifted_datasets
 from skada.deep import CAN, DAN, DeepCoral
-from skada.deep.losses import dan_loss
+from skada.deep.losses import cdd_loss, dan_loss
 from skada.deep.modules import ToyModule2D
 
 
@@ -199,6 +199,17 @@ def test_can_with_custom_callbacks():
     assert "ComputeSourceCentroids" in callback_classes
 
 
+def test_cdd_loss_edge_cases():
+    # Test when median pairwise distance is 0
+    features_s = torch.ones((4, 2))  # All features are identical
+    features_t = torch.randn((4, 2))
+    y_s = torch.tensor([0, 0, 1, 1])  # Two classes
+
+    # This should not raise any errors due to the eps we added
+    loss = cdd_loss(y_s, features_s, features_t)
+    assert not np.isnan(loss)
+
+
 def test_dan_loss_edge_cases():
     # Create identical source features to get median distance = 0
     features_s = torch.tensor([[1.0, 2.0], [1.0, 2.0]], dtype=torch.float32)
```
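To exercise just the new edge-case test locally, a standard pytest invocation should work (shown as an assumption about the usual test layout):

```
pytest skada/deep/tests/test_deep_divergence.py -k test_cdd_loss_edge_cases
```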
