Skip to content

[MRG] Add CAN Method #251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 24, 2024
Merged

[MRG] Add CAN Method #251

merged 10 commits into from
Oct 24, 2024

Conversation

YanisLalou
Copy link
Collaborator

@YanisLalou YanisLalou commented Oct 9, 2024

Paper: https://arxiv.org/pdf/1901.00976
Mostly eq 3-4-5 + paragraph 3.4

New Features:

  • CAN and CANLoss Implementation:
    • Added CANLoss class to skada/deep/_divergence.py to implement the contrastive domain discrepancy (CDD) loss.
    • Added CAN function to skada/deep/_divergence.py to implement the CAN domain adaptation method.

New Utilities:

  • SphericalKMeans:
    • Added SphericalKMeans class to skada/deep/utils.py for clustering using cosine similarity.

Testing:

  • New Tests for CAN:
    • Added tests for CAN in test_deep_divergence.py to ensure the new method works as expected.

Still needs to be done:

  • Double check the implementation

if mask.sum() > 0:
class_features = features_s[mask]
normalized_features = F.normalize(class_features, p=2, dim=1)
centroid = normalized_features.mean(dim=0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the paper it seems to be only a sum no ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In spherical k-means paper:
image

# Discard ambiguous classes
class_counts = torch.bincount(cluster_labels_t, minlength=n_classes)
valid_classes = class_counts >= class_threshold
mask_t = valid_classes[cluster_labels_t]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see what this line is doing?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. class_counts = torch.bincount(cluster_labels_t, minlength=n_classes) counts how many samples are in each cluster.
  2. valid_classes = class_counts >= class_threshold creates a boolean tensor where True indicates classes that have at least class_threshold samples.
  3. mask_t = valid_classes[cluster_labels_t] is using the cluster labels as indices into the valid_classes tensor. This create a boolean mask_t, where True` indicates samples that belong to classes with enough representation.

This part of the code corresponds to the Filter the ambiguous classes part of the paper pseudo algorithm.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it!

features_t = features_t[mask_t]
cluster_labels_t = cluster_labels_t[mask_t]

# Define sigmas
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you cannot use the mmd distance from DAN?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The formula is not exactly the same as for the mmd since before computing each mean we apply a specific mask


for n_iter in range(self.max_iter):
# Assign samples to closest centroids
dissimilarities = self._compute_dissimilarities(X, centroids)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a difference here with the function cosine_similarities de torch?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In paper: cosine_dissimilarity is 0.5*(1 − cosine_similarity)

@tgnassou tgnassou merged commit 8e72df4 into scikit-adaptation:main Oct 24, 2024
5 checks passed
@YanisLalou YanisLalou changed the title [WIP] Add CAN Method [MRG] Add CAN Method Oct 24, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants