Skip to content

[WIP] Implementation of deep Multi-source method #231

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,4 +237,4 @@ The library is distributed under the 3-Clause BSD license.

[32] Hu, D., Liang, J., Liew, J. H., Xue, C., Bai, S., & Wang, X. (2023). [Mixed Samples as Probes for Unsupervised Model Selection in Domain Adaptation](https://proceedings.neurips.cc/paper_files/paper/2023/file/7721f1fea280e9ffae528dc78c732576-Paper-Conference.pdf). Advances in Neural Information Processing Systems 36 (2024).


[33] Zhu, Y., Zhuang, F., & Wang, D. (2022). [Aligning Domain-specific Distribution and Classifier for Cross-domain Classification from Multiple Sources](https://arxiv.org/abs/2201.01003). Association for the Advancement of Artificial Intelligence.
2 changes: 2 additions & 0 deletions docs/source/all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ Deep learning DA :py:mod:`skada.deep`:
DeepJDOTLoss
DANLoss
CDANLoss
MFSANLoss

.. autosummary::
:toctree: gen_modules/
Expand All @@ -159,6 +160,7 @@ Deep learning DA :py:mod:`skada.deep`:
DeepJDOT
DANN
CDAN
MFSAN



Expand Down
2 changes: 1 addition & 1 deletion examples/deep/plot_training_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from skada.deep.base import (
DomainAwareCriterion,
DomainAwareModule,
DomainBalancedDataLoader,
)
from skada.deep.dataloaders import DomainBalancedDataLoader
from skada.deep.modules import MNISTtoUSPSNet

# %%
Expand Down
6 changes: 6 additions & 0 deletions skada/deep/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
from ._divergence import DeepCoral, DeepCoralLoss, DANLoss, DAN
from ._optimal_transport import DeepJDOT, DeepJDOTLoss
from ._adversarial import DANN, CDAN, DANNLoss, CDANLoss
from ._multi_source import MFSAN, MFSANLoss, MultiSourceModule
from ._baseline import SourceOnly, TargetOnly

from . import losses
from . import modules
from . import dataloaders

__all__ = [
'losses',
'modules',
'dataloaders',
'DeepCoralLoss',
'DeepCoral',
'DANLoss',
Expand All @@ -35,6 +38,9 @@
'DANN',
'CDANLoss',
'CDAN',
'MFSANLoss',
'MFSAN',
'MultiSourceModule',
'SourceOnly',
'TargetOnly',
]
25 changes: 6 additions & 19 deletions skada/deep/_adversarial.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
DomainAwareCriterion,
DomainAwareModule,
DomainAwareNet,
DomainBalancedDataLoader,
)
from skada.deep.dataloaders import DomainBalancedDataLoader

from .modules import DomainClassifier
from .utils import check_generator
Expand Down Expand Up @@ -57,6 +57,7 @@ def forward(
domain_pred_t,
features_s,
features_t,
sample_domain,
):
"""Compute the domain adaptation loss"""
domain_label = torch.zeros(
Expand Down Expand Up @@ -189,6 +190,7 @@ def forward(
domain_pred_t,
features_s,
features_t,
sample_domain,
):
"""Compute the domain adaptation loss"""
dtype = torch.float32
Expand Down Expand Up @@ -289,26 +291,11 @@ def forward(self, X, sample_domain=None, is_fit=False, return_features=False):

domain_pred_s = self.domain_classifier_(multilinear_map)
domain_pred_t = self.domain_classifier_(multilinear_map_target)
domain_pred = torch.empty(len(sample_domain), device=domain_pred_s.device)
domain_pred[source_idx] = domain_pred_s
domain_pred[~source_idx] = domain_pred_t

y_pred = torch.empty(
(len(sample_domain), y_pred_s.shape[1]), device=y_pred_s.device
)
y_pred[source_idx] = y_pred_s
y_pred[~source_idx] = y_pred_t

features = torch.empty(
(len(sample_domain), features_s.shape[1]), device=features_s.device
)
features[source_idx] = features_s
features[~source_idx] = features_t

return (
y_pred,
domain_pred,
features,
(y_pred_s, y_pred_t),
(domain_pred_s, domain_pred_t),
(features_s, features_t),
sample_domain,
)
else:
Expand Down
3 changes: 2 additions & 1 deletion skada/deep/_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
DomainAwareCriterion,
DomainAwareModule,
DomainAwareNet,
DomainOnlyDataLoader,
)
from skada.deep.dataloaders import DomainOnlyDataLoader


class DummyLoss(BaseDALoss):
Expand All @@ -33,6 +33,7 @@ def forward(
domain_pred_t,
features_s,
features_t,
sample_domain,
):
"""Compute the domain adaptation loss"""
return 0
Expand Down
8 changes: 5 additions & 3 deletions skada/deep/_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
DomainAwareCriterion,
DomainAwareModule,
DomainAwareNet,
DomainBalancedDataLoader,
)
from skada.deep.dataloaders import DomainBalancedDataLoader

