Commit d83c4ba

[Feature] add multi-source sampler (#1938)
1 parent 2a92b26 commit d83c4ba

File tree

2 files changed: +112 −1 lines


mmpose/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@
 from .builder import build_dataset
 from .dataset_wrappers import CombinedDataset
 from .datasets import *  # noqa
+from .samplers import MultiSourceSampler
 from .transforms import *  # noqa

-__all__ = ['build_dataset', 'CombinedDataset']
+__all__ = ['build_dataset', 'CombinedDataset', 'MultiSourceSampler']
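
With MultiSourceSampler re-exported through __all__, the sampler is importable directly from the package root. A minimal import check, assuming an environment where this commit is installed:

from mmpose.datasets import CombinedDataset, MultiSourceSampler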

mmpose/datasets/samplers.py

Lines changed: 110 additions & 0 deletions
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import math
from typing import Iterator, List, Optional, Sized, Union

import torch
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler

from mmpose.datasets import CombinedDataset
from mmpose.registry import DATA_SAMPLERS


@DATA_SAMPLERS.register_module()
class MultiSourceSampler(Sampler):
    r"""Multi-Source Sampler. According to the sampling ratio, sample data
    from different datasets to form batches.

    Args:
        dataset (Sized): The dataset
        batch_size (int): Size of mini-batch
        source_ratio (list[int | float]): The sampling ratio of different
            source datasets in a mini-batch
        shuffle (bool): Whether to shuffle the dataset or not. Defaults to
            ``True``
        seed (int, optional): Random seed. If ``None``, set a random seed.
            Defaults to ``None``
    """

    def __init__(self,
                 dataset: Sized,
                 batch_size: int,
                 source_ratio: List[Union[int, float]],
                 shuffle: bool = True,
                 seed: Optional[int] = None) -> None:

        assert isinstance(dataset, CombinedDataset), \
            f'The dataset must be CombinedDataset, but got {dataset}'
        assert isinstance(batch_size, int) and batch_size > 0, \
            'batch_size must be a positive integer value, ' \
            f'but got batch_size={batch_size}'
        assert isinstance(source_ratio, list), \
            f'source_ratio must be a list, but got source_ratio={source_ratio}'
        assert len(source_ratio) == len(dataset._lens), \
            'The length of source_ratio must be equal to ' \
            f'the number of datasets, but got source_ratio={source_ratio}'

        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.cumulative_sizes = [0] + list(itertools.accumulate(dataset._lens))
        self.batch_size = batch_size
        self.source_ratio = source_ratio

        self.num_samples = math.ceil(len(self.dataset) / world_size)

        self.num_per_source = [
            int(batch_size * sr / sum(source_ratio)) for sr in source_ratio
        ]
        self.num_per_source[0] = batch_size - sum(self.num_per_source[1:])

        assert sum(self.num_per_source) == batch_size, \
            'The sum of num_per_source must be equal to ' \
            f'batch_size, but got {self.num_per_source}'

        self.seed = sync_random_seed() if seed is None else seed
        self.shuffle = shuffle
        self.source2inds = {
            source: self._indices_of_rank(len(ds))
            for source, ds in enumerate(dataset.datasets)
        }

    def _infinite_indices(self, sample_size: int) -> Iterator[int]:
        """Infinitely yield a sequence of indices."""
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(sample_size, generator=g).tolist()
            else:
                yield from torch.arange(sample_size).tolist()

    def _indices_of_rank(self, sample_size: int) -> Iterator[int]:
        """Slice the infinite indices by rank."""
        yield from itertools.islice(
            self._infinite_indices(sample_size), self.rank, None,
            self.world_size)

    def __iter__(self) -> Iterator[int]:
        batch_buffer = []
        while True:
            for source, num in enumerate(self.num_per_source):
                batch_buffer_per_source = []
                for idx in self.source2inds[source]:
                    idx += self.cumulative_sizes[source]
                    batch_buffer_per_source.append(idx)
                    if len(batch_buffer_per_source) == num:
                        batch_buffer += batch_buffer_per_source
                        break
            yield from batch_buffer
            batch_buffer = []

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Compatible with epoch-based runners."""
        pass
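
Note that __iter__ loops forever: each pass through the source loop fills a fixed-composition batch (num indices from each source, shifted by cumulative_sizes to address the concatenated dataset), so the sampler is suited to iteration-based training, and set_epoch is only a no-op stub for epoch-based compatibility. Below is a hypothetical sketch of how the sampler might be wired up in MMEngine registry-config style; the two source dataset configs and the CombinedDataset keys shown are placeholders, not part of this commit. With batch_size=6 and source_ratio=[1, 2], num_per_source works out to [2, 4] (int(6 * 1 / 3) = 2, int(6 * 2 / 3) = 4), so every mini-batch mixes 2 samples from the first source with 4 from the second.

# Hypothetical training-config sketch (MMEngine registry style).
# dataset_a and dataset_b stand in for real source dataset configs;
# only the sampler wiring reflects the code added in this commit.
dataset_a = dict(type='CocoDataset')  # placeholder: first source config
dataset_b = dict(type='MpiiDataset')  # placeholder: second source config

train_dataloader = dict(
    batch_size=6,
    sampler=dict(
        type='MultiSourceSampler',
        batch_size=6,          # must match the dataloader batch size so
                               # each fetched batch keeps the intended mix
        source_ratio=[1, 2],   # -> num_per_source == [2, 4]
        shuffle=True),
    dataset=dict(
        type='CombinedDataset',
        datasets=[dataset_a, dataset_b]))  # placeholder source list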
