
fp8 support #2989


Status: Open. Wants to merge 11 commits into main.

Conversation

zackangelo (Contributor) commented Jun 11, 2025

Plucked from @EricLBuehler's work in #2745.

This implements fp8 operations where they are straightforward. Many fp8 ops can't be implemented because they require a scale tensor alongside the main tensor to compensate for fp8's limited dynamic range (e.g. matmul).

zackangelo (Contributor, Author):

Current test failures:

failures:

---- gather_gpu stdout ----
Error: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")

---- embeddings_gpu stdout ----
Error: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")

---- asort_gpu stdout ----
Error: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")

---- scatter_gpu stdout ----
Error: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")

---- index_select_gpu stdout ----
Error: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")

---- index_add_gpu stdout ----
Error: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")


failures:
    asort_gpu
    embeddings_gpu
    gather_gpu
    index_add_gpu
    index_select_gpu
    scatter_gpu

#if __CUDA_ARCH__ < 700
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomicadd
// The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
// Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
__device__ __half atomicAdd(__half *address, __half val) {
//__device__ __half atomicAdd(__half *address, __half val) {
zackangelo (Contributor, Author):
Unsure why only this signature was present.

@@ -25,6 +25,18 @@ constexpr uint8_t max_value<uint8_t>() {
return 0xFFu;
}

template <>
__host__ __device__
constexpr int32_t max_value<int32_t>() {
zackangelo (Contributor, Author):

These symbols were missing when the CUDA bindgen ran, for some reason.

@@ -36,6 +36,12 @@ extern "C" __global__ void FN_NAME( \
WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16)
WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16)
WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16)

WHERE_OP(__nv_fp8_e4m3, int16_t, where_i16_fp8_e4m3)
zackangelo (Contributor, Author):

The __CUDA_ARCH__ guard should be 890 here.

zackangelo (Contributor, Author) commented Jun 12, 2025

fp8 seems somewhat slower than I would expect in the candle benchmark harness (this is on a GH100):

cuda_affine_f32/iter    time:   [5.5286 µs 5.5322 µs 5.5349 µs]
                        thrpt:  [705.74 GiB/s 706.10 GiB/s 706.55 GiB/s]
                 change:
                        time:   [-0.0349% +0.0724% +0.1626%] (p = 0.17 > 0.05)
                        thrpt:  [-0.1623% -0.0723% +0.0349%]
                        No change in performance detected.
Found 11 outliers among 100 measurements (11.00%)
  6 (6.00%) low severe
  4 (4.00%) low mild
  1 (1.00%) high severe

cuda_affine_f16/iter    time:   [5.4081 µs 5.4135 µs 5.4180 µs]
                        thrpt:  [360.49 GiB/s 360.79 GiB/s 361.15 GiB/s]
                 change:
                        time:   [-0.8211% -0.7134% -0.6133%] (p = 0.00 < 0.05)
                        thrpt:  [+0.6170% +0.7185% +0.8279%]
                        Change within noise threshold.
Found 11 outliers among 100 measurements (11.00%)
  4 (4.00%) low severe
  6 (6.00%) low mild
  1 (1.00%) high severe

cuda_affine_bf16/iter   time:   [5.4118 µs 5.4154 µs 5.4185 µs]
                        thrpt:  [360.45 GiB/s 360.66 GiB/s 360.90 GiB/s]
                 change:
                        time:   [-0.7404% -0.6370% -0.5323%] (p = 0.00 < 0.05)
                        thrpt:  [+0.5351% +0.6411% +0.7459%]
                        Change within noise threshold.
Found 8 outliers among 100 measurements (8.00%)
  3 (3.00%) low severe
  5 (5.00%) low mild

cuda_affine_fp8/iter    time:   [5.6150 µs 5.6186 µs 5.6216 µs]
                        thrpt:  [173.72 GiB/s 173.81 GiB/s 173.92 GiB/s]
                 change:
                        time:   [+3.6965% +3.7735% +3.8562%] (p = 0.00 < 0.05)
                        thrpt:  [-3.7130% -3.6363% -3.5648%]
                        Performance has regressed.
Found 7 outliers among 100 measurements (7.00%)
  4 (4.00%) low severe
  1 (1.00%) low mild
  1 (1.00%) high mild
  1 (1.00%) high severe

cpu_affine_f32/iter     time:   [60.175 µs 60.214 µs 60.256 µs]
                        thrpt:  [64.827 GiB/s 64.873 GiB/s 64.915 GiB/s]
                 change:
                        time:   [-0.3960% -0.2927% -0.1974%] (p = 0.00 < 0.05)
                        thrpt:  [+0.1978% +0.2936% +0.3976%]
                        Change within noise threshold.
Found 3 outliers among 100 measurements (3.00%)
  3 (3.00%) high mild

cpu_affine_f16/iter     time:   [313.67 µs 314.25 µs 314.84 µs]
                        thrpt:  [6.2035 GiB/s 6.2151 GiB/s 6.2267 GiB/s]
                 change:
                        time:   [+0.2332% +0.3714% +0.5040%] (p = 0.00 < 0.05)
                        thrpt:  [-0.5015% -0.3700% -0.2326%]
                        Change within noise threshold.
Found 15 outliers among 100 measurements (15.00%)
  7 (7.00%) high mild
  8 (8.00%) high severe

cpu_affine_bf16/iter    time:   [3.5991 ms 3.5996 ms 3.6001 ms]
                        thrpt:  [555.54 MiB/s 555.61 MiB/s 555.70 MiB/s]
                 change:
                        time:   [-0.0397% -0.0205% +0.0000%] (p = 0.05 > 0.05)
                        thrpt:  [-0.0000% +0.0205% +0.0397%]
                        No change in performance detected.
Found 7 outliers among 100 measurements (7.00%)
  3 (3.00%) low severe
  1 (1.00%) low mild
  2 (2.00%) high mild
  1 (1.00%) high severe

cpu_affine_fp8/iter     time:   [10.386 ms 10.387 ms 10.389 ms]
                        thrpt:  [96.259 MiB/s 96.271 MiB/s 96.283 MiB/s]
                 change:
                        time:   [-0.4090% -0.3742% -0.3366%] (p = 0.00 < 0.05)
                        thrpt:  [+0.3377% +0.3756% +0.4107%]
                        Change within noise threshold.
Found 2 outliers among 100 measurements (2.00%)
  2 (2.00%) high mild

Probably because we're double-casting from fp8->half->f32?

#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))
AFFINE_OP(__nv_fp8_e4m3, affine_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(mul) + F8E4M3_TO_FLOAT(add)))

@zackangelo zackangelo marked this pull request as ready for review June 21, 2025 00:40
zackangelo (Contributor, Author):

@LaurentMazare let me know if this directionally looks good, happy to make any changes to the approach if needed.

greenrazer (Collaborator):

I fixed a couple of things to get CI passing, but besides that it looks good to me.

zackangelo (Contributor, Author):

Thanks for taking a look @greenrazer!

If there's any hesitance to merge as-is, would putting it behind a feature flag help? We'd probably have to keep the kernel additions, but we could gate all of the Rust code.

EricLBuehler (Member):

@zackangelo can you confirm that the CUDA build works on CC > 8 and CC < 8 (i.e., maintains compatibility)?

zackangelo (Contributor, Author):

@EricLBuehler I tested an earlier build, but it would probably be worth getting some time on an A100 and verifying again. I'll see if I can get to that today or tomorrow.
