
Add support for TE MXFP8 recipe in accelerate #3688

Open
pstjohn wants to merge 4 commits into main from pstjohn/te-mxfp8-recipe

Conversation

pstjohn (Contributor) commented on Jul 21, 2025

What does this PR do?

Adds support for the MXFP8 format in Transformer Engine (TE). See the TE docs for more background:
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#MXFP8-and-block-scaling

This adds an additional fp8 recipe argument, use_mxfp8_block_scaling, which switches the recipe from DelayedScaling to MXFP8BlockScaling.
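
As a quick illustration, here is a minimal sketch of how the new flag might be wired up from user code. It assumes this PR's use_mxfp8_block_scaling argument on TERecipeKwargs (not in a released accelerate version) and a Transformer Engine build with MXFP8 support:

```python
# Sketch only: assumes this PR's use_mxfp8_block_scaling flag on TERecipeKwargs
# and hardware/TE support for MXFP8 block scaling.
from accelerate import Accelerator
from accelerate.utils import TERecipeKwargs

kwargs_handlers = [TERecipeKwargs(use_mxfp8_block_scaling=True)]
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers)

# Models/optimizers are prepared as usual; TE layers would then run under the
# MXFP8BlockScaling recipe instead of the default DelayedScaling recipe.
```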

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

pstjohn (Contributor, Author) commented on Jul 22, 2025

This is outside the initial scope of this PR, but there's some oddity when using DeepSpeed + FP8 + the HF Trainer.

If you set bf16: True in your TrainingArguments, the Trainer will override the fp8 setting you pass to Accelerate (by manually setting ACCELERATE_MIXED_PRECISION=bf16) here:

And if you omit it, ACCELERATE_MIXED_PRECISION stays as fp8, but you then get an error about a config mismatch raised here:

ValueError: Please correct the following DeepSpeed config values that mismatch TrainingArguments values:
- ds bf16.enabled=True vs hf bf16|bf16_full_eval=False
The easiest method is to set these DeepSpeed config values to 'auto'.
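
To make the interaction concrete, a hypothetical repro sketch; the output_dir and DeepSpeed config file name are illustrative, and the accelerate config is assumed to request mixed_precision: fp8:

```python
# Hypothetical repro of the interaction described above.
# Assumes an accelerate config with mixed_precision: fp8 and a DeepSpeed
# config file containing {"bf16": {"enabled": true}}.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    bf16=True,                   # Trainer forces ACCELERATE_MIXED_PRECISION=bf16, dropping fp8
    deepspeed="ds_config.json",  # illustrative file name
)

# With bf16=False (or omitted), ACCELERATE_MIXED_PRECISION stays "fp8" and the
# Trainer instead raises the DeepSpeed config mismatch ValueError quoted above.
```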

Interestingly, it's still possible to use FP8 with DeepSpeed currently, but it seems like a bug. This check:
https://github.com/huggingface/accelerate/blame/2f075c724ccb4e38fade64db3b0627ca167b5fd2/src/accelerate/accelerator.py#L2046-L2047
will trigger when you pass an fp8 backend with DeepSpeed, because it only checks for fp8_backend, not fp8_enabled. But you won't create the TERecipeKwargs object, so it will just create the fp8 autowrap context with the default recipe.

There have been a number of "FP8 + DeepSpeed" PRs here in the past; I'm wondering if the cleanest option is to separate "mixed_precision" from fp8, since FP8 training typically uses bf16 for the model weights and between FP8-enabled layers anyway.

@pstjohn pstjohn changed the title Add support for MXFP8 recipe in accelerate Add support for TE MXFP8 recipe in accelerate Jul 22, 2025
Comment on lines 332 to 340
if (
AcceleratorState._shared_state != {}
and AcceleratorState().distributed_type == DistributedType.DEEPSPEED
):
pstjohn (Contributor, Author)

Formatting-only change; not sure why it's showing as changed from main.

@pstjohn pstjohn force-pushed the pstjohn/te-mxfp8-recipe branch from edbe9d5 to 46ebf27 on July 22, 2025 at 19:37
@pstjohn pstjohn marked this pull request as ready for review on July 30, 2025 at 22:11
@pstjohn pstjohn force-pushed the pstjohn/te-mxfp8-recipe branch from 46ebf27 to 1fb8f76 on July 30, 2025 at 22:13
S1ro1 (Member) commented on Aug 2, 2025

Do I understand correctly that this only covers DeepSpeed?
