Skip to content

[MRG] Handle edge cases for CAN #269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 31 additions & 29 deletions skada/deep/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,34 +40,36 @@ def on_epoch_begin(self, net, dataset_train=None, **kwargs):

X_t = X["X"][X["sample_domain"] < 0]

features_s = net.predict_features(X_s)
features_t = net.predict_features(X_t)

features_s = torch.tensor(features_s, device=net.device)
y_s = torch.tensor(y_s, device=net.device)

features_t = torch.tensor(features_t, device=net.device)

n_classes = len(y_s.unique())
source_centroids = []

for c in range(n_classes):
mask = y_s == c
if mask.sum() > 0:
class_features = features_s[mask]
normalized_features = F.normalize(class_features, p=2, dim=1)
centroid = normalized_features.mean(dim=0)
source_centroids.append(centroid)

source_centroids = torch.stack(source_centroids)

# Use source centroids to initialize target clustering
target_kmeans = SphericalKMeans(
n_clusters=n_classes,
random_state=0,
centroids=source_centroids,
device=features_t.device,
)
target_kmeans.fit(features_t)
# Disable gradient computation for feature extraction
with torch.no_grad():
features_s = net.predict_features(X_s)
features_t = net.predict_features(X_t)

features_s = torch.tensor(features_s, device=net.device)
y_s = torch.tensor(y_s, device=net.device)

features_t = torch.tensor(features_t, device=net.device)

n_classes = len(y_s.unique())
source_centroids = []

for c in range(n_classes):
mask = y_s == c
if mask.sum() > 0:
class_features = features_s[mask]
normalized_features = F.normalize(class_features, p=2, dim=1)
centroid = normalized_features.sum(dim=0)
source_centroids.append(centroid)

source_centroids = torch.stack(source_centroids)

# Use source centroids to initialize target clustering
target_kmeans = SphericalKMeans(
n_clusters=n_classes,
random_state=0,
centroids=source_centroids,
device=features_t.device,
)
target_kmeans.fit(features_t)

