-
Notifications
You must be signed in to change notification settings - Fork 23
[MRG] Add MDD method #263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[MRG] Add MDD method #263
Changes from 21 commits
Commits
Show all changes
36 commits
Select commit
Hold shift + click to select a range
e203380
Update refs
ambroiseodt f7252e9
Correct typos
ambroiseodt 687aaec
Add MDDLoss
ambroiseodt d3d10b0
Add MDD module
ambroiseodt 9ad3b73
Update __init__
ambroiseodt 6ece26a
Update docs
ambroiseodt 17b853a
Merge branch 'main' into add_mdd
ambroiseodt a6aeada
Update docs
ambroiseodt 97e8eb2
Merge branch 'main' into add_mdd
ambroiseodt 8bcbbf7
Merge branch 'main' into add_mdd
ambroiseodt 39b3f0d
Intergrate review
ambroiseodt d4968a7
Update: binary classif not handled
ambroiseodt 3ea625b
Add mdd test
ambroiseodt 12f60fb
Add import MDD
ambroiseodt 81565be
Update docs & binary case not handled
ambroiseodt 3cbba75
Update test
ambroiseodt ada2a56
Update order and type of labels in loss
ambroiseodt f0e318d
debug test MDD
ambroiseodt 51219c3
Debug test
ambroiseodt 37eb528
Merge branch 'main' into add_mdd
YanisLalou daa6108
Debug test
ambroiseodt 795656f
Merge branch 'main' into add_mdd
ambroiseodt f82aadb
Debug MDD
ambroiseodt 1b99644
Debug MDD
ambroiseodt 211f852
Update MDD (docs and code)
ambroiseodt 43a78dd
Update domain_classifier
ambroiseodt f5cc644
Debug test_mdd
ambroiseodt 8d5991b
Update authors
ambroiseodt 677451e
Update __init__
ambroiseodt 659dce5
Update __init__.py
ambroiseodt f98b053
Update __init__.py
ambroiseodt 8ed1305
Debug test
ambroiseodt 15649ae
Update authors
ambroiseodt 4f43050
Update nomenclature "domain" to "disc"
ambroiseodt eb6c984
Merge branch 'main' into add_mdd
ambroiseodt 2de5598
Set disc_criterion default as base_criterion
ambroiseodt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
# Author: Theo Gnassounou <[email protected]> | ||
# Remi Flamary <[email protected]> | ||
# Yanis Lalou <[email protected]> | ||
# | ||
# Ambroise Odonnat <[email protected]> | ||
# License: BSD 3-Clause | ||
|
||
""" | ||
|
@@ -17,30 +17,31 @@ | |
|
||
from ._divergence import DeepCoral, DeepCoralLoss, DANLoss, DAN, CAN, CANLoss | ||
from ._optimal_transport import DeepJDOT, DeepJDOTLoss | ||
from ._adversarial import DANN, CDAN, DANNLoss, CDANLoss | ||
from ._adversarial import DANN, CDAN, MDD, DANNLoss, CDANLoss, MDDLoss | ||
from ._class_confusion import MCC, MCCLoss | ||
from ._baseline import SourceOnly, TargetOnly | ||
|
||
from . import losses | ||
from . import modules | ||
|
||
__all__ = [ | ||
'losses', | ||
'modules', | ||
'DeepCoralLoss', | ||
'DeepCoral', | ||
'DANLoss', | ||
'DAN', | ||
'DeepJDOTLoss', | ||
'DeepJDOT', | ||
'DANNLoss', | ||
'DANN', | ||
'CDANLoss', | ||
'CDAN', | ||
'MCCLoss', | ||
'MCC', | ||
'CANLoss', | ||
'CAN', | ||
'SourceOnly', | ||
'TargetOnly', | ||
"losses", | ||
"modules", | ||
"DeepCoralLoss", | ||
"DeepCoral", | ||
"DANLoss", | ||
"DAN", | ||
"DeepJDOTLoss", | ||
"DeepJDOT", | ||
"DANNLoss", | ||
"DANN", | ||
"CDANLoss", | ||
"CDAN", | ||
"MCCLoss", | ||
"MCC", | ||
"CANLoss", | ||
"CAN", | ||
"MDD", | ||
"SourceOnly", | ||
"TargetOnly", | ||
] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# Author: Theo Gnassounou <[email protected]> | ||
# Ambroise Odonnat <[email protected]> | ||
# | ||
# License: BSD 3-Clause | ||
import math | ||
|
@@ -30,7 +31,7 @@ class DANNLoss(BaseDALoss): | |
|
||
Parameters | ||
---------- | ||
target_criterion : torch criterion (class), default=None | ||
domain_criterion : torch criterion (class), default=None | ||
The initialized criterion (loss) used to compute the | ||
DANN loss. If None, a BCELoss is used. | ||
|
||
|
@@ -161,9 +162,7 @@ class CDANLoss(BaseDALoss): | |
|
||
Parameters | ||
---------- | ||
reg : float, default=1 | ||
Regularization parameter. | ||
target_criterion : torch criterion (class), default=None | ||
domain_criterion : torch criterion (class), default=None | ||
The initialized criterion (loss) used to compute the | ||
CDAN loss. If None, a BCELoss is used. | ||
|
||
|
@@ -370,6 +369,162 @@ def CDAN( | |
return net | ||
|
||
|
||
class MDDLoss(BaseDALoss): | ||
"""Loss MDD. | ||
|
||
This loss tries to minimize the disparity discrepancy between | ||
the source and target domains. The discrepancy is estimated | ||
through adversarial training of three networks: an encoder, | ||
a task network and a discriminator. | ||
|
||
See [35]_ for details. | ||
|
||
Parameters | ||
---------- | ||
gamma : float (default=4.0) | ||
Margin parameter following [35]_ | ||
criterion : torch criterion (class), default=None | ||
The initialized criterion (loss) used to compute the | ||
MDD loss. If None, a CrossEntropyLoss is used. | ||
|
||
References | ||
---------- | ||
.. [35] Yuchen Zhang et. al. Bridging Theory and Algorithm | ||
for Domain Adaptation. In International Conference on | ||
Machine Learning, 2019. | ||
""" | ||
|
||
def __init__(self, gamma=4.0, criterion=None): | ||
super().__init__() | ||
self.gamma = gamma | ||
if criterion is None: | ||
self.criterion = torch.nn.CrossEntropyLoss() | ||
else: | ||
self.criterion = criterion | ||
|
||
def forward( | ||
self, | ||
y_s, | ||
y_pred_s, | ||
y_pred_t, | ||
domain_pred_s, | ||
domain_pred_t, | ||
features_s, | ||
features_t, | ||
): | ||
"""Compute the domain adaptation loss""" | ||
# TODO: handle binary classification | ||
# if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss): | ||
# pseudo_label_s = y_pred_s > 0 | ||
# pseudo_label_t = y_pred_t > 0 | ||
|
||
# elif isinstance(self.criterion, torch.nn.BCELoss): | ||
# pseudo_label_s = y_pred_s > 0.5 | ||
# pseudo_label_t = y_pred_t > 0.5 | ||
|
||
if isinstance(self.criterion, torch.nn.CrossEntropyLoss): | ||
pseudo_label_s = torch.argmax(y_pred_s, axis=-1) | ||
pseudo_label_t = torch.argmax(y_pred_t, axis=-1) | ||
else: | ||
pseudo_label_s = y_pred_s | ||
pseudo_label_t = y_pred_t | ||
|
||
pseudo_label_s = pseudo_label_s.to(domain_pred_s.device) | ||
pseudo_label_t = pseudo_label_t.to(domain_pred_t.device) | ||
|
||
ambroiseodt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Deal with CrossEntropyLoss convention in Pytorch | ||
if domain_pred_s.shape == pseudo_label_s.shape: | ||
domain_pred_s = domain_pred_s.float() | ||
pseudo_label_s = pseudo_label_s.float() | ||
else: | ||
domain_pred_s = domain_pred_s.float() | ||
pseudo_label_s = pseudo_label_s.long() | ||
ambroiseodt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
disc_loss_src = self.criterion(domain_pred_s, pseudo_label_s) | ||
disc_loss_tgt = self.criterion(domain_pred_t, pseudo_label_t) | ||
ambroiseodt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Compute the loss value | ||
disc_loss = self.gamma * disc_loss_src - disc_loss_tgt | ||
|
||
return disc_loss | ||
|
||
|
||
def MDD( | ||
module, | ||
layer_name, | ||
reg=1, | ||
gamma=4.0, | ||
domain_classifier=None, | ||
ambroiseodt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
num_features=None, | ||
base_criterion=None, | ||
criterion=None, | ||
**kwargs, | ||
): | ||
"""Margin Disparity Discrepancy (MDD). | ||
|
||
From [35]_. | ||
|
||
Parameters | ||
---------- | ||
module : torch module (class or instance) | ||
A PyTorch :class:`~torch.nn.Module`. In general, the | ||
uninstantiated class should be passed, although instantiated | ||
modules will also work. | ||
layer_name : str | ||
The name of the module's layer whose outputs are | ||
collected during the training. | ||
reg : float, default=1 | ||
Regularization parameter for DA loss. | ||
gamma : float (default=4.0) | ||
Margin parameter following [35]_. | ||
domain_classifier : torch module, default=None | ||
A PyTorch :class:`~torch.nn.Module` used to classify the | ||
domain. If None, a domain classifier is created following [1]_. | ||
num_features : int, default=None | ||
Size of the input of domain classifier, | ||
e.g size of the last layer of | ||
the feature extractor. | ||
If domain_classifier is None, num_features has to be | ||
provided. | ||
base_criterion : torch criterion (class) | ||
The base criterion used to compute the loss with source | ||
labels. If None, the default is `torch.nn.CrossEntropyLoss`. | ||
criterion : torch criterion (class) | ||
The criterion (loss) used to compute the | ||
MDD loss. If None, a CrossEntropyLoss is used. | ||
|
||
References | ||
---------- | ||
.. [35] Yuchen Zhang et. al. Bridging Theory and Algorithm | ||
for Domain Adaptation. In International Conference on | ||
Machine Learning, 2019. | ||
""" | ||
if domain_classifier is None: | ||
# raise error if num_feature is None | ||
if num_features is None: | ||
raise ValueError( | ||
"If domain_classifier is None, num_features has to be provided" | ||
) | ||
domain_classifier = DomainClassifier(num_features=num_features) | ||
|
||
if base_criterion is None: | ||
base_criterion = torch.nn.CrossEntropyLoss() | ||
|
||
net = DomainAwareNet( | ||
module=DomainAwareModule, | ||
module__base_module=module, | ||
module__layer_name=layer_name, | ||
module__domain_classifier=domain_classifier, | ||
iterator_train=DomainBalancedDataLoader, | ||
criterion=DomainAwareCriterion, | ||
criterion__base_criterion=base_criterion, | ||
criterion__reg=reg, | ||
criterion__adapt_criterion=MDDLoss(gamma=gamma, criterion=criterion), | ||
**kwargs, | ||
) | ||
return net | ||
|
||
|
||
class _RandomLayer(nn.Module): | ||
"""Randomized multilinear map layer. | ||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,16 @@ | ||
# Author: Theo Gnassounou <[email protected]> | ||
# Oleksii Kachaiev <[email protected]> | ||
# | ||
# Ambroise Odonnat <[email protected]> | ||
ambroiseodt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# License: BSD 3-Clause | ||
import pytest | ||
|
||
torch = pytest.importorskip("torch") | ||
|
||
import numpy as np | ||
from torch.nn import BCELoss | ||
from torch.nn import BCELoss, CrossEntropyLoss | ||
|
||
from skada.datasets import make_shifted_datasets | ||
from skada.deep import CDAN, DANN | ||
from skada.deep import CDAN, DANN, MDD | ||
from skada.deep.modules import DomainClassifier, ToyModule2D | ||
|
||
|
||
|
@@ -114,6 +114,55 @@ def test_cdan(domain_classifier, domain_criterion, num_feature, max_feature, n_c | |
assert history[0]["train_loss"] > history[-1]["train_loss"] | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"domain_classifier, criterion, num_features", | ||
[ | ||
(DomainClassifier(num_features=10), CrossEntropyLoss(), None), | ||
(DomainClassifier(num_features=10), None, None), | ||
(None, None, 10), | ||
], | ||
) | ||
def test_mdd(domain_classifier, criterion, num_features): | ||
module = ToyModule2D() | ||
module.eval() | ||
ambroiseodt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
n_samples = 20 | ||
dataset = make_shifted_datasets( | ||
n_samples_source=n_samples, | ||
n_samples_target=n_samples, | ||
shift="concept_drift", | ||
noise=0.1, | ||
random_state=42, | ||
return_dataset=True, | ||
) | ||
|
||
method = MDD( | ||
ToyModule2D(), | ||
ambroiseodt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
reg=1, | ||
gamma=4.0, | ||
domain_classifier=domain_classifier, | ||
num_features=num_features, | ||
criterion=criterion, | ||
ambroiseodt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
layer_name="dropout", | ||
batch_size=10, | ||
max_epochs=10, | ||
train_split=None, | ||
) | ||
|
||
X, y, sample_domain = dataset.pack_train(as_sources=["s"], as_targets=["t"]) | ||
method.fit(X.astype(np.float32), y, sample_domain) | ||
|
||
X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["t"]) | ||
|
||
y_pred = method.predict(X_test.astype(np.float32), sample_domain_test) | ||
|
||
assert y_pred.shape[0] == X_test.shape[0] | ||
|
||
history = method.history_ | ||
|
||
assert history[0]["train_loss"] > history[-1]["train_loss"] | ||
|
||
|
||
def test_missing_num_features(): | ||
with pytest.raises(ValueError): | ||
DANN( | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.