from .losses import dan_loss, deepcoral_loss
from .losses import deepcoral_loss, mmd_loss


class DeepCoralLoss(BaseDALoss):
Expand Down Expand Up @@ -50,6 +50,7 @@ def forward(
domain_pred_t,
features_s,
features_t,
sample_domain,
):
"""Compute the domain adaptation loss"""
loss = deepcoral_loss(features_s, features_t, self.assume_centered)
Expand Down Expand Up @@ -133,9 +134,10 @@ def forward(
domain_pred_t,
features_s,
features_t,
sample_domain,
):
"""Compute the domain adaptation loss"""
loss = dan_loss(features_s, features_t, sigmas=self.sigmas)
loss = mmd_loss(features_s, features_t, sigmas=self.sigmas)
return loss


Expand Down
203 changes: 203 additions & 0 deletions skada/deep/_multi_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# Author: Theo Gnassounou <[email protected]>
#
# License: BSD 3-Clause
import torch
import torch.nn as nn
import torch.nn.functional as F

from skada.deep.base import (
BaseDALoss,
DomainAwareCriterion,
DomainAwareModule,
DomainAwareNet,
)
from skada.deep.dataloaders import MultiSourceDomainBalancedDataLoader
from skada.deep.losses import mmd_loss


class SelectDomainModule(torch.nn.Module):
    """Select, for each sample, the output of the branch of its own domain.

    The module holds no parameters; it only re-indexes its input.
    """

    def __init__(self):
        super().__init__()

    def forward(self, X, sample_domain=None, is_source=True):
        """Gather the per-domain outputs matching ``sample_domain``.

        Parameters
        ----------
        X : torch.Tensor
            Tensor indexed first by domain branch and then by sample.
        sample_domain : torch.Tensor, optional
            1-based domain label of every sample; only read when
            ``is_source`` is truthy.
        is_source : bool, optional (default=True)
            When falsy, ``X`` is returned untouched.

        Returns
        -------
        torch.Tensor
            ``X[sample_domain - 1, k]`` for every sample position ``k`` when
            ``is_source`` is truthy, otherwise ``X`` itself.
        """
        if not is_source:
            return X
        # Labels start at 1 while branches are stored from index 0.
        branch_idx = sample_domain - 1
        sample_idx = torch.arange(X.size(1))
        return X[branch_idx, sample_idx]


