@@ -3614,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
 }
 }
 
+// ggml_compute_forward_geglu_erf
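+// GEGLU with the erf-based ("exact") GELU. Assuming the conventional
+// definition, each output element is
+//   dst[i] = 0.5*x[i]*(1 + erf(x[i]/sqrt(2))) * g[i]
+// where x is the activated half and g is the gate; the element-wise math
+// lives in ggml_vec_geglu_erf_f32/_f16.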
+
+static void ggml_compute_forward_geglu_erf_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
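+    // two modes: with src1, src0 is the activation and src1 the gate;
+    // without src1, each src0 row packs both halves side by side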
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
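+        // single-tensor mode: op param 1 selects which half of the row
+        // is the activation and which half is the gate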
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
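+        // debug builds: verify the row we just wrote is free of NaN/Inf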
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
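+// F16 variant: same splitting and threading logic as the F32 path above,
+// but reading/writing ggml_fp16_t data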
+static void ggml_compute_forward_geglu_erf_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_erf_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_erf_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_geglu_quick
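+// GEGLU with the sigmoid ("quick") approximation of GELU; assuming ggml's
+// usual coefficient of 1.702, each output element is
+//   dst[i] = x[i] * sigmoid(1.702*x[i]) * g[i]
+// with the element-wise math in ggml_vec_geglu_quick_f32/_f16.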
+
+static void ggml_compute_forward_geglu_quick_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
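+// F16 variant of the quick path; mirrors the F32 version above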
+static void ggml_compute_forward_geglu_quick_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_quick_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_quick_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(
@@ -8779,6 +9065,14 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_GEGLU_ERF:
+            {
+                ggml_compute_forward_geglu_erf(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU_QUICK:
+            {
+                ggml_compute_forward_geglu_quick(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");