Skip to content

Commit f01c768

Browse files
authored
sampling: fix alignment issue for vocab_size not divisible by vec_size (#211)
Remove debug codes, and support different vec_size.
1 parent f42e328 commit f01c768

File tree

7 files changed

+146
-311
lines changed

7 files changed

+146
-311
lines changed

CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
cmake_minimum_required(VERSION 3.20)
1+
cmake_minimum_required(VERSION 3.23.1)
22
project(flashinfer CUDA CXX)
33

44
include(cmake/utils/Utils.cmake)

include/flashinfer/norm.cuh

+24-22
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#ifndef FLASHINFER_NORM_CUH_
1717
#define FLASHINFER_NORM_CUH_
1818

19+
#include <numeric>
20+
1921
#include "flashinfer/utils.cuh"
2022
#include "math.cuh"
2123
#include "utils.cuh"
@@ -25,7 +27,7 @@ namespace flashinfer {
2527

2628
namespace norm {
2729

28-
template <typename T>
30+
template <uint32_t VEC_SIZE, typename T>
2931
__global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restrict__ y,
3032
const uint32_t d, float eps) {
3133
const uint32_t bx = blockIdx.x;
@@ -35,20 +37,19 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric
3537
// NOTE(Zihao): it's guaranteed that num_warps should be smaller than 32
3638
const uint32_t thread_id = tx + ty * warp_size;
3739
const uint32_t num_threads = num_warps * warp_size;
38-
constexpr uint32_t vec_size = 16 / sizeof(T);
39-
const uint32_t rounds = ceil_div(d, vec_size * num_threads);
40+
const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
4041
extern __shared__ float smem[];
4142

4243
float sum_sq = 0.f;
4344

4445
for (uint32_t i = 0; i < rounds; i++) {
45-
vec_t<T, vec_size> x_vec;
46+
vec_t<T, VEC_SIZE> x_vec;
4647
x_vec.fill(0);
47-
if ((i * num_threads + thread_id) * vec_size < d) {
48-
x_vec.load(x + bx * d + i * num_threads * vec_size + thread_id * vec_size);
48+
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
49+
x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
4950
}
5051
#pragma unroll
51-
for (uint32_t j = 0; j < vec_size; j++) {
52+
for (uint32_t j = 0; j < VEC_SIZE; j++) {
5253
sum_sq += float(x_vec[j]) * float(x_vec[j]);
5354
}
5455
}
@@ -75,40 +76,41 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric
7576
float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
7677

7778
for (uint32_t i = 0; i < rounds; i++) {
78-
vec_t<T, vec_size> x_vec;
79-
vec_t<T, vec_size> w_vec;
80-
vec_t<T, vec_size> y_vec;
79+
vec_t<T, VEC_SIZE> x_vec;
80+
vec_t<T, VEC_SIZE> w_vec;
81+
vec_t<T, VEC_SIZE> y_vec;
8182
x_vec.fill(0);
8283
w_vec.fill(0);
83-
if ((i * num_threads + thread_id) * vec_size < d) {
84-
x_vec.load(x + bx * d + i * num_threads * vec_size + thread_id * vec_size);
85-
w_vec.load(w + i * num_threads * vec_size + thread_id * vec_size);
84+
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
85+
x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
86+
w_vec.load(w + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
8687
}
8788
#pragma unroll
88-
for (uint32_t j = 0; j < vec_size; j++) {
89+
for (uint32_t j = 0; j < VEC_SIZE; j++) {
8990
y_vec[j] = float(x_vec[j]) * rms_rcp * float(w_vec[j]);
9091
}
91-
if ((i * num_threads + thread_id) * vec_size < d) {
92-
y_vec.store(y + bx * d + i * num_threads * vec_size + thread_id * vec_size);
92+
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
93+
y_vec.store(y + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
9394
}
9495
}
9596
}
9697

9798
template <typename T>
9899
cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps = 1e-5,
99100
cudaStream_t stream = 0) {
100-
constexpr uint32_t vec_size = 16 / sizeof(T);
101-
if (d % vec_size != 0) {
102-
return cudaErrorInvalidValue;
103-
}
101+
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
102+
104103
const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
105104
const uint32_t num_warps = ceil_div(block_size, 32);
106105
dim3 nblks(batch_size);
107106
dim3 nthrs(32, num_warps);
108107
const uint32_t smem_size = num_warps * sizeof(float);
109-
auto kernel = RMSNormKernel<T>;
110108
void* args[] = {&x, &w, &y, &d, &eps};
111-
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
109+
110+
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
111+
auto kernel = RMSNormKernel<VEC_SIZE, T>;
112+
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
113+
});
112114
return cudaSuccess;
113115
}
114116

0 commit comments

Comments (0)