class MultiSourceModule(torch.nn.Module):
    """Multi-source module.

    A sequential module in which selected layers are domain-specific: each
    such layer is replicated once per domain, and is followed by a
    :class:`SelectDomainModule` that keeps, for source samples, the output of
    the branch matching each sample's domain.

    Parameters
    ----------
    layers : list of torch modules
        The modules, in the order in which they are applied.
    domain_specific_layers : list of bool
        ``domain_specific_layers[i]`` is truthy when ``layers[i]`` must be
        domain-specific, falsy when it is shared by all domains.
    n_domains : int
        The number of domains, i.e. the number of copies made of every
        domain-specific layer.
    """

    def __init__(self, layers, domain_specific_layers, n_domains):
        # Local import so this class does not depend on a file-level import.
        import copy

        super().__init__()
        for i, layer in enumerate(layers):
            if domain_specific_layers[i]:
                # Each domain needs its OWN parameters: repeating the same
                # instance would make all branches share their weights.
                self.add_module(
                    f"layer_{i}",
                    nn.ModuleList(
                        [copy.deepcopy(layer) for _ in range(n_domains)]
                    ),
                )
                self.add_module(f"output_layer_{i}", SelectDomainModule())
            else:
                self.add_module(f"layer_{i}", layer)
        self.n_domains = n_domains

    def forward(self, X, sample_domain=None, sample_weight=None, is_source=True):
        """Apply the layers in order, branching on domain-specific ones.

        Parameters
        ----------
        X : torch.Tensor
            Input batch (not yet stacked per domain).
        sample_domain : torch.Tensor, optional
            Domain label of every sample, forwarded to the selection modules
            when ``is_source`` is truthy.
        sample_weight : optional
            Accepted for interface compatibility; not used here.
        is_source : bool, optional (default=True)
            When truthy, the per-domain outputs are reduced to one output per
            sample after each domain-specific layer. When falsy, the outputs
            of all branches are kept, stacked along dimension 0.

        Returns
        -------
        torch.Tensor
            The output of the last layer.
        """
        # Track explicitly whether X currently carries a leading "domain"
        # dimension. Guessing from ``X.size(0) == n_domains`` misfires when a
        # batch happens to contain exactly ``n_domains`` samples.
        stacked = False
        for layer in self.children():
            if isinstance(layer, nn.ModuleList):
                if stacked:
                    outputs = [layer[j](X[j]) for j in range(self.n_domains)]
                else:
                    outputs = [layer[j](X) for j in range(self.n_domains)]
                X = torch.stack(outputs, dim=0)
                stacked = True
            elif isinstance(layer, SelectDomainModule):
                if is_source:
                    X = layer(X, sample_domain)
                    # One output per sample again: domain dimension is gone.
                    stacked = False
                else:
                    X = layer(X, is_source=is_source)
            else:
                X = layer(X)
        return X


class MFSANLoss(BaseDALoss):
    """Loss MFSAN

    The loss for the MFSAN method proposed in [33].


    Parameters
    ----------
    reg_mmd : float, optional (default=1)
        The regularization parameter of the MMD loss.
    reg_cl : float, optional (default=1)
        The regularization parameter of the target discrepancy
        classification loss.
    sigmas : array-like, optional (default=None)
        The sigmas for the Gaussian kernel.

    References
    ----------
    .. [33] Zhu, Y., Zhuang, F., and Wang, D., (2022).
            Aligning Domain-specific Distribution and Classifier
            for Cross-domain Classification from Multiple Sources.
            Association for the Advancement of Artificial Intelligence.
    """

    def __init__(
        self,
        reg_mmd=1,
        reg_cl=1,
        sigmas=None,
    ):
        super().__init__()
        self.reg_mmd = reg_mmd
        self.reg_cl = reg_cl
        self.sigmas = sigmas

    def forward(
        self,
        y_s,
        y_pred_s,
        y_pred_t,
        domain_pred_s,
        domain_pred_t,
        features_s,
        features_t,
        sample_domain,
    ):
        """Compute the domain adaptation loss

        The loss is ``reg_mmd * mmd + reg_cl * disc`` where ``mmd`` is the
        MMD between the source features of each domain and the matching
        target features, averaged over domains, and ``disc`` is the mean
        absolute difference between the target class probabilities of every
        pair of domain branches, averaged over pairs.

        Only ``y_pred_t``, ``features_s``, ``features_t`` and
        ``sample_domain`` are read. ``features_t`` and ``y_pred_t`` are
        indexed by domain branch, and positive entries of ``sample_domain``
        mark source samples, whose labels are expected to be
        ``1..n_domains``.
        """
        n_domains = len(features_t)
        # Domain label of every source sample, aligned with ``features_s``.
        source_domains = sample_domain[sample_domain > 0]

        # Class probabilities of each branch, computed once. ``dim`` is
        # explicit because relying on the implicit softmax axis is
        # deprecated; assumes class scores lie on axis 1, as
        # torch.nn.CrossEntropyLoss expects -- TODO confirm.
        probas_t = [F.softmax(y_pred_t[i], dim=1) for i in range(n_domains)]

        mmd = 0
        disc = 0
        for i in range(n_domains):
            mmd += mmd_loss(
                features_s[torch.where(source_domains == i + 1)[0]],
                features_t[i],
                sigmas=self.sigmas,
            )
            for j in range(i + 1, n_domains):
                disc += torch.mean(torch.abs(probas_t[i] - probas_t[j]))
        mmd /= n_domains

        # With a single domain there is no pair to compare: keep ``disc`` at
        # zero instead of dividing by zero.
        n_pairs = n_domains * (n_domains - 1) / 2
        if n_pairs > 0:
            disc /= n_pairs

        loss = self.reg_mmd * mmd + self.reg_cl * disc
        return loss


