
【Bug fix】Fix randn bug in large shape #70492


Merged: 3 commits, Dec 27, 2024
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/index_impl.cu.h
@@ -53,7 +53,7 @@ __global__ void VectorizedIndexKernel(T *out,
 
 template <typename T, typename Functor>
 void IndexKernel(const KPDevice &dev_ctx, DenseTensor *out, Functor func) {
-  int numel = out->numel();
+  int64_t numel = out->numel();
   T *out_data = dev_ctx.template Alloc<T>(out);
   if (numel <= 0) return;
   int vec_size = phi::GetVectorizedSize(out_data);
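For context: `DenseTensor::numel()` returns an `int64_t`, and the new test below fills a tensor of shape [4, 2, 60000, 12000], i.e. 4 × 2 × 60000 × 12000 = 5,760,000,000 elements, well past INT32_MAX (2,147,483,647). A standalone sketch (not Paddle code; the shape and count are taken from that test) of the narrowing this one-line change removes:

// Minimal sketch of the 32-bit narrowing fixed in IndexKernel above.
#include <cstdint>
#include <iostream>

int main() {
  // Element count of the new test shape [4, 2, 60000, 12000].
  int64_t numel = 4LL * 2 * 60000 * 12000;  // 5,760,000,000 > INT32_MAX
  // The old `int numel = out->numel();` truncated this count; on typical
  // two's-complement platforms it wraps to 1,465,032,704, so every grid
  // size and bounds computation downstream saw too few elements.
  int narrowed = static_cast<int>(numel);
  std::cout << numel << " -> " << narrowed << "\n";
  return 0;
}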
10 changes: 5 additions & 5 deletions paddle/phi/kernels/primitive/datamover_primitives.h
@@ -597,9 +597,9 @@ __device__ __forceinline__ void ReadDataReduce(Ty* dst,
 template <typename T, int NX, int NY, bool IsBoundary = false>
 __device__ __forceinline__ void WriteData(T* dst,
                                           T* __restrict__ src,
-                                          int num) {
+                                          int64_t num) {
   if (IsBoundary) {
-    int thread_offset = threadIdx.x * NX;
+    int64_t thread_offset = threadIdx.x * NX;
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
       if ((thread_offset + idx) < num) {
@@ -611,7 +611,7 @@ __device__ __forceinline__ void WriteData(T* dst,
     constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
     constexpr int kVectorsPerThread = NX / kVectorSize;
 
-    int thread_offset = threadIdx.x * kVectorsPerThread;
+    int64_t thread_offset = threadIdx.x * kVectorsPerThread;
     using VecType = details::VectorType<T, kVectorSize>;
     VecType* vec_dst = reinterpret_cast<VecType*>(dst);
     VecType vec_temp[kVectorsPerThread];
@@ -681,12 +681,12 @@ __device__ __forceinline__ void WriteData(T* dst,
 template <typename Tx, typename Ty, int NX, int NY, bool IsBoundary = false>
 __device__ __forceinline__ void WriteData(Ty* dst,
                                           const Tx* __restrict__ src,
-                                          int size_nx,
+                                          int64_t size_nx,
                                           int size_ny,
                                           int stride_nx,
                                           int stride_ny) {
   int thread_offset = threadIdx.x;
-  int left_size_nx = size_nx - thread_offset;
+  int64_t left_size_nx = size_nx - thread_offset;
 
   // Each branch is added for better performance
   if (NX == 1 && NY == 1) {  // for NX == 1 and NY == 1
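These widenings matter for the `IsBoundary` guard `(thread_offset + idx) < num`: once a kernel can be launched over more than INT32_MAX elements, a 32-bit `num` arrives already wrapped and the guard rejects elements that are actually in range. A host-side stand-in (hypothetical `InBounds` helper, not Paddle code) illustrating the difference:

// Host-side stand-in for the device boundary guard in WriteData.
#include <cassert>
#include <cstdint>

template <typename SizeT>
bool InBounds(int64_t thread_offset, int idx, SizeT num) {
  return (thread_offset + idx) < num;
}

int main() {
  const int64_t numel = 5'760'000'000LL;  // 4 * 2 * 60000 * 12000
  // Widened count: the last element passes the guard.
  assert(InBounds<int64_t>(numel - 1, 0, numel));
  // 32-bit count: numel truncates to 1,465,032,704 (two's complement),
  // so the same logical offset is rejected and never written.
  assert(!InBounds<int>(1'465'032'704, 0, static_cast<int>(numel)));
  return 0;
}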
60 changes: 56 additions & 4 deletions test/legacy_test/test_gaussian_random_op.py
@@ -458,8 +458,8 @@ def test_fixed_random_number(self):
         if "V100" not in paddle.device.cuda.get_device_name():
             return
 
-        def _check_random_value(dtype, expect, expect_mean, expect_std):
-            x = paddle.randn([32, 3, 1024, 1024], dtype=dtype)
+        def _check_random_value(shape, dtype, expect, expect_mean, expect_std):
+            x = paddle.randn(shape, dtype=dtype)
             actual = x.numpy()
             np.testing.assert_allclose(
                 actual[2, 1, 512, 1000:1010], expect, rtol=1e-05
@@ -487,7 +487,9 @@ def _check_random_value(dtype, expect, expect_mean, expect_std):
             -0.0000053026194133403266873214888799115129813799285329878330230713
         )
         expect_std = 0.99999191058126390974081232343451119959354400634765625
-        _check_random_value(paddle.float64, expect, expect_mean, expect_std)
+        _check_random_value(
+            [32, 3, 1024, 1024], paddle.float64, expect, expect_mean, expect_std
+        )
 
         expect = [
             -0.7988942,
@@ -503,7 +505,57 @@
         ]
         expect_mean = -0.00004762359094456769526004791259765625
         expect_std = 0.999975681304931640625
-        _check_random_value(paddle.float32, expect, expect_mean, expect_std)
+        _check_random_value(
+            [32, 3, 1024, 1024], paddle.float32, expect, expect_mean, expect_std
+        )
+
+        # test randn in large shape
+        expect = [
+            -1.4770278,
+            -0.637431,
+            -0.41728288,
+            0.31339037,
+            -1.7627009,
+            0.4061812,
+            1.0679497,
+            0.03405872,
+            -0.7271235,
+            -0.42642546,
+        ]
+
+        expect_mean = 0.0000010386128224126878194510936737060547
+        expect_std = 1.00000822544097900390625
+        _check_random_value(
+            [4, 2, 60000, 12000],
+            paddle.float32,
+            expect,
+            expect_mean,
+            expect_std,
+        )
+
+        # test randn with seed 0 in large shape
+        paddle.seed(0)
+        expect = [
+            -1.7653463,
+            0.5957617,
+            0.45865676,
+            -0.3061651,
+            0.17204928,
+            -1.7802757,
+            -0.10731091,
+            1.042362,
+            0.70476884,
+            0.2720365,
+        ]
+        expect_mean = -0.0000002320642948916429304517805576324463
+        expect_std = 1.00001156330108642578125
+        _check_random_value(
+            [4, 2, 60000, 12000],
+            paddle.float32,
+            expect,
+            expect_mean,
+            expect_std,
+        )
 
 
 if __name__ == "__main__":
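Both large-shape cases pin ten values at `actual[2, 1, 512, 1000:1010]` plus the mean and std over all 5,760,000,000 samples, one continuing the test's existing seeded RNG stream and one after `paddle.seed(0)`, so a wraparound anywhere in the fill path would shift the pinned values. A host-side sketch (hypothetical `Mean` helper, not Paddle code) of the 64-bit indexing such a pass over the buffer needs:

// Host-side sketch: reducing a buffer whose length exceeds INT32_MAX.
#include <cstdint>

double Mean(const float* data, int64_t numel) {
  double sum = 0.0;
  for (int64_t i = 0; i < numel; ++i) {  // 64-bit index: no wraparound
    sum += data[i];
  }
  return sum / static_cast<double>(numel);
}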