Skip to content

[MRG] Add MDD method #263

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
e203380
Update refs
ambroiseodt Oct 24, 2024
f7252e9
Correct typos
ambroiseodt Oct 24, 2024
687aaec
Add MDDLoss
ambroiseodt Oct 24, 2024
d3d10b0
Add MDD module
ambroiseodt Oct 24, 2024
9ad3b73
Update __init__
ambroiseodt Oct 24, 2024
6ece26a
Update docs
ambroiseodt Oct 24, 2024
17b853a
Merge branch 'main' into add_mdd
ambroiseodt Oct 24, 2024
a6aeada
Update docs
ambroiseodt Oct 24, 2024
97e8eb2
Merge branch 'main' into add_mdd
ambroiseodt Oct 24, 2024
8bcbbf7
Merge branch 'main' into add_mdd
ambroiseodt Oct 24, 2024
39b3f0d
Intergrate review
ambroiseodt Oct 24, 2024
d4968a7
Update: binary classif not handled
ambroiseodt Oct 24, 2024
3ea625b
Add mdd test
ambroiseodt Oct 24, 2024
12f60fb
Add import MDD
ambroiseodt Oct 24, 2024
81565be
Update docs & binary case not handled
ambroiseodt Oct 25, 2024
3cbba75
Update test
ambroiseodt Oct 25, 2024
ada2a56
Update order and type of labels in loss
ambroiseodt Oct 25, 2024
f0e318d
debug test MDD
ambroiseodt Oct 25, 2024
51219c3
Debug test
ambroiseodt Oct 25, 2024
37eb528
Merge branch 'main' into add_mdd
YanisLalou Oct 25, 2024
daa6108
Debug test
ambroiseodt Oct 25, 2024
795656f
Merge branch 'main' into add_mdd
ambroiseodt Oct 29, 2024
f82aadb
Debug MDD
ambroiseodt Oct 29, 2024
1b99644
Debug MDD
ambroiseodt Oct 29, 2024
211f852
Update MDD (docs and code)
ambroiseodt Oct 30, 2024
43a78dd
Update domain_classifier
ambroiseodt Oct 30, 2024
f5cc644
Debug test_mdd
ambroiseodt Oct 30, 2024
8d5991b
Update authors
ambroiseodt Oct 30, 2024
677451e
Update __init__
ambroiseodt Oct 30, 2024
659dce5
Update __init__.py
ambroiseodt Oct 30, 2024
f98b053
Update __init__.py
ambroiseodt Oct 30, 2024
8ed1305
Debug test
ambroiseodt Oct 30, 2024
15649ae
Update authors
ambroiseodt Oct 30, 2024
4f43050
Update nomenclature "domain" to "disc"
ambroiseodt Oct 30, 2024
eb6c984
Merge branch 'main' into add_mdd
ambroiseodt Oct 31, 2024
2de5598
Set disc_criterion default as base_criterion
ambroiseodt Oct 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,5 @@ The library is distributed under the 3-Clause BSD license.
[33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019). [Contrastive Adaptation Network for Unsupervised Domain Adaptation](https://arxiv.org/abs/1901.00976). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 4893-4902).

[34] Jin, Ying, Wang, Ximei, Long, Mingsheng, Wang, Jianmin. [Minimum Class Confusion for Versatile Domain Adaptation](https://arxiv.org/pdf/1912.03699). ECCV, 2020.

[35] Zhang, Y., Liu, T., Long, M., & Jordan, M. I. (2019). [Bridging Theory and Algorithm for Domain Adaptation](https://arxiv.org/abs/1904.05801). In Proceedings of the 36th International Conference on Machine Learning, (pp. 7404-7413).
2 changes: 2 additions & 0 deletions docs/source/all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ Deep learning DA :py:mod:`skada.deep`:
CDANLoss
MCCLoss
CANLoss
MDDLoss

.. autosummary::
:toctree: gen_modules/
Expand All @@ -165,6 +166,7 @@ Deep learning DA :py:mod:`skada.deep`:
CDAN
MCC
CAN
MDD



Expand Down
41 changes: 21 additions & 20 deletions skada/deep/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Author: Theo Gnassounou <[email protected]>
# Remi Flamary <[email protected]>
# Yanis Lalou <[email protected]>
#
# Ambroise Odonnat <[email protected]>
# License: BSD 3-Clause

"""
Expand All @@ -17,30 +17,31 @@

from ._divergence import DeepCoral, DeepCoralLoss, DANLoss, DAN, CAN, CANLoss
from ._optimal_transport import DeepJDOT, DeepJDOTLoss
from ._adversarial import DANN, CDAN, DANNLoss, CDANLoss
from ._adversarial import DANN, CDAN, MDD, DANNLoss, CDANLoss, MDDLoss
from ._class_confusion import MCC, MCCLoss
from ._baseline import SourceOnly, TargetOnly

from . import losses
from . import modules

__all__ = [
'losses',
'modules',
'DeepCoralLoss',
'DeepCoral',
'DANLoss',
'DAN',
'DeepJDOTLoss',
'DeepJDOT',
'DANNLoss',
'DANN',
'CDANLoss',
'CDAN',
'MCCLoss',
'MCC',
'CANLoss',
'CAN',
'SourceOnly',
'TargetOnly',
"losses",
"modules",
"DeepCoralLoss",
"DeepCoral",
"DANLoss",
"DAN",
"DeepJDOTLoss",
"DeepJDOT",
"DANNLoss",
"DANN",
"CDANLoss",
"CDAN",
"MCCLoss",
"MCC",
"CANLoss",
"CAN",
"MDD",
"SourceOnly",
"TargetOnly",
]
163 changes: 159 additions & 4 deletions skada/deep/_adversarial.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Author: Theo Gnassounou <[email protected]>
# Ambroise Odonnat <[email protected]>
#
# License: BSD 3-Clause
import math
Expand Down Expand Up @@ -30,7 +31,7 @@ class DANNLoss(BaseDALoss):

Parameters
----------
target_criterion : torch criterion (class), default=None
domain_criterion : torch criterion (class), default=None
The initialized criterion (loss) used to compute the
DANN loss. If None, a BCELoss is used.

Expand Down Expand Up @@ -161,9 +162,7 @@ class CDANLoss(BaseDALoss):

Parameters
----------
reg : float, default=1
Regularization parameter.
target_criterion : torch criterion (class), default=None
domain_criterion : torch criterion (class), default=None
The initialized criterion (loss) used to compute the
CDAN loss. If None, a BCELoss is used.

Expand Down Expand Up @@ -370,6 +369,162 @@ def CDAN(
return net


class MDDLoss(BaseDALoss):
"""Loss MDD.

This loss tries to minimize the disparity discrepancy between
the source and target domains. The discrepancy is estimated
through adversarial training of three networks: an encoder,
a task network and a discriminator.

See [35]_ for details.

Parameters
----------
gamma : float (default=4.0)
Margin parameter following [35]_
criterion : torch criterion (class), default=None
The initialized criterion (loss) used to compute the
MDD loss. If None, a CrossEntropyLoss is used.

References
----------
.. [35] Yuchen Zhang et. al. Bridging Theory and Algorithm
for Domain Adaptation. In International Conference on
Machine Learning, 2019.
"""

def __init__(self, gamma=4.0, criterion=None):
super().__init__()
self.gamma = gamma
if criterion is None:
self.criterion = torch.nn.CrossEntropyLoss()
else:
self.criterion = criterion

def forward(
self,
y_s,
y_pred_s,
y_pred_t,
domain_pred_s,
domain_pred_t,
features_s,
features_t,
):
"""Compute the domain adaptation loss"""
# TODO: handle binary classification
# if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
# pseudo_label_s = y_pred_s > 0
# pseudo_label_t = y_pred_t > 0

# elif isinstance(self.criterion, torch.nn.BCELoss):
# pseudo_label_s = y_pred_s > 0.5
# pseudo_label_t = y_pred_t > 0.5

if isinstance(self.criterion, torch.nn.CrossEntropyLoss):
pseudo_label_s = torch.argmax(y_pred_s, axis=-1)
pseudo_label_t = torch.argmax(y_pred_t, axis=-1)
else:
pseudo_label_s = y_pred_s
pseudo_label_t = y_pred_t

pseudo_label_s = pseudo_label_s.to(domain_pred_s.device)
pseudo_label_t = pseudo_label_t.to(domain_pred_t.device)

# Deal with CrossEntropyLoss convention in Pytorch
if domain_pred_s.shape == pseudo_label_s.shape:
domain_pred_s = domain_pred_s.float()
pseudo_label_s = pseudo_label_s.float()
else:
domain_pred_s = domain_pred_s.float()
pseudo_label_s = pseudo_label_s.long()

disc_loss_src = self.criterion(domain_pred_s, pseudo_label_s)
disc_loss_tgt = self.criterion(domain_pred_t, pseudo_label_t)

# Compute the loss value
disc_loss = self.gamma * disc_loss_src - disc_loss_tgt

return disc_loss


def MDD(
module,
layer_name,
reg=1,
gamma=4.0,
domain_classifier=None,
num_features=None,
base_criterion=None,
criterion=None,
**kwargs,
):
"""Margin Disparity Discrepancy (MDD).

From [35]_.

Parameters
----------
module : torch module (class or instance)
A PyTorch :class:`~torch.nn.Module`. In general, the
uninstantiated class should be passed, although instantiated
modules will also work.
layer_name : str
The name of the module's layer whose outputs are
collected during the training.
reg : float, default=1
Regularization parameter for DA loss.
gamma : float (default=4.0)
Margin parameter following [35]_.
domain_classifier : torch module, default=None
A PyTorch :class:`~torch.nn.Module` used to classify the
domain. If None, a domain classifier is created following [1]_.
num_features : int, default=None
Size of the input of domain classifier,
e.g size of the last layer of
the feature extractor.
If domain_classifier is None, num_features has to be
provided.
base_criterion : torch criterion (class)
The base criterion used to compute the loss with source
labels. If None, the default is `torch.nn.CrossEntropyLoss`.
criterion : torch criterion (class)
The criterion (loss) used to compute the
MDD loss. If None, a CrossEntropyLoss is used.

References
----------
.. [35] Yuchen Zhang et. al. Bridging Theory and Algorithm
for Domain Adaptation. In International Conference on
Machine Learning, 2019.
"""
if domain_classifier is None:
# raise error if num_feature is None
if num_features is None:
raise ValueError(
"If domain_classifier is None, num_features has to be provided"
)
domain_classifier = DomainClassifier(num_features=num_features)

if base_criterion is None:
base_criterion = torch.nn.CrossEntropyLoss()

net = DomainAwareNet(
module=DomainAwareModule,
module__base_module=module,
module__layer_name=layer_name,
module__domain_classifier=domain_classifier,
iterator_train=DomainBalancedDataLoader,
criterion=DomainAwareCriterion,
criterion__base_criterion=base_criterion,
criterion__reg=reg,
criterion__adapt_criterion=MDDLoss(gamma=gamma, criterion=criterion),
**kwargs,
)
return net


class _RandomLayer(nn.Module):
"""Randomized multilinear map layer.

Expand Down
55 changes: 52 additions & 3 deletions skada/deep/tests/test_deep_adversarial.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# Author: Theo Gnassounou <[email protected]>
# Oleksii Kachaiev <[email protected]>
#
# Ambroise Odonnat <[email protected]>
# License: BSD 3-Clause
import pytest

torch = pytest.importorskip("torch")

import numpy as np
from torch.nn import BCELoss
from torch.nn import BCELoss, CrossEntropyLoss

from skada.datasets import make_shifted_datasets
from skada.deep import CDAN, DANN
from skada.deep import CDAN, DANN, MDD
from skada.deep.modules import DomainClassifier, ToyModule2D


Expand Down Expand Up @@ -114,6 +114,55 @@ def test_cdan(domain_classifier, domain_criterion, num_feature, max_feature, n_c
assert history[0]["train_loss"] > history[-1]["train_loss"]


@pytest.mark.parametrize(
"domain_classifier, criterion, num_features",
[
(DomainClassifier(num_features=10), CrossEntropyLoss(), None),
(DomainClassifier(num_features=10), None, None),
(None, None, 10),
],
)
def test_mdd(domain_classifier, criterion, num_features):
module = ToyModule2D()
module.eval()

n_samples = 20
dataset = make_shifted_datasets(
n_samples_source=n_samples,
n_samples_target=n_samples,
shift="concept_drift",
noise=0.1,
random_state=42,
return_dataset=True,
)

method = MDD(
ToyModule2D(),
reg=1,
gamma=4.0,
domain_classifier=domain_classifier,
num_features=num_features,
criterion=criterion,
layer_name="dropout",
batch_size=10,
max_epochs=10,
train_split=None,
)

X, y, sample_domain = dataset.pack_train(as_sources=["s"], as_targets=["t"])
method.fit(X.astype(np.float32), y, sample_domain)

X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["t"])

y_pred = method.predict(X_test.astype(np.float32), sample_domain_test)

assert y_pred.shape[0] == X_test.shape[0]

history = method.history_

assert history[0]["train_loss"] > history[-1]["train_loss"]


def test_missing_num_features():
with pytest.raises(ValueError):
DANN(
Expand Down
Loading