Skip to content

[MRG] Add MDD method #263

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
e203380
Update refs
ambroiseodt Oct 24, 2024
f7252e9
Correct typos
ambroiseodt Oct 24, 2024
687aaec
Add MDDLoss
ambroiseodt Oct 24, 2024
d3d10b0
Add MDD module
ambroiseodt Oct 24, 2024
9ad3b73
Update __init__
ambroiseodt Oct 24, 2024
6ece26a
Update docs
ambroiseodt Oct 24, 2024
17b853a
Merge branch 'main' into add_mdd
ambroiseodt Oct 24, 2024
a6aeada
Update docs
ambroiseodt Oct 24, 2024
97e8eb2
Merge branch 'main' into add_mdd
ambroiseodt Oct 24, 2024
8bcbbf7
Merge branch 'main' into add_mdd
ambroiseodt Oct 24, 2024
39b3f0d
Intergrate review
ambroiseodt Oct 24, 2024
d4968a7
Update: binary classif not handled
ambroiseodt Oct 24, 2024
3ea625b
Add mdd test
ambroiseodt Oct 24, 2024
12f60fb
Add import MDD
ambroiseodt Oct 24, 2024
81565be
Update docs & binary case not handled
ambroiseodt Oct 25, 2024
3cbba75
Update test
ambroiseodt Oct 25, 2024
ada2a56
Update order and type of labels in loss
ambroiseodt Oct 25, 2024
f0e318d
debug test MDD
ambroiseodt Oct 25, 2024
51219c3
Debug test
ambroiseodt Oct 25, 2024
37eb528
Merge branch 'main' into add_mdd
YanisLalou Oct 25, 2024
daa6108
Debug test
ambroiseodt Oct 25, 2024
795656f
Merge branch 'main' into add_mdd
ambroiseodt Oct 29, 2024
f82aadb
Debug MDD
ambroiseodt Oct 29, 2024
1b99644
Debug MDD
ambroiseodt Oct 29, 2024
211f852
Update MDD (docs and code)
ambroiseodt Oct 30, 2024
43a78dd
Update domain_classifier
ambroiseodt Oct 30, 2024
f5cc644
Debug test_mdd
ambroiseodt Oct 30, 2024
8d5991b
Update authors
ambroiseodt Oct 30, 2024
677451e
Update __init__
ambroiseodt Oct 30, 2024
659dce5
Update __init__.py
ambroiseodt Oct 30, 2024
f98b053
Update __init__.py
ambroiseodt Oct 30, 2024
8ed1305
Debug test
ambroiseodt Oct 30, 2024
15649ae
Update authors
ambroiseodt Oct 30, 2024
4f43050
Update nomenclature "domain" to "disc"
ambroiseodt Oct 30, 2024
eb6c984
Merge branch 'main' into add_mdd
ambroiseodt Oct 31, 2024
2de5598
Set disc_criterion default as base_criterion
ambroiseodt Oct 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,5 @@ The library is distributed under the 3-Clause BSD license.
[33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019). [Contrastive Adaptation Network for Unsupervised Domain Adaptation](https://arxiv.org/abs/1901.00976). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 4893-4902).

[34] Jin, Ying, Wang, Ximei, Long, Mingsheng, Wang, Jianmin. [Minimum Class Confusion for Versatile Domain Adaptation](https://arxiv.org/pdf/1912.03699). ECCV, 2020.

[35] Zhang, Y., Liu, T., Long, M., & Jordan, M. I. (2019). [Bridging Theory and Algorithm for Domain Adaptation](https://arxiv.org/abs/1904.05801). In Proceedings of the 36th International Conference on Machine Learning, (pp. 7404-7413).
2 changes: 2 additions & 0 deletions docs/source/all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ Deep learning DA :py:mod:`skada.deep`:
CDANLoss
MCCLoss
CANLoss
MDDLoss

.. autosummary::
:toctree: gen_modules/
Expand All @@ -165,6 +166,7 @@ Deep learning DA :py:mod:`skada.deep`:
CDAN
MCC
CAN
MDD



Expand Down
41 changes: 22 additions & 19 deletions skada/deep/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Author: Theo Gnassounou <[email protected]>
# Remi Flamary <[email protected]>
# Yanis Lalou <[email protected]>
# Ambroise Odonnat <[email protected]>
#
# License: BSD 3-Clause

Expand All @@ -17,30 +18,32 @@

from ._divergence import DeepCoral, DeepCoralLoss, DANLoss, DAN, CAN, CANLoss
from ._optimal_transport import DeepJDOT, DeepJDOTLoss
from ._adversarial import DANN, CDAN, DANNLoss, CDANLoss
from ._adversarial import DANN, CDAN, MDD, DANNLoss, CDANLoss, MDDLoss
from ._class_confusion import MCC, MCCLoss
from ._baseline import SourceOnly, TargetOnly

from . import losses
from . import modules

__all__ = [
'losses',
'modules',
'DeepCoralLoss',
'DeepCoral',
'DANLoss',
'DAN',
'DeepJDOTLoss',
'DeepJDOT',
'DANNLoss',
'DANN',
'CDANLoss',
'CDAN',
'MCCLoss',
'MCC',
'CANLoss',
'CAN',
'SourceOnly',
'TargetOnly',
"losses",
"modules",
"DeepCoralLoss",
"DeepCoral",
"DANLoss",
"DAN",
"DeepJDOTLoss",
"DeepJDOT",
"DANNLoss",
"DANN",
"CDANLoss",
"CDAN",
"MCCLoss",
"MCC",
"MDDLoss",
"MDD",
"CANLoss",
"CAN",
"SourceOnly",
"TargetOnly",
]
160 changes: 154 additions & 6 deletions skada/deep/_adversarial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Author: Theo Gnassounou <[email protected]>
# Ambroise Odonnat <[email protected]>
#
# License: BSD 3-Clause
import copy
import math

import numpy as np
Expand Down Expand Up @@ -30,14 +32,14 @@

Parameters
----------
target_criterion : torch criterion (class), default=None
domain_criterion : torch criterion (class), default=None
The initialized criterion (loss) used to compute the
DANN loss. If None, a BCELoss is used.

References
----------
.. [15] Yaroslav Ganin et al. Domain-Adversarial Training
of Neural Networks In Journal of Machine Learning
of Neural Networks. In Journal of Machine Learning
Research, 2016.
"""

Expand Down Expand Up @@ -120,7 +122,7 @@
References
----------
.. [15] Yaroslav Ganin et al. Domain-Adversarial Training
of Neural Networks In Journal of Machine Learning
of Neural Networks. In Journal of Machine Learning
Research, 2016.
"""
if domain_classifier is None:
Expand Down Expand Up @@ -161,9 +163,7 @@

Parameters
----------
reg : float, default=1
Regularization parameter.
target_criterion : torch criterion (class), default=None
domain_criterion : torch criterion (class), default=None
The initialized criterion (loss) used to compute the
CDAN loss. If None, a BCELoss is used.

Expand Down Expand Up @@ -370,6 +370,154 @@
return net


class MDDLoss(BaseDALoss):
    """Loss MDD.

    This loss tries to minimize the disparity discrepancy between
    the source and target domains. The discrepancy is estimated
    through adversarial training of three networks: an encoder,
    a task network and a discriminator.

    See [35]_ for details.

    Parameters
    ----------
    disc_criterion : torch criterion (class), default=None
        The initialized criterion (loss) used to compute the
        MDD loss for the discriminator. It should
        be the same loss as the base criterion.
        If None, a CrossEntropyLoss is used.
    gamma : float, default=4.0
        Margin parameter following [35]_.

    References
    ----------
    .. [35] Yuchen Zhang et al. Bridging Theory and Algorithm
            for Domain Adaptation. In International Conference on
            Machine Learning, 2019.
    """

    def __init__(self, disc_criterion=None, gamma=4.0):
        super().__init__()
        # The criterion is called directly in ``forward``, so it must be an
        # already-instantiated loss; fall back to cross-entropy when omitted.
        if disc_criterion is None:
            self.disc_criterion_ = torch.nn.CrossEntropyLoss()
        else:
            self.disc_criterion_ = disc_criterion
        self.gamma = gamma

    def forward(
        self,
        y_s,
        y_pred_s,
        y_pred_t,
        disc_pred_s,
        disc_pred_t,
        features_s,
        features_t,
    ):
        """Compute the domain adaptation loss.

        Only the task predictions and the discriminator predictions are
        used; ``y_s``, ``features_s`` and ``features_t`` are accepted but
        not read by this loss.

        Parameters
        ----------
        y_s : tensor
            Source labels (unused).
        y_pred_s, y_pred_t : tensor
            Task-network predictions on source / target samples. The
            pseudo-labels are the argmax over the last dimension.
        disc_pred_s, disc_pred_t : tensor
            Discriminator predictions on source / target samples.
            NOTE(review): presumably per-class scores with the same
            number of classes as the task predictions -- confirm.
        features_s, features_t : tensor
            Intermediate features (unused).

        Returns
        -------
        disc_loss : tensor
            ``gamma * loss(disc_pred_s, pseudo_s) - loss(disc_pred_t, pseudo_t)``.
        """
        # TODO: handle binary classification
        # Multiclass classification: the task network's hard predictions
        # serve as targets for the discriminator.
        pseudo_label_s = torch.argmax(y_pred_s, dim=-1)
        pseudo_label_t = torch.argmax(y_pred_t, dim=-1)

        disc_loss_src = self.disc_criterion_(disc_pred_s, pseudo_label_s)
        disc_loss_tgt = self.disc_criterion_(disc_pred_t, pseudo_label_t)

        # Compute the MDD loss value: margin-weighted source term minus
        # target term.
        disc_loss = self.gamma * disc_loss_src - disc_loss_tgt

        return disc_loss


def MDD(
    module,
    layer_name,
    reg=1,
    gamma=4.0,
    disc_classifier=None,
    num_features=None,
    n_classes=None,
    base_criterion=None,
    disc_criterion=None,
    **kwargs,
):
    """Margin Disparity Discrepancy (MDD).

    From [35]_.

    Parameters
    ----------
    module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`. In general, the
        uninstantiated class should be passed, although instantiated
        modules will also work.
    layer_name : str
        The name of the module's layer whose outputs are
        collected during the training.
    reg : float, default=1
        Regularization parameter for DA loss.
    gamma : float, default=4.0
        Margin parameter following [35]_.
    disc_classifier : torch module, default=None
        A PyTorch :class:`~torch.nn.Module` used as a discriminator.
        It should have the same architecture as the classifier
        used on the source. If None, a ``DomainClassifier`` is
        created from ``num_features`` and ``n_classes``.
    num_features : int, default=None
        Size of the input of the discriminator,
        e.g. size of the last layer of
        the feature extractor.
        If disc_classifier is None, num_features has to be
        provided.
    n_classes : int, default=None
        Number of classes. If disc_classifier is None,
        n_classes has to be provided.
    base_criterion : torch criterion (class), default=None
        The base criterion used to compute the loss with source
        labels. If None, the default is `torch.nn.CrossEntropyLoss`.
    disc_criterion : torch criterion (class), default=None
        The criterion (loss) used to compute the
        MDD loss for the discriminator.
        If None, a deep copy of base_criterion is used.
    **kwargs
        Additional keyword arguments forwarded to ``DomainAwareNet``.

    Returns
    -------
    net : DomainAwareNet
        The network configured with the MDD adaptation criterion.

    Raises
    ------
    ValueError
        If ``disc_classifier`` is None and ``num_features`` or
        ``n_classes`` is not provided.

    References
    ----------
    .. [35] Yuchen Zhang et al. Bridging Theory and Algorithm
            for Domain Adaptation. In International Conference on
            Machine Learning, 2019.
    """
    if disc_classifier is None:
        # Both sizes are required to build the default discriminator;
        # fail early with an explicit message rather than deep inside
        # the layer constructors.
        if num_features is None:
            raise ValueError(
                "If disc_classifier is None, num_features has to be provided"
            )
        if n_classes is None:
            raise ValueError(
                "If disc_classifier is None, n_classes has to be provided"
            )
        disc_classifier = DomainClassifier(
            num_features=num_features, n_classes=n_classes
        )

    if base_criterion is None:
        base_criterion = torch.nn.CrossEntropyLoss()

    if disc_criterion is None:
        # Use an independent copy so the discriminator loss does not share
        # state with the base criterion object.
        disc_criterion = copy.deepcopy(base_criterion)

    net = DomainAwareNet(
        module=DomainAwareModule,
        module__base_module=module,
        module__layer_name=layer_name,
        module__domain_classifier=disc_classifier,
        iterator_train=DomainBalancedDataLoader,
        criterion=DomainAwareCriterion,
        criterion__base_criterion=base_criterion,
        criterion__reg=reg,
        criterion__adapt_criterion=MDDLoss(gamma=gamma, disc_criterion=disc_criterion),
        **kwargs,
    )
    return net


class _RandomLayer(nn.Module):
"""Randomized multilinear map layer.

Expand Down
3 changes: 2 additions & 1 deletion skada/deep/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Remi Flamary <[email protected]>
# Yanis Lalou <[email protected]>
# Antoine Collas <[email protected]>
# Ambroise Odonnat <[email protected]>
#
# License: BSD 3-Clause
import torch
Expand Down Expand Up @@ -192,7 +193,7 @@ def forward(self, x, sample_weight=None):
Parameter for the reverse layer.
"""
reverse_x = GradientReversalLayer.apply(x, self.alpha)
return self.classifier(reverse_x).flatten()
return self.classifier(reverse_x).squeeze()


class MNISTtoUSPSNet(nn.Module):
Expand Down
63 changes: 61 additions & 2 deletions skada/deep/tests/test_deep_adversarial.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# Author: Theo Gnassounou <[email protected]>
# Oleksii Kachaiev <[email protected]>
# Ambroise Odonnat <[email protected]>
#
# License: BSD 3-Clause
import pytest

torch = pytest.importorskip("torch")

import numpy as np
from torch.nn import BCELoss
from torch.nn import BCELoss, CrossEntropyLoss

from skada.datasets import make_shifted_datasets
from skada.deep import CDAN, DANN
from skada.deep import CDAN, DANN, MDD
from skada.deep.modules import DomainClassifier, ToyModule2D


Expand Down Expand Up @@ -114,6 +115,64 @@ def test_cdan(domain_classifier, domain_criterion, num_feature, max_feature, n_c
assert history[0]["train_loss"] > history[-1]["train_loss"]


@pytest.mark.parametrize(
    "disc_classifier, disc_criterion, num_features, n_classes",
    [
        (
            DomainClassifier(num_features=10, n_classes=5),
            CrossEntropyLoss(),
            10,
            5,
        ),
        (
            DomainClassifier(num_features=10, n_classes=5),
            CrossEntropyLoss(),
            None,
            None,
        ),
        (DomainClassifier(num_features=10, n_classes=5), None, None, None),
        (None, None, 10, 5),
    ],
)
def test_mdd(disc_classifier, disc_criterion, num_features, n_classes):
    """Fit MDD on a toy shifted dataset and check that training progresses.

    Covers an explicit discriminator with and without the size arguments,
    the default discriminator criterion, and the automatically built
    discriminator.
    """
    n_per_domain = 20
    shifted = make_shifted_datasets(
        n_samples_source=n_per_domain,
        n_samples_target=n_per_domain,
        shift="concept_drift",
        noise=0.1,
        random_state=42,
        return_dataset=True,
    )

    estimator = MDD(
        ToyModule2D(n_classes=5),
        reg=1,
        gamma=4.0,
        disc_classifier=disc_classifier,
        num_features=num_features,
        n_classes=n_classes,
        disc_criterion=disc_criterion,
        layer_name="dropout",
        batch_size=10,
        max_epochs=50,
        train_split=None,
    )

    # Train on the packed source + target data.
    X_train, y_train, domain_train = shifted.pack_train(
        as_sources=["s"], as_targets=["t"]
    )
    estimator.fit(X_train.astype(np.float32), y_train, domain_train)

    # Predict on the held-out target samples only.
    X_target, _y_target, domain_target = shifted.pack_test(as_targets=["t"])
    predictions = estimator.predict(X_target.astype(np.float32), domain_target)

    # One prediction per test sample.
    assert predictions.shape[0] == X_target.shape[0]

    # The training loss must be lower at the last epoch than at the first.
    epochs = estimator.history_
    assert epochs[0]["train_loss"] > epochs[-1]["train_loss"]


def test_missing_num_features():
with pytest.raises(ValueError):
DANN(
Expand Down
Loading