Skip to content

Commit 5e530c7

Browse files
[MERGE] Fix mixval (#222)
* fix mixval * fix scorer evaluation * add classification check * simplify mix_labels computation * sample_domain * alpha >= 0.5 and TODO handle multiple target domains --------- Co-authored-by: Antoine Collas <[email protected]>
1 parent d7cef20 commit 5e530c7

File tree

2 files changed

+48
-19
lines changed

2 files changed

+48
-19
lines changed

skada/metrics.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -634,15 +634,20 @@ class MixValScorer(_BaseDomainAwareScorer):
634634
----------
635635
alpha : float, default=0.55
636636
Mixing parameter for mixup.
637-
random_state : int, RandomState instance or None, default=None
638-
Controls the randomness of the mixing process.
639-
greater_is_better : bool, default=True
640-
Whether higher scores are better.
641637
ice_type : {'both', 'intra', 'inter'}, default='both'
642638
Type of ICE score to compute:
643639
- 'both': Compute both intra-cluster and inter-cluster ICE scores (average).
644640
- 'intra': Compute only intra-cluster ICE score.
645641
- 'inter': Compute only inter-cluster ICE score.
642+
scoring : str or callable, default=None
643+
A string (see model evaluation documentation) or
644+
a scorer callable object / function with signature
645+
``scorer(estimator, X, y)``.
646+
If None, the provided estimator object's `score` method is used.
647+
greater_is_better : bool, default=True
648+
Whether higher scores are better.
649+
random_state : int, RandomState instance or None, default=None
650+
Controls the randomness of the mixing process.
646651
647652
Attributes
648653
----------
@@ -665,15 +670,17 @@ class MixValScorer(_BaseDomainAwareScorer):
665670
def __init__(
666671
self,
667672
alpha=0.55,
668-
random_state=None,
669-
greater_is_better=True,
670673
ice_type="both",
674+
scoring=None,
675+
greater_is_better=True,
676+
random_state=None,
671677
):
672678
super().__init__()
673679
self.alpha = alpha
674-
self.random_state = random_state
675-
self._sign = 1 if greater_is_better else -1
676680
self.ice_type = ice_type
681+
self.scoring = scoring
682+
self._sign = 1 if greater_is_better else -1
683+
self.random_state = random_state
677684

678685
if self.ice_type not in ["both", "intra", "inter"]:
679686
raise ValueError("ice_type must be 'both', 'intra', or 'inter'")
@@ -698,10 +705,17 @@ def _score(self, estimator, X, y=None, sample_domain=None, **params):
698705
score : float
699706
The ICE score.
700707
"""
708+
scorer = check_scoring(estimator, self.scoring)
709+
701710
X, _, sample_domain = check_X_y_domain(X, y, sample_domain)
702711
source_idx = extract_source_indices(sample_domain)
703712
X_target = X[~source_idx]
704713

714+
# Check from y values if it is a classification problem
715+
y_type = _find_y_type(y)
716+
if y_type != Y_Type.DISCRETE:
717+
raise ValueError("MixVal scorer only supports classification problems.")
718+
705719
rng = check_random_state(self.random_state)
706720
rand_idx = rng.permutation(X_target.shape[0])
707721

@@ -713,24 +727,29 @@ def _score(self, estimator, X, y=None, sample_domain=None, **params):
713727
same_idx = (labels_a == labels_b).nonzero()[0]
714728
diff_idx = (labels_a != labels_b).nonzero()[0]
715729

716-
# Mixup with images and hard pseudo labels
730+
# Mixup with X_target and hard pseudo labels
717731
mix_inputs = self.alpha * X_target + (1 - self.alpha) * X_target[rand_idx]
718-
mix_labels = self.alpha * labels_a + (1 - self.alpha) * labels_b
719-
720-
# Obtain predictions for the mixed samples
721-
mix_pred = estimator.predict(
722-
mix_inputs, sample_domain=np.full(mix_inputs.shape[0], -1)
723-
)
732+
if self.alpha >= 0.5:
733+
mix_labels = labels_a
734+
else:
735+
mix_labels = labels_b
724736

725737
# Calculate ICE scores based on ice_type
738+
# TODO: handle multiple target domains
726739
if self.ice_type in ["both", "intra"]:
727-
ice_same = (
728-
np.sum(mix_pred[same_idx] == mix_labels[same_idx]) / same_idx.shape[0]
740+
ice_same = scorer(
741+
estimator,
742+
mix_inputs[same_idx],
743+
mix_labels[same_idx],
744+
sample_domain=np.full(same_idx.shape[0], -1),
729745
)
730746

731747
if self.ice_type in ["both", "inter"]:
732-
ice_diff = (
733-
np.sum(mix_pred[diff_idx] == mix_labels[diff_idx]) / diff_idx.shape[0]
748+
ice_diff = scorer(
749+
estimator,
750+
mix_inputs[diff_idx],
751+
mix_labels[diff_idx],
752+
sample_domain=np.full(diff_idx.shape[0], -1),
734753
)
735754

736755
if self.ice_type == "both":

skada/tests/test_scorer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,13 @@ def test_mixval_scorer(da_dataset):
297297
# Test invalid ice_type
298298
with pytest.raises(ValueError):
299299
MixValScorer(ice_type="invalid")
300+
301+
302+
def test_mixval_scorer_regression(da_reg_dataset):
303+
X, y, sample_domain = da_reg_dataset
304+
305+
estimator = make_da_pipeline(DensityReweightAdapter(), LinearRegression())
306+
307+
scorer = MixValScorer(alpha=0.55, random_state=42)
308+
with pytest.raises(ValueError):
309+
scorer(estimator, X, y, sample_domain)

0 commit comments

Comments
 (0)