3
3
// This source code is licensed under the BSD-style license found in the
4
4
// LICENSE file in the root directory of this source tree.
5
5
6
- $ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
7
6
$assert REQUANTIZATION == "FP32" or not REQUANTIZATION
8
- $assert DATATYPE in ["QC4_F32"]
7
+ $assert DATATYPE in ["QC4_F32", "QS8_QC4" ]
9
8
$assert SSE in [3, 4]
10
9
#include <assert.h>
11
10
#include <stddef.h>
@@ -24,14 +23,14 @@ $if PREFETCH:
24
23
#include "src/xnnpack/unaligned.h"
25
24
26
25
27
- $DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QU8": "qu8", "QD8_F16" : "qd8_f16_qc8w", "QD8_F32": "qd8_f32_qc8w", "QC4_F16": "qd8_f16_qc4w", "QC4_F32": "qd8_f32_qc4w"}[DATATYPE]
26
+ $DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QS8_QC4": "qs8_qc4w", " QU8": "qu8", "QD8_F16" : "qd8_f16_qc8w", "QD8_F32": "qd8_f32_qc8w", "QC4_F16": "qd8_f16_qc4w", "QC4_F32": "qd8_f32_qc4w"}[DATATYPE]
28
27
$REQUANTIZATION_SPEC = "_" + REQUANTIZATION.lower() if REQUANTIZATION else ""
29
28
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_scalar" if REQUANTIZATION else "scalar"
30
- $PARAMS_TYPE = {"QC8": "union xnn_qs8_qc8w_conv_minmax_params", "QU8": "union xnn_qu8_conv_minmax_params", "QD8_F16": "struct xnn_f16_minmax_params", "QD8_F32": "struct xnn_f32_minmax_params", "QC4_F16": "struct xnn_f16_qc4w_minmax_params", "QC4_F32": "struct xnn_f32_qc4w_minmax_params"}[DATATYPE]
29
+ $PARAMS_TYPE = {"QC8": "union xnn_qs8_qc8w_conv_minmax_params", "QS8_QC4": "union xnn_qs8_qc8w_conv_minmax_params", " QU8": "union xnn_qu8_conv_minmax_params", "QD8_F16": "struct xnn_f16_minmax_params", "QD8_F32": "struct xnn_f32_minmax_params", "QC4_F16": "struct xnn_f16_qc4w_minmax_params", "QC4_F32": "struct xnn_f32_qc4w_minmax_params"}[DATATYPE]
31
30
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
32
- $OUT_T = {"QC8": "int8_t", "QD8_F16": "xnn_float16", "QD8_F32": "float", "QC4_F16": "xnn_float16", "QC4_F32": "float", "QU8": "uint8_t"}[DATATYPE]
31
+ $OUT_T = {"QC8": "int8_t", "QS8_QC4": "int8_t", " QD8_F16": "xnn_float16", "QD8_F32": "float", "QC4_F16": "xnn_float16", "QC4_F32": "float", "QU8": "uint8_t"}[DATATYPE]
33
32
$_MM_PACKXS_EPI16 = "_mm_packus_epi16" if DATATYPE == "QU8" else "_mm_packs_epi16"
34
- $_MM_MAX_EPX8 = "_mm_max_epu8 " if DATATYPE == "QU8" else "_mm_max_epi8 "
33
+ $_MM_MAX_EPX16 = "_mm_max_epu16 " if DATATYPE == "QU8" else "_mm_max_epi16 "
35
34
$_MM_CVTXEPI32_EPI8 = "_mm_cvtusepi32_epi8" if DATATYPE == "QU8" else "_mm_cvtsepi32_epi8"
36
35
$_MM_DPBUSD_EPI32 = "_mm_dpbusd_epi32_madd" if VARIANT == "MADD" else "_mm_dpbusd_avx_epi32" if AVX == 2 else "_mm_dpbusd_epi32"
37
36
$ISA = "sse41" if SSE == 4 else "ssse3"
@@ -97,24 +96,21 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
97
96
$else:
98
97
const __m128 voutput_min = _mm_set1_ps(params->scalar.min);
99
98
const __m128 voutput_max = _mm_set1_ps(params->scalar.max);
100
- $if DATATYPE in ["QC4_F16", "QC4_F32"]:
101
- $if VARIANT == "MADD":
102
- const __m128i vmask = _mm_set1_epi8(0x0F);
103
- $else:
104
- const __m128i vmask = _mm_set1_epi8(0xF0);
105
- XNN_FORCE_REALIZATION(vmask);
106
- $if GFNI:
107
- const __m128i vshl4 = _mm_set1_epi64x(0x01020408);
108
- XNN_FORCE_REALIZATION(vshl4);
109
99
$else:
110
100
const __m128i vsign_mask = _mm_set1_epi8(0x80);
111
101
XNN_FORCE_REALIZATION(vsign_mask);
112
102
const __m128 voutput_max_less_zero_point = _mm_set1_ps((int32_t) params->${PARAMS_STRUCT}.output_max - (int32_t) params->${PARAMS_STRUCT}.output_zero_point);
113
103
const __m128i voutput_zero_point = _mm_set1_epi32(params->${PARAMS_STRUCT}.output_zero_point);
114
- const __m128i voutput_min = _mm_set1_epi8(params->${PARAMS_STRUCT}.output_min);
115
- // XNN_FORCE_REALIZATION(voutput_max_less_zero_point);
116
- // XNN_FORCE_REALIZATION(voutput_zero_point);
117
- // XNN_FORCE_REALIZATION(voutput_min);
104
+ const __m128i voutput_min = _mm_set1_epi16(params->${PARAMS_STRUCT}.output_min);
105
+ $if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4"]:
106
+ $if VARIANT == "MADD":
107
+ const __m128i vmask = _mm_set1_epi8(0x0F);
108
+ $else:
109
+ const __m128i vmask = _mm_set1_epi8(0xF0);
110
+ XNN_FORCE_REALIZATION(vmask);
111
+ $if GFNI:
112
+ const __m128i vshl4 = _mm_set1_epi64x(0x01020408);
113
+ XNN_FORCE_REALIZATION(vshl4);
118
114
do {
119
115
$if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]:
120
116
const __m128i vksum0123 = _mm_load_si128(w);
@@ -133,8 +129,9 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
133
129
__m128i vacc${M}x01 = _mm_unpacklo_epi32(vsum${M}x01, _mm_setzero_si128());
134
130
__m128i vacc${M}x23 = _mm_unpacklo_epi32(vsum${M}x23, _mm_setzero_si128());
135
131
$else:
136
- __m128i vacc0x01 = _mm_unpacklo_epi32(vsum${M}x0123, _mm_setzero_si128());
137
- __m128i vacc0x23 = _mm_unpackhi_epi32(vsum${M}x0123, _mm_setzero_si128());
132
+ const __m128i vksum0123 = _mm_load_si128(w);
133
+ __m128i vacc0x01 = _mm_unpacklo_epi32(vksum0123, _mm_setzero_si128());
134
+ __m128i vacc0x23 = _mm_unpackhi_epi32(vksum0123, _mm_setzero_si128());
138
135
$for M in range(1, MR):
139
136
__m128i vacc${M}x01 = vacc0x01;
140
137
__m128i vacc${M}x23 = vacc0x23;
@@ -155,7 +152,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
155
152
const __m128i va${M}x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a${M} + 8)), vsign_mask);
156
153
a${M} += 16;
157
154
158
- $if DATATYPE in ["QC4_F16", "QC4_F32"]:
155
+ $if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4" ]:
159
156
const __m128i vbb01234567x0123 = _mm_load_si128(w);
160
157
const __m128i vbb89ABCDEFx0123 = _mm_load_si128((const __m128i*) ((const ${XINT8_T}*) w + 16));
161
158
$if GFNI:
@@ -198,7 +195,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
198
195
vacc${M}x01 = ${_MM_DPBUSD_EPI32}(vacc${M}x01, va${M}x89ABCDEF, vb01234567x23);
199
196
vacc${M}x23 = ${_MM_DPBUSD_EPI32}(vacc${M}x23, va${M}x89ABCDEF, vb89ABCDEFx23);
200
197
201
- $if DATATYPE in ["QC4_F16", "QC4_F32"]:
198
+ $if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4" ]:
202
199
w = (const ${XINT8_T}*) w + 32;
203
200
$else:
204
201
w = (const ${XINT8_T}*) w + 64;
@@ -210,10 +207,10 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
210
207
$if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]:
211
208
const __m128i va${M}x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a${M}));
212
209
$else:
213
- const __m128i va${M}x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a${M} + 8 )), vsign_mask);
210
+ const __m128i va${M}x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a${M})), vsign_mask);
214
211
a${M} += 8;
215
212
216
- $if DATATYPE in ["QC4_F16", "QC4_F32"]:
213
+ $if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4" ]:
217
214
const __m128i vbb01234567x0123 = _mm_load_si128(w);
218
215
const __m128i vbb89ABCDEFx0123 = _mm_load_si128((const __m128i*) ((const ${XINT8_T}*) w + 16));
219
216
$if GFNI:
@@ -247,7 +244,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
247
244
$for M in range(MR):
248
245
__m128i vacc${M}x0123 = _mm_hadd_epi32(vacc${M}x01, vacc${M}x23);
249
246
250
- $if DATATYPE in ["QC4_F16", "QC4_F32"] and VARIANT != "MADD":
247
+ $if DATATYPE in ["QC4_F16", "QC4_F32", "QS8_QC4" ] and VARIANT != "MADD":
251
248
$for M in range(MR):
252
249
vacc${M}x0123 = _mm_srai_epi32(vacc${M}x0123, 4);
253
250
$for M in range(MR):
@@ -339,20 +336,16 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
339
336
$for M in range(MR):
340
337
vacc${M}x0123 = _mm_add_epi32(vacc${M}x0123, voutput_zero_point);
341
338
342
- $if SSE >= 3:
343
- $for M in range(MR):
344
- vacc${M}x0123 = _mm_packs_epi32(vacc${M}x0123, vacc${M}x0123);
345
- __m128i voutb${M}x0123 = _mm_packs_epi16(vacc${M}x0123, vacc${M}x0123);
346
- $else:
347
- $for M in range(MR):
348
- __m128i voutb${M}x0123 = ${_MM_CVTXEPI32_EPI8}(vacc${M}x0123);
349
-
350
339
$for M in range(MR):
351
- voutb${M}x0123 = ${_MM_MAX_EPX8}(voutb${M}x0123, voutput_min);
340
+ vacc${M}x0123 = _mm_packs_epi32(vacc${M}x0123, vacc${M}x0123);
341
+ $for M in range(MR):
342
+ vacc${M}x0123 = ${_MM_MAX_EPX16}(vacc${M}x0123, voutput_min);
343
+ $for M in range(MR):
344
+ __m128i voutb${M}x0123 = _mm_packs_epi16(vacc${M}x0123, vacc${M}x0123);
352
345
353
346
if (nc >= 4) {
354
347
$for M in range(MR):
355
- _mm_storeu_ps (c${M}, voutb${M}x0123);
348
+ _mm_storeu_si32 (c${M}, voutb${M}x0123);
356
349
c${M} = (${OUT_T}*) ((uintptr_t) c${M} + cn_stride);
357
350
a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc);
358
351
@@ -368,7 +361,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$
368
361
}
369
362
if (nc & 1) {
370
363
$for M in range(MR):
371
- *c${M} = (${XINT8_T }) _mm_extract_epi8 (voutb${M}x0123, 0);
364
+ *c${M} = (${OUT_T }) _mm_extract_epi16 (voutb${M}x0123, 0);
372
365
}
373
366
$else:
374
367
// Prepare mask for valid 8-bit elements (depends on nc).
0 commit comments