Skip to content

[To_review] Modify sampler to take the max of the two domains #241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions skada/deep/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def forward(
pass



class DomainBalancedSampler(Sampler):
"""Domain balanced sampler

Expand All @@ -173,19 +174,35 @@ class DomainBalancedSampler(Sampler):
----------
dataset : torch dataset
The dataset to sample from.
batch_size : int
The batch size.
max_samples : str, default='max'
    Strategy used to choose the number of samples drawn per domain.
    One of 'max', 'min', 'source', or 'target'.
"""

def __init__(self, dataset, batch_size):
def __init__(self, dataset, batch_size, max_samples="max"):
    """Split the dataset indices by domain and fix the epoch length.

    Parameters
    ----------
    dataset : torch dataset
        The dataset to sample from. Each item is expected to be a tuple
        whose first element is a dict with a ``'sample_domain'`` entry
        (``>= 0`` for source samples, ``< 0`` for target samples).
    batch_size : int
        Number of samples drawn per domain in each batch.
    max_samples : str, default='max'
        Strategy used to pick the per-domain number of samples:
        'max' (larger domain), 'min' (smaller domain), 'source', or 'target'.

    Raises
    ------
    ValueError
        If ``max_samples`` is not one of the accepted strategies.
    """
    self.dataset = dataset
    # Partition indices in a single pass over the dataset (iterating a torch
    # dataset can be expensive, so avoid scanning it twice).
    self.positive_indices = []
    self.negative_indices = []
    for idx, sample in enumerate(dataset):
        if sample[0]["sample_domain"] >= 0:
            self.positive_indices.append(idx)
        else:
            self.negative_indices.append(idx)
    # Truncate each domain to a whole number of batches.
    self.num_samples_source = (
        len(self.positive_indices) - len(self.positive_indices) % batch_size
    )
    self.num_samples_target = (
        len(self.negative_indices) - len(self.negative_indices) % batch_size
    )
    if max_samples == "max":
        self.num_samples = max(self.num_samples_source, self.num_samples_target)
    elif max_samples == "min":
        self.num_samples = min(self.num_samples_source, self.num_samples_target)
    elif max_samples == "source":
        self.num_samples = self.num_samples_source
    elif max_samples == "target":
        self.num_samples = self.num_samples_target
    else:
        # Fail fast: otherwise self.num_samples is never set and __iter__ /
        # __len__ would later raise a confusing AttributeError.
        raise ValueError(
            f"Unknown value for max_samples: {max_samples!r}. "
            "Expected one of 'max', 'min', 'source', or 'target'."
        )


def __iter__(self):
positive_sampler = torch.utils.data.sampler.RandomSampler(self.positive_indices)
Expand All @@ -195,7 +212,11 @@ def __iter__(self):
negative_iter = iter(negative_sampler)

for _ in range(self.num_samples):
pos_idx = self.positive_indices[next(positive_iter)]
try:
pos_idx = self.positive_indices[next(positive_iter)]
except StopIteration:
positive_iter = iter(positive_sampler)
pos_idx = self.positive_indices[next(positive_iter)]
try:
neg_idx = self.negative_indices[next(negative_iter)]
except StopIteration:
Expand All @@ -208,6 +229,7 @@ def __len__(self):
return 2 * self.num_samples



class DomainBalancedDataLoader(DataLoader):
"""Domain balanced data loader

Expand All @@ -217,12 +239,17 @@ class DomainBalancedDataLoader(DataLoader):
----------
dataset : torch dataset
The dataset to sample from.
batch_size : int
The batch size.
max_samples : str, default='max'
    Strategy used to choose the number of samples drawn per domain.
    One of 'max', 'min', 'source', or 'target'.
"""

def __init__(
self,
dataset,
batch_size,
max_samples="max",
shuffle=False,
sampler=None,
batch_sampler=None,
Expand All @@ -234,7 +261,7 @@ def __init__(
worker_init_fn=None,
multiprocessing_context=None,
):
sampler = DomainBalancedSampler(dataset, batch_size)
sampler = DomainBalancedSampler(dataset, batch_size, max_samples=max_samples)
super().__init__(
dataset,
2 * batch_size,
Expand Down
25 changes: 22 additions & 3 deletions skada/deep/tests/test_deep_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,16 @@ def test_return_features():
assert features.shape == (X_test.shape[0], num_features)


def test_domain_balanced_sampler():
@pytest.mark.parametrize(
"max_samples",
[
"max",
"source",
"target",
"min",
],
)
def test_domain_balanced_sampler(max_samples):
n_samples = 20
dataset = make_shifted_datasets(
n_samples_source=n_samples,
Expand All @@ -303,10 +312,20 @@ def test_domain_balanced_sampler():
X, y, sample_domain = dataset.pack_train(as_sources=["s"], as_targets=["t"])
X_dict = {"X": X.astype(np.float32), "sample_domain": sample_domain}

n_samples_source = np.sum(sample_domain > 0)
n_samples_target = np.sum(sample_domain < 0)

dataset = Dataset(X_dict, y)

sampler = DomainBalancedSampler(dataset, 10)
assert len(sampler) == 2 * np.sum(sample_domain > 0)
sampler = DomainBalancedSampler(dataset, 10, max_samples=max_samples)
if max_samples == "max":
assert len(sampler) == 2 * max(n_samples_source, n_samples_target)
elif max_samples == "source":
assert len(sampler) == 2 * n_samples_source
elif max_samples == "target":
assert len(sampler) == 2 * n_samples_target
elif max_samples == "min":
assert len(sampler) == 2 * min(n_samples_source, n_samples_target)


def test_domain_balanced_dataloader():
Expand Down
Loading