Skip to content

[MRG] Add CAN Method #251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 24, 2024
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,4 +237,4 @@ The library is distributed under the 3-Clause BSD license.

[32] Hu, D., Liang, J., Liew, J. H., Xue, C., Bai, S., & Wang, X. (2023). [Mixed Samples as Probes for Unsupervised Model Selection in Domain Adaptation](https://proceedings.neurips.cc/paper_files/paper/2023/file/7721f1fea280e9ffae528dc78c732576-Paper-Conference.pdf). Advances in Neural Information Processing Systems 36 (2024).


[33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019). [Contrastive Adaptation Network for Unsupervised Domain Adaptation](https://arxiv.org/abs/1901.00976). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 4893-4902).
3 changes: 3 additions & 0 deletions docs/source/all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ Deep learning DA :py:mod:`skada.deep`:
DeepJDOTLoss
DANLoss
CDANLoss
CANLoss

.. autosummary::
:toctree: gen_modules/
Expand All @@ -155,10 +156,12 @@ Deep learning DA :py:mod:`skada.deep`:
dan_loss
deepcoral_loss
deepjdot_loss
cdd_loss
DeepCoral
DeepJDOT
DANN
CDAN
CAN



Expand Down
5 changes: 4 additions & 1 deletion skada/deep/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Author: Theo Gnassounou <[email protected]>
# Remi Flamary <[email protected]>
# Yanis Lalou <[email protected]>
#
# License: BSD 3-Clause

Expand All @@ -14,7 +15,7 @@
"torch and skorch are required for importing skada.deep.* modules."
) from e

from ._divergence import DeepCoral, DeepCoralLoss, DANLoss, DAN
from ._divergence import DeepCoral, DeepCoralLoss, DANLoss, DAN, CAN, CANLoss
from ._optimal_transport import DeepJDOT, DeepJDOTLoss
from ._adversarial import DANN, CDAN, DANNLoss, CDANLoss
from ._baseline import SourceOnly, TargetOnly
Expand All @@ -35,6 +36,8 @@
'DANN',
'CDANLoss',
'CDAN',
'CANLoss',
'CAN',
'SourceOnly',
'TargetOnly',
]
124 changes: 123 additions & 1 deletion skada/deep/_divergence.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Author: Theo Gnassounou <[email protected]>
# Remi Flamary <[email protected]>
# Yanis Lalou <[email protected]>
#
# License: BSD 3-Clause
import torch
Expand All @@ -12,7 +13,7 @@
DomainBalancedDataLoader,
)

from .losses import dan_loss, deepcoral_loss
from .losses import cdd_loss, dan_loss, deepcoral_loss


class DeepCoralLoss(BaseDALoss):
Expand Down Expand Up @@ -180,3 +181,124 @@ def DAN(module, layer_name, reg=1, sigmas=None, base_criterion=None, **kwargs):
**kwargs,
)
return net


class CANLoss(BaseDALoss):
"""Loss for Contrastive Adaptation Network (CAN)

This loss implements the contrastive domain discrepancy (CDD)
as described in [33].

Parameters
----------
distance_threshold : float, optional (default=0.5)
Distance threshold for discarding the samples that are
to far from the centroids.
class_threshold : int, optional (default=3)
Minimum number of samples in a class to be considered for the loss.
sigmas : array like, default=None,
If array, sigmas used for the multi gaussian kernel.
If None, uses sigmas proposed in [1]_.

References
----------
.. [33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019).
Contrastive adaptation network for unsupervised domain adaptation.
In Proceedings of the IEEE/CVF Conference on Computer Vision
and Pattern Recognition (pp. 4893-4902).
"""

def __init__(
self,
distance_threshold=0.5,
class_threshold=3,
sigmas=None,
):
super().__init__()
self.distance_threshold = distance_threshold
self.class_threshold = class_threshold
self.sigmas = sigmas

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
):
loss = cdd_loss(
y_s,
features_s,
features_t,
sigmas=self.sigmas,
distance_threshold=self.distance_threshold,
class_threshold=self.class_threshold,
)

return loss


def CAN(
module,
layer_name,
reg=1,
distance_threshold=0.5,
class_threshold=3,
sigmas=None,
base_criterion=None,
**kwargs,
):
"""Contrastive Adaptation Network (CAN) domain adaptation method.

From [33].

Parameters
----------
module : torch module (class or instance)
A PyTorch :class:`~torch.nn.Module`.
layer_name : str
The name of the module's layer whose outputs are
collected during the training for the adaptation.
reg : float, optional (default=1)
Regularization parameter for DA loss.
distance_threshold : float, optional (default=0.5)
Distance threshold for discarding the samples that are
to far from the centroids.
class_threshold : int, optional (default=3)
Minimum number of samples in a class to be considered for the loss.
sigmas : array like, default=None,
If array, sigmas used for the multi gaussian kernel.
If None, uses sigmas proposed in [1]_.
base_criterion : torch criterion (class)
The base criterion used to compute the loss with source
labels. If None, the default is `torch.nn.CrossEntropyLoss`.

References
----------
.. [33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019).
Contrastive adaptation network for unsupervised domain adaptation.
In Proceedings of the IEEE/CVF Conference on Computer Vision
and Pattern Recognition (pp. 4893-4902).
"""
if base_criterion is None:
base_criterion = torch.nn.CrossEntropyLoss()

