@@ -1509,15 +1509,135 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
 }
 
 static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+    assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
 
     block_q8_0 * restrict y = vy;
 
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = vld1q_f32(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        for (int l = 0; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+        // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );  // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction fixes the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Plain AVX lacks the 256-bit integer pack instructions,
+        // so we split the registers in half and use the SSE counterparts
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1 );
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1 );
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1 );
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1 );
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs + 0),  ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#else
+    // scalar
     quantize_row_q8_0_reference(x, y, k);
+#endif
 }
 
 // reference implementation for deterministic creation of model files
 static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
+    assert(QK8_1 == 32);
     assert(k % QK8_1 == 0);
 
     const int nb = k / QK8_1;
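
For reference, both SIMD paths added above implement the same per-block q8_0
scheme as the scalar reference path: take amax = max(|x|) over the 32 values of
a block, store d = amax/127 as the block scale, then round each x*(1/d) to the
nearest signed byte. A minimal self-contained scalar sketch of that scheme is
shown below; block_q8_0_sketch and QK8_0_SKETCH are illustrative stand-ins for
ggml's block_q8_0 and QK8_0, not names from ggml.c.

    #include <assert.h>
    #include <math.h>
    #include <stdint.h>

    #define QK8_0_SKETCH 32

    typedef struct {
        float  d;                    // per-block scale
        int8_t qs[QK8_0_SKETCH];     // quantized values
    } block_q8_0_sketch;

    static void quantize_row_q8_0_sketch(const float * x, block_q8_0_sketch * y, int k) {
        assert(k % QK8_0_SKETCH == 0);
        const int nb = k / QK8_0_SKETCH;

        for (int i = 0; i < nb; i++) {
            // amax = max(|x|) over the block
            float amax = 0.0f;
            for (int l = 0; l < QK8_0_SKETCH; l++) {
                const float v = fabsf(x[i*QK8_0_SKETCH + l]);
                if (v > amax) amax = v;
            }

            // map amax to 127; id is the inverse scale applied to each element
            const float d  = amax / 127.0f;
            const float id = d != 0.0f ? 1.0f/d : 0.0f;

            y[i].d = d;

            // round to nearest, then store as signed bytes
            for (int l = 0; l < QK8_0_SKETCH; l++) {
                y[i].qs[l] = (int8_t) roundf(x[i*QK8_0_SKETCH + l]*id);
            }
        }
    }

The NEON path performs the rounding with vcvtnq_s32_f32, while the AVX path uses
_mm256_round_ps followed by pack/permute to get the bytes back into block order;
otherwise both follow the sketch above one 32-float block at a time.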