Skip to content

Commit bee2842

Browse files
authored
opencl : broadcast for soft_max (#14510)
1 parent 2b72bed commit bee2842

File tree

5 files changed: 132 additions and 59 deletions.

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5763,19 +5763,31 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
57635763

57645764
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
57655765

5766-
const int ne00 = src0 ? src0->ne[0] : 0;
5767-
const int ne01 = src0 ? src0->ne[1] : 0;
5768-
const int ne02 = src0 ? src0->ne[2] : 0;
5769-
const int ne03 = src0 ? src0->ne[3] : 0;
5766+
const int ne00 = src0->ne[0];
5767+
const int ne01 = src0->ne[1];
5768+
const int ne02 = src0->ne[2];
5769+
const int ne03 = src0->ne[3];
5770+
5771+
const cl_long nb01 = src0->nb[1];
5772+
const cl_long nb02 = src0->nb[2];
5773+
const cl_long nb03 = src0->nb[3];
5774+
5775+
const int ne12 = src1 ? src1->ne[2] : 0;
5776+
const int ne13 = src1 ? src1->ne[3] : 0;
5777+
5778+
const cl_long nb11 = src1 ? src1->nb[1] : 0;
5779+
const cl_long nb12 = src1 ? src1->nb[2] : 0;
5780+
const cl_long nb13 = src1 ? src1->nb[3] : 0;
5781+
5782+
const cl_long nb1 = dst->nb[1];
5783+
const cl_long nb2 = dst->nb[2];
5784+
const cl_long nb3 = dst->nb[3];
57705785

57715786
float scale, max_bias;
57725787
memcpy(&scale, dst->op_params + 0, sizeof(float));
57735788
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
57745789

5775-
const int nrows_x = ggml_nrows(src0);
5776-
const int nrows_y = src0->ne[1];
5777-
5778-
const int n_head = nrows_x/nrows_y;
5790+
const int n_head = src0->ne[2];
57795791
const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
57805792

57815793
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -5820,13 +5832,22 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
58205832
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
58215833
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
58225834
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
5823-
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
5824-
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
5825-
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale));
5826-
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias));
5827-
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0));
5828-
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1));
5829-
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2));
5835+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
5836+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
5837+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
5838+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
5839+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
5840+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
5841+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
5842+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
5843+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
5844+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
5845+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
5846+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
5847+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
5848+
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
5849+
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
5850+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
58305851

58315852
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
58325853
size_t local_work_size[] = {(size_t)nth, 1, 1};

