Skip to content

opencl: broadcast for soft_max #14510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 3, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 38 additions & 13 deletions ggml/src/ggml-opencl/ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5757,19 +5757,32 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c

cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;

const int ne00 = src0 ? src0->ne[0] : 0;
const int ne01 = src0 ? src0->ne[1] : 0;
const int ne02 = src0 ? src0->ne[2] : 0;
const int ne03 = src0 ? src0->ne[3] : 0;
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];

const cl_long nb01 = src0->nb[1];
const cl_long nb02 = src0->nb[2];
const cl_long nb03 = src0->nb[3];

const int ne11 = src1 ? src1->ne[1] : 0;
const int ne12 = src1 ? src1->ne[2] : 0;
const int ne13 = src1 ? src1->ne[3] : 0;

const cl_long nb11 = src1 ? src1->nb[1] : 0;
const cl_long nb12 = src1 ? src1->nb[2] : 0;
const cl_long nb13 = src1 ? src1->nb[3] : 0;

const cl_long nb1 = dst->nb[1];
const cl_long nb2 = dst->nb[2];
const cl_long nb3 = dst->nb[3];

float scale, max_bias;
memcpy(&scale, dst->op_params + 0, sizeof(float));
memcpy(&max_bias, dst->op_params + 1, sizeof(float));

const int nrows_x = ggml_nrows(src0);
const int nrows_y = src0->ne[1];

const int n_head = nrows_x/nrows_y;
const int n_head = src0->ne[2];
const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
Expand Down Expand Up @@ -5816,11 +5829,23 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb3));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float), &max_bias));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &m0));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float), &m1));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &n_head_log2));

size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};
Expand Down
34 changes: 25 additions & 9 deletions ggml/src/ggml-opencl/kernels/softmax_4_f16.cl
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,48 @@
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_4_f16(
global float * src0,
global char * src0,
ulong offset0,
global half * src1,
global char * src1,
ulong offset1,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb01,
ulong nb02,
ulong nb03,
int ne11,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float *)((global char *)src0 + offset0);
src1 = (global half *)((global char *)src1 + offset1);
dst = (global float *)((global char *)dst + offsetd);
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;

int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);

global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0;
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
int i13 = i03%ne13;
int i12 = i02%ne12;
int i11 = i01;

global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

float slope = 1.0f;

Expand Down
34 changes: 25 additions & 9 deletions ggml/src/ggml-opencl/kernels/softmax_4_f32.cl
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,48 @@
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_4(
global float * src0,
global char * src0,
ulong offset0,
global float * src1,
global char * src1,
ulong offset1,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb01,
ulong nb02,
ulong nb03,
int ne11,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;

int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);

global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0;
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
int i13 = i03%ne13;
int i12 = i02%ne12;
int i11 = i01;

global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

float slope = 1.0f;

Expand Down
34 changes: 25 additions & 9 deletions ggml/src/ggml-opencl/kernels/softmax_f16.cl
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,48 @@
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_f16(
global float * src0,
global char * src0,
ulong offset0,
global half * src1,
global char * src1,
ulong offset1,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb01,
ulong nb02,
ulong nb03,
int ne11,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float *)((global char *)src0 + offset0);
src1 = (global half *)((global char *)src1 + offset1);
dst = (global float *)((global char *)dst + offsetd);
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;

int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);

global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0;
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
int i13 = i03%ne13;
int i12 = i02%ne12;
int i11 = i01;

global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

float slope = 1.0f;

Expand Down
34 changes: 25 additions & 9 deletions ggml/src/ggml-opencl/kernels/softmax_f32.cl
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,48 @@
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max(
global float * src0,
global char * src0,
ulong offset0,
global float * src1,
global char * src1,
ulong offset1,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb01,
ulong nb02,
ulong nb03,
int ne11,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;

int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);

global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0;
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
int i13 = i03%ne13;
int i12 = i02%ne12;
int i11 = i01;

global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

float slope = 1.0f;

Expand Down
Loading