Skip to content

Commit aa54376

Browse files
authored
Fix randn bug in large shape (PaddlePaddle#70492)
1 parent a2f1280 commit aa54376

File tree

3 files changed

+62
-10
lines changed

3 files changed

+62
-10
lines changed

paddle/phi/kernels/funcs/index_impl.cu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ __global__ void VectorizedIndexKernel(T *out,
5353

5454
template <typename T, typename Functor>
5555
void IndexKernel(const KPDevice &dev_ctx, DenseTensor *out, Functor func) {
56-
int numel = out->numel();
56+
int64_t numel = out->numel();
5757
T *out_data = dev_ctx.template Alloc<T>(out);
5858
if (numel <= 0) return;
5959
int vec_size = phi::GetVectorizedSize(out_data);

paddle/phi/kernels/primitive/datamover_primitives.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -597,9 +597,9 @@ __device__ __forceinline__ void ReadDataReduce(Ty* dst,
597597
template <typename T, int NX, int NY, bool IsBoundary = false>
598598
__device__ __forceinline__ void WriteData(T* dst,
599599
T* __restrict__ src,
600-
int num) {
600+
int64_t num) {
601601
if (IsBoundary) {
602-
int thread_offset = threadIdx.x * NX;
602+
int64_t thread_offset = threadIdx.x * NX;
603603
#pragma unroll
604604
for (int idx = 0; idx < NX; ++idx) {
605605
if ((thread_offset + idx) < num) {
@@ -611,7 +611,7 @@ __device__ __forceinline__ void WriteData(T* dst,
611611
constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
612612
constexpr int kVectorsPerThread = NX / kVectorSize;
613613

614-
int thread_offset = threadIdx.x * kVectorsPerThread;
614+
int64_t thread_offset = threadIdx.x * kVectorsPerThread;
615615
using VecType = details::VectorType<T, kVectorSize>;
616616
VecType* vec_dst = reinterpret_cast<VecType*>(dst);
617617
VecType vec_temp[kVectorsPerThread];
@@ -681,12 +681,12 @@ __device__ __forceinline__ void WriteData(T* dst,
681681
template <typename Tx, typename Ty, int NX, int NY, bool IsBoundary = false>
682682
__device__ __forceinline__ void WriteData(Ty* dst,
683683
const Tx* __restrict__ src,
684-
int size_nx,
684+
int64_t size_nx,
685685
int size_ny,
686686
int stride_nx,
687687
int stride_ny) {
688688
int thread_offset = threadIdx.x;
689-
int left_size_nx = size_nx - thread_offset;
689+
int64_t left_size_nx = size_nx - thread_offset;
690690

691691
// Each branch is added for better performance
692692
if (NX == 1 && NY == 1) { // for NX == 1 and NY == 1

test/legacy_test/test_gaussian_random_op.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -458,8 +458,8 @@ def test_fixed_random_number(self):
458458
if "V100" not in paddle.device.cuda.get_device_name():
459459
return
460460

461-
def _check_random_value(dtype, expect, expect_mean, expect_std):
462-
x = paddle.randn([32, 3, 1024, 1024], dtype=dtype)
461+
def _check_random_value(shape, dtype, expect, expect_mean, expect_std):
462+
x = paddle.randn(shape, dtype=dtype)
463463
actual = x.numpy()
464464
np.testing.assert_allclose(
465465
actual[2, 1, 512, 1000:1010], expect, rtol=1e-05
@@ -487,7 +487,9 @@ def _check_random_value(dtype, expect, expect_mean, expect_std):
487487
-0.0000053026194133403266873214888799115129813799285329878330230713
488488
)
489489
expect_std = 0.99999191058126390974081232343451119959354400634765625
490-
_check_random_value(paddle.float64, expect, expect_mean, expect_std)
490+
_check_random_value(
491+
[32, 3, 1024, 1024], paddle.float64, expect, expect_mean, expect_std
492+
)
491493

492494
expect = [
493495
-0.7988942,
@@ -503,7 +505,57 @@ def _check_random_value(dtype, expect, expect_mean, expect_std):
503505
]
504506
expect_mean = -0.00004762359094456769526004791259765625
505507
expect_std = 0.999975681304931640625
506-
_check_random_value(paddle.float32, expect, expect_mean, expect_std)
508+
_check_random_value(
509+
[32, 3, 1024, 1024], paddle.float32, expect, expect_mean, expect_std
510+
)
511+
512+
# test randn in large shape
513+
expect = [
514+
-1.4770278,
515+
-0.637431,
516+
-0.41728288,
517+
0.31339037,
518+
-1.7627009,
519+
0.4061812,
520+
1.0679497,
521+
0.03405872,
522+
-0.7271235,
523+
-0.42642546,
524+
]
525+
526+
expect_mean = 0.0000010386128224126878194510936737060547
527+
expect_std = 1.00000822544097900390625
528+
_check_random_value(
529+
[4, 2, 60000, 12000],
530+
paddle.float32,
531+
expect,
532+
expect_mean,
533+
expect_std,
534+
)
535+
536+
# test randn with seed 0 in large shape
537+
paddle.seed(0)
538+
expect = [
539+
-1.7653463,
540+
0.5957617,
541+
0.45865676,
542+
-0.3061651,
543+
0.17204928,
544+
-1.7802757,
545+
-0.10731091,
546+
1.042362,
547+
0.70476884,
548+
0.2720365,
549+
]
550+
expect_mean = -0.0000002320642948916429304517805576324463
551+
expect_std = 1.00001156330108642578125
552+
_check_random_value(
553+
[4, 2, 60000, 12000],
554+
paddle.float32,
555+
expect,
556+
expect_mean,
557+
expect_std,
558+
)
507559

508560

509561
if __name__ == "__main__":

0 commit comments

Comments (0)