
[MRG] Add MCC method #250


Merged: 7 commits, Oct 24, 2024
2 changes: 2 additions & 0 deletions README.md
@@ -238,3 +238,5 @@ The library is distributed under the 3-Clause BSD license.
[32] Hu, D., Liang, J., Liew, J. H., Xue, C., Bai, S., & Wang, X. (2023). [Mixed Samples as Probes for Unsupervised Model Selection in Domain Adaptation](https://proceedings.neurips.cc/paper_files/paper/2023/file/7721f1fea280e9ffae528dc78c732576-Paper-Conference.pdf). Advances in Neural Information Processing Systems 36 (2024).

[33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019). [Contrastive Adaptation Network for Unsupervised Domain Adaptation](https://arxiv.org/abs/1901.00976). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 4893-4902).

[34] Jin, Y., Wang, X., Long, M., & Wang, J. (2020). [Minimum Class Confusion for Versatile Domain Adaptation](https://arxiv.org/pdf/1912.03699). In European Conference on Computer Vision (ECCV).
3 changes: 3 additions & 0 deletions docs/source/all.rst
@@ -147,6 +147,7 @@ Deep learning DA :py:mod:`skada.deep`:
    DeepJDOTLoss
    DANLoss
    CDANLoss
    MCCLoss
    CANLoss

.. autosummary::
@@ -156,11 +157,13 @@ Deep learning DA :py:mod:`skada.deep`:
    dan_loss
    deepcoral_loss
    deepjdot_loss
    mcc_loss
    cdd_loss
    DeepCoral
    DeepJDOT
    DANN
    CDAN
    MCC
    CAN


3 changes: 3 additions & 0 deletions skada/deep/__init__.py
@@ -18,6 +18,7 @@
from ._divergence import DeepCoral, DeepCoralLoss, DANLoss, DAN, CAN, CANLoss
from ._optimal_transport import DeepJDOT, DeepJDOTLoss
from ._adversarial import DANN, CDAN, DANNLoss, CDANLoss
from ._class_confusion import MCC, MCCLoss
from ._baseline import SourceOnly, TargetOnly

from . import losses
@@ -36,6 +37,8 @@
    'DANN',
    'CDANLoss',
    'CDAN',
    'MCCLoss',
    'MCC',
    'CANLoss',
    'CAN',
    'SourceOnly',
105 changes: 105 additions & 0 deletions skada/deep/_class_confusion.py
@@ -0,0 +1,105 @@
# Author: Theo Gnassounou <[email protected]>
#
# License: BSD 3-Clause
import torch

from skada.deep.base import (
    BaseDALoss,
    DomainAwareCriterion,
    DomainAwareModule,
    DomainAwareNet,
    DomainBalancedDataLoader,
)

from .losses import mcc_loss


class MCCLoss(BaseDALoss):
    """MCC loss.

    This loss reduces the class confusion among the predicted labels
    of the target domain. See [33]_.

    Parameters
    ----------
    T : float, default=1
        Temperature parameter for the scaling.
        If T=1, the scaling is a plain softmax.

    References
    ----------
    .. [33] Ying Jin, Ximei Wang, Mingsheng Long, Jianmin Wang.
            Minimum Class Confusion for Versatile Domain Adaptation.
            In ECCV, 2020.
    """

    def __init__(self, T=1):
        super().__init__()
        self.T = T

    def forward(
        self,
        y_s,
        y_pred_s,
        y_pred_t,
        domain_pred_s,
        domain_pred_t,
        features_s,
        features_t,
    ):
        """Compute the domain adaptation loss."""
        loss = mcc_loss(
            y_pred_t,
            T=self.T,
        )
        return loss
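
As a sketch of the API (made-up shapes): the loss only reads the target logits, so the remaining arguments of the BaseDALoss signature can be passed as None here.

import torch

criterion = MCCLoss(T=2)
y_pred_t = torch.randn(8, 5)  # hypothetical batch: 8 target samples, 5 classes
loss = criterion.forward(None, None, y_pred_t, None, None, None, None)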


def MCC(
    module,
    layer_name,
    reg=1,
    T=1,
    base_criterion=None,
    **kwargs,
):
    """MCC (Minimum Class Confusion) domain adaptation method.

    See [33]_.

    Parameters
    ----------
    module : torch module (class or instance)
        A PyTorch :class:`~torch.nn.Module`.
    layer_name : str
        The name of the module's layer whose outputs are
        collected during training for the adaptation.
    reg : float, default=1
        Regularization parameter for the DA loss.
    T : float, default=1
        Temperature parameter for the scaling.
    base_criterion : torch criterion (class)
        The base criterion used to compute the loss with source
        labels. If None, the default is `torch.nn.CrossEntropyLoss`.

    References
    ----------
    .. [33] Ying Jin, Ximei Wang, Mingsheng Long, Jianmin Wang.
            Minimum Class Confusion for Versatile Domain Adaptation.
            In ECCV, 2020.
    """
    if base_criterion is None:
        base_criterion = torch.nn.CrossEntropyLoss()

    net = DomainAwareNet(
        module=DomainAwareModule,
        module__base_module=module,
        module__layer_name=layer_name,
        iterator_train=DomainBalancedDataLoader,
        criterion=DomainAwareCriterion,
        criterion__base_criterion=base_criterion,
        criterion__adapt_criterion=MCCLoss(T=T),
        criterion__reg=reg,
        **kwargs,
    )
    return net
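
A minimal usage sketch, mirroring the test added below; the toy module and dataset ship with skada, and the hyperparameter values are illustrative:

import numpy as np
from skada.datasets import make_shifted_datasets
from skada.deep import MCC
from skada.deep.modules import ToyModule2D

dataset = make_shifted_datasets(
    n_samples_source=50,
    n_samples_target=50,
    shift="concept_drift",
    return_dataset=True,
    label="multiclass",
)
X, y, sample_domain = dataset.pack_train(as_sources=["s"], as_targets=["t"])

method = MCC(
    ToyModule2D(n_classes=5),
    layer_name="dropout",
    reg=1,
    T=2,
    batch_size=32,
    max_epochs=5,
    train_split=None,
)
method.fit(X.astype(np.float32), y, sample_domain)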
60 changes: 60 additions & 0 deletions skada/deep/losses.py
@@ -348,3 +348,63 @@ def forward(
    ):
        """Compute the domain adaptation loss"""
        return 0


def probability_scaling(logits, temperature=1):
    """Probability scaling.

    Parameters
    ----------
    logits : torch.Tensor
        The logits.
    temperature : float, default=1
        The temperature.

    Returns
    -------
    torch.Tensor
        The scaled probabilities.
    """
    return torch.nn.functional.softmax(logits / temperature, dim=1)
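
A quick illustration with made-up logits: raising the temperature flattens the distribution, which is how MCC softens over-confident target predictions before measuring class confusion.

import torch

logits = torch.tensor([[2.0, 1.0, 0.1]])
probability_scaling(logits, temperature=1)   # ~[0.66, 0.24, 0.10], peaked
probability_scaling(logits, temperature=10)  # ~[0.37, 0.33, 0.30], flatter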


def mcc_loss(y, T=1):
    """Compute the Minimum Class Confusion loss
    for the MCC method [33]_.

    Parameters
    ----------
    y : torch.Tensor
        The model's logits on the target domain.
    T : float, default=1
        The temperature for the scaling.

    Returns
    -------
    loss : torch.Tensor
        The loss of the method.

    References
    ----------
    .. [33] Ying Jin, Ximei Wang, Mingsheng Long, Jianmin Wang.
            Minimum Class Confusion for Versatile Domain Adaptation.
            In ECCV, 2020.
    """
    # Probability rescaling
    y_scaled = probability_scaling(y, temperature=T)

    # Uncertainty reweighting: weight each sample by its (transformed)
    # prediction entropy, then build the class-correlation matrix
    H = -torch.sum(y_scaled * torch.log(y_scaled), axis=1)
    W = (1 + torch.exp(-H)) / torch.mean(1 + torch.exp(-H))
    y_weighted = torch.matmul(torch.diag(W), y_scaled)
    C = torch.einsum("ij,ik->jk", y_scaled, y_weighted)

    # Category normalization: each row sums to one
    C_tilde = C / torch.sum(C, axis=1, keepdim=True)

    # MCC loss: mean absolute off-diagonal mass of the normalized
    # class-confusion matrix (the diagonal is zeroed out)
    C_ = C_tilde - torch.diag(torch.diag(C_tilde))

Collaborator (review comment on the line above):

Suggested change:
-    C_ = C_tilde - torch.diag(torch.diag(C_tilde))
+    C_ = C_tilde - torch.diag(C_tilde)

Author (reply):
What I want is to remove the diagonal. If I apply torch.diag only once, I get a vector, and subtracting it would remove its values from every row. Applying torch.diag twice creates a diagonal matrix.

Collaborator (reply):
Ah, my bad, I thought it returned a diagonal matrix directly.
    loss = torch.mean(torch.sum(torch.abs(C_), axis=1))

    return loss
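
As a quick sanity sketch (hypothetical logits, not part of this diff): confident, well-separated target predictions should produce a smaller loss than near-uniform ones.

import torch
from skada.deep.losses import mcc_loss

confident = torch.tensor([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0]])
uncertain = torch.tensor([[1.0, 0.9, 0.8], [0.9, 1.0, 0.8]])
assert mcc_loss(confident) < mcc_loss(uncertain)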
56 changes: 56 additions & 0 deletions skada/deep/tests/test_deep_class_confusion.py
@@ -0,0 +1,56 @@
# Author: Theo Gnassounou <[email protected]>
# Oleksii Kachaiev <[email protected]>
# Yanis Lalou <[email protected]>
#
# License: BSD 3-Clause
import pytest

pytest.importorskip("torch")

import numpy as np

from skada.datasets import make_shifted_datasets
from skada.deep import MCC
from skada.deep.modules import ToyModule2D


@pytest.mark.parametrize(
    "T",
    [1, 0.5],
)
def test_mcc(T):
    module = ToyModule2D(n_classes=5)
    module.eval()

    n_samples = 50
    dataset = make_shifted_datasets(
        n_samples_source=n_samples,
        n_samples_target=n_samples,
        shift="concept_drift",
        noise=0.1,
        random_state=42,
        return_dataset=True,
        label="multiclass",
    )

    method = MCC(
        module,
        reg=1,
        layer_name="dropout",
        batch_size=32,
        max_epochs=5,
        train_split=None,
        T=T,
    )

    X, y, sample_domain = dataset.pack_train(as_sources=["s"], as_targets=["t"])
    method.fit(X.astype(np.float32), y, sample_domain)
    X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["t"])

    y_pred = method.predict(X_test.astype(np.float32), sample_domain_test)

    assert y_pred.shape[0] == X_test.shape[0]

    history = method.history_

    assert history[0]["train_loss"] > history[-1]["train_loss"]
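
The test can be run directly with pytest, e.g. `pytest skada/deep/tests/test_deep_class_confusion.py`.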