Skip to content

[MRG] Check if sample_domain have only unique domains indexes in check_*_domain #261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Oct 25, 2024
29 changes: 29 additions & 0 deletions skada/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,12 @@ def test_check_X_y_allow_exceptions():
random_sample_domain = rng.choice(
np.concatenate((np.arange(-5, 0), np.arange(1, 6))), size=len(y)
)
common_idx_sample_domain = rng.choice(np.array([1, -1]), size=len(y))
allow_source = False
allow_target = False
allow_multi_source = False
allow_multi_target = False
allow_common_domain_idx = False

positive_numbers = random_sample_domain[random_sample_domain > 0]
negative_numbers = random_sample_domain[random_sample_domain < 0]
Expand Down Expand Up @@ -296,6 +298,18 @@ def test_check_X_y_allow_exceptions():
allow_auto_sample_domain=False,
allow_multi_target=allow_multi_target,
)
with pytest.raises(
ValueError,
match=(
"Domain labels should be unique: the same domain "
"index should not be used both for source and target"
),
):
check_X_domain(
X,
sample_domain=common_idx_sample_domain,
allow_common_domain_idx=allow_common_domain_idx,
)


def test_check_X_allow_exceptions():
Expand All @@ -314,10 +328,12 @@ def test_check_X_allow_exceptions():
random_sample_domain = rng.choice(
np.concatenate((np.arange(-5, 0), np.arange(1, 6))), size=len(y)
)
common_idx_sample_domain = rng.choice(np.array([1, -1]), size=len(y))
allow_source = False
allow_target = False
allow_multi_source = False
allow_multi_target = False
allow_common_domain_idx = False

positive_numbers = random_sample_domain[random_sample_domain > 0]
negative_numbers = random_sample_domain[random_sample_domain < 0]
Expand Down Expand Up @@ -382,6 +398,19 @@ def test_check_X_allow_exceptions():
allow_multi_target=allow_multi_target,
)

with pytest.raises(
ValueError,
match=(
"Domain labels should be unique: the same domain "
"index should not be used both for source and target"
),
):
check_X_domain(
X,
sample_domain=common_idx_sample_domain,
allow_common_domain_idx=allow_common_domain_idx,
)


def test_check_X_domain_multi_nd():
# Create a 3D array (10 samples, 2 features, 3 channels)
Expand Down
22 changes: 22 additions & 0 deletions skada/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def check_X_y_domain(
allow_multi_source: bool = True,
allow_target: bool = True,
allow_multi_target: bool = True,
allow_common_domain_idx: bool = True,
allow_auto_sample_domain: bool = True,
allow_nd: bool = False,
allow_label_masks: bool = True,
Expand All @@ -57,6 +58,8 @@ def check_X_y_domain(
Allow the presence of target domains.
allow_multi_target : bool, optional (default=True)
Allow multiple target domains.
allow_common_domain_idx : bool, optional (default=True)
Allow the same domain index to be used for source and target domains, e.g 1 for a source domain and -1 for a target domain.
allow_auto_sample_domain : bool, optional (default=True)
Allow automatic generation of sample_domain if not provided.
allow_nd : bool, optional (default=False)
Expand Down Expand Up @@ -115,6 +118,14 @@ def check_X_y_domain(
raise ValueError(f"Number of targets provided is {n_targets} "
"and 'allow_multi_target' is set to False")

# Check for unique domain idx
if not allow_common_domain_idx:
unique_domain_idx = np.unique(sample_domain)
unique_domain_idx_abs = np.abs(unique_domain_idx)
if len(unique_domain_idx) != len(np.unique(unique_domain_idx_abs)):
raise ValueError("Domain labels should be unique: the same domain "
"index should not be used both for source and target")

return X, y, sample_domain


Expand All @@ -128,6 +139,7 @@ def check_X_domain(
allow_multi_source: bool = True,
allow_target: bool = True,
allow_multi_target: bool = True,
allow_common_domain_idx: bool = True,
allow_auto_sample_domain: bool = True,
allow_nd: bool = False,
):
Expand All @@ -153,6 +165,8 @@ def check_X_domain(
Allow the presence of target domains.
allow_multi_target : bool, optional (default=True)
Allow multiple target domains.
allow_common_domain_idx : bool, optional (default=True)
Allow the same domain index to be used for source and target domains, e.g 1 for a source domain and -1 for a target domain.
allow_auto_sample_domain : bool, optional (default=True)
Allow automatic generation of sample_domain if not provided.
allow_nd : bool, optional (default=False)
Expand Down Expand Up @@ -206,6 +220,14 @@ def check_X_domain(
if not allow_multi_target and n_sources > 1:
raise ValueError(f"Number of targets provided is {n_targets} "
"and 'allow_multi_target' is set to False")

# Check for unique domain idx
if not allow_common_domain_idx:
unique_domain_idx = np.unique(sample_domain)
unique_domain_idx_abs = np.abs(unique_domain_idx)
if len(unique_domain_idx) != len(np.unique(unique_domain_idx_abs)):
raise ValueError("Domain labels should be unique: the same domain "
"index should not be used both for source and target")

return X, sample_domain

Expand Down
Loading