20
20
#if !defined(ELEM_TYPE )
21
21
# define ELEM_TYPE double
22
22
#endif
23
+ #if !defined(EPSILON )
24
+ # define EPSILON 1E-3
25
+ #endif
23
26
#if !defined(MAX_KERNEL_DIM )
24
27
# define MAX_KERNEL_DIM 80
25
28
#endif
@@ -67,44 +70,66 @@ int main(int argc, char* argv[])
67
70
const int mn = m * n , mk = m * k , kn = k * n ;
68
71
#endif
69
72
#if defined(WARMUP ) && (0 < WARMUP ) && !defined(_DEBUG )
70
- const int warmup = WARMUP ;
73
+ const int warmup = MAX ( WARMUP , 2 ) / 2 * 2 ;
71
74
#else
72
75
const int warmup = 0 ;
73
76
#endif
74
- int * stack_hst = NULL , * stack_dev = NULL ;
77
+ int * stack_hst = NULL , * stack_dev = NULL , * trans_hst = NULL , * trans_dev = NULL ;
75
78
ELEM_TYPE * amat_hst = NULL , * bmat_hst = NULL , * cmat_hst = NULL ;
76
79
ELEM_TYPE * amat_dev = NULL , * bmat_dev = NULL , * cmat_dev = NULL ;
77
- int result = EXIT_SUCCESS , r , i ;
80
+ int result = EXIT_SUCCESS , ndevices = 0 , r , i ;
78
81
void * stream = NULL ;
79
82
#if defined(USE_LIBXSMM )
80
83
libxsmm_timer_tickint start ;
81
- double duration ;
84
+ double duration , transpose ;
82
85
#endif
83
86
assert (m <= (mn / n ) && 0 == (mn % n ) && k <= (mk / k ) && 0 == (mk % k ) && n <= (kn / n ) && 0 == (kn % n ));
84
- printf ("%s%s%i %i %i %i %i\n" , 0 < argc ? argv [0 ] : "" , 0 < argc ? " " : "" , nrepeat , stack_size , m , n , k );
87
+ printf ("%s%s%i %i %i %i %i %i %i %i\n" , 0 < argc ? argv [0 ] : "" , 0 < argc ? " " : "" ,
88
+ nrepeat , stack_size , m , n , k , nc , na , nb );
85
89
CHECK (acc_init (), & result );
90
+ CHECK (acc_get_ndevices (& ndevices ), & result );
91
+ if (0 < ndevices ) {
92
+ #if defined(_DEBUG )
93
+ fprintf (stderr , "number of devices found: %i\n" , ndevices );
94
+ #endif
95
+ }
96
+ else {
97
+ #if defined(_DEBUG )
98
+ fprintf (stderr , "Error: no device found!\n" );
99
+ #endif
100
+ CHECK (acc_finalize (), NULL );
101
+ return result ;
102
+ }
103
+ printf ("element type: %s\n" , DBCSR_STRINGIFY (ELEM_TYPE ));
86
104
CHECK (acc_stream_create (& stream , "stream" , -1 /*default priority*/ ), & result );
87
- CHECK (acc_host_mem_allocate ((void * * )& amat_hst , sizeof (ELEM_TYPE ) * mk * stack_size , stream ), & result );
88
- CHECK (acc_host_mem_allocate ((void * * )& bmat_hst , sizeof (ELEM_TYPE ) * kn * stack_size , stream ), & result );
89
- CHECK (acc_host_mem_allocate ((void * * )& cmat_hst , sizeof (ELEM_TYPE ) * mn * stack_size , stream ), & result );
105
+ CHECK (acc_host_mem_allocate ((void * * )& amat_hst , sizeof (ELEM_TYPE ) * mk * na , stream ), & result );
106
+ CHECK (acc_host_mem_allocate ((void * * )& bmat_hst , sizeof (ELEM_TYPE ) * kn * nb , stream ), & result );
107
+ CHECK (acc_host_mem_allocate ((void * * )& cmat_hst , sizeof (ELEM_TYPE ) * mn * nc , stream ), & result );
90
108
CHECK (acc_host_mem_allocate ((void * * )& stack_hst , sizeof (int ) * 3 * stack_size , stream ), & result );
109
+ CHECK (acc_host_mem_allocate ((void * * )& trans_hst , sizeof (int ) * nb , stream ), & result );
91
110
CHECK (acc_stream_sync (stream ), & result ); /* ensure host-data is allocated */
92
- for (i = 0 ; i < stack_size ; ++ i ) { /* initialize matrices */
111
+ /* initialize matrices */
112
+ for (i = 0 ; i < na ; ++ i ) {
93
113
init (i /*seed*/ + 42 , & amat_hst [i * mk ], m , k );
114
+ }
115
+ for (i = 0 ; i < nb ; ++ i ) {
94
116
init (i /*seed*/ + 24 , & bmat_hst [i * kn ], k , n );
117
+ trans_hst [i ] = i * kn ;
95
118
}
96
119
init_stack (stack_hst , stack_size , mn , mk , kn , nc , na , nb );
97
- CHECK (acc_dev_mem_allocate ((void * * )& amat_dev , sizeof (ELEM_TYPE ) * mk * stack_size ), & result );
98
- CHECK (acc_dev_mem_allocate ((void * * )& bmat_dev , sizeof (ELEM_TYPE ) * kn * stack_size ), & result );
99
- CHECK (acc_dev_mem_allocate ((void * * )& cmat_dev , sizeof (ELEM_TYPE ) * mn * stack_size ), & result );
120
+ CHECK (acc_dev_mem_allocate ((void * * )& amat_dev , sizeof (ELEM_TYPE ) * mk * na ), & result );
121
+ CHECK (acc_dev_mem_allocate ((void * * )& bmat_dev , sizeof (ELEM_TYPE ) * kn * nb ), & result );
122
+ CHECK (acc_dev_mem_allocate ((void * * )& cmat_dev , sizeof (ELEM_TYPE ) * mn * nc ), & result );
100
123
CHECK (acc_dev_mem_allocate ((void * * )& stack_dev , sizeof (int ) * 3 * stack_size ), & result );
101
- CHECK (acc_memset_zero (cmat_dev , 0 /*offset*/ , sizeof (ELEM_TYPE ) * mn * stack_size , stream ), & result );
124
+ CHECK (acc_dev_mem_allocate ((void * * )& trans_dev , sizeof (int ) * nb ), & result );
125
+ CHECK (acc_memset_zero (cmat_dev , 0 /*offset*/ , sizeof (ELEM_TYPE ) * mn * nc , stream ), & result );
126
+ CHECK (acc_memcpy_h2d (trans_hst , trans_dev , sizeof (int ) * nb , stream ), & result );
102
127
#if defined(USE_LIBXSMM )
103
128
CHECK (acc_stream_sync (stream ), & result );
104
129
start = libxsmm_timer_tick ();
105
130
#endif
106
- CHECK (acc_memcpy_h2d (amat_hst , amat_dev , sizeof (ELEM_TYPE ) * mk * stack_size , stream ), & result );
107
- CHECK (acc_memcpy_h2d (bmat_hst , bmat_dev , sizeof (ELEM_TYPE ) * kn * stack_size , stream ), & result );
131
+ CHECK (acc_memcpy_h2d (amat_hst , amat_dev , sizeof (ELEM_TYPE ) * mk * na , stream ), & result );
132
+ CHECK (acc_memcpy_h2d (bmat_hst , bmat_dev , sizeof (ELEM_TYPE ) * kn * nb , stream ), & result );
108
133
CHECK (acc_memcpy_h2d (stack_hst , stack_dev , sizeof (int ) * 3 * stack_size , stream ), & result );
109
134
#if defined(USE_LIBXSMM )
110
135
CHECK (acc_stream_sync (stream ), & result );
@@ -113,55 +138,118 @@ int main(int argc, char* argv[])
113
138
(sizeof (ELEM_TYPE ) * (mk + kn ) + sizeof (int ) * 3 )
114
139
* stack_size / (duration * (1ULL << 30 )));
115
140
#endif
116
- /* warmup execution and prebuild JIT kernels */
141
+ /* warmup execution and prebuild transpose-kernel */
142
+ for (r = 0 ; r < warmup / 2 ; ++ r ) {
143
+ CHECK (libsmm_acc_transpose (trans_dev , 0 /*offset*/ , nb , bmat_dev ,
144
+ DBCSR_TYPE (ELEM_TYPE ), k , n , MAX_KERNEL_DIM , stream ), & result );
145
+ CHECK (libsmm_acc_transpose (trans_dev , 0 /*offset*/ , nb , bmat_dev ,
146
+ DBCSR_TYPE (ELEM_TYPE ), n , k , MAX_KERNEL_DIM , stream ), & result );
147
+ }
148
+ #if defined(USE_LIBXSMM )
149
+ CHECK (acc_stream_sync (stream ), & result );
150
+ start = libxsmm_timer_tick ();
151
+ #endif
152
+ /* to perform NN-SMMs on the device, all B-matrices are transposed upfront (SMM-kernel is limited to NT) */
153
+ CHECK (libsmm_acc_transpose (trans_dev , 0 /*offset*/ , nb , bmat_dev ,
154
+ DBCSR_TYPE (ELEM_TYPE ), k , n , MAX_KERNEL_DIM , stream ), & result );
155
+ #if defined(USE_LIBXSMM )
156
+ CHECK (acc_stream_sync (stream ), & result );
157
+ transpose = libxsmm_timer_duration (start , libxsmm_timer_tick ());
158
+ #endif
159
+ /* warmup execution and prebuild SMM-kernel */
117
160
for (r = 0 ; r < warmup ; ++ r ) {
118
161
CHECK (libsmm_acc_process (stack_hst , stack_dev , stack_size , 3 /*nparams*/ , DBCSR_TYPE (ELEM_TYPE ),
119
162
amat_dev , bmat_dev , cmat_dev , m , n , k , MAX_KERNEL_DIM , 1 /*homogeneous*/ , stream , stream ), & result );
120
163
}
164
+ CHECK (acc_memset_zero (cmat_dev , 0 /*offset*/ , sizeof (ELEM_TYPE ) * mn * nc , stream ), & result );
121
165
#if defined(USE_LIBXSMM )
122
166
CHECK (acc_stream_sync (stream ), & result );
123
167
start = libxsmm_timer_tick ();
124
168
#endif
125
169
for (r = 0 ; r < nrepeat ; ++ r ) {
126
- /* GPU-kernel is limited to C += Ai * Bi^T ( i.e., NT, for NN, all Bi must be transposed upfront) */
170
+ /* GPU-kernel is limited to C += Ai * Bi^T, i.e., NT ( for NN, all Bi must be transposed upfront) */
127
171
CHECK (libsmm_acc_process (stack_hst , stack_dev , stack_size , 3 /*nparams*/ , DBCSR_TYPE (ELEM_TYPE ),
128
172
amat_dev , bmat_dev , cmat_dev , m , n , k , MAX_KERNEL_DIM , 1 /*homogeneous*/ , stream , stream ), & result );
129
173
}
130
174
#if defined(USE_LIBXSMM )
131
175
CHECK (acc_stream_sync (stream ), & result );
132
176
duration = libxsmm_timer_duration (start , libxsmm_timer_tick ());
133
177
if (EXIT_SUCCESS == result ) {
134
- const char transa = 'N' , transb = 'T' ;
178
+ ELEM_TYPE * const gold_hst = (ELEM_TYPE * )libxsmm_malloc (sizeof (ELEM_TYPE ) * mn * nc );
179
+ const char transa = 'N' , transb = 'N' ;
135
180
const ELEM_TYPE alpha = 1 , beta = 1 ;
181
+ printf ("transpose: %.1f ms %.1f GFLOPS/s\n" , 1000.0 * (duration + transpose ) / nrepeat ,
182
+ ((size_t )2 * m * n * k ) * stack_size / ((duration + transpose ) * (1ULL << 30 ) / nrepeat ));
136
183
printf ("device: %.1f ms %.1f GFLOPS/s\n" , 1000.0 * duration / nrepeat ,
137
184
((size_t )2 * m * n * k ) * stack_size / (duration * (1ULL << 30 ) / nrepeat ));
138
- memset (cmat_hst , 0 , sizeof (ELEM_TYPE ) * mn * stack_size );
185
+ memset (gold_hst , 0 , sizeof (ELEM_TYPE ) * mn * nc );
186
+ for (r = 0 ; r < warmup ; ++ r ) {
187
+ libxsmm_gemm_batch_omp (LIBXSMM_GEMM_PRECISION (ELEM_TYPE ), LIBXSMM_GEMM_PRECISION (ELEM_TYPE ),
188
+ & transa , & transb , m , n , k , & alpha , amat_hst , & m /*lda*/ , bmat_hst , & k /*ldb*/ ,
189
+ & beta , gold_hst , & m /*ldc*/ , 1 /*index_base*/ , sizeof (int ) * 3 ,
190
+ stack_hst + 0 , stack_hst + 1 , stack_hst + 2 , stack_size );
191
+ }
192
+ memset (gold_hst , 0 , sizeof (ELEM_TYPE ) * mn * nc );
139
193
start = libxsmm_timer_tick ();
194
+ /* CPU-kernel operates on data that is not initialized in NUMA-aware fashion */
140
195
for (r = 0 ; r < nrepeat ; ++ r ) {
141
- /* CPU-kernel performs C += Ai * Bi^T to match result of GPU-kernel (NT may perform below NN) */
142
196
libxsmm_gemm_batch_omp (LIBXSMM_GEMM_PRECISION (ELEM_TYPE ), LIBXSMM_GEMM_PRECISION (ELEM_TYPE ),
143
197
& transa , & transb , m , n , k , & alpha , amat_hst , & m /*lda*/ , bmat_hst , & k /*ldb*/ ,
144
- & beta , cmat_hst , & m /*ldc*/ , 1 /*index_base*/ , sizeof (int ) * 3 ,
198
+ & beta , gold_hst , & m /*ldc*/ , 1 /*index_base*/ , sizeof (int ) * 3 ,
145
199
stack_hst + 0 , stack_hst + 1 , stack_hst + 2 , stack_size );
146
200
}
147
201
duration = libxsmm_timer_duration (start , libxsmm_timer_tick ());
148
202
printf ("host: %.1f ms %.1f GFLOPS/s\n" , 1000.0 * duration / nrepeat ,
149
203
((size_t )2 * m * n * k ) * stack_size / (duration * (1ULL << 30 ) / nrepeat ));
150
- /* transfer result from device back to host for validation */
151
- CHECK (acc_memcpy_d2h (cmat_dev , cmat_hst , sizeof (ELEM_TYPE ) * mn * stack_size , stream ), & result );
204
+ /* transfer result from device to host for validation */
205
+ CHECK (acc_memcpy_d2h (cmat_dev , cmat_hst , sizeof (ELEM_TYPE ) * mn * nc , stream ), & result );
152
206
CHECK (acc_stream_sync (stream ), & result );
153
- /* TODO: validation code TBD */
207
+ if (EXIT_SUCCESS == result ) {
208
+ double abserror = 0 , relerror = 0 ;
209
+ for (i = 0 ; i < nc ; ++ i ) {
210
+ const ELEM_TYPE * const gold = gold_hst + mn * i ;
211
+ const ELEM_TYPE * const test = cmat_hst + mn * i ;
212
+ double diff = 0 , a = 0 , b = 0 ;
213
+ for (r = 0 ; r < (m * n ); ++ r ) {
214
+ const double ar = (double )gold [r ];
215
+ const double br = (double )test [r ];
216
+ const double d = fabs (ar - br );
217
+ if (d > diff ) {
218
+ diff = d ;
219
+ a = ar ;
220
+ b = br ;
221
+ }
222
+ }
223
+ if (0 < diff ) {
224
+ # if defined(_DEBUG )
225
+ print (stderr , "gold = " , gold , m , n );
226
+ print (stderr , "test = " , test , m , n );
227
+ fprintf (stderr , "diff = %g (%g != %g)\n" , diff , a , b );
228
+ # endif
229
+ if (abserror < diff ) {
230
+ relerror = fabs (0 != a ? (diff / a ) : (diff / b ));
231
+ abserror = diff ;
232
+ }
233
+ }
234
+ }
235
+ printf ("max.error: rel=%g\n" , relerror );
236
+ if (EPSILON < relerror ) result = EXIT_FAILURE ;
237
+ }
238
+ libxsmm_free (gold_hst );
154
239
}
155
240
#endif
156
241
CHECK (acc_host_mem_deallocate (stack_hst , stream ), NULL );
242
+ CHECK (acc_host_mem_deallocate (trans_hst , stream ), NULL );
157
243
CHECK (acc_host_mem_deallocate (amat_hst , stream ), NULL );
158
244
CHECK (acc_host_mem_deallocate (bmat_hst , stream ), NULL );
159
245
CHECK (acc_host_mem_deallocate (cmat_hst , stream ), NULL );
160
246
CHECK (acc_dev_mem_deallocate (stack_dev ), NULL );
247
+ CHECK (acc_dev_mem_deallocate (trans_dev ), NULL );
161
248
CHECK (acc_dev_mem_deallocate (amat_dev ), NULL );
162
249
CHECK (acc_dev_mem_deallocate (bmat_dev ), NULL );
163
250
CHECK (acc_dev_mem_deallocate (cmat_dev ), NULL );
164
251
CHECK (acc_stream_destroy (stream ), NULL );
252
+ CHECK (acc_finalize (), NULL );
165
253
if (EXIT_SUCCESS != result ) {
166
254
fprintf (stderr , "FAILED\n" );
167
255
}
0 commit comments