@@ -6596,7 +6596,118 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     }
 
     *s = hsum_float_8(acc);
+#elif defined(__VXE__) || defined(__VXE2__)
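+    // Each q3_K super-block packs 256 quants: two low bits per quant in qs[64],
+    // one high bit per quant in hmask[32], and sixteen 6-bit sub-block scales.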
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    const int32x4_t v_z = vec_splat_s32(0);
+    const uint8x16_t v_3m = vec_splat_u8(0x03);
+
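+    // Single-bit masks 0x01/0x02/0x04/0x08, used to select successive hmask bit planes.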
+    const uint8x16_t v_0c = vec_splat_u8(1);
+    const uint8x16_t v_1c = vec_sl(v_0c, 1);
+    const uint8x16_t v_2c = vec_sl(v_0c, 2);
+    const uint8x16_t v_3c = vec_sl(v_0c, 3);
+
+    uint8x16_t q3h[4];
+    uint8x16_t q3b[2];
+    int8x16_t q3bytes[4];
+    int8x16_t q8bytes[8];
+    uint8x16_t qhbits[2];
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
+        const uint8_t * restrict x0l = x[i].qs;
+        const uint8_t * restrict x0h = x[i].hmask;
+        const int8_t  * restrict y0 = y[i].qs;
+
+        qhbits[0] = vec_xl(0 , x0h);
+        qhbits[1] = vec_xl(16, x0h);
+
+        int32_t isum = 0;
+
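+        // Unpack the sixteen 6-bit sub-block scales from the 12 packed bytes,
+        // then re-center them from unsigned [0,63] to signed [-32,31].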
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+        for (int j = 0; j < 16; ++j) scale[j] -= 32;
+
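+        // Each pass consumes 32 q3 bytes (4 crumbs each = 128 quants) and 128 q8 bytes.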
+        for (int j = 0; j < QK_K/128; ++j) {
+            int32x4_t isum0, isum1, isum2, isum3;
+
+            q3b[0] = vec_xl(0 , x0l);
+            q3b[1] = vec_xl(16, x0l);
+            x0l += 32;
+
+            q8bytes[0] = vec_xl(0  , y0);
+            q8bytes[1] = vec_xl(16 , y0);
+            q8bytes[2] = vec_xl(32 , y0);
+            q8bytes[3] = vec_xl(48 , y0);
+            q8bytes[4] = vec_xl(64 , y0);
+            q8bytes[5] = vec_xl(80 , y0);
+            q8bytes[6] = vec_xl(96 , y0);
+            q8bytes[7] = vec_xl(112, y0);
+            y0 += 128;
+
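+            // A set hmask bit contributes +4 to a quant; rather than add it, subtract
+            // 4 wherever the bit is clear: (~hmask & bit), shifted so the lane holds 4.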
+            q3h[0] = vec_sl(vec_andc(v_0c, qhbits[0]), 2);
+            q3h[1] = vec_sl(vec_andc(v_0c, qhbits[1]), 2);
+            q3h[2] = vec_sl(vec_andc(v_1c, qhbits[0]), 1);
+            q3h[3] = vec_sl(vec_andc(v_1c, qhbits[1]), 1);
+
+            q3bytes[0] = vec_sub((int8x16_t)vec_and(q3b[0], v_3m), (int8x16_t)q3h[0]);
+            q3bytes[1] = vec_sub((int8x16_t)vec_and(q3b[1], v_3m), (int8x16_t)q3h[1]);
+            q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 2), v_3m), (int8x16_t)q3h[2]);
+            q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 2), v_3m), (int8x16_t)q3h[3]);
+
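+            // ggml_vec_dot(v_z, a, b) yields four int32 lane sums of the byte
+            // products a*b (per its use here); the lanes are then summed and
+            // weighted by the per-sub-block scale.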
+            isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[0]);
+            isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[1]);
+            isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[2]);
+            isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[3]);
+
+            isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
+            isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
+            isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
+            isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
+
+            scale += 4;
+
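+            // Second half: crumbs 2 and 3 of each q3 byte, paired with hmask bits 2 and 3.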
+            q3h[0] = vec_andc(v_2c, qhbits[0]);
+            q3h[1] = vec_andc(v_2c, qhbits[1]);
+            q3h[2] = vec_sr(vec_andc(v_3c, qhbits[0]), 1);
+            q3h[3] = vec_sr(vec_andc(v_3c, qhbits[1]), 1);
+
+            q3bytes[0] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 4), v_3m), (int8x16_t)q3h[0]);
+            q3bytes[1] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 4), v_3m), (int8x16_t)q3h[1]);
+            q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 6), v_3m), (int8x16_t)q3h[2]);
+            q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 6), v_3m), (int8x16_t)q3h[3]);
+
+            isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[4]);
+            isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[5]);
+            isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]);
+            isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]);
+
+            isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
+            isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
+            isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
+            isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
+
+            scale += 4;
+
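+            // After the first 128 quants (j == 0), shift hmask right by 4 so bits
+            // 0-3 address the high bits of the second half of the super-block.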
+            if (j == 0) {
+                qhbits[0] = vec_sr(qhbits[0], 4);
+                qhbits[1] = vec_sr(qhbits[1], 4);
+            }
+        }
+
+        sum += d * isum;
+    }
+
+    *s = sum;
 #else
     // scalar version
     // This function is written like this so the compiler can manage to vectorize most of it