Commit 799fdc1

ggml : vectorize Q8_0 quantization
ggml-org/ggml#127 (comment)
1 parent 6daa09d commit 799fdc1
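
For context, the diff below shows that a Q8_0 block packs QK8_0 == 32 floats into one float scale d plus 32 signed bytes qs, where d = max|x| / 127 and each byte is round(x/d). The following is a rough scalar sketch of what both new SIMD paths compute; the function name quantize_row_q8_0_scalar_sketch is hypothetical and this is illustrative only (it mirrors, but is not, the quantize_row_q8_0_reference code that the scalar fallback calls), and it assumes the block_q8_0 fields d and qs used in the diff.

    // Illustrative scalar version of Q8_0 quantization (not the actual reference code).
    #include <math.h>
    #include <stdint.h>

    static void quantize_row_q8_0_scalar_sketch(const float * x, block_q8_0 * y, int nb) {
        for (int i = 0; i < nb; i++) {
            float amax = 0.0f;                           // absolute max over the block
            for (int l = 0; l < 32; l++) {
                const float v = fabsf(x[i*32 + l]);
                if (v > amax) amax = v;
            }

            const float d  = amax / 127.0f;              // largest value maps to +/-127
            const float id = d ? 1.0f/d : 0.0f;          // inverse scale; an all-zero block maps to zeros

            y[i].d = d;
            for (int l = 0; l < 32; l++) {
                y[i].qs[l] = (int8_t) roundf(x[i*32 + l] * id);   // round to nearest int8
            }
        }
    }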

1 file changed: ggml.c (+120, -0 lines)
@@ -1509,15 +1509,135 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
 }
 
 static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+    assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
 
     block_q8_0 * restrict y = vy;
 
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = vld1q_f32(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        for (int l = 0; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+        // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );  // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#else
+    // scalar
     quantize_row_q8_0_reference(x, y, k);
+#endif
 }
 
 // reference implementation for deterministic creation of model files
 static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
+    assert(QK8_1 == 32);
     assert(k % QK8_1 == 0);
     const int nb = k / QK8_1;
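
Note that the SIMD paths round to nearest (ties to even) via vcvtnq_s32_f32 and _mm256_round_ps, while a scalar round() may treat exact halfway cases differently, so byte-for-byte agreement with the reference path is not guaranteed. A simpler sanity check is that dequantizing (x is approximately d * qs[l]) stays within about half a quantization step. Below is a minimal sketch of such a check; check_q8_0_roundtrip is a hypothetical helper, not part of this commit, and it assumes the block_q8_0 fields d and qs used above.

    // Hypothetical round-trip check for a quantized row (not part of the commit).
    // For each element, |d*qs - x| should be at most about d/2, since qs = round(x/d).
    #include <assert.h>
    #include <math.h>

    static void check_q8_0_roundtrip(const float * x, const block_q8_0 * y, int nb) {
        for (int i = 0; i < nb; i++) {
            for (int l = 0; l < 32; l++) {
                const float xr = y[i].d * y[i].qs[l];                     // dequantize
                assert(fabsf(xr - x[i*32 + l]) <= 0.5f*y[i].d + 1e-6f);   // within half a step
            }
        }
    }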