@@ -76,14 +76,18 @@ void compute_ref(
     size_t m, size_t k, size_t n,
     D_Type* ref_d_data,
     float* ref_d_amax,
-    Gelu_Type* ref_gelu_data){
+    Gelu_Type* ref_gelu_data,
+    bool transa,
+    bool transb){

   *ref_d_amax = 0;
   for (size_t ii = 0; ii < m; ii++){
     for (size_t jj = 0; jj < n; jj++){
       float val = 0;
       for (size_t kk = 0; kk < k; kk++){
-        val += a_scale_inv*b_scale_inv*((float)a_data[ii + kk*m])*((float)b_data[kk + jj*k]);
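+        // Operands are stored column-major; a transposed operand has its
+        // dimensions swapped, so op(A)(ii,kk) reads a_data[kk + ii*k] and
+        // op(B)(kk,jj) reads b_data[jj + kk*n].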
+        float a_val = transa ? (float)a_data[kk + ii*k] : (float)a_data[ii + kk*m];
+        float b_val = transb ? (float)b_data[jj + kk*n] : (float)b_data[kk + jj*k];
+        val += a_scale_inv*b_scale_inv*a_val*b_val;
       }
       if (bias_data){
         val += (float)bias_data[ii];
@@ -103,16 +107,24 @@ void compute_ref(
 }

 template <typename A_Type, typename B_Type, typename Bias_Type, typename Gelu_Type, typename D_Type>
-void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, const size_t n) {
+void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, const size_t n, char transa_char = 'N', char transb_char = 'N') {
   DType atype = TypeInfo<A_Type>::dtype;
   DType btype = TypeInfo<B_Type>::dtype;
   DType bias_type = TypeInfo<Bias_Type>::dtype;
   DType gelu_type = TypeInfo<Gelu_Type>::dtype;
   DType dtype = TypeInfo<D_Type>::dtype;
-
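+  // Anything other than 'T'/'t' is treated as 'N' (no transpose).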
+  bool transa = (transa_char == 'T' || transa_char == 't');
+  bool transb = (transb_char == 'T' || transb_char == 't');
+
   // pytorch tensor storage is row-major while cublas/rocblas is column-major
   Tensor A({ k, m }, atype);
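+  // A transposed operand is allocated with its shape swapped, so the
+  // column-major view seen by cublas/rocblas is the transposed matrix.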
+  if (transa){
+    A = Tensor({ m, k }, atype);
+  }
   Tensor B({ n, k }, btype);
+  if (transb){
+    B = Tensor({ k, n }, btype);
+  }
   Tensor D({ n, m }, dtype);
   Tensor bias;
   if (use_bias){
@@ -133,8 +145,7 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
   if (isFp8Type(dtype)){
     setRandomScale(&D);
   }
-  bool transa = false;
-  bool transb = false;
+
   bool grad = false;
   bool accumulate = false;
@@ -189,7 +200,9 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
       m, k, n,
       ref_D.get(),
       &ref_amax_d,
-      use_gelu? ref_pre_gelu_out.get(): nullptr);
+      use_gelu? ref_pre_gelu_out.get(): nullptr,
+      transa,
+      transb);
   // check if error message happens in running
   cudaDeviceSynchronize();
   auto err = cudaGetLastError();
@@ -221,7 +234,28 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
 using fp32=float;
 using fp8=fp8e4m3;
 using bf8=fp8e5m2;
-
+
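+// Exercises the new transpose path: A is passed as 'T', B as 'N'.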
+TEST_P(GEMMTestSuite, Testfp8xfp8xfp16xfp16xfp8) {
+  using namespace transformer_engine;
+  using namespace test;
+
+  const size_t m = std::get<0>(std::get<0>(GetParam()));
+  const size_t k = std::get<1>(std::get<0>(GetParam()));
+  const size_t n = std::get<2>(std::get<0>(GetParam()));
+  const bool use_bias = std::get<1>(GetParam());
+  const bool use_gelu = std::get<2>(GetParam());
+  char transa_char = 'T';
+  char transb_char = 'N';
+
+  using A_Type = fp8;
+  using B_Type = fp8;
+  using Bias_Type = fp16;
+  using Gelu_Type = fp16;
+  using D_Type = fp8;
+
+  performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n, transa_char, transb_char);
+}
+
 TEST_P(GEMMTestSuite, Testfp32xfp32xfp32xfp32xfp32) {
   using namespace transformer_engine;
   using namespace test;