ggml/src/ggml-opencl/kernels/softmax_4_f16.cl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,45 @@
2222
REQD_SUBGROUP_SIZE_64
2323
#endif
2424
kernel void kernel_soft_max_4_f16(
25-
global float * src0,
25+
global char * src0,
2626
ulong offset0,
27-
global half * src1,
27+
global char * src1,
2828
ulong offset1,
29-
global float * dst,
29+
global char * dst,
3030
ulong offsetd,
3131
int ne00,
32-
int ne01,
33-
int ne02,
32+
ulong nb01,
33+
ulong nb02,
34+
ulong nb03,
35+
int ne12,
36+
int ne13,
37+
ulong nb11,
38+
ulong nb12,
39+
ulong nb13,
40+
ulong nb1,
41+
ulong nb2,
42+
ulong nb3,
3443
float scale,
3544
float max_bias,
3645
float m0,
3746
float m1,
3847
int n_head_log2
3948
) {
40-
src0 = (global float *)((global char *)src0 + offset0);
41-
src1 = (global half *)((global char *)src1 + offset1);
42-
dst = (global float *)((global char *)dst + offsetd);
49+
src0 = src0 + offset0;
50+
src1 = src1 + offset1;
51+
dst = dst + offsetd;
4352

4453
int i03 = get_group_id(2);
4554
int i02 = get_group_id(1);
4655
int i01 = get_group_id(0);
4756

48-
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
49-
global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0;
50-
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
57+
int i13 = i03%ne13;
58+
int i12 = i02%ne12;
59+
int i11 = i01;
60+
61+
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
62+
global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
63+
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
5164

5265
float slope = 1.0f;
5366

ggml/src/ggml-opencl/kernels/softmax_4_f32.cl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,45 @@
2222
REQD_SUBGROUP_SIZE_64
2323
#endif
2424
kernel void kernel_soft_max_4(
25-
global float * src0,
25+
global char * src0,
2626
ulong offset0,
27-
global float * src1,
27+
global char * src1,
2828
ulong offset1,
29-
global float * dst,
29+
global char * dst,
3030
ulong offsetd,
3131
int ne00,
32-
int ne01,
33-
int ne02,
32+
ulong nb01,
33+
ulong nb02,
34+
ulong nb03,
35+
int ne12,
36+
int ne13,
37+
ulong nb11,
38+
ulong nb12,
39+
ulong nb13,
40+
ulong nb1,
41+
ulong nb2,
42+
ulong nb3,
3443
float scale,
3544
float max_bias,
3645
float m0,
3746
float m1,
3847
int n_head_log2
3948
) {
40-
src0 = (global float*)((global char*)src0 + offset0);
41-
src1 = (global float*)((global char*)src1 + offset1);
42-
dst = (global float*)((global char*)dst + offsetd);
49+
src0 = src0 + offset0;
50+
src1 = src1 + offset1;
51+
dst = dst + offsetd;
4352

4453
int i03 = get_group_id(2);
4554
int i02 = get_group_id(1);
4655
int i01 = get_group_id(0);
4756

48-
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
49-
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0;
50-
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
57+
int i13 = i03%ne13;
58+
int i12 = i02%ne12;
59+
int i11 = i01;
60+
61+
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
62+
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
63+
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
5164

5265
float slope = 1.0f;
5366

ggml/src/ggml-opencl/kernels/softmax_f16.cl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,45 @@
2222
REQD_SUBGROUP_SIZE_64
2323
#endif
2424
kernel void kernel_soft_max_f16(
25-
global float * src0,
25+
global char * src0,
2626
ulong offset0,
27-
global half * src1,
27+
global char * src1,
2828
ulong offset1,
29-
global float * dst,
29+
global char * dst,
3030
ulong offsetd,
3131
int ne00,
32-
int ne01,
33-
int ne02,
32+
ulong nb01,
33+
ulong nb02,
34+
ulong nb03,
35+
int ne12,
36+
int ne13,
37+
ulong nb11,
38+
ulong nb12,
39+
ulong nb13,
40+
ulong nb1,
41+
ulong nb2,
42+
ulong nb3,
3443
float scale,
3544
float max_bias,
3645
float m0,
3746
float m1,
3847
int n_head_log2
3948
) {
40-
src0 = (global float *)((global char *)src0 + offset0);
41-
src1 = (global half *)((global char *)src1 + offset1);
42-
dst = (global float *)((global char *)dst + offsetd);
49+
src0 = src0 + offset0;
50+
src1 = src1 + offset1;
51+
dst = dst + offsetd;
4352

4453
int i03 = get_group_id(2);
4554
int i02 = get_group_id(1);
4655
int i01 = get_group_id(0);
4756

48-
global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
49-
global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0;
50-
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
57+
int i13 = i03%ne13;
58+
int i12 = i02%ne12;
59+
int i11 = i01;
60+
61+
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
62+
global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
63+
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
5164

5265
float slope = 1.0f;
5366

ggml/src/ggml-opencl/kernels/softmax_f32.cl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,45 @@
2222
REQD_SUBGROUP_SIZE_64
2323
#endif
2424
kernel void kernel_soft_max(
25-
global float * src0,
25+
global char * src0,
2626
ulong offset0,
27-
global float * src1,
27+
global char * src1,
2828
ulong offset1,
29-
global float * dst,
29+
global char * dst,
3030
ulong offsetd,
3131
int ne00,
32-
int ne01,
33-
int ne02,
32+
ulong nb01,
33+
ulong nb02,
34+
ulong nb03,
35+
int ne12,
36+
int ne13,
37+
ulong nb11,
38+
ulong nb12,
39+
ulong nb13,
40+
ulong nb1,
41+
ulong nb2,
42+
ulong nb3,
3443
float scale,
3544
float max_bias,
3645
float m0,
3746
float m1,
3847
int n_head_log2
3948
) {
40-
src0 = (global float*)((global char*)src0 + offset0);
41-
src1 = (global float*)((global char*)src1 + offset1);
42-
dst = (global float*)((global char*)dst + offsetd);
49+
src0 = src0 + offset0;
50+
src1 = src1 + offset1;
51+
dst = dst + offsetd;
4352

4453
int i03 = get_group_id(2);
4554
int i02 = get_group_id(1);
4655
int i01 = get_group_id(0);
4756

48-
global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
49-
global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0;
50-
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
57+
int i13 = i03%ne13;
58+
int i12 = i02%ne12;
59+
int i11 = i01;
60+
61+
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
62+
global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
63+
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
5164

5265
float slope = 1.0f;
5366

Comments (0): this commit has no comments.