Skip to content

Commit ca6ef98

Browse files
[Feature] Support IterableDataset on distributed environment (#1151)
* support IterableDataset on distributed environment; * update distributed-launch script in user guide
1 parent ba9f019 commit ca6ef98

File tree

3 files changed

+81
-8
lines changed

3 files changed

+81
-8
lines changed

docs/zh/user_guide.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -952,8 +952,7 @@ best_value: 0.02460772916674614
952952
953953
``` sh
954954
# 指定 0,1,2,3 张卡启动分布式数据并行训练
955-
export CUDA_VISIBLE_DEVICES=0,1,2,3
956-
python -m paddle.distributed.launch --gpus="0,1,2,3" poiseuille_flow.py
955+
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch poiseuille_flow.py
957956
```
958957
959958
<!-- #### 2.2.2 模型并行

ppsci/data/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,6 @@ def build_dataloader(_dataset, cfg):
6060
world_size = dist.get_world_size()
6161
# just return IterableDataset as dataloader
6262
if isinstance(_dataset, io.IterableDataset):
63-
if world_size > 1:
64-
raise ValueError(
65-
f"world_size({world_size}) should be 1 when using IterableDataset."
66-
)
6763
return _dataset
6864

6965
cfg = copy.deepcopy(cfg)

ppsci/data/dataset/array_dataset.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,68 @@
2020

2121
import numpy as np
2222
import paddle
23+
from paddle import distributed as dist
2324
from paddle import io
2425
from paddle import vision
2526

2627
from ppsci.utils import logger
2728

2829

30+
def _group_array_into_ranks(
31+
data: Optional[np.ndarray], rank: int, world_size: int
32+
) -> Optional[np.ndarray]:
33+
"""
34+
Group data into different ranks. For example, if data is [1, 2, 3, 4, 5, 6, 7, 8, 9] and
35+
world_size is 3, then the result will be rank0: [1, 4, 7], rank1: [2, 5, 8], rank2: [3, 6, 9].
36+
37+
Args:
38+
data (Optional[np.ndarray]): Data to be grouped, can be np.ndarray or None.
39+
rank (int): Rank number.
40+
world_size (int): Number of workers.
41+
42+
Returns:
43+
np.ndarray: Grouped data.
44+
"""
45+
if data is None:
46+
# skip grouping if data is None
47+
return None
48+
49+
# check if data can be grouped evenly into different ranks
50+
if len(data) < world_size:
51+
raise ValueError(
52+
f"Length of data to be grouped{len(data)} must be larger than world_size."
53+
)
54+
if len(data) % world_size != 0:
55+
raise ValueError(
56+
f"Length of data to be grouped{len(data)} must be divisible by world_size."
57+
)
58+
59+
return data[rank::world_size]
60+
61+
62+
def _group_dict_into_ranks(
63+
data_dict: Optional[Dict[str, Optional[np.ndarray]]], rank: int, world_size: int
64+
) -> Optional[Dict[str, Optional[np.ndarray]]]:
65+
"""
66+
Group data dict into different ranks for each key-value pair.
67+
68+
Args:
69+
data_dict (Dict[str, Optional[np.ndarray]]): Data to be grouped, can be Dict[str, Optional[np.ndarray]] or None.
70+
rank (int): Rank number.
71+
world_size (int): Number of workers.
72+
73+
Returns:
74+
Optional[Dict[str, Optional[np.ndarray]]]: Grouped data dict.
75+
"""
76+
77+
if data_dict is None:
78+
return data_dict
79+
80+
return {
81+
k: _group_array_into_ranks(v, rank, world_size) for k, v in data_dict.items()
82+
}
83+
84+
2985
class NamedArrayDataset(io.Dataset):
3086
"""Class for Named Array Dataset.
3187
@@ -132,6 +188,8 @@ def __init__(
132188
)
133189
self._len = len(next(iter(self.input.values())))
134190
self.transforms = transforms
191+
self.world_size_ = dist.get_world_size()
192+
self.rank_ = dist.get_rank()
135193

136194
@property
137195
def num_samples(self):
@@ -143,9 +201,15 @@ def __iter__(self):
143201
input_, label_, weight_ = self.transforms(
144202
self.input, self.label, self.weight
145203
)
146-
yield input_, label_, weight_
147204
else:
148-
yield self.input, self.label, self.weight
205+
input_, label_, weight_ = self.input, self.label, self.weight
206+
207+
if self.world_size_ > 1:
208+
input_ = _group_dict_into_ranks(input_, self.rank_, self.world_size_)
209+
label_ = _group_dict_into_ranks(label_, self.rank_, self.world_size_)
210+
weight_ = _group_dict_into_ranks(weight_, self.rank_, self.world_size_)
211+
212+
yield input_, label_, weight_
149213

150214
def __len__(self):
151215
return 1
@@ -197,6 +261,8 @@ def __init__(
197261

198262
self.weight_fn = weight
199263
self.transforms = transforms
264+
self.world_size_ = dist.get_world_size()
265+
self.rank_ = dist.get_rank()
200266

201267
@property
202268
def num_samples(self):
@@ -223,6 +289,18 @@ def to_tensor_dict(_dict):
223289
input_batch, label_batch, weight_batch = self.transforms(
224290
input_batch, label_batch, weight_batch
225291
)
292+
293+
if self.world_size_ > 1:
294+
input_batch = _group_dict_into_ranks(
295+
input_batch, self.rank_, self.world_size_
296+
)
297+
label_batch = _group_dict_into_ranks(
298+
label_batch, self.rank_, self.world_size_
299+
)
300+
weight_batch = _group_dict_into_ranks(
301+
weight_batch, self.rank_, self.world_size_
302+
)
303+
226304
yield to_tensor_dict(input_batch), to_tensor_dict(
227305
label_batch
228306
), to_tensor_dict(weight_batch)

0 commit comments

Comments
 (0)