@@ -312,55 +312,151 @@ void PASTEMAC(ch,varname) \
312
312
dim_t ir_nt = bli_thread_n_way ( caucus ); \
313
313
dim_t ir_tid = bli_thread_work_id ( caucus ); \
314
314
\
315
- dim_t jr_start , jr_end ; \
316
- dim_t ir_start , ir_end ; \
317
- dim_t jr_inc , ir_inc ; \
318
- \
319
- /* Determine the thread range and increment for the 2nd and 1st loops.
320
- NOTE: The definition of bli_thread_range_jrir() will depend on whether
321
- slab or round-robin partitioning was requested at configure-time. */ \
322
- bli_thread_range_jrir ( thread , n_iter , 1 , FALSE, & jr_start , & jr_end , & jr_inc ); \
323
- bli_thread_range_jrir ( caucus , m_iter , 1 , FALSE, & ir_start , & ir_end , & ir_inc ); \
324
- \
325
- /* Loop over the n dimension (NR columns at a time). */ \
326
- for ( j = jr_start ; j < jr_end ; j += jr_inc ) \
315
+ /* Mis-balancing detection: if n_iter (or m_iter) is not multiple of
316
+ jr_nt (or ir_nt), then we collapse the two JR/IR loops and dispatch
317
+ work on jr_nt * ir_nt threads. */ \
318
+ const bool misbalancing = ((n_iter % jr_nt ) != 0 ) || ((m_iter % ir_nt ) != 0 ); \
319
+ \
320
+ if ( !misbalancing ) \
327
321
{ \
328
- ctype * restrict a1 ; \
329
- ctype * restrict c11 ; \
330
- ctype * restrict b2 ; \
322
+ /* Use traditional two loops */ \
323
+ dim_t jr_start , jr_end ; \
324
+ dim_t ir_start , ir_end ; \
325
+ dim_t jr_inc , ir_inc ; \
326
+ \
327
+ /* Determine the thread range and increment for the 2nd and 1st loops.
328
+ NOTE: The definition of bli_thread_range_jrir() will depend on whether
329
+ slab or round-robin partitioning was requested at configure-time. */ \
330
+ bli_thread_range_jrir ( thread , n_iter , 1 , FALSE, & jr_start , & jr_end , & jr_inc ); \
331
+ bli_thread_range_jrir ( caucus , m_iter , 1 , FALSE, & ir_start , & ir_end , & ir_inc ); \
332
+ \
333
+ /* Loop over the n dimension (NR columns at a time). */ \
334
+ for ( j = jr_start ; j < jr_end ; j += jr_inc ) \
335
+ { \
336
+ ctype * restrict a1 ; \
337
+ ctype * restrict c11 ; \
338
+ ctype * restrict b2 ; \
331
339
\
332
- b1 = b_cast + j * cstep_b ; \
333
- c1 = c_cast + j * cstep_c ; \
340
+ b1 = b_cast + j * cstep_b ; \
341
+ c1 = c_cast + j * cstep_c ; \
334
342
\
335
- n_cur = ( bli_is_not_edge_f ( j , n_iter , n_left ) ? NR : n_left ); \
343
+ n_cur = ( bli_is_not_edge_f ( j , n_iter , n_left ) ? NR : n_left ); \
336
344
\
337
- /* Initialize our next panel of B to be the current panel of B. */ \
338
- b2 = b1 ; \
345
+ /* Initialize our next panel of B to be the current panel of B. */ \
346
+ b2 = b1 ; \
339
347
\
340
- /* Loop over the m dimension (MR rows at a time). */ \
341
- for ( i = ir_start ; i < ir_end ; i += ir_inc ) \
342
- { \
343
- ctype * restrict a2 ; \
348
+ /* Loop over the m dimension (MR rows at a time). */ \
349
+ for ( i = ir_start ; i < ir_end ; i += ir_inc ) \
350
+ { \
351
+ ctype * restrict a2 ; \
352
+ \
353
+ a1 = a_cast + i * rstep_a ; \
354
+ c11 = c1 + i * rstep_c ; \
355
+ \
356
+ m_cur = ( bli_is_not_edge_f ( i , m_iter , m_left ) ? MR : m_left ); \
357
+ \
358
+ /* Compute the addresses of the next panels of A and B. */ \
359
+ a2 = bli_gemm_get_next_a_upanel ( a1 , rstep_a , ir_inc ); \
360
+ if ( bli_is_last_iter ( i , ir_end , ir_tid , ir_nt ) ) \
361
+ { \
362
+ a2 = a_cast ; \
363
+ b2 = bli_gemm_get_next_b_upanel ( b1 , cstep_b , jr_inc ); \
364
+ if ( bli_is_last_iter ( j , jr_end , jr_tid , jr_nt ) ) \
365
+ b2 = b_cast ; \
366
+ } \
367
+ \
368
+ /* Save addresses of next panels of A and B to the auxinfo_t
369
+ object. */ \
370
+ bli_auxinfo_set_next_a ( a2 , & aux ); \
371
+ bli_auxinfo_set_next_b ( b2 , & aux ); \
372
+ \
373
+ /* Handle interior and edge cases separately. */ \
374
+ if ( m_cur == MR && n_cur == NR ) \
375
+ { \
376
+ /* Invoke the gemm micro-kernel. */ \
377
+ gemm_ukr \
378
+ ( \
379
+ k , \
380
+ alpha_cast , \
381
+ a1 , \
382
+ b1 , \
383
+ beta_cast , \
384
+ c11 , rs_c , cs_c , \
385
+ & aux , \
386
+ cntx \
387
+ ); \
388
+ } \
389
+ else \
390
+ { \
391
+ /* Invoke the gemm micro-kernel. */ \
392
+ gemm_ukr \
393
+ ( \
394
+ k , \
395
+ alpha_cast , \
396
+ a1 , \
397
+ b1 , \
398
+ zero , \
399
+ ct , rs_ct , cs_ct , \
400
+ & aux , \
401
+ cntx \
402
+ ); \
403
+ \
404
+ /* Scale the bottom edge of C and add the result from above. */ \
405
+ PASTEMAC (ch ,xpbys_mxn )( m_cur , n_cur , \
406
+ ct , rs_ct , cs_ct , \
407
+ beta_cast , \
408
+ c11 , rs_c , cs_c ); \
409
+ } \
410
+ } \
411
+ } \
412
+ } \
413
+ else /* misbalancing == TRUE */ \
414
+ { \
415
+ /* Fused number of threads in JR and IR loops */ \
416
+ dim_t jrir_nt = jr_nt * ir_nt ; \
344
417
\
345
- a1 = a_cast + i * rstep_a ; \
346
- c11 = c1 + i * rstep_c ; \
418
+ /* My thread id in the fused thread-domain */ \
419
+ dim_t jrir_tid = ir_tid * jr_nt + jr_tid ; \
347
420
\
348
- m_cur = ( bli_is_not_edge_f ( i , m_iter , m_left ) ? MR : m_left ); \
421
+ /* Build a temporary thrinfo_t for dispatching
422
+ NOTE: Only n_way and work_id is needed for bli_thread_range_jrir() */ \
423
+ thrinfo_t jrir_tinfo ; \
424
+ bli_thrinfo_init ( & jrir_tinfo , NULL , 0 , jrir_nt , jrir_tid , FALSE, BLIS_NO_PART , NULL ); \
349
425
\
350
- /* Compute the addresses of the next panels of A and B. */ \
351
- a2 = bli_gemm_get_next_a_upanel ( a1 , rstep_a , ir_inc ); \
352
- if ( bli_is_last_iter ( i , ir_end , ir_tid , ir_nt ) ) \
353
- { \
354
- a2 = a_cast ; \
355
- b2 = bli_gemm_get_next_b_upanel ( b1 , cstep_b , jr_inc ); \
356
- if ( bli_is_last_iter ( j , jr_end , jr_tid , jr_nt ) ) \
357
- b2 = b_cast ; \
358
- } \
426
+ /* Dispatch (n_iter * m_iter) micro-tiles */ \
427
+ dim_t jrir_start , jrir_end ; \
428
+ dim_t jrir_inc ; \
429
+ bli_thread_range_jrir ( & jrir_tinfo , (n_iter * m_iter ), 1 , FALSE, \
430
+ & jrir_start , & jrir_end , & jrir_inc ); \
431
+ \
432
+ /* Loop over the fused JR/IR dimension. */ \
433
+ for ( dim_t ji = jrir_start ; ji < jrir_end ; ji += jrir_inc ) \
434
+ { \
435
+ ctype * restrict a1 ; \
436
+ ctype * restrict c11 ; \
437
+ \
438
+ /* Update current tile */ \
439
+ j = ji % n_iter ; \
440
+ i = ji / n_iter ; \
441
+ a1 = a_cast + i * rstep_a ; \
442
+ b1 = b_cast + j * cstep_b ; \
443
+ \
444
+ /* Next tile and panels */ \
445
+ dim_t jnext = (ji + jrir_inc ) % n_iter ; \
446
+ dim_t inext = (ji + jrir_inc ) / n_iter ; \
447
+ ctype * restrict a2 = a_cast + inext * rstep_a ; \
448
+ ctype * restrict b2 = b_cast + jnext * cstep_b ; \
359
449
\
360
450
/* Save addresses of next panels of A and B to the auxinfo_t
361
451
object. */ \
362
452
bli_auxinfo_set_next_a ( a2 , & aux ); \
363
453
bli_auxinfo_set_next_b ( b2 , & aux ); \
454
+ \
455
+ c1 = c_cast + j * cstep_c ; \
456
+ c11 = c1 + i * rstep_c ; \
457
+ \
458
+ n_cur = ( bli_is_not_edge_f ( j , n_iter , n_left ) ? NR : n_left ); \
459
+ m_cur = ( bli_is_not_edge_f ( i , m_iter , m_left ) ? MR : m_left ); \
364
460
\
365
461
/* Handle interior and edge cases separately. */ \
366
462
if ( m_cur == MR && n_cur == NR ) \
@@ -400,7 +496,7 @@ void PASTEMAC(ch,varname) \
400
496
c11 , rs_c , cs_c ); \
401
497
} \
402
498
} \
403
- } \
499
+ } /* misbalancing */ \
404
500
\
405
501
/*
406
502
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \
0 commit comments