Skip to content

Commit da0829f

Browse files
fbarchard authored and xnnpack-bot committed
Add int8xint4 FC XNNPACK kernel for SSE
- Uses the fake-VNNI method, replacing VNNI with pmaddubsw.
- Same as qd8_qc4w, but the input is qs8 (signed int) and requires an XOR with 0x80.
- ssse3 uses _mm_mul_epu32, which is faster than sse4 _mm_mullo_epi32 on Intel.
- The 4-bit values are unsigned and kept in the low 4 bits so pmaddubsw does not overflow.
- Quantization is the same as qs8_qc8w, which is per NC channel.
- 1.6x faster than using the 16-bit pmaddwd method.
- Prefetch is supported and is 5% faster on Silvermont (sse4 Atom).
- Supports 32-bit, but 2x4c8 is needed to avoid spills.
- ssse3 is required for the pmaddubsw 8-bit multiply.

PiperOrigin-RevId: 785454672
1 parent 11a60c3 commit da0829f

27 files changed: 3701 additions and 49 deletions.

cmake/gen/sse41_microkernels.cmake

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -175,6 +175,14 @@ SET(NON_PROD_SSE41_MICROKERNEL_SRCS
175175
src/qs8-f32-vcvt/gen/qs8-f32-vcvt-sse41-u8.c
176176
src/qs8-f32-vcvt/gen/qs8-f32-vcvt-sse41-u24.c
177177
src/qs8-f32-vcvt/gen/qs8-f32-vcvt-sse41-u32.c
178+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x4c8-minmax-sse41-madd-prfm.c
179+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x4c8-minmax-sse41-madd.c
180+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x4c8-minmax-sse41-madd-prfm.c
181+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x4c8-minmax-sse41-madd.c
182+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x4c8-minmax-sse41-madd-prfm.c
183+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x4c8-minmax-sse41-madd.c
184+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x4c8-minmax-sse41-madd-prfm.c
185+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x4c8-minmax-sse41-madd.c
178186
src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-9p8c-minmax-fp32-sse41-mul16-add16.c
179187
src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-9p8c-minmax-fp32-sse41-mul32.c
180188
src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-9p16c-minmax-fp32-sse41-mul16-add16.c

cmake/gen/ssse3_microkernels.cmake

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,14 @@ SET(NON_PROD_SSSE3_MICROKERNEL_SRCS
3737
src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-ssse3-madd-prfm.c
3838
src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-ssse3-madd.c
3939
src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-ssse3-madd-prfm.c
40+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x4c8-minmax-ssse3-madd-prfm.c
41+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x4c8-minmax-ssse3-madd.c
42+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x4c8-minmax-ssse3-madd-prfm.c
43+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x4c8-minmax-ssse3-madd.c
44+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x4c8-minmax-ssse3-madd-prfm.c
45+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x4c8-minmax-ssse3-madd.c
46+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x4c8-minmax-ssse3-madd-prfm.c
47+
src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x4c8-minmax-ssse3-madd.c
4048
src/qs8-rsum/gen/qs8-rsum-ssse3-u16.c
4149
src/qs8-rsum/gen/qs8-rsum-ssse3-u64-acc2.c
4250
src/qs8-rsum/gen/qs8-rsum-ssse3-u64-acc4.c

gen/sse41_microkernels.bzl

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -172,6 +172,14 @@ NON_PROD_SSE41_MICROKERNEL_SRCS = [
172172
"src/qs8-f32-vcvt/gen/qs8-f32-vcvt-sse41-u8.c",
173173
"src/qs8-f32-vcvt/gen/qs8-f32-vcvt-sse41-u24.c",
174174
"src/qs8-f32-vcvt/gen/qs8-f32-vcvt-sse41-u32.c",
175+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x4c8-minmax-sse41-madd-prfm.c",
176+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x4c8-minmax-sse41-madd.c",
177+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x4c8-minmax-sse41-madd-prfm.c",
178+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x4c8-minmax-sse41-madd.c",
179+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x4c8-minmax-sse41-madd-prfm.c",
180+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x4c8-minmax-sse41-madd.c",
181+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x4c8-minmax-sse41-madd-prfm.c",
182+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x4c8-minmax-sse41-madd.c",
175183
"src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-9p8c-minmax-fp32-sse41-mul16-add16.c",
176184
"src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-9p8c-minmax-fp32-sse41-mul32.c",
177185
"src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-9p16c-minmax-fp32-sse41-mul16-add16.c",

gen/ssse3_microkernels.bzl

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,14 @@ NON_PROD_SSSE3_MICROKERNEL_SRCS = [
3434
"src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-ssse3-madd-prfm.c",
3535
"src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-ssse3-madd.c",
3636
"src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-ssse3-madd-prfm.c",
37+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x4c8-minmax-ssse3-madd-prfm.c",
38+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x4c8-minmax-ssse3-madd.c",
39+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x4c8-minmax-ssse3-madd-prfm.c",
40+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x4c8-minmax-ssse3-madd.c",
41+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x4c8-minmax-ssse3-madd-prfm.c",
42+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x4c8-minmax-ssse3-madd.c",
43+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x4c8-minmax-ssse3-madd-prfm.c",
44+
"src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x4c8-minmax-ssse3-madd.c",
3745
"src/qs8-rsum/gen/qs8-rsum-ssse3-u16.c",
3846
"src/qs8-rsum/gen/qs8-rsum-ssse3-u64-acc2.c",
3947
"src/qs8-rsum/gen/qs8-rsum-ssse3-u64-acc4.c",

scripts/generate-qs8-gemm.sh

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1563,6 +1563,24 @@ tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=2 -D DATATYPE=QC4_F32 -D SSE
15631563
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=3 -D DATATYPE=QC4_F32 -D SSE=4 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=1 -D REQUANTIZATION= -o src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-sse41-madd-prfm.c &
15641564
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=4 -D DATATYPE=QC4_F32 -D SSE=4 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=1 -D REQUANTIZATION= -o src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-sse41-madd-prfm.c &
15651565

1566+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=1 -D DATATYPE=QS8_QC4 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x4c8-minmax-ssse3-madd.c &
1567+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=2 -D DATATYPE=QS8_QC4 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x4c8-minmax-ssse3-madd.c &
1568+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=3 -D DATATYPE=QS8_QC4 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x4c8-minmax-ssse3-madd.c &
1569+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=4 -D DATATYPE=QS8_QC4 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x4c8-minmax-ssse3-madd.c &
1570+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=1 -D DATATYPE=QS8_QC4 -D SSE=4 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x4c8-minmax-sse41-madd.c &
1571+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=2 -D DATATYPE=QS8_QC4 -D SSE=4 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x4c8-minmax-sse41-madd.c &
1572+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=3 -D DATATYPE=QS8_QC4 -D SSE=4 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x4c8-minmax-sse41-madd.c &
1573+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=4 -D DATATYPE=QS8_QC4 -D SSE=4 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x4c8-minmax-sse41-madd.c &
1574+
1575+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=1 -D DATATYPE=QS8_QC4 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=1 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x4c8-minmax-ssse3-madd-prfm.c &
1576+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=2 -D DATATYPE=QS8_QC4 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=1 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x4c8-minmax-ssse3-madd-prfm.c &
1577+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=3 -D DATATYPE=QS8_QC4 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=1 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x4c8-minmax-ssse3-madd-prfm.c &
1578+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=4 -D DATATYPE=QS8_QC4 -D SSE=3 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=1 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x4c8-minmax-ssse3-madd-prfm.c &
1579+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=1 -D DATATYPE=QS8_QC4 -D SSE=4 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=1 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-1x4c8-minmax-sse41-madd-prfm.c &
1580+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=2 -D DATATYPE=QS8_QC4 -D SSE=4 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=1 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-2x4c8-minmax-sse41-madd-prfm.c &
1581+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=3 -D DATATYPE=QS8_QC4 -D SSE=4 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=1 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-3x4c8-minmax-sse41-madd-prfm.c &
1582+
tools/xngen src/qs8-gemm/MRx4c8-ssevnni.c.in -D MR=4 -D DATATYPE=QS8_QC4 -D SSE=4 -D AVX=0 -D VARIANT=MADD -D GFNI=0 -D PREFETCH=1 -D REQUANTIZATION=FP32 -o src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-4x4c8-minmax-sse41-madd-prfm.c &
1583+
15661584
################################## x86 AVX256 VNNI EVEX #################################
15671585
### C8 micro-kernels
15681586
tools/xngen src/qs8-gemm/MRx8c8-avxvnni.c.in -D MR=1 -D DATATYPE=QC8 -D AVX=10 -D VARIANT= -D GFNI=0 -D PREFETCH=0 -D REQUANTIZATION=FP32 -o src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avx256vnni.c &

src/qs8-gemm/MRx4c8-ssevnni.c.in

Lines changed: 30 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -3,9 +3,8 @@
33
// This source code is licensed under the BSD-style license found in the
44
// LICENSE file in the root directory of this source tree.
55

6-
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
76
$assert REQUANTIZATION == "FP32" or not REQUANTIZATION
8-
$assert DATATYPE in ["QC4_F32"]
7+
$assert DATATYPE in ["QC4_F32", "QS8_QC4"]
98
$assert SSE in [3, 4]
109
#include <assert.h>
1110
#include <stddef.h>
@@ -24,14 +23,14 @@ $if PREFETCH:
2423
#include "src/xnnpack/unaligned.h"
2524

2625

27-
$DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QU8": "qu8", "QD8_F16" : "qd8_f16_qc8w", "QD8_F32": "qd8_f32_qc8w", "QC4_F16": "qd8_f16_qc4w", "QC4_F32": "qd8_f32_qc4w"}[DATATYPE]
26+
$DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QS8_QC4": "qs8_qc4w", "QU8": "qu8", "QD8_F16" : "qd8_f16_qc8w", "QD8_F32": "qd8_f32_qc8w", "QC4_F16": "qd8_f16_qc4w", "QC4_F32": "qd8_f32_qc4w"}[DATATYPE]
2827
$REQUANTIZATION_SPEC = "_" + REQUANTIZATION.lower() if REQUANTIZATION else ""
2928
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_scalar" if REQUANTIZATION else "scalar"
30-
$PARAMS_TYPE = {"QC8": "union xnn_qs8_qc8w_conv_minmax_params", "QU8": "union xnn_qu8_conv_minmax_params", "QD8_F16": "struct xnn_f16_minmax_params", "QD8_F32": "struct xnn_f32_minmax_params", "QC4_F16": "struct xnn_f16_qc4w_minmax_params", "QC4_F32": "struct xnn_f32_qc4w_minmax_params"}[DATATYPE]
29+
$PARAMS_TYPE = {"QC8": "union xnn_qs8_qc8w_conv_minmax_params", "QS8_QC4": "union xnn_qs8_qc8w_conv_minmax_params", "QU8": "union xnn_qu8_conv_minmax_params", "QD8_F16": "struct xnn_f16_minmax_params", "QD8_F32": "struct xnn_f32_minmax_params", "QC4_F16": "struct xnn_f16_qc4w_minmax_params", "QC4_F32": "struct xnn_f32_qc4w_minmax_params"}[DATATYPE]
3130
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
32-
$OUT_T = {"QC8": "int8_t", "QD8_F16": "xnn_float16", "QD8_F32": "float", "QC4_F16": "xnn_float16", "QC4_F32": "float", "QU8": "uint8_t"}[DATATYPE]
31+
$OUT_T = {"QC8": "int8_t", "QS8_QC4": "int8_t", "QD8_F16": "xnn_float16", "QD8_F32": "float", "QC4_F16": "xnn_float16", "QC4_F32": "float", "QU8": "uint8_t"}[DATATYPE]
3332
$_MM_PACKXS_EPI16 = "_mm_packus_epi16" if DATATYPE == "QU8" else "_mm_packs_epi16"
34-
$_MM_MAX_EPX8 = "_mm_max_epu8" if DATATYPE == "QU8" else "_mm_max_epi8"
33+
$_MM_MAX_EPX16 = "_mm_max_epu16" if DATATYPE == "QU8" else "_mm_max_epi16"
3534
$_MM_CVTXEPI32_EPI8 = "_mm_cvtusepi32_epi8" if DATATYPE == "QU8" else "_mm_cvtsepi32_epi8"
3635
$_MM_DPBUSD_EPI32 = "_mm_dpbusd_epi32_madd" if VARIANT == "MADD" else "_mm_dpbusd_avx_epi32" if AVX == 2 else "_mm_dpbusd_epi32"
3736
$ISA = "sse41" if SSE == 4 else "ssse3"
@@ -97,24 +96,21 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
9796
$else:
9897
const __m128 voutput_min = _mm_set1_ps(params->scalar.min);
9998
const __m128 voutput_max = _mm_set1_ps(params->scalar.max);
100-
$if DATATYPE in ["QC4_F16", "QC4_F32"]:
101-
$if VARIANT == "MADD":
102-
const __m128i vmask = _mm_set1_epi8(0x0F);
103-
$else:
104-
const __m128i vmask = _mm_set1_epi8(0xF0);
105-
XNN_FORCE_REALIZATION(vmask);
106-
$if GFNI:
107-
const __m128i vshl4 = _mm_set1_epi64x(0x01020408);
108-
XNN_FORCE_REALIZATION(vshl4);
10999
$else:
110100
const __m128i vsign_mask = _mm_set1_epi8(0x80);
111101
XNN_FORCE_REALIZATION(vsign_mask);
112102
const __m128 voutput_max_less_zero_point = _mm_set1_ps((int32_t) params->${PARAMS_STRUCT}.output_max - (int32_t) params->${PARAMS_STRUCT}.output_zero_point);
113103
const __m128i voutput_zero_point = _mm_set1_epi32(params->${PARAMS_STRUCT}.output_zero_point);
114-
const __m128i voutput_min = _mm_set1_epi8(params->${PARAMS_STRUCT}.output_min);
115-
// XNN_FORCE_REALIZATION(voutput_max_less_zero_point);
116-
// XNN_FORCE_REALIZATION(voutput_zero_point);
117-
// XNN_FORCE_REALIZATION(voutput_min);
104+
const __m128i voutput_min = _mm_set1_epi16(params->${PARAMS_STRUCT}.output_min);
105+
$if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]:
106+
$if VARIANT == "MADD":
107+
const __m128i vmask = _mm_set1_epi8(0x0F);
108+
$else:
109+
const __m128i vmask = _mm_set1_epi8(0xF0);
110+
XNN_FORCE_REALIZATION(vmask);
111+
$if GFNI:
112+
const __m128i vshl4 = _mm_set1_epi64x(0x01020408);
113+
XNN_FORCE_REALIZATION(vshl4);
118114
do {
119115
$if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]:
120116
const __m128i vksum0123 = _mm_load_si128(w);
@@ -133,8 +129,9 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
133129
__m128i vacc${M}x01 = _mm_unpacklo_epi32(vsum${M}x01, _mm_setzero_si128());
134130
__m128i vacc${M}x23 = _mm_unpacklo_epi32(vsum${M}x23, _mm_setzero_si128());
135131
$else:
136-
__m128i vacc0x01 = _mm_unpacklo_epi32(vsum${M}x0123, _mm_setzero_si128());
137-
__m128i vacc0x23 = _mm_unpackhi_epi32(vsum${M}x0123, _mm_setzero_si128());
132+
const __m128i vksum0123 = _mm_load_si128(w);
133+
__m128i vacc0x01 = _mm_unpacklo_epi32(vksum0123, _mm_setzero_si128());
134+
__m128i vacc0x23 = _mm_unpackhi_epi32(vksum0123, _mm_setzero_si128());
138135
$for M in range(1, MR):
139136
__m128i vacc${M}x01 = vacc0x01;
140137
__m128i vacc${M}x23 = vacc0x23;
@@ -155,7 +152,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
155152
const __m128i va${M}x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a${M} + 8)), vsign_mask);
156153
a${M} += 16;
157154