net = DomainAwareNet(
module=DomainAwareModule,
module__base_module=module,
module__layer_name=layer_name,
iterator_train=DomainBalancedDataLoader,
criterion=DomainAwareCriterion,
criterion__base_criterion=base_criterion,
criterion__reg=reg,
criterion__adapt_criterion=CANLoss(
distance_threshold=distance_threshold,
class_threshold=class_threshold,
sigmas=sigmas,
),
**kwargs,
)
return net
130 changes: 130 additions & 0 deletions skada/deep/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import ot
import skorch # noqa: F401
import torch # noqa: F401
import torch.nn.functional as F
from torch.nn.functional import mse_loss

from skada.deep.base import BaseDALoss
from skada.deep.utils import SphericalKMeans


def deepcoral_loss(features, features_target, assume_centered=False):
Expand Down Expand Up @@ -189,6 +191,134 @@ def dan_loss(features_s, features_t, sigmas=None):
return loss


def cdd_loss(
y_s,
features_s,
features_t,
sigmas=None,
distance_threshold=0.5,
class_threshold=3,
):
"""Define the contrastive domain discrepancy loss based on [33]_.

Parameters
----------
y_s : tensor
labels of the source data used to compute the loss.
features_s : tensor
features of the source data used to compute the loss.
features_t : tensor
features of the target data used to compute the loss.
sigmas : array like, default=None,
If array, sigmas used for the multi gaussian kernel.
If None, uses sigmas proposed in [1]_.
distance_threshold : float, optional (default=0.5)
Distance threshold for discarding the samples that are
to far from the centroids.
class_threshold : int, optional (default=3)
Minimum number of samples in a class to be considered for the loss.

Returns
-------
loss : float
The loss of the method.

References
----------
.. [33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019).
Contrastive adaptation network for unsupervised domain adaptation.
In Proceedings of the IEEE/CVF Conference on Computer Vision
and Pattern Recognition (pp. 4893-4902).
"""
n_classes = len(y_s.unique())

# Calculate source centroids
source_centroids = []
for c in range(n_classes):
mask = y_s == c
if mask.sum() > 0:
class_features = features_s[mask]
normalized_features = F.normalize(class_features, p=2, dim=1)
centroid = normalized_features.mean(dim=0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the paper it seems to be only a sum no ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In spherical k-means paper:
image

source_centroids.append(centroid)

source_centroids = torch.stack(source_centroids)

# Use source centroids to initialize target clustering
target_kmeans = SphericalKMeans(
n_clusters=n_classes,
random_state=0,
centroids=source_centroids,
device=features_t.device,
)
target_kmeans.fit(features_t)

# Predict clusters for target samples
cluster_labels_t = target_kmeans.predict(features_t)

# Discard ambiguous target samples
similarities = F.cosine_similarity(
features_t.unsqueeze(1), target_kmeans.cluster_centers_.unsqueeze(0)
)
mask_t = 0.5 * (1 - similarities.max(dim=1)[0]) < distance_threshold
features_t = features_t[mask_t]
cluster_labels_t = cluster_labels_t[mask_t]

# Discard ambiguous classes
class_counts = torch.bincount(cluster_labels_t, minlength=n_classes)
valid_classes = class_counts >= class_threshold
mask_t = valid_classes[cluster_labels_t]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see what this line is doing?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. class_counts = torch.bincount(cluster_labels_t, minlength=n_classes) counts how many samples are in each cluster.
  2. valid_classes = class_counts >= class_threshold creates a boolean tensor where True indicates classes that have at least class_threshold samples.
  3. mask_t = valid_classes[cluster_labels_t] is using the cluster labels as indices into the valid_classes tensor. This create a boolean mask_t, where True` indicates samples that belong to classes with enough representation.

This part of the code corresponds to the Filter the ambiguous classes part of the paper pseudo algorithm.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it!

features_t = features_t[mask_t]
cluster_labels_t = cluster_labels_t[mask_t]

# Define sigmas
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you cannot use the mmd distance from DAN?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The formula is not exactly the same as for the mmd since before computing each mean we apply a specific mask

if sigmas is None:
median_pairwise_distance = torch.median(torch.cdist(features_s, features_s))
sigmas = (
torch.tensor([2 ** (-8) * 2 ** (i * 1 / 2) for i in range(33)]).to(
features_s.device
)
* median_pairwise_distance
)
else:
sigmas = torch.tensor(sigmas).to(features_s.device)

# Compute CDD
intraclass = 0
interclass = 0

for c1 in range(n_classes):
for c2 in range(c1, n_classes):
if valid_classes[c1] and valid_classes[c2]:
# Compute e1
kernel_ss = _gaussian_kernel(features_s, features_s, sigmas)
mask_c1_c1 = (y_s == c1).float()
e1 = (kernel_ss * mask_c1_c1).sum() / (mask_c1_c1.sum() ** 2)

# Compute e2
kernel_tt = _gaussian_kernel(features_t, features_t, sigmas)
mask_c2_c2 = (cluster_labels_t == c2).float()
e2 = (kernel_tt * mask_c2_c2).sum() / (mask_c2_c2.sum() ** 2)

# Compute e3
kernel_st = _gaussian_kernel(features_s, features_t, sigmas)
mask_c1 = (y_s == c1).float().unsqueeze(1)
mask_c2 = (cluster_labels_t == c2).float().unsqueeze(0)
mask_c1_c2 = mask_c1 * mask_c2
e3 = (kernel_st * mask_c1_c2).sum() / (mask_c1_c2.sum() ** 2)

if c1 == c2:
intraclass += e1 + e2 - 2 * e3
else:
interclass += e1 + e2 - 2 * e3

cdd = (intraclass / len(valid_classes)) - (
interclass / (len(valid_classes) ** 2 - len(valid_classes))
)

return cdd


class TestLoss(BaseDALoss):
"""Test Loss to check the deep API"""

Expand Down
Loading
Loading