From 12bef0d3ad38c2b3dc8bc5d443be2fed36c88bb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 19 Apr 2025 23:58:46 +0200 Subject: [PATCH] CUDA: batched+noncont MMQ, refactor bs>1 MoE code --- ggml/src/ggml-cuda/getrows.cu | 171 ++++++---- ggml/src/ggml-cuda/getrows.cuh | 7 + ggml/src/ggml-cuda/ggml-cuda.cu | 260 +++++++-------- ggml/src/ggml-cuda/mmq.cu | 220 +++++++++++-- ggml/src/ggml-cuda/mmq.cuh | 554 +++++++++++++++++++++++--------- ggml/src/ggml-cuda/mmvq.cu | 6 +- ggml/src/ggml-cuda/quantize.cu | 49 +-- ggml/src/ggml-cuda/quantize.cuh | 15 +- tests/test-backend-ops.cpp | 5 + 9 files changed, 858 insertions(+), 429 deletions(-) diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 4cef53a98cfd6..ea8bf69160996 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -33,8 +33,8 @@ static __global__ void k_get_rows( dfloat2 v; dequantize_kernel(src0_row, ib, iqs, v); - dst_row[iybs + iqs + 0] = v.x; - dst_row[iybs + iqs + y_offset] = v.y; + dst_row[iybs + iqs + 0] = float(v.x); + dst_row[iybs + iqs + y_offset] = float(v.y); } template @@ -60,7 +60,7 @@ static __global__ void k_get_rows_float( dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); - dst_row[i00] = src0_row[i00]; + dst_row[i00] = float(src0_row[i00]); } template @@ -86,122 +86,161 @@ static __global__ void k_get_rows_back_float( dst[dst_row*ncols + col] = sum; } -template -static void get_rows_cuda( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - +template +static void get_rows_cuda_q( + const void * src0_d, const int32_t * src1_d, dst_t * dst_d, + const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); const dim3 block_nums(block_num_x, ne10, ne11*ne12); // strides in elements - //const size_t s0 = nb0 / ggml_element_size(dst); - const size_t s1 = nb1 / ggml_element_size(dst); - const size_t s2 = nb2 / ggml_element_size(dst); - const size_t s3 = nb3 / ggml_element_size(dst); + // const size_t s0 = nb0 / sizeof(dst_t); + const size_t s1 = nb1 / sizeof(dst_t); + const size_t s2 = nb2 / sizeof(dst_t); + const size_t s3 = nb3 / sizeof(dst_t); - const size_t s10 = nb10 / ggml_element_size(src1); - const size_t s11 = nb11 / ggml_element_size(src1); - const size_t s12 = nb12 / ggml_element_size(src1); - //const size_t s13 = nb13 / ggml_element_size(src1); + const size_t s10 = nb10 / sizeof(int32_t); + const size_t s11 = nb11 / sizeof(int32_t); + const size_t s12 = nb12 / sizeof(int32_t); + // const size_t s13 = nb13 / sizeof(int32_t); GGML_ASSERT(ne00 % 2 == 0); k_get_rows<<>>( - src0_dd, src1_dd, dst_dd, + src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ /*ne10, ne11,*/ ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); - - GGML_UNUSED(dst); } -template +template static void get_rows_cuda_float( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(ne13 == 1); - + const src0_t * src0_d, const int32_t * src1_d, dst_t * dst_d, + const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; const dim3 block_nums(block_num_x, ne10, ne11*ne12); // strides in elements - //const size_t s0 = nb0 / ggml_element_size(dst); - const size_t s1 = nb1 / ggml_element_size(dst); - const size_t s2 = nb2 / ggml_element_size(dst); - const size_t s3 = nb3 / ggml_element_size(dst); + // const size_t s0 = nb0 / sizeof(dst_t); + const size_t s1 = nb1 / sizeof(dst_t); + const size_t s2 = nb2 / sizeof(dst_t); + const size_t s3 = nb3 / sizeof(dst_t); - const size_t s10 = nb10 / ggml_element_size(src1); - const size_t s11 = nb11 / ggml_element_size(src1); - const size_t s12 = nb12 / ggml_element_size(src1); - //const size_t s13 = nb13 / ggml_element_size(src1); + const size_t s10 = nb10 / sizeof(int32_t); + const size_t s11 = nb11 / sizeof(int32_t); + const size_t s12 = nb12 / sizeof(int32_t); + // const size_t s13 = nb13 / sizeof(int32_t); k_get_rows_float<<>>( - src0_dd, src1_dd, dst_dd, + src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ /*ne10, ne11,*/ ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); - - GGML_UNUSED(dst); } -void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - - const void * src0_d = (const void *) src0->data; - const int32_t * src1_d = (const int32_t *) src1->data; - float * dst_d = (float *) dst->data; - - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); - GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); - - switch (src0->type) { +template +static void ggml_cuda_get_rows_switch_src0_type( + const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d, + const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { + switch (src0_type) { case GGML_TYPE_F16: - get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream); + get_rows_cuda_float((const half *) src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; case GGML_TYPE_F32: - get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream); + get_rows_cuda_float((const float *) src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_BF16: + get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; case GGML_TYPE_Q4_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; case GGML_TYPE_Q4_1: - get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; case GGML_TYPE_Q5_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; case GGML_TYPE_Q5_1: - get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; case GGML_TYPE_Q8_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; default: // TODO: k-quants - GGML_ABORT("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type)); + GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type)); break; } } +void get_rows_cuda( + const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type, + int64_t ne00, size_t nb01, size_t nb02, size_t nb03, + int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12, + size_t nb1, size_t nb2, size_t nb3, + cudaStream_t stream) { + switch (dst_type) { + case GGML_TYPE_F32: + ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_F16: + ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_BF16: + ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (nv_bfloat16 *) dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + default: + GGML_ABORT("%s: unsupported dst type: %s\n", __func__, ggml_type_name(dst_type)); + break; + } +} + +void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + cudaStream_t stream = ctx.stream(); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(ne13 == 1); + + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); + GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); + + get_rows_cuda(src0->data, src0->type, (const int32_t *) src1->data, dst->data, dst->type, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); +} + void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass diff --git a/ggml/src/ggml-cuda/getrows.cuh b/ggml/src/ggml-cuda/getrows.cuh index a1ca643f1c530..3c5bea5f48c1c 100644 --- a/ggml/src/ggml-cuda/getrows.cuh +++ b/ggml/src/ggml-cuda/getrows.cuh @@ -3,6 +3,13 @@ #define CUDA_GET_ROWS_BLOCK_SIZE 256 #define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256 +void get_rows_cuda( + const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type, + int64_t ne00, size_t nb01, size_t nb02, size_t nb03, + int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12, + size_t nb1, size_t nb2, size_t nb3, + cudaStream_t stream); + void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fba8cb6565bae..9fb2134f98d3d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1551,7 +1551,7 @@ static void ggml_cuda_op_mul_mat( if (src1_on_device && src1_is_contiguous) { quantize_src1( - dev[id].src1_ddf, dev[id].src1_ddq, src0->type, ne10, + dev[id].src1_ddf, nullptr, dev[id].src1_ddq, src0->type, ne10, nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float), src1_padded_col_size, ne11, ne12, ne13, stream); CUDA_CHECK(cudaGetLastError()); @@ -1649,7 +1649,7 @@ static void ggml_cuda_op_mul_mat( if (quantize_src1 && !src1_is_contiguous) { quantize_src1( - src1_ddf_i, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10, + src1_ddf_i, nullptr, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10, src1_padded_col_size, src1_ncols, 1, 1, stream); CUDA_CHECK(cudaGetLastError()); } @@ -1949,6 +1949,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst); } else if (!split && use_mul_mat_vec_q) { ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst); + } else if (!split && use_mul_mat_q) { + ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst); } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // general KQ + KQV multi-batch without FlashAttention @@ -1964,183 +1966,145 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor } } -struct mmid_row_mapping { - int32_t i1; - int32_t i2; -}; - -static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous, - int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping, - const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, - int64_t ne11, int64_t ne10, - size_t nb11, size_t nb12) { - int32_t iid1 = blockIdx.x; - int32_t id = blockIdx.y; - - const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); - - if (row_id_i != i02) { - return; - } - - const int64_t i11 = id % ne11; - const int64_t i12 = iid1; - - __shared__ int src1_row; - if (threadIdx.x == 0) { - src1_row = atomicAdd(cur_src1_row, 1); - row_mapping[src1_row] = {id, iid1}; - } - __syncthreads(); - - const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); - float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); - - for (int i = threadIdx.x; i < ne10; i += blockDim.x) { - src1_row_contiguous[i] = src1_row_original[i]; - } -} - -static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous, - const mmid_row_mapping * __restrict__ row_mapping, - int64_t ne0, - size_t nb1, size_t nb2) { - int32_t i = blockIdx.x; - - const int32_t i1 = row_mapping[i].i1; - const int32_t i2 = row_mapping[i].i2; - - const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1); - float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2); - - for (int j = threadIdx.x; j < ne0; j += blockDim.x) { - dst_row_original[j] = dst_row_contiguous[j]; - } -} - static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * ids = dst->src[2]; - GGML_TENSOR_BINARY_OP_LOCALS - - if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ne2 == 1) { - if (ggml_is_quantized(src0->type)) { - ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); - } else { - ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst); - } - return; - } - + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers"); - cudaStream_t stream = ctx.stream(); + GGML_TENSOR_BINARY_OP_LOCALS - const int64_t n_as = ne02; - const int64_t n_ids = ids->ne[0]; + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - std::vector ids_host(ggml_nbytes(ids)); - const char * ids_dev = (const char *) ids->data; - CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); - CUDA_CHECK(cudaStreamSynchronize(stream)); + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (ne2 == 1) { + if (ggml_is_quantized(src0->type)) { + ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); + } else { + ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst); + } + return; + } - ggml_tensor src0_row = *src0; - ggml_tensor src1_row = *src1; - ggml_tensor dst_row = *dst; + if (ggml_cuda_should_use_mmq(src0->type, cc, ne12)) { + ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst); + return; + } + } - char * src0_original = (char *) src0->data; - char * src1_original = (char *) src1->data; - char * dst_original = (char *) dst->data; + cudaStream_t stream = ctx.stream(); - src0_row.ne[2] = 1; - src0_row.ne[3] = 1; - src0_row.nb[3] = nb02; + GGML_ASSERT(nb12 % nb11 == 0); + GGML_ASSERT(nb2 % nb1 == 0); - src1_row.ne[1] = 1; - src1_row.ne[2] = 1; - src1_row.ne[3] = 1; - src1_row.nb[2] = nb11; - src1_row.nb[3] = nb11; + const ggml_type type_src1_sorted = (src0->type == GGML_TYPE_F16 && !fast_fp16_hardware_available(cc)) + || ggml_is_quantized(src0->type) ? GGML_TYPE_F32 : src0->type; + const ggml_type type_dst_sorted = GGML_TYPE_F32; + const size_t ts_src1_sorted = ggml_type_size(type_src1_sorted); + const size_t ts_dst_sorted = ggml_type_size(type_dst_sorted); - dst_row.ne[1] = 1; - dst_row.ne[2] = 1; - dst_row.ne[3] = 1; - dst_row.nb[2] = nb1; - dst_row.nb[3] = nb1; + const int64_t n_expert_used = ids->ne[0]; + const int64_t ne_get_rows = ne12 * n_expert_used; - ggml_cuda_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); - ggml_cuda_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + std::vector ids_to_sorted_host; + ids_to_sorted_host.reserve(2*ne_get_rows); + std::vector ids_from_sorted_host(ne_get_rows); - src1_row.data = src1_contiguous.get(); - dst_row.data = dst_contiguous.get(); + ggml_cuda_pool_alloc ids_buf_dev(ctx.pool(), 2*ne_get_rows); - for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = 0; + std::vector tokens_per_expert(ne02); - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + ggml_cuda_pool_alloc src1_sorted(ctx.pool(), ne12*n_expert_used*ne10*ts_src1_sorted); + ggml_cuda_pool_alloc dst_sorted(ctx.pool(), ne2 *n_expert_used* ne0*ts_dst_sorted); - GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + std::vector ids_host(ggml_nbytes(ids)); + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); - if (row_id_i != i02) { - continue; + for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices + for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens + for (int64_t iex = 0; iex < n_expert_used; ++iex) { + const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]); + assert(expert_to_use >= 0 && expert_to_use < ne02); + if (expert_to_use == i02) { + ids_from_sorted_host[i12*n_expert_used + iex] = ids_to_sorted_host.size(); + ids_to_sorted_host.push_back(i12*ne11 + iex % ne11); + tokens_per_expert[i02]++; + break; } - - num_src1_rows++; } } + } + GGML_ASSERT(ids_to_sorted_host.size() == size_t(ne_get_rows)); - if (num_src1_rows == 0) { - continue; - } - - ggml_cuda_pool_alloc dev_cur_src1_row(ctx.pool(), 1); - ggml_cuda_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); - CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); - - { - dim3 block_dims(std::min((unsigned int)ne10, 768u)); - dim3 grid_dims(ids->ne[1], n_ids); - k_copy_src1_to_contiguous<<>>( - src1_original, src1_contiguous.get(), - dev_cur_src1_row.get(), dev_row_mapping.get(), - ids_dev, i02, ids->nb[1], ids->nb[0], - ne11, ne10, - nb11, nb12); - CUDA_CHECK(cudaGetLastError()); - } + ids_to_sorted_host.insert(ids_to_sorted_host.end(), ids_from_sorted_host.begin(), ids_from_sorted_host.end()); - src0_row.data = src0_original + i02*nb02; + CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_to_sorted_host.data(), 2*ne_get_rows*sizeof(int32_t), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); - GGML_ASSERT(nb11 == sizeof(float)*ne10); - GGML_ASSERT(nb1 == sizeof(float)*ne0); + const int32_t * ids_to_sorted = ids_buf_dev.ptr + 0*ne_get_rows; + const int32_t * ids_from_sorted = ids_buf_dev.ptr + 1*ne_get_rows; - src1_row.ne[1] = num_src1_rows; - src1_row.nb[1] = nb11; - src1_row.nb[2] = num_src1_rows*nb11; - src1_row.nb[3] = num_src1_rows*nb11; + get_rows_cuda(src1->data, src1->type, ids_to_sorted, src1_sorted.ptr, type_src1_sorted, + ne10, nb11, nb12, nb13, + ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t), + ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, stream); + CUDA_CHECK(cudaGetLastError()); - dst_row.ne[1] = num_src1_rows; - dst_row.nb[1] = nb1; - dst_row.nb[2] = num_src1_rows*nb1; - dst_row.nb[3] = num_src1_rows*nb1; + char * src1_data_cur = (char *) src1_sorted.ptr; + char * dst_data_cur = (char *) dst_sorted.ptr; + for (int64_t i02 = 0; i02 < ne02; ++i02) { + if (tokens_per_expert[i02] == 0) { + continue; + } - ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + ggml_tensor src0_slice = *src0; + src0_slice.ne[2] = 1; + src0_slice.nb[3] = src0_slice.nb[2]; + src0_slice.data = (char *) src0->data + i02*nb02; + + ggml_tensor src1_slice; + memset(&src1_slice, 0, sizeof(src1_slice)); + src1_slice.buffer = src1->buffer; + src1_slice.type = type_src1_sorted; + src1_slice.ne[0] = ne10; + src1_slice.ne[1] = tokens_per_expert[i02]; + src1_slice.ne[2] = 1; + src1_slice.ne[3] = 1; + src1_slice.nb[0] = ts_src1_sorted; + src1_slice.nb[1] = src1_slice.ne[0] * src1_slice.nb[0]; + src1_slice.nb[2] = src1_slice.ne[1] * src1_slice.nb[1]; + src1_slice.nb[3] = src1_slice.ne[2] * src1_slice.nb[2]; + src1_slice.data = src1_data_cur; + + ggml_tensor dst_slice; + memset(&dst_slice, 0, sizeof(dst_slice)); + dst_slice.buffer = dst->buffer; + dst_slice.type = type_dst_sorted; + dst_slice.ne[0] = ne0; + dst_slice.ne[1] = tokens_per_expert[i02]; + dst_slice.ne[2] = 1; + dst_slice.ne[3] = 1; + dst_slice.nb[0] = ts_dst_sorted; + dst_slice.nb[1] = dst_slice.ne[0] * dst_slice.nb[0]; + dst_slice.nb[2] = dst_slice.ne[1] * dst_slice.nb[1]; + dst_slice.nb[3] = dst_slice.ne[2] * dst_slice.nb[2]; + dst_slice.data = dst_data_cur; + + ggml_cuda_mul_mat(ctx, &src0_slice, &src1_slice, &dst_slice); + CUDA_CHECK(cudaGetLastError()); - { - dim3 block_dims(std::min((unsigned int)ne0, 768u)); - dim3 grid_dims(num_src1_rows); - k_copy_dst_from_contiguous<<>>( - dst_original, dst_contiguous.get(), - dev_row_mapping.get(), - ne0, - nb1, nb2); - CUDA_CHECK(cudaGetLastError()); - } + src1_data_cur += src1_slice.nb[2]; + dst_data_cur += dst_slice.nb[2]; } + + get_rows_cuda(dst_sorted.ptr, type_dst_sorted, ids_from_sorted, dst->data, dst->type, + ne0, ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted, + ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t), + nb1, nb2, nb3, stream); } static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index b36b43d5417ba..f397a7e038469 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -1,37 +1,10 @@ #include "mmq.cuh" +#include "quantize.cuh" -void ggml_cuda_op_mul_mat_q( - ggml_backend_cuda_context & ctx, - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, - const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, cudaStream_t stream) { - - const int64_t ne00 = src0->ne[0]; - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - GGML_ASSERT(ne10 % QK8_1 == 0); +#include - const int64_t ne0 = dst->ne[0]; - - const int64_t row_diff = row_high - row_low; - const int64_t stride00 = ne00 / ggml_blck_size(src0->type); - - int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - - // the main device has a larger memory buffer to hold the results from all GPUs - // nrows_dst == nrows of the matrix that the kernel writes into - const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - - // The stream-k decomposition is only faster for recent NVIDIA GPUs. - // Also its fixup needs to allocate a temporary buffer in the memory pool. - // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. - const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && - ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11; - const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k}; - - switch (src0->type) { +static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { + switch (args.type_x) { case GGML_TYPE_Q4_0: mul_mat_q_case(ctx, args, stream); break; @@ -90,10 +63,195 @@ void ggml_cuda_op_mul_mat_q( GGML_ABORT("fatal error"); break; } +} + +void ggml_cuda_mul_mat_q( + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID. + + GGML_TENSOR_BINARY_OP_LOCALS; + + cudaStream_t stream = ctx.stream(); + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); + + GGML_ASSERT( nb00 == ts_src0); + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT( nb0 == ts_dst); + GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); + + const char * src0_d = (const char *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); + + const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s1 = dst->nb[1] / ts_dst; + const int64_t s02 = src0->nb[2] / ts_src0; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t s03 = src0->nb[3] / ts_src0; + const int64_t s3 = dst->nb[3] / ts_dst; + + const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA; + + if (!ids) { + const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq); + ggml_cuda_pool_alloc src1_q8_1(ctx.pool(), nbytes_src1_q8_1); + + { + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s13 = src1->nb[3] / ts_src1; + quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, + ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); + } + + const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); + const int64_t s13 = ne12*s12; + + const mmq_args args = { + src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d, + ne00, ne01, ne1, s01, s1, + ne02, ne12, s02, s12, s2, + ne03, ne13, s03, s13, s3, + use_stream_k}; + ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); + return; + } + + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(nb12 % nb11 == 0); + GGML_ASSERT(nb2 % nb1 == 0); + + const int64_t n_expert_used = ids->ne[0]; + const int64_t ne_get_rows = ne12 * n_expert_used; + + std::vector ids_host(ggml_nbytes(ids)); + std::vector ids_src1_host; + ids_src1_host.reserve(ne_get_rows); + std::vector ids_dst_host; + ids_dst_host.reserve(ne_get_rows); + std::vector tokens_per_expert_host(ne02); + std::vector expert_bounds_host(ne02 + 1); + ggml_cuda_pool_alloc ids_buf_dev(ctx.pool()); + + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices + for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens + for (int64_t iex = 0; iex < n_expert_used; ++iex) { + const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]); + assert(expert_to_use >= 0 && expert_to_use < ne02); + if (expert_to_use == i02) { + ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11); + ids_dst_host.push_back(i12*ne1 + iex); + tokens_per_expert_host[i02]++; + break; + } + } + } + } + + int32_t cumsum = 0; + for (int64_t i = 0; i < ne02; ++i) { + expert_bounds_host[i] = cumsum; + cumsum += tokens_per_expert_host[i]; + } + expert_bounds_host[ne02] = cumsum; + + std::vector ids_buf_host; + ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size()); + ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end()); + ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end()); + ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end()); + ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device. + CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + const int32_t * ids_src1_dev = ids_buf_dev.ptr; + const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size(); + const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size(); + + const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 + + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq); + ggml_cuda_pool_alloc src1_q8_1(ctx.pool(), nbytes_src1_q8_1); + + const int64_t ne11_flat = ne12*n_expert_used; + const int64_t ne12_flat = 1; + const int64_t ne13_flat = 1; + + { + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s13 = src1->nb[2] / ts_src1; + quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type, + ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + } + + const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); + const int64_t s13 = ne12*s12; + + // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid. + const mmq_args args = { + src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d, + ne00, ne01, ne_get_rows, s01, s1, + ne02, ne02, s02, s12, s2, + ne03, ne13, s03, s13, s3, + use_stream_k}; + + ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); +} + +void ggml_cuda_op_mul_mat_q( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + GGML_ASSERT(ne10 % QK8_1 == 0); + + const int64_t ne0 = dst->ne[0]; + + const int64_t row_diff = row_high - row_low; + const int64_t stride01 = ne00 / ggml_blck_size(src0->type); + + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the kernel writes into + const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; + + // The stream-k decomposition is only faster for recent NVIDIA GPUs. + // Also its fixup needs to allocate a temporary buffer in the memory pool. + // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. + const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && + ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11; + const mmq_args args = { + src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i, + ne00, row_diff, src1_ncols, stride01, nrows_dst, + 1, 1, 0, 0, 0, + 1, 1, 0, 0, 0, + use_stream_k}; + + ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); GGML_UNUSED(src1); GGML_UNUSED(dst); GGML_UNUSED(src1_ddf_i); + GGML_UNUSED(src1_padded_row_size); } bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 3cb2015520ba1..8c93e8326e20b 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -13,9 +13,10 @@ using namespace ggml_cuda_mma; #define MMQ_ITER_K 256 #define MMQ_NWARPS 8 -typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride); -typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00); -typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max); +typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride); +typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00); +typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted, + float * __restrict__ dst, const int stride, const int i_max, const int j_max); enum mmq_q8_1_ds_layout { MMQ_Q8_1_DS_LAYOUT_D4, @@ -233,7 +234,7 @@ static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */ // ------------------------------------------------------------ template static __device__ __forceinline__ void load_tiles_q4_0( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -289,7 +290,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); const int * x_qs = (const int *) x; @@ -328,7 +329,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( } template static __device__ __forceinline__ void load_tiles_q4_1( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -384,7 +385,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); const int * x_qs = (const int *) x; @@ -423,7 +424,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( } template static __device__ __forceinline__ void load_tiles_q5_0( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -495,7 +496,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_q5_1( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -565,7 +566,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_q8_0( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -621,7 +622,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); const int * x_qs = (const int *) x; @@ -651,7 +652,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { typedef tile<16, 8, int> tile_A; typedef tile< 8, 8, int> tile_B; @@ -732,7 +733,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); const int * x_qs = (const int *) x; @@ -762,7 +763,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { typedef tile<16, 8, int> tile_A; typedef tile< 8, 8, int> tile_B; @@ -839,7 +840,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; const int * x_qs = (const int *) x; @@ -871,7 +872,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #ifdef NEW_MMA_AVAILABLE typedef tile<16, 4, int> tile_A; @@ -955,7 +956,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( } template static __device__ __forceinline__ void load_tiles_q2_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -1011,7 +1012,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); const int * x_qs = (const int *) x; @@ -1074,7 +1075,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #ifdef NEW_MMA_AVAILABLE typedef tile<16, 4, int> tile_A; @@ -1201,7 +1202,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } template static __device__ __forceinline__ void load_tiles_q3_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -1298,7 +1299,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); const int * x_qs = (const int *) x; @@ -1340,7 +1341,7 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co } template static __device__ __forceinline__ void load_tiles_q4_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -1437,7 +1438,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); const int * x_qs = (const int *) x; @@ -1469,7 +1470,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( } template static __device__ __forceinline__ void load_tiles_q5_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -1578,7 +1579,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); const int * x_qs = (const int *) x; @@ -1610,7 +1611,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( } template static __device__ __forceinline__ void load_tiles_q6_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -1693,7 +1694,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); const int * x_qs = (const int *) x; @@ -1726,7 +1727,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #ifdef NEW_MMA_AVAILABLE typedef tile<16, 4, int> tile_A; @@ -1835,7 +1836,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } template static __device__ __forceinline__ void load_tiles_iq4_nl( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -1893,7 +1894,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_iq2_xxs( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -1951,7 +1952,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_iq2_xs( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -2007,7 +2008,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_iq2_s( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -2070,7 +2071,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_iq3_xxs( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -2126,7 +2127,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_iq3_s( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -2189,7 +2190,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_iq1_s( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -2245,7 +2246,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_iq4_xs( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -2306,8 +2307,8 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void mmq_write_back_dp4a( - const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) { - + const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst, + const int stride, const int i_max, const int j_max) { #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -2324,15 +2325,15 @@ static __device__ __forceinline__ void mmq_write_back_dp4a( continue; } - dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; } } } template static __device__ __forceinline__ void mmq_write_back_mma( - const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) { - + const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst, + const int stride, const int i_max, const int j_max) { typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); @@ -2362,7 +2363,7 @@ static __device__ __forceinline__ void mmq_write_back_mma( continue; } - dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l]; + dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l]; } } } @@ -2518,17 +2519,18 @@ struct mmq_type_traits { }; template -static __device__ void mul_mat_q_process_tile( - const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, - const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0, - const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) { +static __device__ __forceinline__ void mul_mat_q_process_tile( + const char * __restrict__ x, const int offset_x, const int * __restrict__ y, + const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup, + const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst, + const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) { constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; - extern __shared__ char data_mul_mat_q[]; - int * tile_y = (int *) data_mul_mat_q; + extern __shared__ int data_mul_mat_q[]; + int * tile_y = data_mul_mat_q + mmq_x; int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE); #ifdef NEW_MMA_AVAILABLE @@ -2543,16 +2545,11 @@ static __device__ void mul_mat_q_process_tile( float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; - const int tile_x_max_i = ne01 - it*mmq_y - 1; - const int tile_y_max_j = ne11 - jt*mmq_x - 1; - - const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int)); - for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { - load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01); + load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x); { - const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); + const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); #pragma unroll for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; @@ -2568,7 +2565,7 @@ static __device__ void mul_mat_q_process_tile( __syncthreads(); { - const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); + const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); #pragma unroll for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; @@ -2585,12 +2582,10 @@ static __device__ void mul_mat_q_process_tile( } if (fixup) { - write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x); + write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x); } else { - write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j); + write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j); } - - GGML_UNUSED(ne00); GGML_UNUSED(ne10); } @@ -2609,8 +2604,11 @@ template #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) static __global__ void mul_mat_q( - const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, - const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) { + const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst, + const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup, + const int ncols_x, const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst, + const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { // Skip unused template specializations for faster compilation: if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { @@ -2621,26 +2619,85 @@ static __global__ void mul_mat_q( constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); + const int ntx = (ncols_y + mmq_x - 1) / mmq_x; // Number of tiles x + const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y + + // Initialize the ids for writing back data with just the index. + // For regular matrix multiplications this is never changed. + // For MoE the correct indices are loaded from ids_dst. + extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory. +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { + const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + + if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = j; + } + // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA { + const int wt = blockIdx.z / nchannels_y; + const int zt = blockIdx.z - wt*nchannels_y; + const int jt = blockIdx.y; + const int it = blockIdx.x; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_y; + int col_diff = ncols_y; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + return; + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { + const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + + if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; + } + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + constexpr bool fixup = false; mul_mat_q_process_tile - (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0, - blockIdx.x, blockIdx.y, 0, ne00/qk); + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst, + tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); return; } #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA - const int64_t blocks_per_ne00 = ne00 / qk; + const int64_t blocks_per_ne00 = ncols_x / qk; constexpr int blocks_per_iter = MMQ_ITER_K / qk; - const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x - const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y - // kbc == k block continuous, current index in continuous ijk space. - int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x; - int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x; + int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter; @@ -2649,13 +2706,64 @@ static __global__ void mul_mat_q( int kb0_start = kbc % blocks_per_ne00; int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc); while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) { - const int jt = kbc / (blocks_per_ne00*nty); // j index of current tile. - const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile. + int tmp = kbc; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_y; + int col_diff = ncols_y; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + kbc += blocks_per_ne00; + kbc -= kbc % blocks_per_ne00; + + kb0_start = 0; + kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + + continue; + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { + const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + + if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; + } + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. mul_mat_q_process_tile - (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0, - it, jt, kb0_start, kb0_stop); + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst, + tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); kbc += blocks_per_ne00; kbc -= kbc % blocks_per_ne00; @@ -2668,55 +2776,106 @@ static __global__ void mul_mat_q( return; } - const int jt = kbc / (blocks_per_ne00*nty); - const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; + int tmp = kbc; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_y; + int col_diff = ncols_y; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + return; + } + + // The memory layout for the fixup buffer is always contiguous, therefore reset ids: +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { + const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + + if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = j; + } + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. mul_mat_q_process_tile - (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0, - it, jt, kb0_start, kb0_stop); + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst, + tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); } template static __global__ void mul_mat_q_stream_k_fixup( - float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) { - + const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, + const int ncols_x, const int nrows_x, const int ncols_y, const int stride_col_dst, + const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) { constexpr int mmq_y = get_mmq_y_device(); constexpr int qk = ggml_cuda_type_traits::qk; constexpr int blocks_per_iter = MMQ_ITER_K / qk; - const int64_t blocks_per_ne00 = ne00 / qk; + const int64_t blocks_per_ne00 = ncols_x / qk; float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; - const int ntx = (ne11 + mmq_x - 1) / mmq_x; - const int nty = (ne01 + mmq_y - 1) / mmq_y; - - bool any_fixup = false; + const int ntx = (ncols_y + mmq_x - 1) / mmq_x; + const int nty = (nrows_x + mmq_y - 1) / mmq_y; - const int bidx_start = ((blockIdx.y*nty + blockIdx.x) * block_num_mmq) / (gridDim.y*gridDim.x); - const int bidx_stop = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x); + const int bidx0 = blockIdx.x; - int64_t kbc_0; - int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq; - - for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) { - kbc_0 = kbc_stop_0; - kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq; + // kbc == k block continuous, current index in continuous ijk space. + int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - const int64_t kbc = kbc_0 - (kbc_0 % blocks_per_ne00) % blocks_per_iter; - const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter; + kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter; + kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter; - // Skip fixup tile if the MMQ CUDA block never wrote anything to it: - if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) { - continue; - } + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0; + const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } - const int jt = kbc_stop / (blocks_per_ne00*nty); - const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; + bool any_fixup = false; - // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block: - if ((unsigned)it != blockIdx.x || (unsigned)jt != blockIdx.y) { + // Iterate over previous blocks and sum up partial sums written to fixup buffer. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int64_t bidx = bidx0 - 1; + int64_t kbc_stop = kbc0; + while(true) { + int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; continue; } @@ -2733,16 +2892,71 @@ static __global__ void mul_mat_q_stream_k_fixup( sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; } } + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) { + break; + } + bidx--; + kbc_stop = kbc; } if (!any_fixup) { return; } - dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y; + int tmp = kbc0; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; - const int i_max = ne01 - blockIdx.x*mmq_y - 1; - const int j_max = ne11 - blockIdx.y*mmq_x - 1; + if (!ids_dst) { + const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y; + dst += offset_dst; + + const int i_max = nrows_x - it*mmq_y - 1; + const int j_max = ncols_y - jt*mmq_x - 1; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j > j_max) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + if (need_check && i > i_max) { + continue; + } + + dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + } + } + return; + } + + __shared__ int ids_dst_shared[mmq_x]; + const int col_low = expert_bounds[zt + 0]; + const int col_high = expert_bounds[zt + 1]; + const int col_diff = col_high - col_low; + + for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) { + ids_dst_shared[j] = ids_dst[col_low + j]; + } + + const int offset_dst = it*mmq_y; + dst += offset_dst; + + const int i_max = nrows_x - it*mmq_y - 1; + const int j_max = col_diff - jt*mmq_x - 1; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -2760,26 +2974,27 @@ static __global__ void mul_mat_q_stream_k_fixup( continue; } - dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; } } } struct mmq_args { - const char * x; const char * y; float * dst; - int64_t ne00; int64_t ne01; int64_t stride01; - int64_t ne10; int64_t ne11; int64_t stride11; - int64_t ne0; + const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst; + int64_t ncols_x; int64_t nrows_x; int64_t ncols_y; int64_t stride_row_x; int64_t nrows_dst; + int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst; + int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst; bool use_stream_k; }; template -static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) { +static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) { const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); - const int shmem_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); - const int shmem_y = mmq_x*sizeof(block_q8_1_mmq); - return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); + const size_t nbs_ids = mmq_x*sizeof(int); + const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); + return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); } template @@ -2791,86 +3006,114 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1); - const int shmem = mmq_get_shmem(mmq_x, mmq_y, cc); + const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc); #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) - static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; - if (!shmem_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); - CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); - shmem_limit_raised[id] = true; + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared)); + CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared)); + shared_memory_limit_raised[id] = true; } #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) - const int nty = (args.ne01 + mmq_y - 1) / mmq_y; - const int ntx = (args.ne11 + mmq_x - 1) / mmq_x; - const dim3 block_nums_xy_tiling(nty, ntx, 1); + const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; + const int ntx = (args.ncols_y + mmq_x - 1) / mmq_x; + const int ntzw = args.nchannels_y * args.nsamples_y; + const dim3 block_nums_xy_tiling(nty, ntx, ntzw); + + GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0); + GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0); + const int channel_ratio = args.nchannels_y / args.nchannels_x; + const int sample_ratio = args.nsamples_y / args.nsamples_x; if (!args.use_stream_k) { - if (args.ne01 % mmq_y == 0) { + if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; - mul_mat_q<<>> - (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); + mul_mat_q<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, + args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); } else { constexpr bool need_check = true; - mul_mat_q<<>> - (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); + mul_mat_q<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, + args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); } return; } - const dim3 block_nums_mmq(nsm, 1, 1); + const dim3 block_nums_stream_k(nsm, 1, 1); + const bool fixup_needed = ntx*nty*ntzw % nsm != 0; ggml_cuda_pool & pool = ctx.pool(id); - ggml_cuda_pool_alloc tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y); + ggml_cuda_pool_alloc tmp_fixup(pool); + if (fixup_needed) { + tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y); + } - if (args.ne01 % mmq_y == 0) { + if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; - mul_mat_q<<>> - (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); + mul_mat_q<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, + args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); + + if (!fixup_needed) { + return; + } - mul_mat_q_stream_k_fixup<<>> - (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x); + mul_mat_q_stream_k_fixup<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y, + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst); } else { constexpr bool need_check = true; - mul_mat_q<<>> - (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); + mul_mat_q<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, + args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); + + if (!fixup_needed) { + return; + } - mul_mat_q_stream_k_fixup<<>> - (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x); + mul_mat_q_stream_k_fixup<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y, + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst); } } template void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - const int smpbo = ggml_cuda_info().devices[id].smpbo; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; const int mmq_x_max = get_mmq_x_max_host(cc); const int mmq_y = get_mmq_y_host(cc); - const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y; - const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA; int mmq_x_best = 0; - int nparts_best = INT_MAX; + int ntiles_x_best = INT_MAX; - for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) { + for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) { const int granularity = mmq_get_granularity_host(mmq_x, cc); - if (mmq_x % granularity != 0 || mmq_get_shmem(mmq_x, mmq_y, cc) > smpbo) { + if (mmq_x % granularity != 0 || mmq_get_nbytes_shared(mmq_x, mmq_y, cc) > smpbo) { continue; } - const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x; - const int nwaves_xy_tiling = ntiles_x*block_num_y; - const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling; + const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x; - if (nparts < nparts_best) { - mmq_x_best = mmq_x; - nparts_best = nparts; + if (ntiles_x < ntiles_x_best) { + mmq_x_best = mmq_x; + ntiles_x_best = ntiles_x; } } @@ -2954,6 +3197,9 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); // ------------------------------------------------------------------------------------------------------------------------- +void ggml_cuda_mul_mat_q( + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); + void ggml_cuda_op_mul_mat_q( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index d846e35a6a26d..132c466fd1aa6 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -158,7 +158,7 @@ static __global__ void mul_mat_vec_q( const int blocks_per_row_x = ncols_x / qk; constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; - // The MUL_MAT_ID code path with ids != nullptr is only implemetned for ncols_dst == 1. + // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. const int channel_dst = blockIdx.y; const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio; const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst; @@ -507,7 +507,7 @@ void ggml_cuda_mul_mat_vec_q( GGML_ASSERT( nb0 == ts_dst); GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); - GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. + GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. const float * src1_d = (const float *) src1->data; const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; @@ -519,7 +519,7 @@ void ggml_cuda_mul_mat_vec_q( const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[3] / ts_src1; - quantize_row_q8_1_cuda(src1_d, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); + quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); } const int64_t s01 = src0->nb[1] / ts_src0; diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 3bab47d56a22e..931a45ad347dc 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -49,29 +49,38 @@ static __global__ void quantize_q8_1( template static __global__ void quantize_mmq_q8_1( - const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) { + const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int ne1, const int ne2) { constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32; - const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4; + const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4; - if (ix0 >= kx0_padded) { + if (i0 >= ne0) { return; } - const float4 * x4 = (const float4 *) x; + const int64_t i1 = blockIdx.y; + const int64_t i2 = blockIdx.z % ne2; + const int64_t i3 = blockIdx.z / ne2; - const int64_t ix1 = kx1*blockIdx.z + blockIdx.y; + const int64_t i00 = i0; + const int64_t i01 = ids ? ids[i1] : i1; + const int64_t i02 = i2; + const int64_t i03 = i3; + + const float4 * x4 = (const float4 *) x; block_q8_1_mmq * y = (block_q8_1_mmq *) vy; const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel - const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel - const int64_t iqs = ix0 % (4*QK8_1); // quant index in block + const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel + const int64_t iqs = i0 % (4*QK8_1); // quant index in block // Load 4 floats per thread and calculate max. abs. value between them: - const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f); + const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f); float amax = fabsf(xi.x); amax = fmaxf(amax, fabsf(xi.y)); amax = fmaxf(amax, fabsf(xi.z)); @@ -87,7 +96,7 @@ static __global__ void quantize_mmq_q8_1( if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) { sum = xi.x + xi.y + xi.z + xi.w; - // Exchange calculate sum across vals_per_sum/4 threads. + // Calculate sums across vals_per_sum/4 threads. #pragma unroll for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) { sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE); @@ -137,9 +146,10 @@ static __global__ void quantize_mmq_q8_1( } void quantize_row_q8_1_cuda( - const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, - const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { - + const float * x, const int32_t * ids, void * vy, const ggml_type type_src0, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { + GGML_ASSERT(!ids); GGML_ASSERT(ne0 % QK8_1 == 0); const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; @@ -150,9 +160,9 @@ void quantize_row_q8_1_cuda( } void quantize_mmq_q8_1_cuda( - const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, - const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { - + const float * x, const int32_t * ids, void * vy, const ggml_type type_src0, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { GGML_ASSERT(ne0 % (4*QK8_1) == 0); const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); @@ -161,21 +171,18 @@ void quantize_mmq_q8_1_cuda( switch (mmq_get_q8_1_ds_layout(type_src0)) { case MMQ_Q8_1_DS_LAYOUT_D4: quantize_mmq_q8_1 - <<>>(x, vy, ne00, ne1, ne0); + <<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); break; case MMQ_Q8_1_DS_LAYOUT_DS4: quantize_mmq_q8_1 - <<>>(x, vy, ne00, ne1, ne0); + <<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); break; case MMQ_Q8_1_DS_LAYOUT_D2S6: quantize_mmq_q8_1 - <<>>(x, vy, ne00, ne1, ne0); + <<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); break; default: GGML_ABORT("fatal error"); break; } - GGML_UNUSED(s01); - GGML_UNUSED(s02); - GGML_UNUSED(s03); } diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh index b627c4e4008b4..725ab52443c0e 100644 --- a/ggml/src/ggml-cuda/quantize.cuh +++ b/ggml/src/ggml-cuda/quantize.cuh @@ -12,13 +12,16 @@ static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access."); typedef void (*quantize_cuda_t)( - const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, - const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream); + const float * x, const int32_t * ids, void * vy, + ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); void quantize_row_q8_1_cuda( - const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, - const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream); + const float * x, const int32_t * ids, void * vy, + ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); void quantize_mmq_q8_1_cuda( - const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, - const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream); + const float * x, const int32_t * ids, void * vy, + ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d70acb7719435..9591b1a89e723 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4184,6 +4184,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1})); + + // test cases with large ne00/ne10 to cover stream-k fixup + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 1024, {3, 2}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 1024, {3, 2}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1})); } } for (ggml_type type_a : other_types) {