@@ -5,9 +5,10 @@ template <typename T, typename type_acc, int block_size>
 static __global__ void mul_mat_vec(
         const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
-    const int64_t row     = blockIdx.x;
-    const int64_t channel = blockIdx.z;
-    const int     tid     = threadIdx.x;
+    const int64_t row       = blockIdx.x;
+    const int64_t channel   = blockIdx.z;
+    const int     tid       = threadIdx.x;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     x   += (channel/channel_ratio)*stride_channel_x + row*stride_row;
     y   +=  channel               *stride_channel_y;
@@ -18,8 +19,8 @@ static __global__ void mul_mat_vec(
     extern __shared__ char data_mmv[];
     float * buf_iw = (float *) data_mmv;
 
-    if (block_size > WARP_SIZE) {
-        if (tid < WARP_SIZE) {
+    if (block_size > warp_size) {
+        if (tid < warp_size) {
             buf_iw[tid] = 0.0f;
         }
         __syncthreads();
@@ -67,16 +68,16 @@ static __global__ void mul_mat_vec(
         static_assert(std::is_same<T, void>::value, "unsupported type");
     }
 
-    sumf = warp_reduce_sum(sumf);
+    sumf = warp_reduce_sum<warp_size>(sumf);
 
-    if (block_size > WARP_SIZE) {
-        buf_iw[tid/WARP_SIZE] = sumf;
+    if (block_size > warp_size) {
+        buf_iw[tid/warp_size] = sumf;
         __syncthreads();
-        if (tid >= WARP_SIZE) {
+        if (tid >= warp_size) {
             return;
         }
         sumf = buf_iw[tid];
-        sumf = warp_reduce_sum(sumf);
+        sumf = warp_reduce_sum<warp_size>(sumf);
     }
 
     if (tid != 0) {
@@ -96,18 +97,27 @@ static void launch_mul_mat_vec_cuda(
     GGML_ASSERT(stride_row % 2 == 0);
     GGML_ASSERT(nchannels_y % nchannels_x == 0);
     const int64_t channel_ratio = nchannels_y / nchannels_x;
+    int device;
+    int warp_size;
 
-    int64_t block_size_best = WARP_SIZE;
-    int64_t niter_best      = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
-    for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
+    CUDA_CHECK(cudaGetDevice(&device));
+    warp_size = ggml_cuda_info().devices[device].warp_size;
+
+    int64_t block_size_best = warp_size;
+    int64_t niter_best      = (ncols + 2*warp_size - 1) / (2*warp_size);
+    int64_t max_block_size  = 256;
+    if (ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
+        max_block_size = 128;
+    }
+    for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
         const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
         if (niter < niter_best) {
             niter_best      = niter;
             block_size_best = block_size;
         }
     }
 
-    const int smem = WARP_SIZE*sizeof(float);
+    const int smem = warp_size*sizeof(float);
     const dim3 block_nums(nrows, 1, nchannels_y);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
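For readers unfamiliar with the helper being templated here, `warp_reduce_sum<warp_size>` is a warp-level reduction whose width is now a compile-time argument rather than the fixed `WARP_SIZE` macro, which is what lets the same kernel serve 32-wide NVIDIA warps and 64-wide AMD GCN/CDNA wavefronts. The snippet below is only a minimal sketch of such a width-templated reduction, assuming the CUDA shuffle intrinsics with 32-bit masks; the name `warp_reduce_sum_sketch` is made up for illustration and is not the actual ggml-cuda helper.

// Illustrative sketch only (not the ggml-cuda implementation): a butterfly
// warp reduction whose width is a template parameter. Assumes the CUDA path,
// where the sync mask is 32-bit and width is at most 32; an AMD/HIP path
// would use the corresponding 64-lane shuffle instead.
template <int width>
static __device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        // Each step folds the upper half of the remaining lanes onto the lower half.
        x += __shfl_xor_sync(0xFFFFFFFF, x, offset, width);
    }
    return x;
}

Passing the width as a template argument lets the compiler fully unroll the loop for either warp size, so supporting both vendors adds no runtime cost.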