Proof-of-concept: speeding up gemm reference kernel #863

Draft: wants to merge 1 commit into base: master
115 changes: 113 additions & 2 deletions ref_kernels/3/bli_gemm_ref.c
@@ -157,6 +89 @@ INSERT_GENTFUNCR_BASIC( gemm_gen, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX )
// instructions via constant loop bounds + #pragma omp simd directives.
// If compile-time MR/NR are not available (indicated by BLIS_[MN]R_x = -1),
// then the non-unrolled version (above) is used.
// first the fastest case, 4 macros for m==mr, n==nr, k>0
// cs_c = 1, beta != 0 (row major)
// cs_c = 1, beta == 0
// rs_c = 1, beta != 0 (column major)
// rs_c = 1, beta == 0

#define TAIL_NITER 5 // in units of 4x k iterations
#define CACHELINE_SIZE 64
#define TAXPBYS_BETA0(ch1,ch2,ch3,ch4,ch5,alpha,ab,beta,c) bli_tscal2s(ch1,ch2,ch3,ch4,alpha,ab,c)
#undef GENTFUNC
#define GENTFUNC( ctype, ch, opname, arch, suf, taxpbys, i_or_j, j_or_i, mr_or_nr, nr_or_mr ) \
\
static void PASTEMAC(ch,ch,opname,arch,suf) \
( \
dim_t k, \
const ctype* alpha, \
const ctype* a, \
const ctype* b, \
const ctype* beta, \
ctype* c, inc_t s_c \
) \
{ \
const dim_t mr = PASTECH(BLIS_,mr_or_nr,_,ch); \
const dim_t nr = PASTECH(BLIS_,nr_or_mr,_,ch); \
\
const inc_t cs_a = PASTECH(BLIS_PACKMR_,ch); \
const inc_t rs_b = PASTECH(BLIS_PACKNR_,ch); \
\
char ab_[ BLIS_STACK_BUF_MAX_SIZE ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))) = { 0 }; \
Contributor:

I know this is copied from the reference kernel, but this zero-init is redundant with the loop that immediately follows (L195-198). I suggest you check in the assembly whether the compiler eliminated one of them (and whether it managed to vectorize the zeroing as well).

Member:

The zero-init is actually required for certain versions of clang since it improperly optimizes out some of the later zero assignments. See #854.

Contributor Author:

Yes, exactly; I saw #854 as well because this puzzled me too. But the compiler generated optimal code like this anyway (for my case with MR=32, NR=6 on Zen4; from objdump -d):

     664:       c5 d1 57 ed             vxorpd %xmm5,%xmm5,%xmm5
     668:       48 89 c8                mov    %rcx,%rax
     66b:       48 89 e5                mov    %rsp,%rbp
     66e:       41 54                   push   %r12
     670:       53                      push   %rbx
     671:       41 bc 1a 00 00 00       mov    $0x1a,%r12d
     677:       48 8b 5d 10             mov    0x10(%rbp),%rbx
     67b:       49 29 fc                sub    %rdi,%r12
     67e:       62 f1 fd 48 28 f5       vmovapd %zmm5,%zmm6
     684:       62 f1 fd 48 28 fd       vmovapd %zmm5,%zmm7
     68a:       4c 89 e7                mov    %r12,%rdi
     68d:       62 71 fd 48 28 c5       vmovapd %zmm5,%zmm8
     693:       62 71 fd 48 28 cd       vmovapd %zmm5,%zmm9

and then continuing all the way up to zmm28, so 24 zeroed vectors of 8 doubles each, which is exactly 6*32.

Contributor:

Oh, I did not make the connection to #854; hmm, it does look like a weird issue.
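
A minimal standalone sketch of the double-zeroing pattern discussed above; the function name, buffer layout, and MR/NR values are hypothetical and only illustrate the shape of the code, not the actual BLIS macros:

    #include <stddef.h>

    enum { MR = 32, NR = 6 };   /* hypothetical register-block sizes (the Zen4 example above) */

    void zero_init_accumulators( double *ab_out )
    {
        /* zero-init #1: the buffer is value-initialized, as in the kernel */
        double ab[ MR * NR ] = { 0 };

        /* zero-init #2: the explicit loop that follows in the kernel; it is kept
           because some clang versions otherwise drop later zero assignments (#854) */
        for ( size_t i = 0; i < (size_t)( MR * NR ); ++i )
            ab[ i ] = 0.0;

        /* in the real kernel the accumulators then receive k rank-1 updates;
           here they are just copied out so the compiler cannot discard the work */
        for ( size_t i = 0; i < (size_t)( MR * NR ); ++i )
            ab_out[ i ] = ab[ i ];
    }

As the disassembly above shows, a good compiler can fold both zero-initializations into a single set of vxorpd/vmovapd register clears.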

ctype* ab = (ctype*)ab_; \
const inc_t s_ab = nr; \
\
\
/* Initialize the accumulator elements in ab to zero. */ \
PRAGMA_SIMD \
for ( dim_t i = 0; i < mr * nr; ++i ) \
{ \
bli_tset0s( ch, ab[ i ] ); \
} \
\
/* Perform a series of k rank-1 updates into ab. */ \
dim_t l = 0; do \
{ \
dim_t i = l + TAIL_NITER*4 + mr - k; \
if ( i >= 0 && i < mr ) \
for ( dim_t j = 0; j < nr; j += CACHELINE_SIZE/sizeof(double) ) \
bli_prefetch( &c[ i*s_c + j ], 0, 3 ); \
for ( dim_t i = 0; i < mr; ++i ) \
Contributor:

Since PRAGMA_SIMD was used, why not also use #pragma unroll(n) on the MR loop in order to fill the core pipeline?

Contributor Author:

I thought about unroll pragmas, but their syntax differs between compilers; e.g. for GCC it is #pragma GCC unroll n.
OpenMP 5.1 has a common syntax for this, but it is not yet widely supported (it will be in GCC 15, for example, but is not in 14); see the sketch after the disassembly below. In the end GCC will unroll small loops with a constant number of iterations all by itself at -O3, which is used here by default, so the code generated here was already pretty close to optimal:

     730:       62 f1 fd 48 10 1a       vmovupd (%rdx),%zmm3
     736:       62 f1 fd 48 10 52 01    vmovupd 0x40(%rdx),%zmm2
     73d:       62 f1 fd 48 10 4a 02    vmovupd 0x80(%rdx),%zmm1
     744:       48 ff c1                inc    %rcx
     747:       62 f1 fd 48 10 42 03    vmovupd 0xc0(%rdx),%zmm0
     74e:       62 f2 fd 48 19 20       vbroadcastsd (%rax),%zmm4
     754:       48 81 c2 00 01 00 00    add    $0x100,%rdx
     75b:       48 83 c0 30             add    $0x30,%rax
     75f:       4c 01 df                add    %r11,%rdi
     762:       62 62 dd 48 b8 e3       vfmadd231pd %zmm3,%zmm4,%zmm28
     768:       62 62 dd 48 b8 da       vfmadd231pd %zmm2,%zmm4,%zmm27
     76e:       62 62 dd 48 b8 d1       vfmadd231pd %zmm1,%zmm4,%zmm26
     774:       62 62 dd 48 b8 c8       vfmadd231pd %zmm0,%zmm4,%zmm25
     77a:       62 f2 fd 48 19 60 fb    vbroadcastsd -0x28(%rax),%zmm4
     781:       62 62 e5 48 b8 c4       vfmadd231pd %zmm4,%zmm3,%zmm24
     787:       62 e2 ed 48 b8 fc       vfmadd231pd %zmm4,%zmm2,%zmm23
     78d:       62 e2 f5 48 b8 f4       vfmadd231pd %zmm4,%zmm1,%zmm22
     793:       62 e2 fd 48 b8 ec       vfmadd231pd %zmm4,%zmm0,%zmm21
     799:       62 f2 fd 48 19 60 fc    vbroadcastsd -0x20(%rax),%zmm4
     7a0:       62 e2 e5 48 b8 e4       vfmadd231pd %zmm4,%zmm3,%zmm20
     7a6:       62 e2 ed 48 b8 dc       vfmadd231pd %zmm4,%zmm2,%zmm19
     7ac:       62 e2 f5 48 b8 d4       vfmadd231pd %zmm4,%zmm1,%zmm18
     7b2:       62 e2 fd 48 b8 cc       vfmadd231pd %zmm4,%zmm0,%zmm17
     7b8:       62 f2 fd 48 19 60 fd    vbroadcastsd -0x18(%rax),%zmm4
     7bf:       62 e2 e5 48 b8 c4       vfmadd231pd %zmm4,%zmm3,%zmm16
     7c5:       62 72 ed 48 b8 fc       vfmadd231pd %zmm4,%zmm2,%zmm15
     7cb:       62 72 f5 48 b8 f4       vfmadd231pd %zmm4,%zmm1,%zmm14
     7d1:       62 72 fd 48 b8 ec       vfmadd231pd %zmm4,%zmm0,%zmm13
     7d7:       62 f2 fd 48 19 60 fe    vbroadcastsd -0x10(%rax),%zmm4
     7de:       62 72 e5 48 b8 e4       vfmadd231pd %zmm4,%zmm3,%zmm12
     7e4:       62 72 ed 48 b8 dc       vfmadd231pd %zmm4,%zmm2,%zmm11
     7ea:       62 72 f5 48 b8 d4       vfmadd231pd %zmm4,%zmm1,%zmm10
     7f0:       62 72 fd 48 b8 cc       vfmadd231pd %zmm4,%zmm0,%zmm9
     7f6:       62 f2 fd 48 19 60 ff    vbroadcastsd -0x8(%rax),%zmm4
     7fd:       62 72 dd 48 b8 c3       vfmadd231pd %zmm3,%zmm4,%zmm8
     803:       62 f2 dd 48 b8 fa       vfmadd231pd %zmm2,%zmm4,%zmm7
     809:       62 f2 dd 48 b8 f1       vfmadd231pd %zmm1,%zmm4,%zmm6
     80f:       62 f2 dd 48 b8 e8       vfmadd231pd %zmm0,%zmm4,%zmm5
     815:       49 39 ca                cmp    %rcx,%r10
     818:       7e 28                   jle    842 <bli_ddgemm_vect_c_generic_ref+0x1e2>
     81a:       49 8d 1c 0c             lea    (%r12,%rcx,1),%rbx
     81e:       48 83 fb 05             cmp    $0x5,%rbx
     822:       0f 87 08 ff ff ff       ja     730 <bli_ddgemm_vect_c_generic_ref+0xd0>
     828:       0f 18 0f                prefetcht0 (%rdi)
     82b:       0f 18 4f 40             prefetcht0 0x40(%rdi)
     82f:       0f 18 8f 80 00 00 00    prefetcht0 0x80(%rdi)
     836:       0f 18 8f c0 00 00 00    prefetcht0 0xc0(%rdi)
     83d:       e9 ee fe ff ff          jmp    730 <bli_ddgemm_vect_c_generic_ref+0xd0>
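
For what it is worth, one way to paper over the per-compiler unroll syntax is a small wrapper macro. This is only a sketch, assuming a name like PRAGMA_UNROLL would be acceptable; BLIS does not currently define it:

    /* Hypothetical helper: expand to the compiler's own unroll pragma, or to
       nothing if the compiler is unknown. Clang is tested before GCC because
       it also defines __GNUC__. */
    #define PRAGMA_(x)  _Pragma( #x )
    #define PRAGMA(x)   PRAGMA_( x )

    #if defined(__clang__)
      #define PRAGMA_UNROLL(n)  PRAGMA( clang loop unroll_count(n) )
    #elif defined(__GNUC__)
      #define PRAGMA_UNROLL(n)  PRAGMA( GCC unroll n )
    #else
      #define PRAGMA_UNROLL(n)
    #endif

    /* Example use on an MR-style loop: */
    void scale_rows( int mr, double *x, double a )
    {
        PRAGMA_UNROLL(4)
        for ( int i = 0; i < mr; ++i )
            x[ i ] *= a;
    }

The OpenMP 5.1 spelling, #pragma omp unroll partial(n), could eventually replace both branches once compiler support catches up.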

{ \
PRAGMA_SIMD \
for ( dim_t j = 0; j < nr; ++j ) \
{ \
bli_tdots \
( \
ch,ch,ch,ch, \
a[ i_or_j ], \
b[ j_or_i ], \
ab[ i*s_ab + j ] \
); \
} \
} \
\
a += cs_a; \
b += rs_b; \
} while ( ++l < k ); \
\
for ( dim_t i = 0; i < mr; ++i ) \
PRAGMA_SIMD \
for ( dim_t j = 0; j < nr; ++j ) \
taxpbys \
Contributor (@hominhquan, Mar 10, 2025):

This is probably what brought you some gain compared to the reference kernel: the scaling by alpha is done at the same time as the accumulation into c by using AXPBY (FMA).

Now the point is: this reference kernel was written to be simple and stupid, easily comprehensible, and not aimed at being fast. Do we really want to make it a little harder for new people to understand, in exchange for some percentage of performance, given that speed was not the original purpose of this kernel?

BTW, maybe I am being a bit paranoid, and perhaps a simple comment saying // Scale ab by alpha and accumulate into c with AXPBY() suffices to help the reader.

I would prefer direct modifications to the original reference kernel (PRAGMA_UNROLL, removing the redundant ab zero-init, the AXPBY alpha-scaling accumulation, or even __builtin_prefetch()), with well-explained comments, rather than four new reference kernels that will take more time for future readers to understand.
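
To make the AXPBY point concrete, here is a scalar sketch of the fused update; the helper name is made up, and the real kernel goes through the bli_taxpbys type macro instead:

    /* Scalar illustration: instead of scaling ab by alpha first and then doing a
       separate beta-scaled accumulate into c, the axpby form computes
       c = alpha*ab + beta*c in one step, which the compiler can map onto an FMA. */
    static inline void axpby_update( double alpha, double ab,
                                     double beta,  double *c )
    {
        *c = alpha * ab + beta * ( *c );
    }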

Contributor Author:

Well, there are already two original reference kernels: the slow version is the first one in the file, already called gemm_genxx, and gemm itself was already the fast path, as the comment says:

// An implementation that attempts to facilitate emission of vectorized
// instructions via constant loop bounds + #pragma omp simd directives.
// If compile-time MR/NR are not available (indicated by BLIS_[MN]R_x = -1),
// then the non-unrolled version (above) is used.

If I try to make the original fast path faster, I simply don't get the same speed-ups, because the whole C tile is spilled to memory, and then I might as well not change anything.

An alternative would be to have the 4 new kernels handle only the fast path, and let all other oddball cases use gemm_genxx, i.e. changing the if in

        /* If either BLIS_MR_? or BLIS_NR_? was left undefined by the subconfig,
           the compiler can't fully unroll the MR and NR loop iterations below,
           which means there's no benefit to using this kernel over a general-
           purpose implementation instead. */ \
        if ( mr == -1 || nr == -1 || rs_a != 1 || cs_b != 1 ) \
        { \
                PASTEMAC(ch,ch,gemm_gen,arch,suf) \

to

      if ( mr != m || nr != n || rs_a != 1 || cs_b != 1 || (cs_c != 1 && rs_c != 1) || k == 0) \

or even flipping it around: the slow version becomes the main reference gemm, and the fast version is called via that if statement, inverted?
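
A rough sketch of that flipped dispatch, written as a plain C predicate; the typedefs and the function name are made up for illustration and are not BLIS code:

    typedef long dim_t;   /* stand-ins for the BLIS integer types */
    typedef long inc_t;

    /* Return 1 when the fast, fully-unrolled path applies: full micro-tile,
       unit-stride packed A and B, row- or column-major C, and k > 0. Everything
       else (edge tiles, general strides, k == 0) would fall back to gemm_gen. */
    int use_fast_path( dim_t m, dim_t n, dim_t k, dim_t mr, dim_t nr,
                       inc_t rs_a, inc_t cs_b, inc_t rs_c, inc_t cs_c )
    {
        return m == mr && n == nr && k > 0
            && rs_a == 1 && cs_b == 1
            && ( rs_c == 1 || cs_c == 1 );
    }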

Member:

I will have to think about whether the k==0 restriction can be lifted for all gemm kernels, in which case modifying the check as suggested would be fine.

( \
ch,ch,ch,ch,ch, \
*alpha, \
ab[ i*s_ab + j ], \
*beta, \
c [ i*s_c + j ] \
); \
}

INSERT_GENTFUNC_BASIC( gemm_vect_r_beta0, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, TAXPBYS_BETA0, i, j, MR, NR )
INSERT_GENTFUNC_BASIC( gemm_vect_r, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, bli_taxpbys, i, j, MR, NR )
INSERT_GENTFUNC_BASIC( gemm_vect_c_beta0, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, TAXPBYS_BETA0, j, i, NR, MR )
INSERT_GENTFUNC_BASIC( gemm_vect_c, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, bli_taxpbys, j, i, NR, MR )

#undef GENTFUNC
#define GENTFUNC( ctype, ch, opname, arch, suf ) \
@@ -210,6 +293,36 @@ void PASTEMAC(ch,ch,opname,arch,suf) \
); \
return; \
} \
\
if ( m == mr && n == nr && k > 0 ) \
{ \
if ( cs_c == 1 ) \
{ \
(bli_teq0s( ch, *beta ) ? PASTEMAC(ch,ch,gemm_vect_r_beta0,arch,suf) : PASTEMAC(ch,ch,gemm_vect_r,arch,suf)) \
( \
k, \
alpha, \
a, \
b, \
beta, \
c, rs_c \
); \
return; \
} \
if ( rs_c == 1 ) \
{ \
(bli_teq0s( ch, *beta ) ? PASTEMAC(ch,ch,gemm_vect_c_beta0,arch,suf) : PASTEMAC(ch,ch,gemm_vect_c,arch,suf)) \
( \
k, \
alpha, \
a, \
b, \
beta, \
c, cs_c \
); \
return; \
} \
} \
\
char ab_[ BLIS_STACK_BUF_MAX_SIZE ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))) = { 0 }; \
ctype* ab = (ctype*)ab_; \
@@ -382,5 +495,3 @@ void PASTEMAC(chab,chc,opname,arch,suf) \
}

INSERT_GENTFUNC2_MIX_P( gemm, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX )