Commit 73aa98c
[0d Tensor] update scatter for zero-dimension tensor (#49279)
* revert concat and change concat to stack
* let stack kernel support int8, uint8 and bool type
Parent: 1c0afa7
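Note on why stack is the right packing primitive here: stack adds a new leading axis, so it is well defined for zero-dimension inputs, while concat joins along an existing axis, which a rank-0 tensor does not have. A minimal Python-level sketch of the difference (not part of this commit; assumes a Paddle build with 0-d tensor support):

import paddle

xs = [paddle.to_tensor(1.0), paddle.to_tensor(2.0)]  # zero-dimension tensors, shape []
y = paddle.stack(xs, axis=0)                         # stack adds a leading axis
print(y.shape)                                       # [2]
# concat joins along an existing axis; rank-0 tensors have none, so with the
# relaxed axis check removed below, concat now rejects them:
# paddle.concat(xs, axis=0)  # InvalidArgument: axis expected in [0, 0)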

8 files changed (+37 lines, -54 lines)


paddle/fluid/pybind/distributed_py.cc

Lines changed: 20 additions & 20 deletions
@@ -255,9 +255,9 @@ void BindDistributed(py::module *m) {
              bool sync_op) {
             auto out_tensor_list =
                 CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
-            Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
+            Tensor stack_out_tensor = paddle::stack(out_tensor_list, 0);
             auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
-                concat_out_tensor.impl());
+                stack_out_tensor.impl());
             auto *out_dense = p_out_tensor.get();

             auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
@@ -307,16 +307,16 @@ void BindDistributed(py::module *m) {
              bool sync_op) {
             auto out_tensor_list =
                 CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
-            Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
+            Tensor stack_out_tensor = paddle::stack(out_tensor_list, 0);
             auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
-                concat_out_tensor.impl());
+                stack_out_tensor.impl());
             auto *out_dense = p_out_tensor.get();

             auto in_tensor_list =
                 CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
-            Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
+            Tensor stack_in_tensor = paddle::stack(in_tensor_list, 0);
             auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
-                concat_in_tensor.impl());
+                stack_in_tensor.impl());
             auto in_dense = *p_in_tensor;

             // in_tensor_list should not be empty
@@ -430,9 +430,9 @@ void BindDistributed(py::module *m) {

             auto in_tensor_list =
                 CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
-            Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
+            Tensor stack_in_tensor = paddle::stack(in_tensor_list, 0);
             auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
-                concat_in_tensor.impl());
+                stack_in_tensor.impl());
             auto in_dense = *p_in_tensor;

             distributed::ReduceScatterOptions opts{op};
@@ -484,9 +484,9 @@ void BindDistributed(py::module *m) {

             auto in_tensor_list =
                 CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
-            Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
+            Tensor stack_in_tensor = paddle::stack(in_tensor_list, 0);
             auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
-                concat_in_tensor.impl());
+                stack_in_tensor.impl());
             auto in_dense = *p_in_tensor;

             distributed::ScatterOptions opts{src};
@@ -746,9 +746,9 @@ void BindDistributed(py::module *m) {
              py::handle py_in_tensor) {
             auto out_tensor_list =
                 CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
-            Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
+            Tensor stack_out_tensor = paddle::stack(out_tensor_list, 0);
             auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
-                concat_out_tensor.impl());
+                stack_out_tensor.impl());
             auto *out_dense = p_out_tensor.get();

             auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
@@ -854,16 +854,16 @@ void BindDistributed(py::module *m) {
              py::handle py_in_tensor_list) {
             auto out_tensor_list =
                 CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
-            Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
+            Tensor stack_out_tensor = paddle::stack(out_tensor_list, 0);
             auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
-                concat_out_tensor.impl());
+                stack_out_tensor.impl());
             auto *out_dense = p_out_tensor.get();

             auto in_tensor_list =
                 CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
-            Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
+            Tensor stack_in_tensor = paddle::stack(in_tensor_list, 0);
             auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
-                concat_in_tensor.impl());
+                stack_in_tensor.impl());
             auto in_dense = *p_in_tensor;

             // in_tensor_list should not be empty
@@ -999,9 +999,9 @@ void BindDistributed(py::module *m) {

             auto in_tensor_list =
                 CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
-            Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
+            Tensor stack_in_tensor = paddle::stack(in_tensor_list, 0);
             auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
-                concat_in_tensor.impl());
+                stack_in_tensor.impl());
             auto in_dense = *p_in_tensor;

             distributed::ReduceScatterOptions opts{op};
@@ -1057,9 +1057,9 @@ void BindDistributed(py::module *m) {

             auto in_tensor_list =
                 CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
-            Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
+            Tensor stack_in_tensor = paddle::stack(in_tensor_list, 0);
             auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
-                concat_in_tensor.impl());
+                stack_in_tensor.impl());
             auto in_dense = *p_in_tensor;

             distributed::ScatterOptions opts{src};
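All eight hunks above make the same mechanical swap: the per-rank tensor lists handed to the collective bindings are now packed with paddle::stack instead of paddle::concat, so the path keeps working when the list elements are zero-dimension tensors. Roughly how this surfaces at the Python level (a hypothetical two-process sketch, not from this commit; requires an initialized distributed environment):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
x = paddle.to_tensor(float(dist.get_rank()))  # shape [] on every rank
out = []
dist.all_gather(out, x)  # the binding packs the per-rank buffers via paddle::stack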

paddle/phi/infermeta/multiary.cc

Lines changed: 1 addition & 2 deletions
@@ -911,14 +911,13 @@ void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
   // 1. calculate axis
   int rank = x.at(0)->dims().size();
   PADDLE_ENFORCE_EQ(
-      !rank || (axis >= -rank && axis < rank),
+      axis >= -rank && axis < rank,
       true,
       phi::errors::InvalidArgument(
           "The axis is expected to be in range of [%d, %d), but got %d",
           -rank,
           rank,
           axis));
-  axis = rank ? axis : 0;
   if (axis < 0) {
     axis = axis + rank;
   }
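With the `!rank || ...` escape and the `axis = rank ? axis : 0;` fallback gone, rank == 0 leaves the valid range [0, 0) empty, so ConcatInferMeta now rejects every axis for zero-dimension inputs instead of silently coercing the axis to 0. A Python stand-in for the resulting normalization logic (illustrative only, not code from this commit):

def compute_axis(axis: int, rank: int) -> int:
    # valid range is [-rank, rank); empty when rank == 0
    if not (-rank <= axis < rank):
        raise ValueError(
            f"The axis is expected to be in range of [{-rank}, {rank}), but got {axis}")
    return axis + rank if axis < 0 else axis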

paddle/phi/kernels/cpu/stack_grad_kernel.cc

Lines changed: 4 additions & 0 deletions
@@ -54,6 +54,10 @@ PD_REGISTER_KERNEL(stack_grad,
                    phi::StackGradKernel,
                    float,
                    double,
+                   bool,
                    int64_t,
                    int,
+                   uint8_t,
+                   int8_t,
+                   phi::dtype::float16,
                    phi::dtype::bfloat16) {}

paddle/phi/kernels/cpu/stack_kernel.cc

Lines changed: 5 additions & 1 deletion
@@ -57,6 +57,10 @@ PD_REGISTER_KERNEL(stack,
                    phi::StackKernel,
                    float,
                    double,
-                   int,
+                   bool,
                    int64_t,
+                   int,
+                   uint8_t,
+                   int8_t,
+                   phi::dtype::float16,
                    phi::dtype::bfloat16) {}
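Together with the stack_grad registration above, the CPU stack kernels now cover bool, uint8, int8, and float16 (the GPU registrations below add the same bool and integral types; float16 was already present there). A quick sanity sketch, assuming a Paddle build containing this commit:

import paddle

paddle.stack([paddle.to_tensor(True), paddle.to_tensor(False)])                          # bool
paddle.stack([paddle.to_tensor(1, dtype='uint8'), paddle.to_tensor(2, dtype='uint8')])   # uint8
paddle.stack([paddle.to_tensor(-3, dtype='int8'), paddle.to_tensor(4, dtype='int8')])    # int8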

paddle/phi/kernels/funcs/concat_funcs.h

Lines changed: 1 addition & 2 deletions
@@ -21,14 +21,13 @@ namespace funcs {

 static inline int64_t ComputeAxis(int64_t axis, int64_t rank) {
   PADDLE_ENFORCE_EQ(
-      !rank || (axis >= -rank && axis < rank),
+      axis >= -rank && axis < rank,
       true,
       phi::errors::InvalidArgument(
           "The axis is expected to be in range of [%d, %d), but got %d",
           -rank,
           rank,
           axis));
-  axis = rank ? axis : 0;
   if (axis < 0) {
     axis = axis + rank;
   }
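This is the same escape-hatch removal as in ConcatInferMeta above (see the compute_axis sketch there); both validation paths now treat rank-0 inputs uniformly rather than coercing the axis to 0.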

paddle/phi/kernels/gpu/concat_kernel.cu

Lines changed: 0 additions & 29 deletions
@@ -34,35 +34,6 @@ void ConcatKernel(const Context& dev_ctx,
                   DenseTensor* out) {
   int64_t axis = axis_scalar.to<int64_t>();

-  if (UNLIKELY(x[0]->dims().size() == 0)) {
-    // for dims is 0 specially
-    phi::DDim tmp_1dim, out_dims;
-    out_dims[0] = x.size();
-    tmp_1dim[0] = 1;
-
-    out->Resize(out_dims);
-    dev_ctx.template Alloc<T>(out);
-
-    size_t output_offset = 0;
-    for (auto* in : x) {
-      if (in->numel() == 0UL) {
-        continue;
-      }
-      auto in_stride = phi::stride_numel(tmp_1dim);
-      auto out_stride = phi::stride_numel(out->dims());
-      paddle::operators::StridedNumelCopyWithAxis<T>(
-          dev_ctx,
-          axis,
-          out->data<T>() + output_offset,
-          out_stride,
-          in->data<T>(),
-          in_stride,
-          in_stride[axis]);
-      output_offset += in_stride[axis];
-    }
-    return;
-  }
-
   axis = phi::funcs::ComputeAxis(axis, x[0]->dims().size());

   std::vector<phi::DDim> x_dims;
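With every caller switched to paddle::stack, this ad-hoc zero-dimension branch, which hand-built a 1-D output by copying each input through StridedNumelCopyWithAxis, is no longer reachable and is deleted outright; rank-0 inputs now fail the ComputeAxis check like everywhere else.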

paddle/phi/kernels/gpu/stack_grad_kernel.cu

Lines changed: 3 additions & 0 deletions
@@ -139,7 +139,10 @@ PD_REGISTER_KERNEL(stack_grad,
                    phi::StackGradKernel,
                    float,
                    double,
+                   bool,
                    int64_t,
                    int,
+                   uint8_t,
+                   int8_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}

paddle/phi/kernels/gpu/stack_kernel.cu

Lines changed: 3 additions & 0 deletions
@@ -175,7 +175,10 @@ PD_REGISTER_KERNEL(stack,
                    phi::StackKernel,
                    float,
                    double,
+                   bool,
                    int64_t,
                    int,
+                   uint8_t,
+                   int8_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
