-
Notifications
You must be signed in to change notification settings - Fork 23
[MRG] Add MCC method #250
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[MRG] Add MCC method #250
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
4512894
add MCC method
tgnassou 44d9743
add -
tgnassou ff7ee86
Merge branch 'main' into mcc
tgnassou f3450c6
Merge branch 'main' into mcc
tgnassou 29cd6b7
merge main
tgnassou e322150
Merge branch 'main' into mcc
YanisLalou d9164ac
Merge branch 'main' into mcc
YanisLalou File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Author: Theo Gnassounou <[email protected]> | ||
# | ||
# License: BSD 3-Clause | ||
import torch | ||
|
||
from skada.deep.base import ( | ||
BaseDALoss, | ||
DomainAwareCriterion, | ||
DomainAwareModule, | ||
DomainAwareNet, | ||
DomainBalancedDataLoader, | ||
) | ||
|
||
from .losses import mcc_loss | ||
|
||
|
||
class MCCLoss(BaseDALoss): | ||
"""Loss MCC. | ||
|
||
This loss reduces the class confusion of the predicted label of target domain | ||
See [33]_. | ||
|
||
Parameters | ||
---------- | ||
T : float, default=1 | ||
Temperature parameter for the scaling. | ||
If T=1, the scaling is a softmax function. | ||
|
||
References | ||
---------- | ||
.. [33] Ying Jin, Ximei Wang, Mingsheng Long, Jianmin Wang. | ||
Minimum Class Confusion for Versatile Domain Adaptation. | ||
In ECCV, 2020. | ||
""" | ||
|
||
def __init__(self, T=1): | ||
super().__init__() | ||
self.T = T | ||
|
||
def forward( | ||
self, | ||
y_s, | ||
y_pred_s, | ||
y_pred_t, | ||
domain_pred_s, | ||
domain_pred_t, | ||
features_s, | ||
features_t, | ||
): | ||
"""Compute the domain adaptation loss""" | ||
loss = mcc_loss( | ||
y_pred_t, | ||
T=self.T, | ||
) | ||
return loss | ||
|
||
|
||
def MCC( | ||
module, | ||
layer_name, | ||
reg=1, | ||
T=1, | ||
base_criterion=None, | ||
**kwargs, | ||
): | ||
"""MCC. | ||
|
||
See [33]_. | ||
|
||
Parameters | ||
---------- | ||
module : torch module (class or instance) | ||
A PyTorch :class:`~torch.nn.Module`. | ||
layer_name : str | ||
The name of the module's layer whose outputs are | ||
collected during the training for the adaptation. | ||
reg : float, default=1 | ||
Regularization parameter for DA loss. | ||
T : float, default=1 | ||
Temperature parameter for the scaling. | ||
base_criterion : torch criterion (class) | ||
The base criterion used to compute the loss with source | ||
labels. If None, the default is `torch.nn.CrossEntropyLoss`. | ||
|
||
References | ||
---------- | ||
.. [33] Ying Jin, Ximei Wang, Mingsheng Long, Jianmin Wang. | ||
Minimum Class Confusion for Versatile Domain Adaptation. | ||
In ECCV, 2020. | ||
""" | ||
if base_criterion is None: | ||
base_criterion = torch.nn.CrossEntropyLoss() | ||
|
||
net = DomainAwareNet( | ||
module=DomainAwareModule, | ||
module__base_module=module, | ||
module__layer_name=layer_name, | ||
iterator_train=DomainBalancedDataLoader, | ||
criterion=DomainAwareCriterion, | ||
criterion__base_criterion=base_criterion, | ||
criterion__adapt_criterion=MCCLoss(T=T), | ||
criterion__reg=reg, | ||
**kwargs, | ||
) | ||
return net |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Author: Theo Gnassounou <[email protected]> | ||
# Oleksii Kachaiev <[email protected]> | ||
# Yanis Lalou <[email protected]> | ||
# | ||
# License: BSD 3-Clause | ||
import pytest | ||
|
||
pytest.importorskip("torch") | ||
|
||
import numpy as np | ||
|
||
from skada.datasets import make_shifted_datasets | ||
from skada.deep import MCC | ||
from skada.deep.modules import ToyModule2D | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"T", | ||
[1, 0.5], | ||
) | ||
def test_mcc(T): | ||
module = ToyModule2D(n_classes=5) | ||
module.eval() | ||
|
||
n_samples = 50 | ||
dataset = make_shifted_datasets( | ||
n_samples_source=n_samples, | ||
n_samples_target=n_samples, | ||
shift="concept_drift", | ||
noise=0.1, | ||
random_state=42, | ||
return_dataset=True, | ||
label="multiclass", | ||
) | ||
|
||
method = MCC( | ||
module, | ||
reg=1, | ||
layer_name="dropout", | ||
batch_size=32, | ||
max_epochs=5, | ||
train_split=None, | ||
T=T, | ||
) | ||
|
||
X, y, sample_domain = dataset.pack_train(as_sources=["s"], as_targets=["t"]) | ||
method.fit(X.astype(np.float32), y, sample_domain) | ||
X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["t"]) | ||
|
||
y_pred = method.predict(X_test.astype(np.float32), sample_domain_test) | ||
|
||
assert y_pred.shape[0] == X_test.shape[0] | ||
|
||
history = method.history_ | ||
|
||
assert history[0]["train_loss"] > history[-1]["train_loss"] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I want is to remove the diag. If I do only one torch.diag I will have a vector and then remove the diag on all the line. Two torch.diag create a diagonal matrix
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah my bad I thought it returned a diagonal matrix directly