Skip to content

Commit 6681688

Browse files
authored
opencl: add GELU_ERF (#14476)
1 parent bac8bed commit 6681688

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ struct ggml_backend_opencl_context {
398398
cl_kernel kernel_scale;
399399
cl_kernel kernel_silu, kernel_silu_4;
400400
cl_kernel kernel_gelu, kernel_gelu_4;
401+
cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
401402
cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
402403
cl_kernel kernel_relu;
403404
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
@@ -736,6 +737,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
736737

737738
CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu", &err), err));
738739
CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_4", &err), err));
740+
CL_CHECK((backend_ctx->kernel_gelu_erf = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_erf", &err), err));
741+
CL_CHECK((backend_ctx->kernel_gelu_erf_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_erf_4", &err), err));
739742
CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick", &err), err));
740743
CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick_4", &err), err));
741744
GGML_LOG_CONT(".");
@@ -2266,6 +2269,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
22662269
case GGML_UNARY_OP_GELU:
22672270
case GGML_UNARY_OP_SILU:
22682271
case GGML_UNARY_OP_RELU:
2272+
case GGML_UNARY_OP_GELU_ERF:
22692273
case GGML_UNARY_OP_GELU_QUICK:
22702274
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
22712275
case GGML_UNARY_OP_SIGMOID:
@@ -3870,6 +3874,44 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const
38703874
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
38713875
}
38723876

3877+
static void ggml_cl_gelu_erf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3878+
GGML_ASSERT(src0);
3879+
GGML_ASSERT(src0->extra);
3880+
GGML_ASSERT(dst);
3881+
GGML_ASSERT(dst->extra);
3882+
3883+
UNUSED(src1);
3884+
3885+
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
3886+
3887+
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
3888+
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
3889+
3890+
cl_ulong offset0 = extra0->offset + src0->view_offs;
3891+
cl_ulong offsetd = extrad->offset + dst->view_offs;
3892+
3893+
cl_kernel kernel;
3894+
3895+
int n = ggml_nelements(dst);
3896+
3897+
if (n % 4 == 0) {
3898+
kernel = backend_ctx->kernel_gelu_erf_4;
3899+
n /= 4;
3900+
} else {
3901+
kernel = backend_ctx->kernel_gelu_erf;
3902+
}
3903+
3904+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
3905+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
3906+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
3907+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
3908+
3909+
size_t global_work_size[] = {(size_t)n, 1, 1};
3910+
size_t local_work_size[] = {64, 1, 1};
3911+
3912+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
3913+
}
3914+
38733915
static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
38743916
GGML_ASSERT(src0);
38753917
GGML_ASSERT(src0->extra);
@@ -6388,6 +6430,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
63886430
}
63896431
func = ggml_cl_gelu;
63906432
break;
6433+
case GGML_UNARY_OP_GELU_ERF:
6434+
if (!any_on_device) {
6435+
return false;
6436+
}
6437+
func = ggml_cl_gelu_erf;
6438+
break;
63916439
case GGML_UNARY_OP_GELU_QUICK:
63926440
if (!any_on_device) {
63936441
return false;

ggml/src/ggml-opencl/kernels/gelu.cl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#define GELU_COEF_A 0.044715f
77
#define GELU_QUICK_COEF -1.702f
88
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
9+
#define SQRT_2_INV 0.70710678118654752440084436210484f
910

1011
kernel void kernel_gelu(
1112
global float * src0,
@@ -35,6 +36,32 @@ kernel void kernel_gelu_4(
3536
dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
3637
}
3738

39+
kernel void kernel_gelu_erf(
40+
global float * src0,
41+
ulong offset0,
42+
global float * dst,
43+
ulong offsetd
44+
) {
45+
src0 = (global float*)((global char*)src0 + offset0);
46+
dst = (global float*)((global char*)dst + offsetd);
47+
48+
float x = src0[get_global_id(0)];
49+
dst[get_global_id(0)] = 0.5f*x*(1.0f + erf(x*SQRT_2_INV));
50+
}
51+
52+
kernel void kernel_gelu_erf_4(
53+
global float4 * src0,
54+
ulong offset0,
55+
global float4 * dst,
56+
ulong offsetd
57+
) {
58+
src0 = (global float4*)((global char*)src0 + offset0);
59+
dst = (global float4*)((global char*)dst + offsetd);
60+
61+
float4 x = src0[get_global_id(0)];
62+
dst[get_global_id(0)] = 0.5f*x*(1.0f + erf(x*SQRT_2_INV));
63+
}
64+
3865
kernel void kernel_gelu_quick(
3966
global float * src0,
4067
ulong offset0,

0 commit comments

Comments
 (0)