@@ -204,6 +204,7 @@ def cdd_loss(
204
204
sigmas = None ,
205
205
distance_threshold = 0.5 ,
206
206
class_threshold = 3 ,
207
+ eps = 1e-7 ,
207
208
):
208
209
"""Define the contrastive domain discrepancy loss based on [33]_.
209
210
@@ -225,6 +226,8 @@ def cdd_loss(
225
226
too far from the centroids.
226
227
class_threshold : int, optional (default=3)
227
228
Minimum number of samples in a class to be considered for the loss.
229
+ eps : float, optional (default=1e-7)
230
+ Small constant added to median distance calculation for numerical stability.
228
231
229
232
Returns
230
233
-------
@@ -240,31 +243,34 @@ def cdd_loss(
240
243
"""
241
244
n_classes = len (y_s .unique ())
242
245
243
- # Use pre-computed cluster_labels_t
246
+ # Use pre-computed target_kmeans
244
247
if target_kmeans is None :
245
- warnings .warn (
246
- "Source centroids are not computed for the whole training set, "
247
- "computing them on the current batch set."
248
- )
248
+ with torch .no_grad ():
249
+ warnings .warn (
250
+ "Source centroids are not computed for the whole training set, "
251
+ "computing them on the current batch set."
252
+ )
249
253
250
- source_centroids = []
251
-
252
- for c in range (n_classes ):
253
- mask = y_s == c
254
- if mask .sum () > 0 :
255
- class_features = features_s [mask ]
256
- normalized_features = F .normalize (class_features , p = 2 , dim = 1 )
257
- centroid = normalized_features .sum (dim = 0 )
258
- source_centroids .append (centroid )
259
-
260
- # Use source centroids to initialize target clustering
261
- target_kmeans = SphericalKMeans (
262
- n_clusters = n_classes ,
263
- random_state = 0 ,
264
- centroids = source_centroids ,
265
- device = features_t .device ,
266
- )
267
- target_kmeans .fit (features_t )
254
+ source_centroids = []
255
+
256
+ for c in range (n_classes ):
257
+ mask = y_s == c
258
+ if mask .sum () > 0 :
259
+ class_features = features_s [mask ]
260
+ normalized_features = F .normalize (class_features , p = 2 , dim = 1 )
261
+ centroid = normalized_features .sum (dim = 0 )
262
+ source_centroids .append (centroid )
263
+
264
+ source_centroids = torch .stack (source_centroids )
265
+
266
+ # Use source centroids to initialize target clustering
267
+ target_kmeans = SphericalKMeans (
268
+ n_clusters = n_classes ,
269
+ random_state = 0 ,
270
+ centroids = source_centroids ,
271
+ device = features_t .device ,
272
+ )
273
+ target_kmeans .fit (features_t )
268
274
269
275
# Predict clusters for target samples
270
276
cluster_labels_t = target_kmeans .predict (features_t )
@@ -283,10 +289,11 @@ def cdd_loss(
283
289
mask_t = valid_classes [cluster_labels_t ]
284
290
features_t = features_t [mask_t ]
285
291
cluster_labels_t = cluster_labels_t [mask_t ]
286
-
287
292
# Define sigmas
288
293
if sigmas is None :
289
- median_pairwise_distance = torch .median (torch .cdist (features_s , features_s ))
294
+ median_pairwise_distance = (
295
+ torch .median (torch .cdist (features_s , features_s )) + eps
296
+ )
290
297
sigmas = (
291
298
torch .tensor ([2 ** (- 8 ) * 2 ** (i * 1 / 2 ) for i in range (33 )]).to (
292
299
features_s .device
@@ -299,26 +306,43 @@ def cdd_loss(
299
306
# Compute CDD
300
307
intraclass = 0
301
308
interclass = 0
302
-
303
309
for c1 in range (n_classes ):
304
310
for c2 in range (c1 , n_classes ):
305
311
if valid_classes [c1 ] and valid_classes [c2 ]:
306
312
# Compute e1
307
313
kernel_ss = _gaussian_kernel (features_s , features_s , sigmas )
308
314
mask_c1_c1 = (y_s == c1 ).float ()
309
- e1 = (kernel_ss * mask_c1_c1 ).sum () / (mask_c1_c1 .sum () ** 2 )
315
+
316
+ # e1 measures the intra-class domain discrepancy
317
+ # Thus if mask_c1_c1.sum() = 0 --> e1 = 0
318
+ if mask_c1_c1 .sum () > 0 :
319
+ e1 = (kernel_ss * mask_c1_c1 ).sum () / (mask_c1_c1 .sum () ** 2 )
320
+ else :
321
+ e1 = 0
310
322
311
323
# Compute e2
312
324
kernel_tt = _gaussian_kernel (features_t , features_t , sigmas )
313
325
mask_c2_c2 = (cluster_labels_t == c2 ).float ()
314
- e2 = (kernel_tt * mask_c2_c2 ).sum () / (mask_c2_c2 .sum () ** 2 )
326
+
327
+ # e2 measures the intra-class domain discrepancy
328
+ # Thus if mask_c2_c2.sum() = 0 --> e2 = 0
329
+ if mask_c2_c2 .sum () > 0 :
330
+ e2 = (kernel_tt * mask_c2_c2 ).sum () / (mask_c2_c2 .sum () ** 2 )
331
+ else :
332
+ e2 = 0
315
333
316
334
# Compute e3
317
335
kernel_st = _gaussian_kernel (features_s , features_t , sigmas )
318
336
mask_c1 = (y_s == c1 ).float ().unsqueeze (1 )
319
337
mask_c2 = (cluster_labels_t == c2 ).float ().unsqueeze (0 )
320
338
mask_c1_c2 = mask_c1 * mask_c2
321
- e3 = (kernel_st * mask_c1_c2 ).sum () / (mask_c1_c2 .sum () ** 2 )
339
+
340
+ # e3 measures the inter-class domain discrepancy
341
+ # Thus if mask_c1_c2.sum() = 0 --> e3 = 0
342
+ if mask_c1_c2 .sum () > 0 :
343
+ e3 = (kernel_st * mask_c1_c2 ).sum () / (mask_c1_c2 .sum () ** 2 )
344
+ else :
345
+ e3 = 0
322
346
323
347
if c1 == c2 :
324
348
intraclass += e1 + e2 - 2 * e3
0 commit comments