-
Notifications
You must be signed in to change notification settings - Fork 23
[MRG] Add CAN Method #251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Add CAN Method #251
Changes from 2 commits
dd134c1
e588483
d2eb3ee
9355a67
b3f301f
f64873d
03f3cd5
efff28d
8efe6bd
3dfdc3d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# Author: Theo Gnassounou <[email protected]> | ||
# Remi Flamary <[email protected]> | ||
# Yanis Lalou <[email protected]> | ||
# | ||
# License: BSD 3-Clause | ||
|
||
|
@@ -14,7 +15,7 @@ | |
"torch and skorch are required for importing skada.deep.* modules." | ||
) from e | ||
|
||
from ._divergence import DeepCoral, DeepCoralLoss, DANLoss, DAN | ||
from ._divergence import DeepCoral, DeepCoralLoss, DANLoss, DAN, CAN, CANLoss | ||
from ._optimal_transport import DeepJDOT, DeepJDOTLoss | ||
from ._adversarial import DANN, CDAN, DANNLoss, CDANLoss | ||
from ._baseline import SourceOnly, TargetOnly | ||
|
@@ -35,6 +36,8 @@ | |
'DANN', | ||
'CDANLoss', | ||
'CDAN', | ||
'CANLoss', | ||
'CAN', | ||
'SourceOnly', | ||
'TargetOnly', | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# Author: Theo Gnassounou <[email protected]> | ||
# Remi Flamary <[email protected]> | ||
# Yanis Lalou <[email protected]> | ||
# | ||
# License: BSD 3-Clause | ||
import torch | ||
|
@@ -12,7 +13,7 @@ | |
DomainBalancedDataLoader, | ||
) | ||
|
||
from .losses import dan_loss, deepcoral_loss | ||
from .losses import cdd_loss, dan_loss, deepcoral_loss | ||
|
||
|
||
class DeepCoralLoss(BaseDALoss): | ||
|
@@ -180,3 +181,124 @@ def DAN(module, layer_name, reg=1, sigmas=None, base_criterion=None, **kwargs): | |
**kwargs, | ||
) | ||
return net | ||
|
||
|
||
class CANLoss(BaseDALoss): | ||
"""Loss for Contrastive Adaptation Network (CAN) | ||
|
||
This loss implements the contrastive domain discrepancy (CDD) | ||
as described in [33]. | ||
|
||
Parameters | ||
---------- | ||
distance_threshold : float, optional (default=0.5) | ||
Distance threshold for discarding the samples that are | ||
to far from the centroids. | ||
class_threshold : int, optional (default=3) | ||
Minimum number of samples in a class to be considered for the loss. | ||
sigmas : array like, default=None, | ||
If array, sigmas used for the multi gaussian kernel. | ||
If None, uses sigmas proposed in [1]_. | ||
|
||
References | ||
---------- | ||
.. [33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019). | ||
Contrastive adaptation network for unsupervised domain adaptation. | ||
In Proceedings of the IEEE/CVF Conference on Computer Vision | ||
and Pattern Recognition (pp. 4893-4902). | ||
""" | ||
|
||
def __init__( | ||
self, | ||
distance_threshold=0.5, | ||
class_threshold=3, | ||
sigmas=None, | ||
): | ||
super().__init__() | ||
self.distance_threshold = distance_threshold | ||
self.class_threshold = class_threshold | ||
self.sigmas = sigmas | ||
|
||
def forward( | ||
self, | ||
y_s, | ||
y_pred_s, | ||
y_pred_t, | ||
domain_pred_s, | ||
domain_pred_t, | ||
features_s, | ||
features_t, | ||
): | ||
loss = cdd_loss( | ||
y_s, | ||
features_s, | ||
features_t, | ||
sigmas=self.sigmas, | ||
distance_threshold=self.distance_threshold, | ||
class_threshold=self.class_threshold, | ||
) | ||
|
||
return loss | ||
|
||
|
||
def CAN( | ||
module, | ||
layer_name, | ||
reg=1, | ||
distance_threshold=0.5, | ||
class_threshold=3, | ||
sigmas=None, | ||
base_criterion=None, | ||
**kwargs, | ||
): | ||
"""Contrastive Adaptation Network (CAN) domain adaptation method. | ||
|
||
From [33]. | ||
|
||
Parameters | ||
---------- | ||
module : torch module (class or instance) | ||
A PyTorch :class:`~torch.nn.Module`. | ||
layer_name : str | ||
The name of the module's layer whose outputs are | ||
collected during the training for the adaptation. | ||
reg : float, optional (default=1) | ||
Regularization parameter for DA loss. | ||
distance_threshold : float, optional (default=0.5) | ||
Distance threshold for discarding the samples that are | ||
to far from the centroids. | ||
class_threshold : int, optional (default=3) | ||
Minimum number of samples in a class to be considered for the loss. | ||
sigmas : array like, default=None, | ||
If array, sigmas used for the multi gaussian kernel. | ||
If None, uses sigmas proposed in [1]_. | ||
base_criterion : torch criterion (class) | ||
The base criterion used to compute the loss with source | ||
labels. If None, the default is `torch.nn.CrossEntropyLoss`. | ||
|
||
References | ||
---------- | ||
.. [33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019). | ||
Contrastive adaptation network for unsupervised domain adaptation. | ||
In Proceedings of the IEEE/CVF Conference on Computer Vision | ||
and Pattern Recognition (pp. 4893-4902). | ||
""" | ||
if base_criterion is None: | ||
base_criterion = torch.nn.CrossEntropyLoss() | ||
|
||
net = DomainAwareNet( | ||
module=DomainAwareModule, | ||
module__base_module=module, | ||
module__layer_name=layer_name, | ||
iterator_train=DomainBalancedDataLoader, | ||
criterion=DomainAwareCriterion, | ||
criterion__base_criterion=base_criterion, | ||
criterion__reg=reg, | ||
criterion__adapt_criterion=CANLoss( | ||
distance_threshold=distance_threshold, | ||
class_threshold=class_threshold, | ||
sigmas=sigmas, | ||
), | ||
**kwargs, | ||
) | ||
return net |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,9 +10,11 @@ | |
import ot | ||
import skorch # noqa: F401 | ||
import torch # noqa: F401 | ||
import torch.nn.functional as F | ||
from torch.nn.functional import mse_loss | ||
|
||
from skada.deep.base import BaseDALoss | ||
from skada.deep.utils import SphericalKMeans | ||
|
||
|
||
def deepcoral_loss(features, features_target, assume_centered=False): | ||
|
@@ -189,6 +191,134 @@ def dan_loss(features_s, features_t, sigmas=None): | |
return loss | ||
|
||
|
||
def cdd_loss( | ||
y_s, | ||
features_s, | ||
features_t, | ||
sigmas=None, | ||
distance_threshold=0.5, | ||
class_threshold=3, | ||
): | ||
"""Define the contrastive domain discrepancy loss based on [33]_. | ||
|
||
Parameters | ||
---------- | ||
y_s : tensor | ||
labels of the source data used to compute the loss. | ||
features_s : tensor | ||
features of the source data used to compute the loss. | ||
features_t : tensor | ||
features of the target data used to compute the loss. | ||
sigmas : array like, default=None, | ||
If array, sigmas used for the multi gaussian kernel. | ||
If None, uses sigmas proposed in [1]_. | ||
distance_threshold : float, optional (default=0.5) | ||
Distance threshold for discarding the samples that are | ||
to far from the centroids. | ||
class_threshold : int, optional (default=3) | ||
Minimum number of samples in a class to be considered for the loss. | ||
|
||
Returns | ||
------- | ||
loss : float | ||
The loss of the method. | ||
|
||
References | ||
---------- | ||
.. [33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019). | ||
Contrastive adaptation network for unsupervised domain adaptation. | ||
In Proceedings of the IEEE/CVF Conference on Computer Vision | ||
and Pattern Recognition (pp. 4893-4902). | ||
""" | ||
n_classes = len(y_s.unique()) | ||
|
||
# Calculate source centroids | ||
source_centroids = [] | ||
for c in range(n_classes): | ||
mask = y_s == c | ||
if mask.sum() > 0: | ||
class_features = features_s[mask] | ||
normalized_features = F.normalize(class_features, p=2, dim=1) | ||
centroid = normalized_features.mean(dim=0) | ||
source_centroids.append(centroid) | ||
|
||
source_centroids = torch.stack(source_centroids) | ||
|
||
# Use source centroids to initialize target clustering | ||
target_kmeans = SphericalKMeans( | ||
YanisLalou marked this conversation as resolved.
Show resolved
Hide resolved
|
||
n_clusters=n_classes, | ||
random_state=0, | ||
centroids=source_centroids, | ||
device=features_t.device, | ||
) | ||
target_kmeans.fit(features_t) | ||
|
||
# Predict clusters for target samples | ||
cluster_labels_t = target_kmeans.predict(features_t) | ||
|
||
# Discard ambiguous target samples | ||
similarities = F.cosine_similarity( | ||
YanisLalou marked this conversation as resolved.
Show resolved
Hide resolved
|
||
features_t.unsqueeze(1), target_kmeans.cluster_centers_.unsqueeze(0) | ||
) | ||
mask_t = 0.5 * (1 - similarities.max(dim=1)[0]) < distance_threshold | ||
features_t = features_t[mask_t] | ||
cluster_labels_t = cluster_labels_t[mask_t] | ||
|
||
# Discard ambiguous classes | ||
class_counts = torch.bincount(cluster_labels_t, minlength=n_classes) | ||
valid_classes = class_counts >= class_threshold | ||
mask_t = valid_classes[cluster_labels_t] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see what this line is doing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This part of the code corresponds to the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it! |
||
features_t = features_t[mask_t] | ||
cluster_labels_t = cluster_labels_t[mask_t] | ||
|
||
# Define sigmas | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you cannot use the mmd distance from DAN? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The formula is not exactly the same as for the mmd since before computing each mean we apply a specific mask |
||
if sigmas is None: | ||
median_pairwise_distance = torch.median(torch.cdist(features_s, features_s)) | ||
sigmas = ( | ||
torch.tensor([2 ** (-8) * 2 ** (i * 1 / 2) for i in range(33)]).to( | ||
features_s.device | ||
) | ||
* median_pairwise_distance | ||
) | ||
else: | ||
sigmas = torch.tensor(sigmas).to(features_s.device) | ||
|
||
# Compute CDD | ||
intraclass = 0 | ||
interclass = 0 | ||
|
||
for c1 in range(n_classes): | ||
for c2 in range(c1, n_classes): | ||
if valid_classes[c1] and valid_classes[c2]: | ||
# Compute e1 | ||
kernel_ss = _gaussian_kernel(features_s, features_s, sigmas) | ||
mask_c1_c1 = (y_s == c1).float() | ||
e1 = (kernel_ss * mask_c1_c1).sum() / (mask_c1_c1.sum() ** 2) | ||
|
||
# Compute e2 | ||
kernel_tt = _gaussian_kernel(features_t, features_t, sigmas) | ||
mask_c2_c2 = (cluster_labels_t == c2).float() | ||
e2 = (kernel_tt * mask_c2_c2).sum() / (mask_c2_c2.sum() ** 2) | ||
|
||
# Compute e3 | ||
kernel_st = _gaussian_kernel(features_s, features_t, sigmas) | ||
mask_c1 = (y_s == c1).float().unsqueeze(1) | ||
mask_c2 = (cluster_labels_t == c2).float().unsqueeze(0) | ||
mask_c1_c2 = mask_c1 * mask_c2 | ||
e3 = (kernel_st * mask_c1_c2).sum() / (mask_c1_c2.sum() ** 2) | ||
|
||
if c1 == c2: | ||
intraclass += e1 + e2 - 2 * e3 | ||
else: | ||
interclass += e1 + e2 - 2 * e3 | ||
|
||
cdd = (intraclass / len(valid_classes)) - ( | ||
interclass / (len(valid_classes) ** 2 - len(valid_classes)) | ||
) | ||
|
||
return cdd | ||
|
||
|
||
class TestLoss(BaseDALoss): | ||
"""Test Loss to check the deep API""" | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the paper it seems to be only a sum no ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In spherical k-means paper:
