@@ -634,15 +634,20 @@ class MixValScorer(_BaseDomainAwareScorer):
634
634
----------
635
635
alpha : float, default=0.55
636
636
Mixing parameter for mixup.
637
- random_state : int, RandomState instance or None, default=None
638
- Controls the randomness of the mixing process.
639
- greater_is_better : bool, default=True
640
- Whether higher scores are better.
641
637
ice_type : {'both', 'intra', 'inter'}, default='both'
642
638
Type of ICE score to compute:
643
639
- 'both': Compute both intra-cluster and inter-cluster ICE scores (average).
644
640
- 'intra': Compute only intra-cluster ICE score.
645
641
- 'inter': Compute only inter-cluster ICE score.
642
+ scoring : str or callable, default=None
643
+ A string (see model evaluation documentation) or
644
+ a scorer callable object / function with signature
645
+ ``scorer(estimator, X, y)``.
646
+ If None, the provided estimator object's `score` method is used.
647
+ greater_is_better : bool, default=True
648
+ Whether higher scores are better.
649
+ random_state : int, RandomState instance or None, default=None
650
+ Controls the randomness of the mixing process.
646
651
647
652
Attributes
648
653
----------
@@ -665,15 +670,17 @@ class MixValScorer(_BaseDomainAwareScorer):
665
670
def __init__ (
666
671
self ,
667
672
alpha = 0.55 ,
668
- random_state = None ,
669
- greater_is_better = True ,
670
673
ice_type = "both" ,
674
+ scoring = None ,
675
+ greater_is_better = True ,
676
+ random_state = None ,
671
677
):
672
678
super ().__init__ ()
673
679
self .alpha = alpha
674
- self .random_state = random_state
675
- self ._sign = 1 if greater_is_better else - 1
676
680
self .ice_type = ice_type
681
+ self .scoring = scoring
682
+ self ._sign = 1 if greater_is_better else - 1
683
+ self .random_state = random_state
677
684
678
685
if self .ice_type not in ["both" , "intra" , "inter" ]:
679
686
raise ValueError ("ice_type must be 'both', 'intra', or 'inter'" )
@@ -698,10 +705,17 @@ def _score(self, estimator, X, y=None, sample_domain=None, **params):
698
705
score : float
699
706
The ICE score.
700
707
"""
708
+ scorer = check_scoring (estimator , self .scoring )
709
+
701
710
X , _ , sample_domain = check_X_y_domain (X , y , sample_domain )
702
711
source_idx = extract_source_indices (sample_domain )
703
712
X_target = X [~ source_idx ]
704
713
714
+ # Check from y values if it is a classification problem
715
+ y_type = _find_y_type (y )
716
+ if y_type != Y_Type .DISCRETE :
717
+ raise ValueError ("MixVal scorer only supports classification problems." )
718
+
705
719
rng = check_random_state (self .random_state )
706
720
rand_idx = rng .permutation (X_target .shape [0 ])
707
721
@@ -713,24 +727,29 @@ def _score(self, estimator, X, y=None, sample_domain=None, **params):
713
727
same_idx = (labels_a == labels_b ).nonzero ()[0 ]
714
728
diff_idx = (labels_a != labels_b ).nonzero ()[0 ]
715
729
716
- # Mixup with images and hard pseudo labels
730
+ # Mixup with X_target and hard pseudo labels
717
731
mix_inputs = self .alpha * X_target + (1 - self .alpha ) * X_target [rand_idx ]
718
- mix_labels = self .alpha * labels_a + (1 - self .alpha ) * labels_b
719
-
720
- # Obtain predictions for the mixed samples
721
- mix_pred = estimator .predict (
722
- mix_inputs , sample_domain = np .full (mix_inputs .shape [0 ], - 1 )
723
- )
732
+ if self .alpha >= 0.5 :
733
+ mix_labels = labels_a
734
+ else :
735
+ mix_labels = labels_b
724
736
725
737
# Calculate ICE scores based on ice_type
738
+ # TODO: handle multiple target domains
726
739
if self .ice_type in ["both" , "intra" ]:
727
- ice_same = (
728
- np .sum (mix_pred [same_idx ] == mix_labels [same_idx ]) / same_idx .shape [0 ]
740
+ ice_same = scorer (
741
+ estimator ,
742
+ mix_inputs [same_idx ],
743
+ mix_labels [same_idx ],
744
+ sample_domain = np .full (same_idx .shape [0 ], - 1 ),
729
745
)
730
746
731
747
if self .ice_type in ["both" , "inter" ]:
732
- ice_diff = (
733
- np .sum (mix_pred [diff_idx ] == mix_labels [diff_idx ]) / diff_idx .shape [0 ]
748
+ ice_diff = scorer (
749
+ estimator ,
750
+ mix_inputs [diff_idx ],
751
+ mix_labels [diff_idx ],
752
+ sample_domain = np .full (diff_idx .shape [0 ], - 1 ),
734
753
)
735
754
736
755
if self .ice_type == "both" :
0 commit comments