158-
$if DATATYPE in ["QC4_F16", "QC4_F32"]:
155+
$if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]:
159156
const __m128i vbb01234567x0123 = _mm_load_si128(w);
160157
const __m128i vbb89ABCDEFx0123 = _mm_load_si128((const __m128i*) ((const ${XINT8_T}*) w + 16));
161158
$if GFNI:
@@ -198,7 +195,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
198195
vacc${M}x01 = ${_MM_DPBUSD_EPI32}(vacc${M}x01, va${M}x89ABCDEF, vb01234567x23);
199196
vacc${M}x23 = ${_MM_DPBUSD_EPI32}(vacc${M}x23, va${M}x89ABCDEF, vb89ABCDEFx23);
200197

201-
$if DATATYPE in ["QC4_F16", "QC4_F32"]:
198+
$if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]:
202199
w = (const ${XINT8_T}*) w + 32;
203200
$else:
204201
w = (const ${XINT8_T}*) w + 64;
@@ -210,10 +207,10 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
210207
$if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]:
211208
const __m128i va${M}x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a${M}));
212209
$else:
213-
const __m128i va${M}x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a${M} + 8)), vsign_mask);
210+
const __m128i va${M}x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a${M})), vsign_mask);
214211
a${M} += 8;
215212

216-
$if DATATYPE in ["QC4_F16", "QC4_F32"]:
213+
$if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]:
217214
const __m128i vbb01234567x0123 = _mm_load_si128(w);
218215
const __m128i vbb89ABCDEFx0123 = _mm_load_si128((const __m128i*) ((const ${XINT8_T}*) w + 16));
219216
$if GFNI:
@@ -247,7 +244,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
247244
$for M in range(MR):
248245
__m128i vacc${M}x0123 = _mm_hadd_epi32(vacc${M}x01, vacc${M}x23);
249246

