Commit 70784fb

WintersMontagne10335 authored and jiahy0825 committed
[Hackathon 5th No.51] Add SPMD sharding derivation rules for flatten to Paddle (PaddlePaddle#57875)
* Add flatten SPMD sharding and derivation rules for Paddle
* fix bugs
* add unit test code
* modified: test/auto_parallel/spmd_rules/CMakeLists.txt
* modify the code according to the review
* modified: paddle/phi/infermeta/spmd_rules/flatten.cc
1 parent ea99965 commit 70784fb
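For a concrete sense of what the rule derives (an illustrative example of ours, not taken from the patch): flattening a tensor X of shape [8, 16, 24] with start_axis=1 and stop_axis=2 produces shape [8, 384]. If X carries dims_mapping [0, 1, -1], the merged output axis inherits the sharding of its leading input axis, giving an inferred output dims_mapping of [0, 1]; sharding on a non-leading merged axis would instead trigger a reshard of the input to replicated on that axis. The sketches after the two source files below walk through this in code.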

File tree

5 files changed: +642 -0 lines changed
paddle/phi/infermeta/spmd_rules/flatten.cc

Lines changed: 203 additions & 0 deletions

@@ -0,0 +1,203 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/spmd_rules/flatten.h"

#include <numeric>

#include "glog/logging.h"

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/dim_trans.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi {
namespace distributed {

using phi::distributed::auto_parallel::str_join;

// Normalize a possibly negative axis into [0, ndim) and validate it.
int PreprocessAxis(int axis, int ndim) {
  if (axis < 0) {
    axis += ndim;
  }

  PADDLE_ENFORCE_LT(
      axis,
      ndim,
      phi::errors::InvalidArgument("The Start_axis or Stop_axis [%d] is not "
                                   "less than the Tensor X's rank [%d].",
                                   axis,
                                   ndim));

  return axis;
}

// Build the dim transformation from the input shape to the flattened shape:
// axes outside [start_axis, stop_axis] pass through unchanged, while the
// axes inside the range are merged into a single flattened axis.
std::vector<DimTrans*> MakeFlattenDimTrans(
    const std::vector<int64_t>& src_shape, int start_axis, int stop_axis) {
  std::vector<DimTrans*> ret;

  std::vector<DimTrans*> input_dims;
  for (int64_t i = 0; i < static_cast<int64_t>(src_shape.size()); i++) {
    if (i < start_axis || i > stop_axis) {
      ret.emplace_back(new InputDim(i));
    } else {
      input_dims.emplace_back(new InputDim(i));
    }

    if (i == stop_axis) {
      ret.emplace_back(make_flatten(input_dims));
    }
  }

  return ret;
}

// Build the inverse transformation, from the flattened output shape back to
// the original input shape: the single flattened axis is split back into
// the original axes in [start_axis, stop_axis].
std::vector<DimTrans*> MakeFlattenDimTransReverse(
    const std::vector<int64_t>& src_shape, int start_axis, int stop_axis) {
  std::vector<DimTrans*> ret;

  std::vector<int64_t> tgt_splitted_shape;
  for (int i = start_axis; i <= stop_axis; i++) {
    tgt_splitted_shape.emplace_back(src_shape[i]);
  }

  for (int64_t i = 0; i < static_cast<int64_t>(src_shape.size()); i++) {
    if (i < start_axis) {
      ret.emplace_back(new InputDim(i));
    } else if (i > stop_axis) {
      ret.emplace_back(new InputDim(i - (stop_axis - start_axis)));
    } else {
      ret.emplace_back(make_split(
          new InputDim(start_axis), tgt_splitted_shape, i - start_axis));
    }
  }

  return ret;
}

SpmdInfo FlattenInferSpmd(const DistMetaTensor& x,
                          int start_axis,
                          int stop_axis) {
  // Step0: Verify input args based on flatten logic
  auto src_shape = phi::vectorize(x.dims());
  int x_ndim = static_cast<int>(src_shape.size());
  auto x_dist_attr_src = x.dist_attr();
  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      x_ndim,
      x_dims_mapping.size(),
      phi::errors::InvalidArgument("The Tensor X's rank [%d] and X's "
                                   "dims_mapping size [%d] are not matched.",
                                   x_ndim,
                                   x_dims_mapping.size()));

  // Step1: Build the transformation from
  // the original shape to the target shape

  start_axis = PreprocessAxis(start_axis, x_ndim);
  stop_axis = PreprocessAxis(stop_axis, x_ndim);
  std::vector<DimTrans*> trans =
      MakeFlattenDimTrans(src_shape, start_axis, stop_axis);

  // Step2: Infer the dims mapping of input (if reshard is
  // needed) and output from the dimension transformation.
  std::vector<std::vector<int64_t>> dims_mapping_vec =
      InferFromDimTrans(x, trans);

  // Step3: Update the dist attributes of input
  // and output with the inferred dims mapping.
  TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
  TensorDistAttr out_dist_attr(x_dist_attr_src);
  out_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

  VLOG(4) << "FlattenInferSpmd: X shape: [" << str_join(src_shape) << "]";
  VLOG(4) << "Start_axis: " << start_axis;
  VLOG(4) << "Stop_axis: " << stop_axis;
  VLOG(4) << "Transformation from input to output:";
  for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) {
    DimTrans* t = trans[i];
    VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string();
  }
  VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping)
          << "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]";
  VLOG(4) << "Out dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n";

  CleanUp();

  return {{x_dist_attr_dst}, {out_dist_attr}};
}

SpmdInfo FlattenInferSpmdReverse(const DistMetaTensor& x,
                                 const DistMetaTensor& out,
                                 int start_axis,
                                 int stop_axis) {
  // Step0: Verify input args based on flatten logic
  auto x_shape = phi::vectorize(x.dims());
  int x_ndim = static_cast<int>(x_shape.size());
  auto out_shape = phi::vectorize(out.dims());
  int out_ndim = static_cast<int>(out_shape.size());
  auto out_dist_attr_src = out.dist_attr();
  std::vector<int64_t> out_dims_mapping = out_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      out_ndim,
      out_dims_mapping.size(),
      phi::errors::InvalidArgument("The Tensor Out's rank [%d] and Out's "
                                   "dims_mapping size [%d] are not matched.",
                                   out_ndim,
                                   out_dims_mapping.size()));

  // Step1: Build the transformation from the output shape
  // to the original shape. This function infers the dims mapping
  // from output to input: we first get the transformation
  // from output to input so that we can infer the dims mapping
  // with the map from output axes to input axes.

  start_axis = PreprocessAxis(start_axis, x_ndim);
  stop_axis = PreprocessAxis(stop_axis, x_ndim);

  std::vector<DimTrans*> trans =
      MakeFlattenDimTransReverse(x_shape, start_axis, stop_axis);

  // Step2: Infer the dims mapping of input with
  // output's dims_mapping and the transformation.
  std::vector<std::vector<int64_t>> dims_mapping_vec =
      InferFromDimTrans(out, trans);

  // Step3: Update the dist attributes of input
  // and output with the inferred dims mapping
  TensorDistAttr out_dist_attr_dst(out_dist_attr_src);
  out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
  TensorDistAttr x_dist_attr(x.dist_attr());
  x_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

  VLOG(4) << "FlattenInferSpmdReverse: Out shape: [" << str_join(out_shape)
          << "] X shape: [" << str_join(x_shape) << "]";
  VLOG(4) << "Transformation from output to input:";
  for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) {
    DimTrans* t = trans[i];
    VLOG(4) << "\tX axis[" << i << "]: " << t->to_string();
  }
  VLOG(4) << "Out dims_mapping_src: [" << str_join(out_dims_mapping) << "] "
          << "dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]";
  VLOG(4) << "X dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n";

  CleanUp();

  return {{x_dist_attr}, {out_dist_attr_dst}};
}

}  // namespace distributed
}  // namespace phi
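To make the propagation concrete, the following is a minimal standalone sketch (our own simplification, not part of this patch) of how the forward rule maps dims_mapping across a flatten. It assumes the DimTrans behavior for merged axes: only the sharding of the leading merged axis survives, and the remaining merged axes must be replicated on the input side; the real InferFromDimTrans also validates mesh and shape details that this model ignores.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Simplified model of FlattenInferSpmd's dims_mapping propagation.
// Returns {x_dims_mapping_dst, out_dims_mapping}; -1 means "replicated".
std::pair<std::vector<int64_t>, std::vector<int64_t>> InferFlattenMapping(
    const std::vector<int64_t>& x_mapping, int start_axis, int stop_axis) {
  std::vector<int64_t> x_dst(x_mapping);
  std::vector<int64_t> out;
  for (int i = 0; i < static_cast<int>(x_mapping.size()); ++i) {
    if (i < start_axis || i > stop_axis) {
      out.push_back(x_mapping[i]);  // axes outside the range pass through
    } else if (i == start_axis) {
      out.push_back(x_mapping[i]);  // merged axis keeps the leading sharding
    } else {
      x_dst[i] = -1;  // non-leading merged axes must become replicated
    }
  }
  return {x_dst, out};
}

int main() {
  // X: shape [8, 16, 24], dims_mapping [0, 1, -1], flatten axes 1..2.
  auto [x_dst, out] = InferFlattenMapping({0, 1, -1}, 1, 2);
  // Under this model X stays [0, 1, -1] and Out becomes [0, 1].
  assert((x_dst == std::vector<int64_t>{0, 1, -1}));
  assert((out == std::vector<int64_t>{0, 1}));

  // Sharding on a non-leading merged axis is dropped (reshard needed).
  auto [x_dst2, out2] = InferFlattenMapping({-1, -1, 0}, 1, 2);
  assert((x_dst2 == std::vector<int64_t>{-1, -1, -1}));
  assert((out2 == std::vector<int64_t>{-1, -1}));

  std::cout << "flatten dims_mapping sketch OK" << std::endl;
  return 0;
}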
paddle/phi/infermeta/spmd_rules/flatten.h

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {

SpmdInfo FlattenInferSpmd(const DistMetaTensor& x,
                          int start_axis,
                          int stop_axis);

SpmdInfo FlattenInferSpmdReverse(const DistMetaTensor& x,
                                 const DistMetaTensor& out,
                                 int start_axis,
                                 int stop_axis);

}  // namespace distributed
}  // namespace phi
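As a usage illustration, here is a hypothetical call sketch of ours (not from the patch), modeled on how Paddle's C++ spmd_rule unit tests construct inputs; the 2x2 mesh, shapes, and mappings are made up, and namespace and constructor details may differ across Paddle versions.

#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/infermeta/spmd_rules/flatten.h"

void FlattenSpmdExample() {
  // Made-up 2x2 mesh with axes "x" and "y" (four processes).
  phi::distributed::auto_parallel::ProcessMesh mesh(
      {2, 2}, {0, 1, 2, 3}, {"x", "y"});

  // X: shape [8, 16, 24]; dim 0 sharded on mesh axis 0, dim 1 on mesh axis 1.
  phi::distributed::TensorDistAttr x_dist_attr;
  x_dist_attr.set_process_mesh(mesh);
  x_dist_attr.set_dims_mapping({0, 1, -1});

  phi::distributed::DistMetaTensor x(phi::make_ddim({8, 16, 24}),
                                     x_dist_attr);

  // Flatten axes 1..2; the expected Out dims_mapping is [0, 1].
  auto spmd_info =
      phi::distributed::FlattenInferSpmd(x, /*start_axis=*/1, /*stop_axis=*/2);
  (void)spmd_info;  // holds the inferred input and output dist attributes
}

Under the rule above, this call should leave X's mapping unchanged and produce [0, 1] for the flattened output.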

paddle/phi/infermeta/spmd_rules/rules.h

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/phi/infermeta/spmd_rules/default_data_parallel.h"
 #include "paddle/phi/infermeta/spmd_rules/elementwise.h"
 #include "paddle/phi/infermeta/spmd_rules/embedding.h"
+#include "paddle/phi/infermeta/spmd_rules/flatten.h"
 #include "paddle/phi/infermeta/spmd_rules/layer_norm.h"
 #include "paddle/phi/infermeta/spmd_rules/matmul.h"
 #include "paddle/phi/infermeta/spmd_rules/reduction.h"

@@ -492,6 +493,11 @@ PD_REGISTER_SPMD_RULE(reshape2,
                       PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd),
                       PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse));

+// flatten rule
+PD_REGISTER_SPMD_RULE(flatten,
+                      PD_INFER_SPMD(phi::distributed::FlattenInferSpmd),
+                      PD_INFER_SPMD(phi::distributed::FlattenInferSpmdReverse));
+
 // embedding rule
 PD_REGISTER_SPMD_RULE(
     embedding,
test/auto_parallel/spmd_rules/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ if(WITH_DISTRIBUTE)
   py_test_modules(test_default_data_parallel_rule MODULES
                   test_default_data_parallel_rule)
   py_test_modules(test_layer_norm_rule MODULES test_layer_norm_rule)
+  py_test_modules(test_flatten_rule MODULES test_flatten_rule)
   # End of unittests WITH single card WITHOUT timeout

endif()
