Commit 28657a8

ggml : implement GEGLU_ERF and GEGLU_QUICK ops (#14445)
1 parent bee2842

File tree

20 files changed: +789 additions, -32 deletions

ggml/include/ggml.h

Lines changed: 28 additions & 0 deletions

```diff
@@ -557,6 +557,8 @@ extern "C" {
         GGML_GLU_OP_REGLU,
         GGML_GLU_OP_GEGLU,
         GGML_GLU_OP_SWIGLU,
+        GGML_GLU_OP_GEGLU_ERF,
+        GGML_GLU_OP_GEGLU_QUICK,
 
         GGML_GLU_OP_COUNT,
     };
@@ -1147,6 +1149,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_geglu_erf(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_erf_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // A: n columns, r rows,
     // B: n columns, r rows,
     GGML_API struct ggml_tensor * ggml_glu_split(
@@ -1170,6 +1188,16 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_geglu_erf_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
```
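A minimal usage sketch of the new API (not part of the diff; the tensor shapes and the graph built here are hypothetical). Following the convention of the existing GLU ops, the single-tensor variants halve ne[0] (gate and value are packed into one tensor), while the `_split` variants take them as two separate tensors:

```c
#include "ggml.h"

// Hypothetical shapes: 2*1024 columns pack gate and value into one tensor
// for the fused variant; the _split variant takes them separately.
static struct ggml_tensor * glu_examples(struct ggml_context * ctx) {
    struct ggml_tensor * ff = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*1024, 8);
    struct ggml_tensor * y1 = ggml_geglu_erf(ctx, ff);   // y1->ne[0] == 1024

    struct ggml_tensor * gate = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 8);
    struct ggml_tensor * up   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 8);
    struct ggml_tensor * y2   = ggml_geglu_quick_split(ctx, gate, up);

    return ggml_add(ctx, y1, y2);
}
```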

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 2 additions & 0 deletions
```diff
@@ -2172,6 +2172,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     {
                         n_tasks = n_threads;
                     } break;
```
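The per-element math is hidden inside the `ggml_vec_geglu_erf_*` and `ggml_vec_geglu_quick_*` kernels called in ops.cpp below, which this excerpt does not show. As a hedged scalar reference, assuming ggml's usual GELU conventions (the exact erf form, and the 1.702 sigmoid approximation for the quick variant), the gating would compute:

```c
#include <math.h>

// Scalar reference sketch (an assumption, not taken from the diff):
//   GEGLU_ERF:   dst = 0.5*x*(1 + erf(x/sqrt(2))) * g   (exact GELU gate)
//   GEGLU_QUICK: dst = x*sigmoid(1.702*x) * g           (sigmoid approximation)
// x is the half passed through GELU, g the multiplying half.
static float geglu_erf_ref(float x, float g) {
    return 0.5f*x*(1.0f + erff(x/sqrtf(2.0f))) * g;
}

static float geglu_quick_ref(float x, float g) {
    return x*(1.0f/(1.0f + expf(-1.702f*x))) * g;
}
```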

ggml/src/ggml-cpu/ops.cpp

Lines changed: 294 additions & 0 deletions
```diff
@@ -3614,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_geglu_erf
+
+static void ggml_compute_forward_geglu_erf_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_erf_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_erf_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_geglu_quick
+
+static void ggml_compute_forward_geglu_quick_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_quick_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_quick_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(
```
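All four kernels divide the rows across threads with the same ceil-division arithmetic (`dr`, `ir0`, `ir1`). A small worked instance with hypothetical nr/nth values, just to make the partitioning concrete:

```c
#include <stdio.h>

// Worked instance of the row partitioning used by the kernels above
// (nr and nth are hypothetical numbers, not from the diff).
int main(void) {
    const int nr  = 10;                  // total rows
    const int nth = 4;                   // threads
    const int dr  = (nr + nth - 1)/nth;  // rows per thread, rounded up -> 3
    for (int ith = 0; ith < nth; ith++) {
        const int ir0 = dr*ith;
        const int ir1 = ir0 + dr < nr ? ir0 + dr : nr;
        // prints [0,3) [3,6) [6,9) [9,10); the last thread gets the remainder
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}
```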
```diff
@@ -8779,6 +9065,14 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_GEGLU_ERF:
+            {
+                ggml_compute_forward_geglu_erf(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU_QUICK:
+            {
+                ggml_compute_forward_geglu_quick(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
```
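One detail worth spelling out from the kernels: when `src1` is NULL, the op runs on a single fused tensor and splits each row in place, with op param 1 (`swapped`) choosing which half gets the activation. A sketch of that pointer arithmetic; the gated/multiplier naming is an assumption about ggml's GLU convention, not stated in the diff:

```c
// Fused-layout split used when src1 == NULL: a row of width 2*nc holds two
// halves [A | B]. With swapped == 0 the first half is passed through the
// GELU variant and the second multiplies; swapped == 1 exchanges the roles.
// (Naming is assumed, mirroring src0_p/src1_p in the kernels above.)
static void split_fused_row(float * row, int nc, int swapped,
                            float ** gated, float ** mult) {
    *gated = row + (swapped ? nc : 0);
    *mult  = row + (swapped ? 0  : nc);
}
```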
