fp8 support #2989
base: main
Conversation
Current test failures:
#if __CUDA_ARCH__ < 700
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomicadd
// The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
// Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
__device__ __half atomicAdd(__half *address, __half val) {
//__device__ __half atomicAdd(__half *address, __half val) {
unsure why just this signature was present
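For reference, the technique in the linked THCAtomics file emulates the 16-bit add with a 32-bit atomicCAS on the word containing the __half. A sketch of that general approach (my paraphrase, not necessarily the exact code in this PR):

#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
__device__ __half atomicAdd(__half *address, __half val) {
    // Work on the aligned 32-bit word that contains the target __half.
    unsigned int *base = (unsigned int *)((size_t)address & ~(size_t)2);
    unsigned int old = *base, assumed;
    do {
        assumed = old;
        // Extract the current half from the high or low 16 bits.
        __half_raw raw;
        raw.x = ((size_t)address & 2) ? (unsigned short)(assumed >> 16)
                                      : (unsigned short)(assumed & 0xffffu);
        // Do the add in f32 (pre-Volta has no native half atomics anyway).
        __half sum = __float2half(__half2float(__half(raw)) + __half2float(val));
        unsigned short bits = __half_raw(sum).x;
        unsigned int updated = ((size_t)address & 2)
            ? (assumed & 0x0000ffffu) | ((unsigned int)bits << 16)
            : (assumed & 0xffff0000u) | bits;
        // Retry until no other thread modified the word in between.
        old = atomicCAS(base, assumed, updated);
    } while (assumed != old);
    // Return the previous value, matching atomicAdd's convention.
    __half_raw ret;
    ret.x = ((size_t)address & 2) ? (unsigned short)(old >> 16)
                                  : (unsigned short)(old & 0xffffu);
    return __half(ret);
}
#endif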
@@ -25,6 +25,18 @@ constexpr uint8_t max_value<uint8_t>() {
  return 0xFFu;
}

template <>
__host__ __device__
constexpr int32_t max_value<int32_t>() {
these symbols were missing when cuda bindgen ran for some reason
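The hunk above is truncated, but presumably the new int32_t specialization just returns INT32_MAX, mirroring the uint8_t one shown in the diff (a sketch, not copied from the PR):

template <>
__host__ __device__
constexpr int32_t max_value<int32_t>() {
  return 0x7FFFFFFF; // INT32_MAX
}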
@@ -36,6 +36,12 @@ extern "C" __global__ void FN_NAME( \
WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16)
WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16)
WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16)

WHERE_OP(__nv_fp8_e4m3, int16_t, where_i16_fp8_e4m3)
The __CUDA_ARCH__ guard should be 890 here.
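That is, presumably something along these lines (a sketch of the suggested guard; the WHERE_OP instantiation is from the hunk above):

// fp8 support starts at compute capability 8.9 (Ada), so guard on 890 rather than 700/800.
#if __CUDA_ARCH__ >= 890
WHERE_OP(__nv_fp8_e4m3, int16_t, where_i16_fp8_e4m3)
#endif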
fp8 seems somewhat slower than I would expect in the candle benchmark harness (this is on a GH100):
Probably because we're double-casting from fp8 -> half -> f32?

#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))
AFFINE_OP(__nv_fp8_e4m3, affine_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(mul) + F8E4M3_TO_FLOAT(add)))
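One untested idea, not part of this PR: keep the multiply-add in __half via __hfma so the value only converts fp8 -> half instead of fp8 -> half -> f32. The F8E4M3_TO_HALF macro below is hypothetical, and the half-precision intermediate would need to be checked numerically against the current float path:

// Hypothetical alternative (not in the PR): convert fp8 to __half only and
// use __hfma for the affine transform, skipping the extra __half2float hop.
#define F8E4M3_TO_HALF(x) __half(__nv_cvt_fp8_to_halfraw((x).__x, __NV_E4M3))
AFFINE_OP(__nv_fp8_e4m3, affine_f8_e4m3,
    __nv_fp8_e4m3(__hfma(F8E4M3_TO_HALF(x), F8E4M3_TO_HALF(mul), F8E4M3_TO_HALF(add))))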
@LaurentMazare let me know if this directionally looks good; happy to make any changes to the approach if needed.
I fixed a couple of things for the CIs to pass, but besides that it looks good to me.
Thanks for taking a look @greenrazer! If there's any hesitance to merge as-is, would putting it behind a feature help? We'd probably have to leave the kernel additions but could gate all of the Rust code.
@zackangelo can you confirm that the CUDA build works on CC > 8 and CC < 8 (i.e. maintaining compatibility)?
@EricLBuehler I tested an earlier build, but it would probably be worth getting some time on an A100 and verifying again. I'll see if I can get around to that today or tomorrow.
Plucked from @EricLBuehler's work in #2745.
This implements fp8 operations where they are straightforward. Many fp8 ops (e.g. matmul) can't be implemented this way because they require a scale tensor alongside the main tensor to compensate for fp8's limited dynamic range.
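For readers unfamiliar with the scale-tensor point: e4m3 only represents finite values up to 448, so matmul-style kernels quantize each operand with a per-tensor scale and fold the scales back into the output. A plain C++ sketch of that pattern (illustrative only, not candle code):

#include <algorithm>
#include <cmath>
#include <vector>

// Largest finite value representable in fp8 e4m3.
constexpr float F8_E4M3_MAX = 448.0f;

// Pick a per-tensor scale so the largest magnitude maps onto the fp8 range.
float per_tensor_scale(const std::vector<float>& t) {
    float amax = 0.0f;
    for (float v : t) amax = std::max(amax, std::fabs(v));
    return amax > 0.0f ? amax / F8_E4M3_MAX : 1.0f;
}

// An fp8 matmul then operates on the quantized tensors and folds the scales
// back in afterwards:
//   C ~= (scale_a * scale_b) * (A_fp8 @ B_fp8)
// which is why the kernel needs the scale tensors alongside the fp8 data.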