[TO_REVIEW] Add epsilon in MCC to prevent log(0) #270

Merged (2 commits, Oct 31, 2024)
6 changes: 5 additions & 1 deletion skada/deep/_class_confusion.py
@@ -25,6 +25,8 @@ class MCCLoss(BaseDALoss):
T : float, default=1
Temperature parameter for the scaling.
If T=1, the scaling is a softmax function.
eps : float, default=1e-7
Small constant added inside the logarithm for numerical stability.

References
----------
@@ -33,9 +35,10 @@ class MCCLoss(BaseDALoss):
In ECCV, 2020.
"""

def __init__(self, T=1):
def __init__(self, T=1, eps=1e-7):
super().__init__()
self.T = T
self.eps = eps

def forward(
self,
@@ -51,6 +54,7 @@ def forward(
loss = mcc_loss(
y_pred_t,
T=self.T,
eps=self.eps,
)
return loss

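For context, a minimal standalone sketch (not part of this diff) of the failure mode the new eps parameter guards against: with extreme logits, a float32 softmax underflows to exact zeros, and the entropy term then evaluates 0 * log(0), which is NaN in PyTorch.

import torch

# Extreme logits: exp(-200) underflows to 0.0 in float32
logits = torch.tensor([[100.0, -100.0, -100.0]])
probs = torch.nn.functional.softmax(logits, dim=1)  # tensor([[1., 0., 0.]])

# Without eps: log(0) = -inf and 0 * (-inf) = nan, so the entropy is NaN
entropy = -torch.sum(probs * torch.log(probs), dim=1)  # tensor([nan])

# With eps (as introduced by this PR): every log argument is positive, so the entropy stays finite
eps = 1e-7
entropy = -torch.sum(probs * torch.log(probs + eps), dim=1)  # finite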
7 changes: 4 additions & 3 deletions skada/deep/losses.py
@@ -368,17 +368,18 @@ def probability_scaling(logits, temperature=1):
return torch.nn.functional.softmax(logits / temperature, dim=1)


def mcc_loss(y, T=1):
def mcc_loss(y, T=1, eps=1e-7):
"""Estimate the Frobenius norm divide by 4*n**2
for DeepCORAL method [33]_.

Parameters
----------
y : tensor
Output of the model on the target domain.

T : float, default=1
The temperature for the scaling.
eps : float, default=1e-7
Small constant added inside the logarithm for numerical stability.

Returns
-------
@@ -395,7 +396,7 @@ def mcc_loss(y, T=1):
y_scaled = probability_scaling(y, temperature=T)

# Uncertainty Reweighting & class correlation matrix
H = -torch.sum(y_scaled * torch.log(y_scaled), axis=1)
H = -torch.sum(y_scaled * torch.log(y_scaled + eps), axis=1)
W = (1 + torch.exp(-H)) / torch.mean(1 + torch.exp(-H))
y_weighted = torch.matmul(torch.diag(W), y_scaled)
C = torch.einsum("ij,ik->jk", y_scaled, y_weighted)
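A hedged sketch of the downstream effect, mirroring the reweighting lines above: once the entropy H is finite, the weights W and the class-correlation matrix C computed from it are finite as well, so the final loss no longer turns into NaN.

import torch

eps = 1e-7
y = torch.tensor([[100.0, -100.0, -100.0],
                  [-100.0, 100.0, -100.0]])
y_scaled = torch.nn.functional.softmax(y / 1.0, dim=1)  # T = 1

# Same computation as in the diff above, with the epsilon inside the log
H = -torch.sum(y_scaled * torch.log(y_scaled + eps), axis=1)
W = (1 + torch.exp(-H)) / torch.mean(1 + torch.exp(-H))
y_weighted = torch.matmul(torch.diag(W), y_scaled)
C = torch.einsum("ij,ik->jk", y_scaled, y_weighted)

assert torch.isfinite(C).all()  # no NaN propagated from the entropy term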
25 changes: 25 additions & 0 deletions skada/deep/tests/test_deep_class_confusion.py
@@ -8,9 +8,11 @@
pytest.importorskip("torch")

import numpy as np
import torch

from skada.datasets import make_shifted_datasets
from skada.deep import MCC
from skada.deep.losses import mcc_loss
from skada.deep.modules import ToyModule2D


@@ -54,3 +56,26 @@ def test_mcc(T):
history = method.history_

assert history[0]["train_loss"] > history[-1]["train_loss"]


def test_mcc_with_zeros():
"""Test that mcc_loss handles zero probabilities correctly."""
# Create logits with extreme values that will result in zeros
# after softmax operation due to numerical underflow
logits = torch.tensor(
[
[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0],
]
)

# Verify that we actually get zeros in y_scaled
y_scaled = torch.nn.functional.softmax(logits, dim=1)
assert torch.sum(y_scaled == 0.0) > 0, "Test setup failed: no zeros in y_scaled"

# This should not raise any errors due to the epsilon in log
loss = mcc_loss(logits, T=1.0)

assert torch.isfinite(loss) # Check that the loss is not NaN or infinite
assert loss >= 0 # MCC loss should be non-negative
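Assuming a standard pytest setup for the repository, the new regression test can be run on its own with something like:

pytest skada/deep/tests/test_deep_class_confusion.py::test_mcc_with_zeros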