@@ -16,6 +16,8 @@
 #ifndef FLASHINFER_NORM_CUH_
 #define FLASHINFER_NORM_CUH_
 
+#include <numeric>
+
 #include "flashinfer/utils.cuh"
 #include "math.cuh"
 #include "utils.cuh"
@@ -25,7 +27,7 @@ namespace flashinfer {
 
 namespace norm {
 
-template <typename T>
+template <uint32_t VEC_SIZE, typename T>
 __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restrict__ y,
                               const uint32_t d, float eps) {
   const uint32_t bx = blockIdx.x;
@@ -35,20 +37,19 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric
   // NOTE(Zihao): it's guaranteed that num_warps should be smaller than 32
   const uint32_t thread_id = tx + ty * warp_size;
   const uint32_t num_threads = num_warps * warp_size;
-  constexpr uint32_t vec_size = 16 / sizeof(T);
-  const uint32_t rounds = ceil_div(d, vec_size * num_threads);
+  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
   extern __shared__ float smem[];
 
   float sum_sq = 0.f;
 
   for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, vec_size> x_vec;
+    vec_t<T, VEC_SIZE> x_vec;
     x_vec.fill(0);
-    if ((i * num_threads + thread_id) * vec_size < d) {
-      x_vec.load(x + bx * d + i * num_threads * vec_size + thread_id * vec_size);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
-    for (uint32_t j = 0; j < vec_size; j++) {
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
       sum_sq += float(x_vec[j]) * float(x_vec[j]);
     }
   }
@@ -75,40 +76,41 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric
   float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
 
   for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, vec_size> x_vec;
-    vec_t<T, vec_size> w_vec;
-    vec_t<T, vec_size> y_vec;
+    vec_t<T, VEC_SIZE> x_vec;
+    vec_t<T, VEC_SIZE> w_vec;
+    vec_t<T, VEC_SIZE> y_vec;
     x_vec.fill(0);
     w_vec.fill(0);
-    if ((i * num_threads + thread_id) * vec_size < d) {
-      x_vec.load(x + bx * d + i * num_threads * vec_size + thread_id * vec_size);
-      w_vec.load(w + i * num_threads * vec_size + thread_id * vec_size);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      w_vec.load(w + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
-    for (uint32_t j = 0; j < vec_size; j++) {
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
       y_vec[j] = float(x_vec[j]) * rms_rcp * float(w_vec[j]);
     }
-    if ((i * num_threads + thread_id) * vec_size < d) {
-      y_vec.store(y + bx * d + i * num_threads * vec_size + thread_id * vec_size);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      y_vec.store(y + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
   }
 }
 
 template <typename T>
 cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps = 1e-5,
                     cudaStream_t stream = 0) {
-  constexpr uint32_t vec_size = 16 / sizeof(T);
-  if (d % vec_size != 0) {
-    return cudaErrorInvalidValue;
-  }
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
   const uint32_t num_warps = ceil_div(block_size, 32);
   dim3 nblks(batch_size);
   dim3 nthrs(32, num_warps);
   const uint32_t smem_size = num_warps * sizeof(float);
-  auto kernel = RMSNormKernel<T>;
   void* args[] = {&x, &w, &y, &d, &eps};
-  FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = RMSNormKernel<VEC_SIZE, T>;
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
   return cudaSuccess;
 }
 
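For context: the practical effect of replacing the hard `d % vec_size != 0` check with `std::gcd(16 / sizeof(T), d)` plus a `DISPATCH_ALIGNED_VEC_SIZE` compile-time dispatch is that hidden dimensions which are not 16-byte aligned no longer return `cudaErrorInvalidValue`; the launcher just falls back to a smaller vector width. Below is a minimal host-side sketch of that behavior, not part of this commit; the include path and the error-printing scaffolding are assumptions, only `flashinfer::norm::RMSNorm` and its signature come from the diff.

// sketch.cu: with T = float, 16 / sizeof(T) == 4, so d = 1002 was previously
// rejected; after this change vec_size = gcd(4, 1002) = 2 and the launcher
// dispatches RMSNormKernel<2, float>.
#include <cstdio>
#include <cuda_runtime.h>
#include "flashinfer/norm.cuh"  // assumed include path

int main() {
  const uint32_t batch_size = 8, d = 1002;  // d is not a multiple of 4
  float *x, *w, *y;
  cudaMalloc(&x, sizeof(float) * batch_size * d);
  cudaMalloc(&w, sizeof(float) * d);
  cudaMalloc(&y, sizeof(float) * batch_size * d);
  // ... fill x (input) and w (RMSNorm weight) on the device ...
  cudaError_t status = flashinfer::norm::RMSNorm(x, w, y, batch_size, d);
  std::printf("RMSNorm status: %s\n", cudaGetErrorString(status));
  cudaFree(x);
  cudaFree(w);
  cudaFree(y);
  return 0;
}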