net.criterion__adapt_criterion.target_kmeans = target_kmeans
80 changes: 51 additions & 29 deletions skada/deep/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,31 +236,34 @@ def cdd_loss(
"""
n_classes = len(y_s.unique())

# Use pre-computed cluster_labels_t
# Use pre-computed target_kmeans
if target_kmeans is None:
warnings.warn(
"Source centroids are not computed for the whole training set, "
"computing them on the current batch set."
)
with torch.no_grad():
warnings.warn(
"Source centroids are not computed for the whole training set, "
"computing them on the current batch set."
)

source_centroids = []

for c in range(n_classes):
mask = y_s == c
if mask.sum() > 0:
class_features = features_s[mask]
normalized_features = F.normalize(class_features, p=2, dim=1)
centroid = normalized_features.sum(dim=0)
source_centroids.append(centroid)

# Use source centroids to initialize target clustering
target_kmeans = SphericalKMeans(
n_clusters=n_classes,
random_state=0,
centroids=source_centroids,
device=features_t.device,
)
target_kmeans.fit(features_t)
source_centroids = []

for c in range(n_classes):
mask = y_s == c
if mask.sum() > 0:
class_features = features_s[mask]
normalized_features = F.normalize(class_features, p=2, dim=1)
centroid = normalized_features.sum(dim=0)
source_centroids.append(centroid)

source_centroids = torch.stack(source_centroids)

# Use source centroids to initialize target clustering
target_kmeans = SphericalKMeans(
n_clusters=n_classes,
random_state=0,
centroids=source_centroids,
device=features_t.device,
)
target_kmeans.fit(features_t)

# Predict clusters for target samples
cluster_labels_t = target_kmeans.predict(features_t)
Expand All @@ -279,10 +282,12 @@ def cdd_loss(
mask_t = valid_classes[cluster_labels_t]
features_t = features_t[mask_t]
cluster_labels_t = cluster_labels_t[mask_t]

# Define sigmas
if sigmas is None:
median_pairwise_distance = torch.median(torch.cdist(features_s, features_s))
eps = 1e-7
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eps as parameter

median_pairwise_distance = (
torch.median(torch.cdist(features_s, features_s)) + eps
)
sigmas = (
torch.tensor([2 ** (-8) * 2 ** (i * 1 / 2) for i in range(33)]).to(
features_s.device
Expand All @@ -295,26 +300,43 @@ def cdd_loss(
# Compute CDD
intraclass = 0
interclass = 0

for c1 in range(n_classes):
for c2 in range(c1, n_classes):
if valid_classes[c1] and valid_classes[c2]:
# Compute e1
kernel_ss = _gaussian_kernel(features_s, features_s, sigmas)
mask_c1_c1 = (y_s == c1).float()
e1 = (kernel_ss * mask_c1_c1).sum() / (mask_c1_c1.sum() ** 2)

# e1 measure the intra-class domain discrepancy
# Thus if mask_c1_c1.sum() = 0 --> e1 = 0
if mask_c1_c1.sum() > 0:
e1 = (kernel_ss * mask_c1_c1).sum() / (mask_c1_c1.sum() ** 2)
else:
e1 = 0

# Compute e2
kernel_tt = _gaussian_kernel(features_t, features_t, sigmas)
mask_c2_c2 = (cluster_labels_t == c2).float()
e2 = (kernel_tt * mask_c2_c2).sum() / (mask_c2_c2.sum() ** 2)

# e2 measure the intra-class domain discrepancy
# Thus if mask_c2_c2.sum() = 0 --> e2 = 0
if mask_c2_c2.sum() > 0:
e2 = (kernel_tt * mask_c2_c2).sum() / (mask_c2_c2.sum() ** 2)
else:
e2 = 0

# Compute e3
kernel_st = _gaussian_kernel(features_s, features_t, sigmas)
mask_c1 = (y_s == c1).float().unsqueeze(1)
mask_c2 = (cluster_labels_t == c2).float().unsqueeze(0)
mask_c1_c2 = mask_c1 * mask_c2
e3 = (kernel_st * mask_c1_c2).sum() / (mask_c1_c2.sum() ** 2)

# e3 measure the inter-class domain discrepancy
# Thus if mask_c1_c2.sum() = 0 --> e3 = 0
if mask_c1_c2.sum() > 0:
e3 = (kernel_st * mask_c1_c2).sum() / (mask_c1_c2.sum() ** 2)
else:
e3 = 0

if c1 == c2:
intraclass += e1 + e2 - 2 * e3
Expand Down
13 changes: 13 additions & 0 deletions skada/deep/tests/test_deep_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
pytest.importorskip("torch")

import numpy as np
import torch

from skada.datasets import make_shifted_datasets
from skada.deep import CAN, DAN, DeepCoral
from skada.deep.losses import cdd_loss
from skada.deep.modules import ToyModule2D


Expand Down Expand Up @@ -195,3 +197,14 @@ def test_can_with_custom_callbacks():
callback_classes = [cb.__class__.__name__ for cb in method.callbacks]
assert "EpochScoring" in callback_classes
assert "ComputeSourceCentroids" in callback_classes


def test_cdd_loss_edge_cases():
    # Degenerate source batch: every feature vector is identical, which
    # drives the median pairwise distance to exactly zero.
    src_features = torch.ones((4, 2))
    tgt_features = torch.randn((4, 2))
    src_labels = torch.tensor([0, 0, 1, 1])  # two classes present

    # The eps added inside cdd_loss keeps the sigma computation finite,
    # so the loss must come back as a real number, not NaN.
    loss_value = cdd_loss(src_labels, src_features, tgt_features)
    assert not np.isnan(loss_value)
142 changes: 87 additions & 55 deletions skada/deep/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,12 @@ def __init__(self, n_clusters=8, n_init=10, max_iter=300, tol=1e-4,
self.device = device

def _init_centroids(self, X):
# Randomly initialize centroids
n_samples = X.shape[0]
indices = torch.randperm(n_samples, device=self.device)[:self.n_clusters]
centroids = X[indices]
return centroids / torch.norm(centroids, dim=1, keepdim=True)
with torch.no_grad():
# Randomly initialize centroids
n_samples = X.shape[0]
indices = torch.randperm(n_samples, device=self.device)[:self.n_clusters]
centroids = X[indices]
return centroids / torch.norm(centroids, dim=1, keepdim=True)

def fit(self, X, y=None):
"""Compute spherical k-means clustering.
Expand All @@ -184,55 +185,56 @@ def fit(self, X, y=None):
self : object
Fitted estimator.
"""
if not isinstance(X, torch.Tensor):
X = torch.tensor(X, dtype=torch.float32, device=self.device)
else:
X = X.to(self.device)
with torch.no_grad():
if not isinstance(X, torch.Tensor):
X = torch.tensor(X, dtype=torch.float32, device=self.device)
else:
X = X.to(self.device)

# Normalize X
X = X / torch.norm(X, dim=1, keepdim=True)
# Normalize X
X = X / torch.norm(X, dim=1, keepdim=True)

best_inertia = None
best_centroids = None
best_n_iter = None
best_inertia = None
best_centroids = None
best_n_iter = None

for _ in range(self.n_init):
if self.centroids is None:
centroids = self._init_centroids(X)
else:
centroids = self.centroids.to(self.device)
for _ in range(self.n_init):
if self.centroids is None:
centroids = self._init_centroids(X)
else:
centroids = self.centroids.to(self.device)

for n_iter in range(self.max_iter):
# Assign samples to closest centroids
dissimilarities = self._compute_dissimilarities(X, centroids)
labels = torch.argmin(dissimilarities, dim=1)
for n_iter in range(self.max_iter):
# Assign samples to closest centroids
dissimilarities = self._compute_dissimilarities(X, centroids)
labels = torch.argmin(dissimilarities, dim=1)

# Update centroids
new_centroids = torch.zeros_like(centroids)
for k in range(self.n_clusters):
if torch.any(labels == k):
new_centroids[k] = X[labels == k].sum(dim=0)
# Update centroids
new_centroids = torch.zeros_like(centroids)
for k in range(self.n_clusters):
if torch.any(labels == k):
new_centroids[k] = X[labels == k].sum(dim=0)

# Check for convergence
if torch.allclose(centroids, new_centroids, atol=self.tol):
break
# Check for convergence
if torch.allclose(centroids, new_centroids, atol=self.tol):
break

centroids = new_centroids
centroids = new_centroids

# Compute inertia
dissimilarities = self._compute_dissimilarities(X, centroids[labels])
inertia = dissimilarities.sum().item()
# Compute inertia
dissimilarities = self._compute_dissimilarities(X, centroids[labels])
inertia = dissimilarities.sum().item()

if best_inertia is None or inertia < best_inertia:
best_inertia = inertia
best_centroids = centroids
best_n_iter = n_iter
if best_inertia is None or inertia < best_inertia:
best_inertia = inertia
best_centroids = centroids
best_n_iter = n_iter

self.cluster_centers_ = best_centroids
self.inertia_ = best_inertia
self.n_iter_ = best_n_iter + 1
self.cluster_centers_ = best_centroids
self.inertia_ = best_inertia
self.n_iter_ = best_n_iter + 1

return self
return self

def predict(self, X):
"""Predict the closest cluster each sample in X belongs to.
Expand All @@ -247,18 +249,48 @@ def predict(self, X):
labels : torch.Tensor of shape (n_samples,)
Index of the cluster each sample belongs to.
"""
check_is_fitted(self)
if not isinstance(X, torch.Tensor):
X = torch.tensor(X, dtype=torch.float32, device=self.device)
else:
X = X.to(self.device)
with torch.no_grad():
check_is_fitted(self)
if not isinstance(X, torch.Tensor):
X = torch.tensor(X, dtype=torch.float32, device=self.device)
else:
X = X.to(self.device)

# No need to normalize X as it is going
# to be normalized in cosine_similarity
# No need to normalize X as it is going
# to be normalized in cosine_similarity

dissimilarities = self._compute_dissimilarities(X, self.cluster_centers_)
return torch.argmin(dissimilarities, dim=1)
dissimilarities = self._compute_dissimilarities(X, self.cluster_centers_)
return torch.argmin(dissimilarities, dim=1)

def _compute_dissimilarities(self, X, centroids, batch_size=1024):
"""Compute dissimilarities between points and centroids in batches.

Parameters
----------
X : torch.Tensor of shape (n_samples, n_features)
Input data points
centroids : torch.Tensor of shape (n_clusters, n_features)
Cluster centroids
batch_size : int, default=1024
Size of batches to use for computation

def _compute_dissimilarities(self, X, centroids):
similarities = cosine_similarity(X.unsqueeze(1), centroids.unsqueeze(0), dim=2)
return 0.5 * (1 - similarities)
Returns
-------
dissimilarities : torch.Tensor of shape (n_samples, n_clusters)
Dissimilarity matrix between points and centroids
"""
n_samples = X.shape[0]
n_clusters = centroids.shape[0]
dissimilarities = torch.zeros(n_samples, n_clusters, device=self.device)

for i in range(0, n_samples, batch_size):
batch_end = min(i + batch_size, n_samples)
batch = X[i:batch_end]
similarities = cosine_similarity(
batch.unsqueeze(1),
centroids.unsqueeze(0),
dim=2
)
dissimilarities[i:batch_end] = 0.5 * (1 - similarities)

return dissimilarities
Loading