250-
$if DATATYPE in ["QC4_F16", "QC4_F32"] and VARIANT != "MADD":
247+
$if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"] and VARIANT != "MADD":
251248
$for M in range(MR):
252249
vacc${M}x0123 = _mm_srai_epi32(vacc${M}x0123, 4);
253250
$for M in range(MR):
@@ -339,20 +336,16 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
339336
$for M in range(MR):
340337
vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123, voutput_zero_point);
341338

342-
$if SSE >= 3:
343-
$for M in range(MR):
344-
vacc${M}x0123 = _mm_packs_epi32(vacc${M}x0123, vacc${M}x0123);
345-
__m128i voutb${M}x0123 = _mm_packs_epi16(vacc${M}x0123, vacc${M}x0123);
346-
$else:
347-
$for M in range(MR):
348-
__m128i voutb${M}x0123 = ${_MM_CVTXEPI32_EPI8}(vacc${M}x0123);
349-
350339
$for M in range(MR):
351-
voutb${M}x0123 = ${_MM_MAX_EPX8}(voutb${M}x0123, voutput_min);
340+
vacc${M}x0123 = _mm_packs_epi32(vacc${M}x0123, vacc${M}x0123);
341+
$for M in range(MR):
342+
vacc${M}x0123 = ${_MM_MAX_EPX16}(vacc${M}x0123, voutput_min);
343+
$for M in range(MR):
344+
__m128i voutb${M}x0123 = _mm_packs_epi16(vacc${M}x0123, vacc${M}x0123);
352345

353346
if (nc >= 4) {
354347
$for M in range(MR):
355-
_mm_storeu_ps(c${M}, voutb${M}x0123);
348+
_mm_storeu_si32(c${M}, voutb${M}x0123);
356349
c${M} = (${OUT_T}*) ((uintptr_t) c${M} + cn_stride);
357350
a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc);
358351

@@ -368,7 +361,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
368361
}
369362
if (nc & 1) {
370363
$for M in range(MR):
371-
*c${M} = (${XINT8_T}) _mm_extract_epi8(voutb${M}x0123, 0);
364+
*c${M} = (${OUT_T}) _mm_extract_epi16(voutb${M}x0123, 0);
372365
}
373366
$else:
374367
// Prepare mask for valid 8-bit elements (depends on nc).

0 commit comments

Comments
 (0)