
Commit 091ab04

Author: Veera Gopu
Added fp8 gemm gelu_aux_bias support
1 parent 385d10e commit 091ab04

File tree

3 files changed: +73 -10 lines changed


ci/core.sh

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,8 @@ for _gemm in hipblaslt rocblas; do
     fi
     echo ===== Run GEMM $_gemm tests =====
     ctest --test-dir build -j4 -R "OperatorTest/GEMMTestSuite" $_exclude
+    # fp8 GELU_AUX_BIAS tests
+    ctest --test-dir build -j4 -R "OperatorTest/GEMMTestSuite.Testfp8xfp8xfp16xfp16xfp8/.*"
     test $? -eq 0 || test_run_error
 done

tests/cpp/operator/test_cublaslt_gemm.cu

Lines changed: 42 additions & 8 deletions
@@ -76,14 +76,18 @@ void compute_ref(
     size_t m, size_t k, size_t n,
     D_Type* ref_d_data,
     float* ref_d_amax,
-    Gelu_Type* ref_gelu_data){
+    Gelu_Type* ref_gelu_data,
+    bool transa,
+    bool transb){
 
   *ref_d_amax = 0;
   for(size_t ii = 0; ii < m; ii++){
     for(size_t jj = 0; jj < n; jj++){
       float val = 0;
       for(size_t kk = 0; kk < k; kk++){
-        val += a_scale_inv*b_scale_inv*((float)a_data[ii + kk*m])*((float)b_data[kk + jj*k]);
+        float a_val = transa ? (float)a_data[kk + ii*k] : (float)a_data[ii + kk*m];
+        float b_val = transb ? (float)b_data[jj + kk*n] : (float)b_data[kk + jj*k];
+        val += a_scale_inv*b_scale_inv*a_val*b_val;
       }
       if(bias_data){
         val += (float)bias_data[ii];
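
In column-major storage, transposing an operand only changes the stride pattern: a plain A is an m-by-k buffer with A(ii, kk) = a[ii + kk*m], while a transposed A is stored as a k-by-m buffer so the same logical element sits at a[kk + ii*k] (and symmetrically for B). A minimal standalone sketch of the reference indexing above; the helper name ref_dot is hypothetical, not part of the commit:

#include <cstddef>

// Computes one element D(ii, jj) = sum_kk A(ii, kk) * B(kk, jj) for
// column-major buffers, mirroring the transa/transb branches in compute_ref.
float ref_dot(const float* a, const float* b,
              std::size_t m, std::size_t k, std::size_t n,
              std::size_t ii, std::size_t jj, bool transa, bool transb) {
  float val = 0.f;
  for (std::size_t kk = 0; kk < k; kk++) {
    // transa: A stored k-by-m (its transpose), else m-by-k.
    float a_val = transa ? a[kk + ii*k] : a[ii + kk*m];
    // transb: B stored n-by-k (its transpose), else k-by-n.
    float b_val = transb ? b[jj + kk*n] : b[kk + jj*k];
    val += a_val * b_val;
  }
  return val;
}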
@@ -103,16 +107,24 @@ void compute_ref(
 }
 
 template <typename A_Type, typename B_Type, typename Bias_Type, typename Gelu_Type, typename D_Type>
-void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, const size_t n) {
+void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, const size_t n, char transa_char = 'N', char transb_char = 'N') {
   DType atype = TypeInfo<A_Type>::dtype;
   DType btype = TypeInfo<B_Type>::dtype;
   DType bias_type = TypeInfo<Bias_Type>::dtype;
   DType gelu_type = TypeInfo<Gelu_Type>::dtype;
   DType dtype = TypeInfo<D_Type>::dtype;
-
+  bool transa = (transa_char == 'T' || transa_char == 't');
+  bool transb = (transb_char == 'T' || transb_char == 't');
+
   // pytorch tensor storage is row-major while cublas/rocblas is column-major
   Tensor A({ k, m }, atype);
+  if (transa){
+    A = Tensor({ m, k }, atype);
+  }
   Tensor B({ n, k }, btype);
+  if (transb){
+    B = Tensor({ k, n }, btype);
+  }
   Tensor D({ n, m }, dtype);
   Tensor bias;
   if(use_bias){
@@ -133,8 +145,7 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
   if(isFp8Type(dtype)){
     setRandomScale(&D);
   }
-  bool transa = false;
-  bool transb = false;
+
   bool grad = false;
   bool accumulate = false;

@@ -189,7 +200,9 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
     m, k, n,
     ref_D.get(),
     &ref_amax_d,
-    use_gelu? ref_pre_gelu_out.get(): nullptr);
+    use_gelu? ref_pre_gelu_out.get(): nullptr,
+    transa,
+    transb);
   // check if error message happens in running
   cudaDeviceSynchronize();
   auto err = cudaGetLastError();
@@ -221,7 +234,28 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
 using fp32=float;
 using fp8=fp8e4m3;
 using bf8=fp8e5m2;
-
+
+TEST_P(GEMMTestSuite, Testfp8xfp8xfp16xfp16xfp8) {
+  using namespace transformer_engine;
+  using namespace test;
+
+  const size_t m = std::get<0>(std::get<0>(GetParam()));
+  const size_t k = std::get<1>(std::get<0>(GetParam()));
+  const size_t n = std::get<2>(std::get<0>(GetParam()));
+  const bool use_bias = std::get<1>(GetParam());
+  const bool use_gelu = std::get<2>(GetParam());
+  char transa_char = 'T';
+  char transb_char = 'N';
+
+  using A_Type = fp8;
+  using B_Type = fp8;
+  using Bias_Type = fp16;
+  using Gelu_Type = fp16;
+  using D_Type = fp8;
+
+  performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n, transa_char, transb_char);
+}
+
 TEST_P(GEMMTestSuite, Testfp32xfp32xfp32xfp32xfp32) {
   using namespace transformer_engine;
   using namespace test;
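
The new Testfp8xfp8xfp16xfp16xfp8 case hard-codes transa_char = 'T', transb_char = 'N'; the TN operand layout is the arrangement hipblasLt's fp8 GEMM path generally expects. With use_gelu set, each output element goes through the fused epilogue that the reference code checks: the fp16 aux buffer receives A*B + bias, and D is the quantized GELU of that value. A per-element sketch under those assumptions; epilogue() and the tanh-approximation GELU are illustrative, not the commit's code:

#include <algorithm>
#include <cmath>

float gelu_tanh(float x) {  // tanh-approximation GELU, as commonly used in TE
  return 0.5f * x * (1.f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}

// One output element: accumulate in fp32, emit the pre-GELU value as the
// fp16 aux output, track amax of the post-GELU result, scale for fp8 storage.
float epilogue(float acc, float bias, float d_scale,
               float* aux_out /* fp16 in the real test */, float* d_amax) {
  float pre = acc + bias;
  *aux_out = pre;                                // GELU aux ("pre-GELU") output
  float post = gelu_tanh(pre);
  *d_amax = std::max(*d_amax, std::fabs(post));  // amax over unquantized D
  return post * d_scale;                         // caller rounds this to fp8
}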

transformer_engine/common/gemm/rocm_gemm.cu

Lines changed: 29 additions & 2 deletions
@@ -1009,6 +1009,8 @@ void hipblaslt_gemm(const Tensor *inputA,
   void *B = inputB->data.dptr;
   void *B_scale_inverse = inputB->scale_inv.dptr;
   void *D = outputD->data.dptr;
+  void *D_amax = outputD->amax.dptr;
+  void *D_scale = outputD->scale.dptr;
   void *bias_ptr = inputBias->data.dptr;
   const bool bias = bias_ptr != nullptr;
   void *pre_gelu_out = outputPreGelu->data.dptr;
@@ -1028,8 +1030,16 @@ void hipblaslt_gemm(const Tensor *inputA,
   // check consistency of arguments:
   // if fp8 is desired, context cannot be null
   // fp8 + gelu fusion + fp8 aux is unavailable right now.
-  if (use_fp8) {
-    NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
+  const hipblasltDatatype_t aux_type = get_hipblaslt_dtype(outputPreGelu->data.dtype);
+  bool allow_fp8_gemm = (A_type == HIPBLASLT_R_8F_E4M3) &&
+                        (B_type == HIPBLASLT_R_8F_E4M3) &&
+                        (D_type == HIPBLASLT_R_8F_E4M3) &&
+                        (bias_type == HIPBLASLT_R_16F) &&
+                        (aux_type == HIPBLASLT_R_16F);
+  if(!allow_fp8_gemm){
+    if (use_fp8) {
+      NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
+    }
   }
   float one = 1.0;
   float zero = 0.0;
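
Note the shape of the new guard: the old blanket rejection of fp8 + GELU fusion now fires only when the combination is not the single supported one (e4m3 x e4m3 -> e4m3 with fp16 bias and fp16 aux). A behavior-equivalent flattened form, as a sketch only (the commit keeps the nested form; NVTE_CHECK is TE's assertion macro):

// Sketch: equivalent to the nested guard added above.
if (use_fp8 && !allow_fp8_gemm) {
  NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
}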
@@ -1091,11 +1101,28 @@ void hipblaslt_gemm(const Tensor *inputA,
                                                      HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                      &B_scale_inverse,
                                                      sizeof(B_scale_inverse)));
+    if (is_fp8_dtype(outputD->data.dtype)) {
+      NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
+                                                           HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER,
+                                                           &D_amax,
+                                                           sizeof(D_amax)));
+
+      NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
+                                                           HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER,
+                                                           &D_scale,
+                                                           sizeof(D_scale)));
+    }
     if (bias) {
       NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                            HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE,
                                                            &bias_type, sizeof(bias_type)));
     }
+    if (gelu){
+      NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
+                                                           HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE,
+                                                           &aux_type,
+                                                           sizeof(aux_type)));
+    }
   }
 
   if (bias && gelu) {
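
With the AMAX_D and D_SCALE pointers wired up, hipblasLt records amax(D), the largest absolute value of the unquantized results, and multiplies results by *D_scale before rounding them to fp8; the aux data-type attribute makes the pre-GELU output fp16 to match the test. Downstream, an amax-based delayed-scaling recipe can derive the next iteration's scale from the recorded amax, roughly like this (a sketch; the margin parameter and helper name are assumptions, not from this commit):

#include <cmath>

constexpr float kFp8E4M3Max = 448.f;  // largest finite e4m3 magnitude

// Map the recorded amax to a scale that places amax just inside fp8 range.
float scale_from_amax(float amax, float margin = 0.f) {
  if (!(amax > 0.f) || !std::isfinite(amax)) return 1.f;
  return kFp8E4M3Max / amax / std::exp2f(margin);
}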
