@@ -203,6 +203,91 @@ void GEMMBenchmark(benchmark::State& state,
                          benchmark::Counter::kIsRate);
 }
 
+void GEMMBenchmark(benchmark::State& state,
+                   xnn_qs8_qc4w_gemm_minmax_ukernel_fn gemm,
+                   xnn_init_qs8_qc8w_conv_minmax_params_fn init_params,
+                   xnn_pack_qs8_qc4w_gemm_fn pack, size_t mr, size_t nr, size_t kr,
+                   size_t sr, uint64_t arch_flags) {
+  if (!benchmark::utils::CheckArchFlags(state, arch_flags)) {
+    return;
+  }
+
+  const size_t mc = state.range(0);
+  const size_t nc = state.range(1);
+  const size_t kc = state.range(2);
+
+  const size_t nc_stride = benchmark::utils::RoundUp(nc, nr);
+  const size_t kc_stride = benchmark::utils::RoundUp(kc, kr * sr) / 2;
+
+  std::random_device random_device;
+  auto rng = std::mt19937(random_device());
+  auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000),
+                          std::ref(rng));
+
+  xnnpack::Buffer<int8_t> a(mc * kc, xnnpack::XnnExtraBytes);
+  xnnpack::fill_uniform_random_bits(a.data(), a.size(), rng);
+  xnnpack::Buffer<uint8_t> k(nc * kc / 2);
+  xnnpack::fill_uniform_random_bits(k.data(), k.size(), rng);
+  xnnpack::Buffer<int32_t> b(nc);
+  std::generate(b.begin(), b.end(), std::ref(i32rng));
+
+  const size_t w_size = nc_stride * (sizeof(float) + sizeof(int32_t)) +
+                        kc_stride * nc_stride;
+  const size_t c_elements = mc * nc;
+  const size_t num_buffers = 1 + benchmark::utils::DivideRoundUp<size_t>(
+                                     benchmark::utils::GetMaxCacheSize(),
+                                     w_size + c_elements * sizeof(int8_t));
+
+  xnnpack::Buffer<char, XNN_ALLOCATION_ALIGNMENT> w(w_size * num_buffers);
+
+  const xnn_qs8_qc4w_packing_params packing_params = {int8_t(127 - 0x80)};
+  pack(/*g=*/1, nc, kc, nr, kr, sr, k.data(), b.data(), /*scale=*/nullptr,
+       w.data(), nr * sizeof(float), &packing_params);
+
+  xnnpack::Buffer<int8_t> c(c_elements * num_buffers);
+
+  union xnn_qs8_qc8w_conv_minmax_params quantization_params;
+  init_params(&quantization_params,
+              /*output_zero_point=*/127,
+              /*output_min=*/-127,
+              /*output_max=*/126);
+
+  size_t buffer_index = 0;
+  for (auto _ : state) {
+    // Use circular buffers (exceeding cache size) and prefetch to control cache
+    // state:
+    // - A is always in L1 cache (if fits, otherwise L2, L3, etc)
+    // - W is not in cache (for any cache level)
+    // - C is not in cache (for any cache level)
+    state.PauseTiming();
+    benchmark::utils::PrefetchToL1(a.data(), a.size() * sizeof(int8_t));
+    buffer_index = (buffer_index + 1) % num_buffers;
+    state.ResumeTiming();
+
+    for (uint32_t m = 0; m < mc; m += mr) {
+      const uint32_t mb = min(mc - m, mr);
+      for (uint32_t n = 0; n < nc; n += nr) {
+        const uint32_t nb = min(nc - n, nr);
+        gemm(mb, nb, kc * sizeof(int8_t), a.data() + m * kc,
+             kc * sizeof(int8_t),
+             w.data() + w_size * buffer_index +
+                 n * (kc_stride + sizeof(int32_t)),
+             c.data() + (mc * buffer_index + m) * nc + n, nc * sizeof(int8_t),
+             nr * sizeof(int8_t), &quantization_params);
+      }
+    }
+  }
+
+  const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
+  if (cpu_frequency != 0) {
+    state.counters["cpufreq"] = cpu_frequency;
+  }
+
+  state.counters["OPS"] =
+      benchmark::Counter(uint64_t(state.iterations()) * 2 * mc * nc * kc,
+                         benchmark::Counter::kIsRate);
+}
+
 void GEMMBenchmark(benchmark::State& state,
                    xnn_qd8_f16_qc8w_gemm_ukernel_fn gemm,
                    xnn_init_f16_minmax_params_fn init_params,
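For context only (not part of this change): the new overload takes the problem size from state.range(0..2) as mc, nc, kc, so a caller would typically wrap it in a registration function and feed it to Google Benchmark. The sketch below is a hypothetical example; every xnn_* identifier in it is a placeholder standing in for a real QS8-QC4W microkernel, its init function, and its packing function, none of which are introduced by this diff.

// Hypothetical registration sketch for the new QS8-QC4W GEMMBenchmark overload.
// The xnn_* names below are placeholders, not symbols added by this PR.
static void qs8_qc4w_gemm_example(benchmark::State& state) {
  GEMMBenchmark(state,
                /*gemm=*/xnn_qs8_qc4w_gemm_minmax_ukernel_4x16__example,       // placeholder kernel
                /*init_params=*/xnn_init_qs8_qc8w_conv_minmax_example_params,  // placeholder init fn
                /*pack=*/xnn_pack_qs8_qc4w_gemm_example_w,                     // placeholder packing fn
                /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1,
                /*arch_flags=*/0);
}

// state.range(0), range(1), range(2) supply mc, nc, kc for one MxNxK problem size.
BENCHMARK(qs8_qc4w_gemm_example)
    ->Args({/*mc=*/256, /*nc=*/128, /*kc=*/512})
    ->UseRealTime();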