def MFSAN(
    module,
    layer_name,
    source_domains,
    reg_mmd=1,
    reg_cl=1,
    sigmas=None,
    base_criterion=None,
    **kwargs,
):
    """MFSAN domain adaptation method.

    See [33]_.

    Parameters
    ----------
    module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`.
    layer_name : str
        The name of the module's layer whose outputs are
        collected during the training for the adaptation.
    source_domains : list of int
        The list of source domains, handed to the training iterator.
    reg_mmd : float, optional (default=1)
        The regularization parameter of the MMD loss.
    reg_cl : float, optional (default=1)
        The regularization parameter of the target discrepancy
        classification loss.
    sigmas : array-like, optional (default=None)
        The sigmas for the Gaussian kernel.
    base_criterion : torch criterion, optional (default=None)
        The base criterion used to compute the loss with source
        labels. If None, a `torch.nn.CrossEntropyLoss` instance is used.
    **kwargs
        Extra keyword arguments forwarded unchanged to ``DomainAwareNet``.

    Returns
    -------
    net : DomainAwareNet
        The network wired with the multi-source module, the multi-source
        balanced data loader and the MFSAN adaptation criterion.

    References
    ----------
    .. [33] Zhu, Y., Zhuang, F., and Wang, D., (2022).
            Aligning Domain-specific Distribution and Classifier
            for Cross-domain Classification from Multiple Sources.
            Association for the Advancement of Artificial Intelligence.
    """
    source_criterion = (
        torch.nn.CrossEntropyLoss() if base_criterion is None else base_criterion
    )
    adaptation_criterion = MFSANLoss(reg_mmd=reg_mmd, reg_cl=reg_cl, sigmas=sigmas)

    return DomainAwareNet(
        module=DomainAwareModule,
        module__base_module=module,
        module__layer_name=layer_name,
        module__is_multi_source=True,
        module__flatten_features=False,
        iterator_train=MultiSourceDomainBalancedDataLoader,
        iterator_train__source_domains=source_domains,
        criterion=DomainAwareCriterion,
        criterion__base_criterion=source_criterion,
        criterion__reg=1,
        criterion__adapt_criterion=adaptation_criterion,
        criterion__is_multi_source=True,
        **kwargs,
    )
3 changes: 2 additions & 1 deletion skada/deep/_optimal_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
DomainAwareCriterion,
DomainAwareModule,
DomainAwareNet,
DomainBalancedDataLoader,
)
from skada.deep.dataloaders import DomainBalancedDataLoader

from .losses import deepjdot_loss

Expand Down Expand Up @@ -57,6 +57,7 @@ def forward(
domain_pred_t,
features_s,
features_t,
sample_domain,
):
"""Compute the domain adaptation loss"""
loss = deepjdot_loss(
Expand Down
Loading
Loading