 #include "ggml.h"
 #include "softmax.cuh"
 #include <cstdint>
+#include <utility>

 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {
@@ -181,6 +182,37 @@ static __global__ void soft_max_back_f32(
     }
 }

+template <int... Ns, typename T>
+static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+                                    const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
+{
+    const int id       = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    auto launch_kernel = [=](auto I) -> bool {
+        constexpr int ncols = decltype(I)::value;
+        constexpr int block = (ncols > 1024 ? 1024 : ncols);
+
+        if (p.ncols == ncols) {
+            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
+            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, mask, dst, p);
+            return true;
+        }
+        return false;
+    };
+
+    // unary fold over launch_kernel
+    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+        return;
+    }
+
+    // default case
+    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+}
+
+
 template <typename T>
 static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
@@ -193,46 +225,12 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");


-    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
-    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-        }
+    const int id       = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+
+    if (nbytes_shared <= smpbo) {
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
         soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
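
Note on the `(launch_kernel(std::integral_constant<int, Ns>{}) || ...)` line added above: it is a C++17 unary fold over the `Ns` parameter pack, so each compile-time column count is tried in order and the `||` short-circuits at the first one that matches the runtime `p.ncols`. A minimal host-only sketch of the same dispatch pattern, with illustrative names (`dispatch_ncols`, `run_for`) that are not part of this patch:

```cpp
#include <cstdio>
#include <type_traits>

// Stand-in for a kernel launch specialized on a compile-time column count.
template <int ncols>
static void run_for() {
    std::printf("dispatched with compile-time ncols = %d\n", ncols);
}

// Same fold-expression trick as launch_soft_max_kernels: try every value in Ns...
// and stop at the first compile-time constant equal to the runtime ncols.
template <int... Ns>
static bool dispatch_ncols(int ncols) {
    auto try_one = [&](auto I) -> bool {
        constexpr int n = decltype(I)::value;
        if (ncols == n) {
            run_for<n>();
            return true;
        }
        return false;
    };
    // unary fold over ||: short-circuits on the first match
    return (try_one(std::integral_constant<int, Ns>{}) || ...);
}

int main() {
    if (!dispatch_ncols<32, 64, 128, 256>(128)) {
        std::printf("no specialization matched, fall back to the generic kernel\n");
    }
    return 0;
}
```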
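The other functional change in this hunk is the shared-memory gate: the old code compared `nbytes_shared` against `smpb` (the per-block limit available without opt-in, hence the removed FIXME about Ampere), while the new code compares against `smpbo` and raises the kernel's dynamic shared memory cap with `CUDA_SET_SHARED_MEMORY_LIMIT` before launching. Presumably that macro wraps the standard CUDA runtime call for this; a hedged, standalone sketch of the mechanism (the helper name `set_smem_limit` and the macro's exact behavior are assumptions, not taken from ggml):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Illustrative only: opt a kernel into a dynamic shared memory limit above the
// default 48 KiB per block by raising cudaFuncAttributeMaxDynamicSharedMemorySize.
template <typename Kernel>
static void set_smem_limit(Kernel kernel, size_t bytes) {
    const cudaError_t err =
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(bytes));
    if (err != cudaSuccess) {
        std::fprintf(stderr, "cudaFuncSetAttribute failed: %s\n", cudaGetErrorString(err));
    }
}
```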