-
Notifications
You must be signed in to change notification settings - Fork 23
[MRG] Add MDD method #263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[MRG] Add MDD method #263
Changes from all commits
Commits
Show all changes
36 commits
Select commit
Hold shift + click to select a range
e203380
Update refs
ambroiseodt f7252e9
Correct typos
ambroiseodt 687aaec
Add MDDLoss
ambroiseodt d3d10b0
Add MDD module
ambroiseodt 9ad3b73
Update __init__
ambroiseodt 6ece26a
Update docs
ambroiseodt 17b853a
Merge branch 'main' into add_mdd
ambroiseodt a6aeada
Update docs
ambroiseodt 97e8eb2
Merge branch 'main' into add_mdd
ambroiseodt 8bcbbf7
Merge branch 'main' into add_mdd
ambroiseodt 39b3f0d
Intergrate review
ambroiseodt d4968a7
Update: binary classif not handled
ambroiseodt 3ea625b
Add mdd test
ambroiseodt 12f60fb
Add import MDD
ambroiseodt 81565be
Update docs & binary case not handled
ambroiseodt 3cbba75
Update test
ambroiseodt ada2a56
Update order and type of labels in loss
ambroiseodt f0e318d
debug test MDD
ambroiseodt 51219c3
Debug test
ambroiseodt 37eb528
Merge branch 'main' into add_mdd
YanisLalou daa6108
Debug test
ambroiseodt 795656f
Merge branch 'main' into add_mdd
ambroiseodt f82aadb
Debug MDD
ambroiseodt 1b99644
Debug MDD
ambroiseodt 211f852
Update MDD (docs and code)
ambroiseodt 43a78dd
Update domain_classifier
ambroiseodt f5cc644
Debug test_mdd
ambroiseodt 8d5991b
Update authors
ambroiseodt 677451e
Update __init__
ambroiseodt 659dce5
Update __init__.py
ambroiseodt f98b053
Update __init__.py
ambroiseodt 8ed1305
Debug test
ambroiseodt 15649ae
Update authors
ambroiseodt 4f43050
Update nomenclature "domain" to "disc"
ambroiseodt eb6c984
Merge branch 'main' into add_mdd
ambroiseodt 2de5598
Set disc_criterion default as base_criterion
ambroiseodt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
# Author: Theo Gnassounou <[email protected]> | ||
# Remi Flamary <[email protected]> | ||
# Yanis Lalou <[email protected]> | ||
# Ambroise Odonnat <[email protected]> | ||
# | ||
# License: BSD 3-Clause | ||
|
||
|
@@ -17,30 +18,32 @@ | |
|
||
from ._divergence import DeepCoral, DeepCoralLoss, DANLoss, DAN, CAN, CANLoss | ||
from ._optimal_transport import DeepJDOT, DeepJDOTLoss | ||
from ._adversarial import DANN, CDAN, DANNLoss, CDANLoss | ||
from ._adversarial import DANN, CDAN, MDD, DANNLoss, CDANLoss, MDDLoss | ||
from ._class_confusion import MCC, MCCLoss | ||
from ._baseline import SourceOnly, TargetOnly | ||
|
||
from . import losses | ||
from . import modules | ||
|
||
__all__ = [ | ||
'losses', | ||
'modules', | ||
'DeepCoralLoss', | ||
'DeepCoral', | ||
'DANLoss', | ||
'DAN', | ||
'DeepJDOTLoss', | ||
'DeepJDOT', | ||
'DANNLoss', | ||
'DANN', | ||
'CDANLoss', | ||
'CDAN', | ||
'MCCLoss', | ||
'MCC', | ||
'CANLoss', | ||
'CAN', | ||
'SourceOnly', | ||
'TargetOnly', | ||
"losses", | ||
"modules", | ||
"DeepCoralLoss", | ||
"DeepCoral", | ||
"DANLoss", | ||
"DAN", | ||
"DeepJDOTLoss", | ||
"DeepJDOT", | ||
"DANNLoss", | ||
"DANN", | ||
"CDANLoss", | ||
"CDAN", | ||
"MCCLoss", | ||
"MCC", | ||
"MDDLoss", | ||
"MDD", | ||
"CANLoss", | ||
"CAN", | ||
"SourceOnly", | ||
"TargetOnly", | ||
] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
# Author: Theo Gnassounou <[email protected]> | ||
# Ambroise Odonnat <[email protected]> | ||
# | ||
# License: BSD 3-Clause | ||
import copy | ||
import math | ||
|
||
import numpy as np | ||
|
@@ -30,14 +32,14 @@ | |
|
||
Parameters | ||
---------- | ||
target_criterion : torch criterion (class), default=None | ||
domain_criterion : torch criterion (class), default=None | ||
The initialized criterion (loss) used to compute the | ||
DANN loss. If None, a BCELoss is used. | ||
|
||
References | ||
---------- | ||
.. [15] Yaroslav Ganin et. al. Domain-Adversarial Training | ||
of Neural Networks In Journal of Machine Learning | ||
of Neural Networks. In Journal of Machine Learning | ||
Research, 2016. | ||
""" | ||
|
||
|
@@ -120,7 +122,7 @@ | |
References | ||
---------- | ||
.. [15] Yaroslav Ganin et. al. Domain-Adversarial Training | ||
of Neural Networks In Journal of Machine Learning | ||
of Neural Networks. In Journal of Machine Learning | ||
Research, 2016. | ||
""" | ||
if domain_classifier is None: | ||
|
@@ -161,9 +163,7 @@ | |
|
||
Parameters | ||
---------- | ||
reg : float, default=1 | ||
Regularization parameter. | ||
target_criterion : torch criterion (class), default=None | ||
domain_criterion : torch criterion (class), default=None | ||
The initialized criterion (loss) used to compute the | ||
CDAN loss. If None, a BCELoss is used. | ||
|
||
|
@@ -370,6 +370,154 @@ | |
return net | ||
|
||
|
||
class MDDLoss(BaseDALoss): | ||
"""Loss MDD. | ||
|
||
This loss tries to minimize the disparity discrepancy between | ||
the source and target domains. The discrepancy is estimated | ||
through adversarial training of three networks: an encoder, | ||
a task network and a discriminator. | ||
|
||
See [35]_ for details. | ||
|
||
Parameters | ||
---------- | ||
disc_criterion : torch criterion (class) | ||
The criterion (loss) used to compute the | ||
MDD loss for the discriminator. It should | ||
be the same loss as the base criterion. | ||
If None, a CrossEntropyLoss is used. | ||
gamma : float (default=4.0) | ||
Margin parameter following [35]_ | ||
|
||
References | ||
---------- | ||
.. [35] Yuchen Zhang et. al. Bridging Theory and Algorithm | ||
for Domain Adaptation. In International Conference on | ||
Machine Learning, 2019. | ||
""" | ||
|
||
def __init__(self, disc_criterion, gamma=4.0): | ||
super().__init__() | ||
if disc_criterion is None: | ||
self.disc_criterion_ = torch.nn.CrossEntropyLoss() | ||
ambroiseodt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
self.disc_criterion_ = disc_criterion | ||
self.gamma = gamma | ||
|
||
def forward( | ||
self, | ||
y_s, | ||
y_pred_s, | ||
y_pred_t, | ||
disc_pred_s, | ||
disc_pred_t, | ||
features_s, | ||
features_t, | ||
): | ||
"""Compute the domain adaptation loss""" | ||
# TODO: handle binary classification | ||
# Multiclass classification | ||
pseudo_label_s = torch.argmax(y_pred_s, axis=-1) | ||
pseudo_label_t = torch.argmax(y_pred_t, axis=-1) | ||
|
||
disc_loss_src = self.disc_criterion_(disc_pred_s, pseudo_label_s) | ||
disc_loss_tgt = self.disc_criterion_(disc_pred_t, pseudo_label_t) | ||
|
||
# Compute the MDD loss value | ||
disc_loss = self.gamma * disc_loss_src - disc_loss_tgt | ||
|
||
return disc_loss | ||
|
||
|
||
def MDD( | ||
module, | ||
layer_name, | ||
reg=1, | ||
gamma=4.0, | ||
disc_classifier=None, | ||
num_features=None, | ||
n_classes=None, | ||
base_criterion=None, | ||
disc_criterion=None, | ||
**kwargs, | ||
): | ||
"""Margin Disparity Discrepancy (MDD). | ||
|
||
From [35]_. | ||
|
||
Parameters | ||
---------- | ||
module : torch module (class or instance) | ||
A PyTorch :class:`~torch.nn.Module`. In general, the | ||
uninstantiated class should be passed, although instantiated | ||
modules will also work. | ||
layer_name : str | ||
The name of the module's layer whose outputs are | ||
collected during the training. | ||
reg : float, default=1 | ||
Regularization parameter for DA loss. | ||
disc_classifier : torch module, default=None | ||
A PyTorch :class:`~torch.nn.Module` used as a discriminator. | ||
It should have the same architecture than the classifier | ||
used on the source. If None, a domain classifier is | ||
created following [1]_. | ||
num_features : int, default=None | ||
Size of the input of domain classifier, | ||
e.g size of the last layer of | ||
the feature extractor. | ||
If domain_classifier is None, num_features has to be | ||
provided. | ||
n_classes : int, default=None | ||
Number of classes. If domain_classifier is None, | ||
n_classes has to be provided. | ||
base_criterion : torch criterion (class) | ||
The base criterion used to compute the loss with source | ||
labels. If None, the default is `torch.nn.CrossEntropyLoss`. | ||
disc_criterion : torch criterion (class) | ||
The criterion (loss) used to compute the | ||
MDD loss for the discriminator. | ||
If None, use the same loss as base_criterion. | ||
gamma : float (default=4.0) | ||
Margin parameter following [35]_. | ||
|
||
References | ||
---------- | ||
.. [35] Yuchen Zhang et. al. Bridging Theory and Algorithm | ||
for Domain Adaptation. In International Conference on | ||
Machine Learning, 2019. | ||
""" | ||
if disc_classifier is None: | ||
# raise error if num_feature is None | ||
if num_features is None: | ||
raise ValueError( | ||
"If disc_classifier is None, num_features has to be provided" | ||
) | ||
disc_classifier = DomainClassifier( | ||
num_features=num_features, n_classes=n_classes | ||
) | ||
|
||
if base_criterion is None: | ||
base_criterion = torch.nn.CrossEntropyLoss() | ||
|
||
if disc_criterion is None: | ||
disc_criterion = copy.deepcopy(base_criterion) | ||
|
||
net = DomainAwareNet( | ||
module=DomainAwareModule, | ||
module__base_module=module, | ||
module__layer_name=layer_name, | ||
module__domain_classifier=disc_classifier, | ||
iterator_train=DomainBalancedDataLoader, | ||
criterion=DomainAwareCriterion, | ||
criterion__base_criterion=base_criterion, | ||
criterion__reg=reg, | ||
criterion__adapt_criterion=MDDLoss(gamma=gamma, disc_criterion=disc_criterion), | ||
**kwargs, | ||
) | ||
return net | ||
|
||
|
||
class _RandomLayer(nn.Module): | ||
"""Randomized multilinear map layer. | ||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
# Remi Flamary <[email protected]> | ||
# Yanis Lalou <[email protected]> | ||
# Antoine Collas <[email protected]> | ||
# Ambroise Odonnat <[email protected]> | ||
# | ||
# License: BSD 3-Clause | ||
import torch | ||
|
@@ -192,7 +193,7 @@ def forward(self, x, sample_weight=None): | |
Parameter for the reverse layer. | ||
""" | ||
reverse_x = GradientReversalLayer.apply(x, self.alpha) | ||
return self.classifier(reverse_x).flatten() | ||
return self.classifier(reverse_x).squeeze() | ||
|
||
|
||
class MNISTtoUSPSNet(nn.Module): | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,17 @@ | ||
# Author: Theo Gnassounou <[email protected]> | ||
# Oleksii Kachaiev <[email protected]> | ||
# Ambroise Odonnat <[email protected]> | ||
ambroiseodt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# | ||
# License: BSD 3-Clause | ||
import pytest | ||
|
||
torch = pytest.importorskip("torch") | ||
|
||
import numpy as np | ||
from torch.nn import BCELoss | ||
from torch.nn import BCELoss, CrossEntropyLoss | ||
|
||
from skada.datasets import make_shifted_datasets | ||
from skada.deep import CDAN, DANN | ||
from skada.deep import CDAN, DANN, MDD | ||
from skada.deep.modules import DomainClassifier, ToyModule2D | ||
|
||
|
||
|
@@ -114,6 +115,64 @@ def test_cdan(domain_classifier, domain_criterion, num_feature, max_feature, n_c | |
assert history[0]["train_loss"] > history[-1]["train_loss"] | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"disc_classifier, disc_criterion, num_features, n_classes", | ||
[ | ||
( | ||
DomainClassifier(num_features=10, n_classes=5), | ||
CrossEntropyLoss(), | ||
10, | ||
5, | ||
ambroiseodt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
), | ||
( | ||
DomainClassifier(num_features=10, n_classes=5), | ||
CrossEntropyLoss(), | ||
None, | ||
None, | ||
), | ||
(DomainClassifier(num_features=10, n_classes=5), None, None, None), | ||
(None, None, 10, 5), | ||
], | ||
) | ||
def test_mdd(disc_classifier, disc_criterion, num_features, n_classes): | ||
n_samples = 20 | ||
dataset = make_shifted_datasets( | ||
n_samples_source=n_samples, | ||
n_samples_target=n_samples, | ||
shift="concept_drift", | ||
noise=0.1, | ||
random_state=42, | ||
return_dataset=True, | ||
) | ||
|
||
method = MDD( | ||
ToyModule2D(n_classes=5), | ||
reg=1, | ||
gamma=4.0, | ||
disc_classifier=disc_classifier, | ||
num_features=num_features, | ||
n_classes=n_classes, | ||
disc_criterion=disc_criterion, | ||
layer_name="dropout", | ||
batch_size=10, | ||
max_epochs=50, | ||
train_split=None, | ||
) | ||
|
||
X, y, sample_domain = dataset.pack_train(as_sources=["s"], as_targets=["t"]) | ||
method.fit(X.astype(np.float32), y, sample_domain) | ||
|
||
X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["t"]) | ||
|
||
y_pred = method.predict(X_test.astype(np.float32), sample_domain_test) | ||
|
||
assert y_pred.shape[0] == X_test.shape[0] | ||
|
||
history = method.history_ | ||
|
||
assert history[0]["train_loss"] > history[-1]["train_loss"] | ||
|
||
|
||
def test_missing_num_features(): | ||
with pytest.raises(ValueError): | ||
DANN( | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.