@@ -3907,28 +3907,27 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
3907
3907
dst[i + 1 ] = x0*sin_theta + x1*cos_theta;
3908
3908
}
3909
3909
3910
- // TODO: this implementation is wrong!
3911
- // static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
3912
- // const float p_delta, const int p_delta_rows, const float theta_scale) {
3913
- // const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
3914
- //
3915
- // if (col >= ncols) {
3916
- // return;
3917
- // }
3918
- //
3919
- // const int row = blockDim.x*blockIdx.x + threadIdx.x;
3920
- // const int i = row*ncols + col/2;
3921
- //
3922
- // const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
3923
- // const float sin_theta = sinf(theta);
3924
- // const float cos_theta = cosf(theta);
3925
- //
3926
- // const float x0 = x[i + 0];
3927
- // const float x1 = x[i + ncols/2];
3928
- //
3929
- // dst[i + 0] = x0*cos_theta - x1*sin_theta;
3930
- // dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
3931
- // }
3910
// NeoX-style rotary position embedding (RoPE), f32.
// Launch layout: blockIdx.x/threadIdx.x index rows; (blockIdx.y, threadIdx.y)
// tile the column pairs — each thread rotates the value pair
// (x[i], x[i + ncols/2]) for i = row*ncols + col/2.
// The rotation angle grows with the row position (p0 + p_delta per
// p_delta_rows rows) and shrinks geometrically across columns via theta_scale.
static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
                                     const float p_delta, const int p_delta_rows, const float theta_scale) {
    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    // grid tail guard: the y-tiling rarely divides ncols evenly
    if (col >= ncols) {
        return;
    }

    const int row  = blockDim.x*blockIdx.x + threadIdx.x;
    const int base = row*ncols + col/2;

    // per-(row, pair) rotation angle
    const float theta     = (p0 + p_delta*(row/p_delta_rows))*powf(theta_scale, col/2);
    const float sin_theta = sinf(theta);
    const float cos_theta = cosf(theta);

    const float v0 = x[base + 0];
    const float v1 = x[base + ncols/2];

    // 2D rotation of the pair by theta
    dst[base + 0]       = v0*cos_theta - v1*sin_theta;
    dst[base + ncols/2] = v0*sin_theta + v1*cos_theta;
}
3932
3931
3933
3932
static __global__ void rope_glm_f32 (const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
3934
3933
const int col = blockDim .x *blockIdx .x + threadIdx .x ;
@@ -4799,13 +4798,21 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
4799
4798
4800
4799
// Host-side launcher for the rope_f32 kernel: rows map to blockIdx.x,
// column pairs are tiled over (blockIdx.y, threadIdx.y) in chunks of
// 2*CUDA_ROPE_BLOCK_SIZE columns. Asynchronous on `stream`.
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                          const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
    GGML_ASSERT(nrows % 2 == 0); // GG: is this assert really needed? I don't see why

    // ceil-div: number of y-blocks needed to cover all column pairs
    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);

    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 block_nums(nrows, num_blocks_x, 1);

    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
}
4808
4807
4808
// Host-side launcher for the rope_neox_f32 kernel. Same grid layout as
// rope_f32_cuda: one x-block per row, column pairs tiled over the y
// dimension in chunks of 2*CUDA_ROPE_BLOCK_SIZE. Asynchronous on `stream`.
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                               const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
    // ceil-div: y-blocks required to cover every column pair
    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);

    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 block_nums(nrows, num_blocks_x, 1);

    rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
}
4815
+
4809
4816
static void rope_glm_f32_cuda (const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
4810
4817
GGML_ASSERT (nrows % 4 == 0 );
4811
4818
const dim3 block_dims (4 *CUDA_ROPE_BLOCK_SIZE, 1 , 1 );
@@ -5548,8 +5555,9 @@ inline void ggml_cuda_op_rope(
5548
5555
const float block_p = max (p - (n_ctx - 2 .f ), 0 .f );
5549
5556
rope_glm_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
5550
5557
} else if (is_neox) {
5551
- GGML_ASSERT (false && " RoPE NeoX not implemented yet" );
5552
- #pragma message("TODO: implement RoPE NeoX for CUDA")
5558
+ GGML_ASSERT (ne00 == n_dims && " ne00 != n_dims is not implemented for CUDA yet" );
5559
+ const float p0 = (((mode & 1 ) == 0 ? n_past : 0 )) * freq_scale;
5560
+ rope_neox_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
5553
5561
} else {
5554
5562
const float p0 = (((mode & 1 ) == 0 ? n_past : 0 )) * freq_scale;
5555
5563
rope_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
0 commit comments