Skip to content

Commit 4710c49

Browse files
author
Minh Quan Ho
committed
Detect and deal with mis-balancing in GEMM macro-kernel (flame#437)
Details: - In some multi-threading schemes, JR_NT and IR_NT may produce idle threads not performing any computation. - This commits detect such situation and implement a collapse of JR/IR loops.
1 parent e8caf20 commit 4710c49

File tree

2 files changed

+299
-116
lines changed

2 files changed

+299
-116
lines changed

frame/3/gemm/bli_gemm_ker_var2.c

+133-37
Original file line numberDiff line numberDiff line change
@@ -312,55 +312,151 @@ void PASTEMAC(ch,varname) \
312312
dim_t ir_nt = bli_thread_n_way( caucus ); \
313313
dim_t ir_tid = bli_thread_work_id( caucus ); \
314314
\
315-
dim_t jr_start, jr_end; \
316-
dim_t ir_start, ir_end; \
317-
dim_t jr_inc, ir_inc; \
318-
\
319-
/* Determine the thread range and increment for the 2nd and 1st loops.
320-
NOTE: The definition of bli_thread_range_jrir() will depend on whether
321-
slab or round-robin partitioning was requested at configure-time. */ \
322-
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
323-
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
324-
\
325-
/* Loop over the n dimension (NR columns at a time). */ \
326-
for ( j = jr_start; j < jr_end; j += jr_inc ) \
315+
/* Mis-balancing detection: if n_iter (or m_iter) is not multiple of
316+
jr_nt (or ir_nt), then we collapse the two JR/IR loops and dispatch
317+
work on jr_nt * ir_nt threads. */ \
318+
const bool misbalancing = ((n_iter % jr_nt) != 0) || ((m_iter % ir_nt) != 0); \
319+
\
320+
if ( !misbalancing ) \
327321
{ \
328-
ctype* restrict a1; \
329-
ctype* restrict c11; \
330-
ctype* restrict b2; \
322+
/* Use traditional two loops */ \
323+
dim_t jr_start, jr_end; \
324+
dim_t ir_start, ir_end; \
325+
dim_t jr_inc, ir_inc; \
326+
\
327+
/* Determine the thread range and increment for the 2nd and 1st loops.
328+
NOTE: The definition of bli_thread_range_jrir() will depend on whether
329+
slab or round-robin partitioning was requested at configure-time. */ \
330+
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
331+
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
332+
\
333+
/* Loop over the n dimension (NR columns at a time). */ \
334+
for ( j = jr_start; j < jr_end; j += jr_inc ) \
335+
{ \
336+
ctype* restrict a1; \
337+
ctype* restrict c11; \
338+
ctype* restrict b2; \
331339
\
332-
b1 = b_cast + j * cstep_b; \
333-
c1 = c_cast + j * cstep_c; \
340+
b1 = b_cast + j * cstep_b; \
341+
c1 = c_cast + j * cstep_c; \
334342
\
335-
n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
343+
n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
336344
\
337-
/* Initialize our next panel of B to be the current panel of B. */ \
338-
b2 = b1; \
345+
/* Initialize our next panel of B to be the current panel of B. */ \
346+
b2 = b1; \
339347
\
340-
/* Loop over the m dimension (MR rows at a time). */ \
341-
for ( i = ir_start; i < ir_end; i += ir_inc ) \
342-
{ \
343-
ctype* restrict a2; \
348+
/* Loop over the m dimension (MR rows at a time). */ \
349+
for ( i = ir_start; i < ir_end; i += ir_inc ) \
350+
{ \
351+
ctype* restrict a2; \
352+
\
353+
a1 = a_cast + i * rstep_a; \
354+
c11 = c1 + i * rstep_c; \
355+
\
356+
m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \
357+
\
358+
/* Compute the addresses of the next panels of A and B. */ \
359+
a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \
360+
if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \
361+
{ \
362+
a2 = a_cast; \
363+
b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \
364+
if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \
365+
b2 = b_cast; \
366+
} \
367+
\
368+
/* Save addresses of next panels of A and B to the auxinfo_t
369+
object. */ \
370+
bli_auxinfo_set_next_a( a2, &aux ); \
371+
bli_auxinfo_set_next_b( b2, &aux ); \
372+
\
373+
/* Handle interior and edge cases separately. */ \
374+
if ( m_cur == MR && n_cur == NR ) \
375+
{ \
376+
/* Invoke the gemm micro-kernel. */ \
377+
gemm_ukr \
378+
( \
379+
k, \
380+
alpha_cast, \
381+
a1, \
382+
b1, \
383+
beta_cast, \
384+
c11, rs_c, cs_c, \
385+
&aux, \
386+
cntx \
387+
); \
388+
} \
389+
else \
390+
{ \
391+
/* Invoke the gemm micro-kernel. */ \
392+
gemm_ukr \
393+
( \
394+
k, \
395+
alpha_cast, \
396+
a1, \
397+
b1, \
398+
zero, \
399+
ct, rs_ct, cs_ct, \
400+
&aux, \
401+
cntx \
402+
); \
403+
\
404+
/* Scale the bottom edge of C and add the result from above. */ \
405+
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
406+
ct, rs_ct, cs_ct, \
407+
beta_cast, \
408+
c11, rs_c, cs_c ); \
409+
} \
410+
} \
411+
} \
412+
} \
413+
else /* misbalancing == TRUE */ \
414+
{ \
415+
/* Fused number of threads in JR and IR loops */ \
416+
dim_t jrir_nt = jr_nt * ir_nt; \
344417
\
345-
a1 = a_cast + i * rstep_a; \
346-
c11 = c1 + i * rstep_c; \
418+
/* My thread id in the fused thread-domain */ \
419+
dim_t jrir_tid = ir_tid * jr_nt + jr_tid; \
347420
\
348-
m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \
421+
/* Build a temporary thrinfo_t for dispatching
422+
NOTE: Only n_way and work_id is needed for bli_thread_range_jrir() */ \
423+
thrinfo_t jrir_tinfo; \
424+
bli_thrinfo_init( &jrir_tinfo, NULL, 0, jrir_nt, jrir_tid, FALSE, BLIS_NO_PART, NULL ); \
349425
\
350-
/* Compute the addresses of the next panels of A and B. */ \
351-
a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \
352-
if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \
353-
{ \
354-
a2 = a_cast; \
355-
b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \
356-
if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \
357-
b2 = b_cast; \
358-
} \
426+
/* Dispatch (n_iter * m_iter) micro-tiles */ \
427+
dim_t jrir_start, jrir_end; \
428+
dim_t jrir_inc; \
429+
bli_thread_range_jrir( &jrir_tinfo, (n_iter * m_iter), 1, FALSE, \
430+
&jrir_start, &jrir_end, &jrir_inc ); \
431+
\
432+
/* Loop over the fused JR/IR dimension. */ \
433+
for ( dim_t ji = jrir_start; ji < jrir_end; ji += jrir_inc ) \
434+
{ \
435+
ctype* restrict a1; \
436+
ctype* restrict c11; \
437+
\
438+
/* Update current tile */ \
439+
j = ji % n_iter; \
440+
i = ji / n_iter; \
441+
a1 = a_cast + i * rstep_a; \
442+
b1 = b_cast + j * cstep_b; \
443+
\
444+
/* Next tile and panels */ \
445+
dim_t jnext = (ji + jrir_inc) % n_iter; \
446+
dim_t inext = (ji + jrir_inc) / n_iter; \
447+
ctype* restrict a2 = a_cast + inext * rstep_a; \
448+
ctype* restrict b2 = b_cast + jnext * cstep_b; \
359449
\
360450
/* Save addresses of next panels of A and B to the auxinfo_t
361451
object. */ \
362452
bli_auxinfo_set_next_a( a2, &aux ); \
363453
bli_auxinfo_set_next_b( b2, &aux ); \
454+
\
455+
c1 = c_cast + j * cstep_c; \
456+
c11 = c1 + i * rstep_c; \
457+
\
458+
n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
459+
m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \
364460
\
365461
/* Handle interior and edge cases separately. */ \
366462
if ( m_cur == MR && n_cur == NR ) \
@@ -400,7 +496,7 @@ void PASTEMAC(ch,varname) \
400496
c11, rs_c, cs_c ); \
401497
} \
402498
} \
403-
} \
499+
} /* misbalancing */ \
404500
\
405501
/*
406502
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \

0 commit comments

Comments
 (0)