20
20
21
21
import numpy as np
22
22
import paddle
23
+ from paddle import distributed as dist
23
24
from paddle import io
24
25
from paddle import vision
25
26
26
27
from ppsci .utils import logger
27
28
28
29
30
def _group_array_into_ranks(
    data: Optional[np.ndarray], rank: int, world_size: int
) -> Optional[np.ndarray]:
    """
    Group data into different ranks. For example, if data is [1, 2, 3, 4, 5, 6, 7, 8, 9] and
    world_size is 3, then the result will be rank0: [1, 4, 7], rank1: [2, 5, 8], rank2: [3, 6, 9].

    Args:
        data (Optional[np.ndarray]): Data to be grouped, can be np.ndarray or None.
        rank (int): Rank number, must satisfy 0 <= rank < world_size.
        world_size (int): Number of workers, must be a positive integer.

    Returns:
        Optional[np.ndarray]: The strided slice ``data[rank::world_size]`` along the
            first axis, or None if data is None.

    Raises:
        ValueError: If world_size is not positive, if rank is outside
            [0, world_size), if data has fewer items than world_size, or if the
            length of data is not divisible by world_size.
    """
    if data is None:
        # skip grouping if data is None
        return None

    # reject a non-positive world_size up front, otherwise the modulo below would
    # fail with an unrelated ZeroDivisionError
    if world_size < 1:
        raise ValueError(f"world_size ({world_size}) must be a positive integer.")
    # an out-of-range rank would silently produce a wrong (or empty) shard
    if not 0 <= rank < world_size:
        raise ValueError(
            f"rank ({rank}) must be in range [0, world_size ({world_size}))."
        )

    # check if data can be grouped evenly into different ranks
    num_items = len(data)
    if num_items < world_size:
        raise ValueError(
            f"Length of data to be grouped ({num_items}) must be no less than "
            f"world_size ({world_size})."
        )
    if num_items % world_size != 0:
        raise ValueError(
            f"Length of data to be grouped ({num_items}) must be divisible by "
            f"world_size ({world_size})."
        )

    # every world_size-th item starting at index rank belongs to this rank
    return data[rank::world_size]
60
+
61
+
62
def _group_dict_into_ranks(
    data_dict: Optional[Dict[str, Optional[np.ndarray]]], rank: int, world_size: int
) -> Optional[Dict[str, Optional[np.ndarray]]]:
    """
    Apply the per-array rank grouping to every value of a data dict.

    Args:
        data_dict (Optional[Dict[str, Optional[np.ndarray]]]): Mapping whose values
            are to be grouped; None is passed through untouched.
        rank (int): Rank number.
        world_size (int): Number of workers.

    Returns:
        Optional[Dict[str, Optional[np.ndarray]]]: A new dict with the same keys,
            each value replaced by the result of _group_array_into_ranks, or None
            if data_dict is None.
    """
    if data_dict is None:
        # nothing to group
        return None

    grouped = {}
    for name, array in data_dict.items():
        # each value is grouped independently with the same rank/world_size
        grouped[name] = _group_array_into_ranks(array, rank, world_size)
    return grouped
83
+
84
+
29
85
class NamedArrayDataset (io .Dataset ):
30
86
"""Class for Named Array Dataset.
31
87
@@ -132,6 +188,8 @@ def __init__(
132
188
)
133
189
self ._len = len (next (iter (self .input .values ())))
134
190
self .transforms = transforms
191
+ self .world_size_ = dist .get_world_size ()
192
+ self .rank_ = dist .get_rank ()
135
193
136
194
@property
137
195
def num_samples (self ):
@@ -143,9 +201,15 @@ def __iter__(self):
143
201
input_ , label_ , weight_ = self .transforms (
144
202
self .input , self .label , self .weight
145
203
)
146
- yield input_ , label_ , weight_
147
204
else :
148
- yield self .input , self .label , self .weight
205
+ input_ , label_ , weight_ = self .input , self .label , self .weight
206
+
207
+ if self .world_size_ > 1 :
208
+ input_ = _group_dict_into_ranks (input_ , self .rank_ , self .world_size_ )
209
+ label_ = _group_dict_into_ranks (label_ , self .rank_ , self .world_size_ )
210
+ weight_ = _group_dict_into_ranks (weight_ , self .rank_ , self .world_size_ )
211
+
212
+ yield input_ , label_ , weight_
149
213
150
214
def __len__ (self ):
151
215
return 1
@@ -197,6 +261,8 @@ def __init__(
197
261
198
262
self .weight_fn = weight
199
263
self .transforms = transforms
264
+ self .world_size_ = dist .get_world_size ()
265
+ self .rank_ = dist .get_rank ()
200
266
201
267
@property
202
268
def num_samples (self ):
@@ -223,6 +289,18 @@ def to_tensor_dict(_dict):
223
289
input_batch , label_batch , weight_batch = self .transforms (
224
290
input_batch , label_batch , weight_batch
225
291
)
292
+
293
+ if self .world_size_ > 1 :
294
+ input_batch = _group_dict_into_ranks (
295
+ input_batch , self .rank_ , self .world_size_
296
+ )
297
+ label_batch = _group_dict_into_ranks (
298
+ label_batch , self .rank_ , self .world_size_
299
+ )
300
+ weight_batch = _group_dict_into_ranks (
301
+ weight_batch , self .rank_ , self .world_size_
302
+ )
303
+
226
304
yield to_tensor_dict (input_batch ), to_tensor_dict (
227
305
label_batch
228
306
), to_tensor_dict (weight_batch )
0 commit comments