[MRG] Add CAN Method #251
skada/deep/losses.py
Outdated
if mask.sum() > 0:
    class_features = features_s[mask]
    normalized_features = F.normalize(class_features, p=2, dim=1)
    centroid = normalized_features.mean(dim=0)
In the paper it seems to be only a sum, no?
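For what it's worth, a quick check suggests that the sum-versus-mean choice only rescales the centroid by a positive factor, so any cosine-based comparison against it is unchanged. The tensors below are made-up illustrative values, not data from the PR:

```python
import torch
import torch.nn.functional as F

# Illustrative values only (not from the PR).
class_features = torch.tensor([[1.0, 0.0, 0.0],
                               [1.0, 1.0, 0.0],
                               [0.0, 1.0, 1.0]])
query = torch.tensor([[1.0, 2.0, 3.0],
                      [0.0, -1.0, 1.0]])

# Same normalization step as in the diff above.
normalized = F.normalize(class_features, p=2, dim=1)

# Two candidate centroids: they differ only by the factor 1 / n_samples.
centroid_sum = normalized.sum(dim=0)
centroid_mean = normalized.mean(dim=0)

# Cosine similarity ignores vector length, so both centroids rank queries identically.
sim_sum = F.cosine_similarity(query, centroid_sum.unsqueeze(0), dim=1)
sim_mean = F.cosine_similarity(query, centroid_mean.unsqueeze(0), dim=1)
```

The two only diverge if the centroid is later used with a scale-sensitive distance such as the Euclidean one.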
# Discard ambiguous classes
class_counts = torch.bincount(cluster_labels_t, minlength=n_classes)
valid_classes = class_counts >= class_threshold
mask_t = valid_classes[cluster_labels_t]
I don't see what this line is doing?
`class_counts = torch.bincount(cluster_labels_t, minlength=n_classes)` counts how many samples are in each cluster.
`valid_classes = class_counts >= class_threshold` creates a boolean tensor where `True` indicates classes that have at least `class_threshold` samples.
`mask_t = valid_classes[cluster_labels_t]` uses the cluster labels as indices into the `valid_classes` tensor. This creates a boolean `mask_t`, where `True` indicates samples that belong to classes with enough representation.
This part of the code corresponds to the "Filter the ambiguous classes" part of the paper's pseudo-algorithm.
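A tiny worked example may make the indexing step easier to follow. The label values, `n_classes = 4`, and `class_threshold = 2` are arbitrary choices for illustration, not the PR's defaults:

```python
import torch

# Toy cluster assignments for six target samples (illustrative values).
cluster_labels_t = torch.tensor([0, 0, 0, 1, 2, 2])
n_classes = 4
class_threshold = 2

# Number of samples per cluster; minlength pads classes that never occur.
class_counts = torch.bincount(cluster_labels_t, minlength=n_classes)
# tensor([3, 1, 2, 0])

# A class is kept only if it has at least class_threshold samples.
valid_classes = class_counts >= class_threshold
# tensor([ True, False,  True, False])

# Look up each sample's class in valid_classes to get a per-sample mask.
mask_t = valid_classes[cluster_labels_t]
# tensor([ True,  True,  True, False,  True,  True])

# Only the sample assigned to the under-populated class 1 is dropped.
filtered_labels = cluster_labels_t[mask_t]
# tensor([0, 0, 0, 2, 2])
```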
Got it!
features_t = features_t[mask_t]
cluster_labels_t = cluster_labels_t[mask_t]

# Define sigmas
Can't you use the MMD distance from DAN?
The formula is not exactly the same as for the MMD, since we apply a specific mask before computing each mean.
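To illustrate the difference, here is a simplified sketch of a class-masked MMD term. It is not the PR's implementation: the helper names `gaussian_kernel` and `masked_mmd` are hypothetical, and a single bandwidth `sigma` stands in for whatever set of sigmas the actual code defines.

```python
import torch


def gaussian_kernel(a, b, sigma=1.0):
    # Pairwise squared Euclidean distances, shape (len(a), len(b)).
    sq_dists = torch.cdist(a, b) ** 2
    return torch.exp(-sq_dists / (2 * sigma**2))


def masked_mmd(features_s, labels_s, features_t, labels_t, c1, c2, sigma=1.0):
    """Squared MMD between source samples of class c1 and target samples of class c2.

    Hypothetical helper: the masks restrict every kernel mean to one class
    pair, which is what a plain (unmasked) MMD does not do.
    """
    xs = features_s[labels_s == c1]
    xt = features_t[labels_t == c2]
    if len(xs) == 0 or len(xt) == 0:
        # No samples for this class pair: contribute nothing.
        return features_s.new_zeros(())
    e_ss = gaussian_kernel(xs, xs, sigma).mean()
    e_tt = gaussian_kernel(xt, xt, sigma).mean()
    e_st = gaussian_kernel(xs, xt, sigma).mean()
    return e_ss + e_tt - 2 * e_st


# Illustrative data: identical source and target sets with two classes.
feats = torch.tensor([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
labels = torch.tensor([0, 0, 1])
intra = masked_mmd(feats, labels, feats, labels, c1=0, c2=0)  # same class
inter = masked_mmd(feats, labels, feats, labels, c1=0, c2=1)  # different classes
```

With identical source and target sets, the same-class term is (numerically) zero while the cross-class term is strictly positive; a contrastive discrepancy combines such terms so that the former is minimized and the latter maximized.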
for n_iter in range(self.max_iter):
    # Assign samples to closest centroids
    dissimilarities = self._compute_dissimilarities(X, centroids)
Is there a difference here with the `cosine_similarities` function from torch?
In the paper, `cosine_dissimilarity` is `0.5 * (1 − cosine_similarity)`.
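A minimal sketch of that relation, built on `torch.nn.functional.cosine_similarity`. The helper name `cosine_dissimilarity` and the toy inputs are assumptions for illustration; this is not the body of the PR's `_compute_dissimilarities`:

```python
import torch
import torch.nn.functional as F


def cosine_dissimilarity(x, centroids):
    """Pairwise 0.5 * (1 - cosine similarity), shape (n_samples, n_centroids)."""
    # Broadcast (n, 1, d) against (1, k, d) and reduce over the feature axis.
    sim = F.cosine_similarity(x.unsqueeze(1), centroids.unsqueeze(0), dim=2)
    # Maps similarity in [-1, 1] to a dissimilarity in [0, 1].
    return 0.5 * (1 - sim)


x = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
centroids = torch.tensor([[1.0, 0.0]])
d = cosine_dissimilarity(x, centroids)
# Aligned, orthogonal and opposite vectors give 0, 0.5 and 1 respectively.

# Assignment step of a k-means style loop: pick the closest centroid per sample.
assignments = d.argmin(dim=1)
```

Because the mapping is a decreasing affine function of the similarity, taking the argmin of the dissimilarity selects the same centroid as taking the argmax of the raw cosine similarity.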
Paper: https://arxiv.org/pdf/1901.00976
Mostly eq 3-4-5 + paragraph 3.4
New Features:
- `CANLoss` class added to `skada/deep/_divergence.py` to implement the contrastive domain discrepancy (CDD) loss.
- `CAN` function added to `skada/deep/_divergence.py` to implement the CAN domain adaptation method.

New Utilities:
- `SphericalKMeans` class added to `skada/deep/utils.py` for clustering using cosine similarity.

Testing:
- Tests added to `test_deep_divergence.py` to ensure the new method works as expected.

Still needs to be done: