
Fix FP8 tests, enable FP8 to be used without direct Accelerator() configuring #3677

Merged
merged 3 commits on Jul 15, 2025
2 changes: 1 addition & 1 deletion benchmarks/fp8/transformer_engine/Dockerfile
@@ -7,7 +7,7 @@ RUN pip install transformers evaluate datasets
RUN git clone https://github.com/huggingface/accelerate.git

RUN cd accelerate && \
pip install -e . && \
pip install -e .[deepspeed] && \
Contributor Author

This is the container used for the FP8 CI tests, which include DeepSpeed. So even though we don't use DeepSpeed in these benchmark scripts (which we should double-check are still functional 😄), installing the extra here allows the requires_deepspeed tests to run as part of the FP8 test suite.
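
A quick sanity check (illustrative only, not part of this PR) could be run inside the container to confirm the packages the FP8 CI relies on are importable:

```python
# Hypothetical check, not part of the PR: confirm the packages the FP8 CI relies on
# (the [deepspeed] extra plus the base image contents) are importable.
import importlib.util

for pkg in ("deepspeed", "transformer_engine", "accelerate"):
    status = "ok" if importlib.util.find_spec(pkg) is not None else "missing"
    print(f"{pkg}: {status}")
```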

cd benchmarks/fp8

RUN /bin/bash
4 changes: 2 additions & 2 deletions examples/config_yaml_templates/fp8.yaml
@@ -11,8 +11,8 @@ fp8_config:
fp8_format: E4M3
interval: 1
margin: 0
override_linear_precision: (false, false, false)
override_linear_precision: [false, false, false]
Contributor Author

This isn't exercised anywhere in CI, but I caught this bug while using this template to debug locally 🤷. The (...) expression just gets parsed as a string rather than a list.

Member

nice catch
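
The difference is easy to reproduce with PyYAML (illustrative sketch, not from the PR), since YAML has no tuple literal:

```python
# Sketch of the bug described above: YAML has no tuple syntax, so the parenthesised
# value is parsed as a plain string rather than a list of booleans.
import yaml

old = yaml.safe_load("override_linear_precision: (false, false, false)")
new = yaml.safe_load("override_linear_precision: [false, false, false]")

print(old["override_linear_precision"])  # '(false, false, false)'  -> a str
print(new["override_linear_precision"])  # [False, False, False]    -> a list of bools
```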

# Generally this should always be set to `false` to have the most realistic fp8 eval performance
use_autocast_during_eval: false
# If using MS-AMP, we ignore all of the prior and set a opt_level
#opt_level: O1
#opt_level: O1
79 changes: 47 additions & 32 deletions src/accelerate/accelerator.py
@@ -33,6 +33,8 @@
import torch.utils.hooks as hooks
from huggingface_hub import split_torch_state_dict_into_shards

from accelerate.utils.dataclasses import FP8BackendType
Contributor Author

We already had this enum, so I figured it was worth using it here instead of the string comparisons.
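
A small illustrative sketch (not from the PR) of what the enum buys over bare string comparisons:

```python
# Illustrative sketch: FP8BackendType validates the value once at parse time, whereas
# a bare string comparison with a typo would just be False and fall through silently.
from accelerate.utils.dataclasses import FP8BackendType

backend = FP8BackendType("TE")          # validated here
print(backend == FP8BackendType.TE)     # True (it also subclasses str, so == "TE" works too)

try:
    FP8BackendType("TEE")               # a typo raises immediately...
except ValueError as exc:
    print(exc)                          # ...instead of quietly matching nothing
```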


from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
from .logging import get_logger
@@ -301,6 +303,7 @@ def __init__(
self.project_configuration = ProjectConfiguration(project_dir=project_dir)
if project_dir is not None and self.project_dir is None:
self.project_configuration.set_directories(project_dir)

if mixed_precision is not None:
mixed_precision = str(mixed_precision)
if mixed_precision not in PrecisionType:
@@ -458,27 +461,34 @@

# Check for automatic FP8 recipe creation
if self.fp8_enabled and not self.has_fp8_handler:
# Prioritize AO -> TE -> MSAMP
if is_torchao_available():
logger.info("Found `torchao` installed, using it for FP8 training.")
if self.fp8_backend == FP8BackendType.AO:
Contributor Author

Here we first defer to the fp8_backend specified in the YAML config, and only if one wasn't specified do we fall back to the AO -> TE -> MSAMP preference order.
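
As an illustrative sketch of the resulting end-user flow (not from the PR; assumes torchao is installed and mixed_precision: fp8 comes from the accelerate config or launcher rather than being passed in directly):

```python
# Rough sketch, assuming torchao is installed and mixed_precision is set to "fp8"
# via the accelerate config / launcher instead of an explicit Accelerator argument.
from accelerate import Accelerator

accelerator = Accelerator()  # no kwargs_handlers, no FP8 recipe passed directly
# With no fp8_backend in the config, the AO -> TE -> MSAMP auto-detection kicks in;
# with e.g. fp8_backend: TE in the config, a TERecipeKwargs() is created instead.
print(accelerator.fp8_backend)
```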

self.ao_recipe_handler = AORecipeKwargs()
elif is_transformer_engine_available():
logger.info("Found `transformer-engine` installed, using it for FP8 training.")
elif self.fp8_backend == FP8BackendType.TE:
self.te_recipe_handler = TERecipeKwargs()
elif is_msamp_available():
logger.info("Found `msamp` installed, using it for FP8 training.")
elif self.fp8_backend == FP8BackendType.MSAMP:
self.msamp_recipe_handler = MSAMPRecipeKwargs()
else:
raise ImportError(
"Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
"Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
)
elif self.fp8_backend == FP8BackendType.NO:
# Prioritize AO -> TE -> MSAMP
if is_torchao_available():
logger.info("Found `torchao` installed, using it for FP8 training.")
self.ao_recipe_handler = AORecipeKwargs()
elif is_transformer_engine_available():
logger.info("Found `transformer-engine` installed, using it for FP8 training.")
self.te_recipe_handler = TERecipeKwargs()
elif is_msamp_available():
logger.info("Found `msamp` installed, using it for FP8 training.")
self.msamp_recipe_handler = MSAMPRecipeKwargs()
else:
raise ImportError(
"Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
"Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
)
self.has_fp8_handler = True

self.delayed_fp8_autocast = False
if self.has_fp8_handler:
# We already check if FP8 is available during `self.state`
if mixed_precision != "fp8" and (
if not self.fp8_enabled and (
self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)
):
raise ValueError("Passing in an FP8 configuration requires setting `mixed_precision='fp8'`.")
@@ -488,7 +498,11 @@
)

# TODO: S1ro - this is probably gonna be a problem with other fp8 backends too
if self.fp8_backend == "AO" and self.state.fsdp_plugin.cpu_ram_efficient_loading:
if (
self.fp8_backend == FP8BackendType.AO
and self.state.distributed_type == DistributedType.FSDP
and self.state.fsdp_plugin.cpu_ram_efficient_loading
):
Comment on lines +501 to +505

I was about to submit a PR that made this exact change :)

Contributor Author

Yeah, IIRC this was where some of the existing FP8 tests were failing.

raise ValueError(
"torchao with FSDP2 and cpu_ram_efficient_loading is not supported, setting `cpu_ram_efficient_loading` to False will fix the issue and work as intended."
)
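
For reference, a hedged sketch (not from this PR) of the workaround the error message above suggests for a torchao + FSDP2 setup:

```python
# Hedged sketch of the workaround the check above points at: keep torchao FP8 + FSDP2,
# but turn off cpu_ram_efficient_loading explicitly on the FSDP plugin.
from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,                    # FSDP2; assumed to be set here rather than via the config file
    cpu_ram_efficient_loading=False,   # avoids the ValueError raised above
)
accelerator = Accelerator(
    mixed_precision="fp8",
    fsdp_plugin=fsdp_plugin,
    kwargs_handlers=[AORecipeKwargs()],
)
```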
@@ -572,7 +586,7 @@ def __init__(
elif self.fp8_enabled:
# We always enable `native_amp` for FP8
self.native_amp = True
if self.fp8_backend == "MSAMP":
if self.fp8_backend == FP8BackendType.MSAMP:
if self.distributed_type == DistributedType.FSDP:
raise NotImplementedError(
"`accelerate` + `MS-AMP` + `FSDP` is not supported at this time. "
@@ -1419,9 +1433,9 @@ def prepare(self, *args, device_placement=None):
"You are using lower version of PyTorch(< 2.7.0) with ipex acceleration on Intel CPU or XPU, Intel has upstreamed most of the optimizations into stock PyTorch from 2.7.0, we enourage you to install the latest stock PyTorch and enjoy the out-of-experience on Intel CPU/XPU."
)
args = self._prepare_ipex(*args)
if self.fp8_backend == "TE":
if self.fp8_backend == FP8BackendType.TE:
args = self._prepare_te(*args)
elif self.fp8_backend == "AO":
elif self.fp8_backend == FP8BackendType.AO:
args = self._prepare_ao(*args)
if self.distributed_type == DistributedType.DEEPSPEED:
result = self._prepare_deepspeed(*args)
@@ -1430,7 +1444,7 @@
elif self.is_fsdp2:
result = self._prepare_fsdp2(*args)
else:
if self.fp8_backend == "MSAMP":
if self.fp8_backend == FP8BackendType.MSAMP:
args, device_placement = self._prepare_msamp(*args, device_placement=device_placement)
result = tuple(
self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
@@ -1570,7 +1584,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
model._original_forward = model.forward
autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
# NOTE: MS-AMP adds `__func__` already to `model.forward`, so we should always use `model.forward`
if self.fp8_backend == "MSAMP" or not hasattr(model.forward, "__func__"):
if self.fp8_backend == FP8BackendType.MSAMP or not hasattr(model.forward, "__func__"):
model_forward_func = model.forward
model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func))
else:
@@ -1580,7 +1594,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)

# We prepare TE after, allowing for bf16 autocast to happen first
if self.fp8_backend == "TE" and not self.delayed_fp8_autocast:
if self.fp8_backend == FP8BackendType.TE and not self.delayed_fp8_autocast:
model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)

if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
@@ -1806,7 +1820,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:
model = xmp.MpModelWrapper(model).to(self.device)
# Now we can apply the FP8 autocast
if self.fp8_backend == "TE" and self.delayed_fp8_autocast:
if self.fp8_backend == FP8BackendType.TE and self.delayed_fp8_autocast:
model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
# torch.compile should be called last and only if the model isn't already compiled
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
@@ -1884,7 +1898,7 @@ def _prepare_deepspeed(self, *args):
import deepspeed

ds_initialize = deepspeed.initialize
if self.fp8_backend == "MSAMP":
if self.fp8_backend == FP8BackendType.MSAMP:
# MS-AMP requires DeepSpeed patches
from msamp import deepspeed as msamp_deepspeed

@@ -2022,7 +2036,7 @@ def _prepare_deepspeed(self, *args):

if model is not None:
# If we are using FP8, we need to apply the autowrap now
if self.fp8_backend == "TE":
if self.fp8_backend == FP8BackendType.TE:
model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
# if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules
deepspeed_plugin.set_moe_leaf_modules(model)
@@ -2479,7 +2493,7 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=N
device_placement = self.device_placement
# NOTE: Special case with MS-AMP we do *not* pass in the scaler explicitly to the `AcceleratedOptimizer`,
# Their optimizer handles it for us.
scaler = None if self.fp8_backend == "MSAMP" else self.scaler
scaler = None if self.fp8_backend == FP8BackendType.MSAMP else self.scaler
optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=scaler)
self._optimizers.append(optimizer)
return optimizer
@@ -3668,7 +3682,7 @@ def _get_named_parameters(self, *args, drop_refs=False):

# we need this bit as `WeightWithDynamic...` returns 0 when `data_ptr()` is called,
# the underlying pointer is actually hidden in `_tensor` attribute
if self.fp8_backend == "AO":
if self.fp8_backend == FP8BackendType.AO:
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor

accessor_mapping[WeightWithDynamicFloat8CastTensor] = "_tensor"
@@ -3977,17 +3991,18 @@ def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None:
)

@property
def fp8_backend(self):
def fp8_backend(self) -> FP8BackendType:
"Returns the configured backend for training in FP8"
if self.has_fp8_handler:
if self.fp8_recipe_handler is not None:
return self.fp8_recipe_handler.backend
return FP8BackendType(self.fp8_recipe_handler.backend)
elif self.ao_recipe_handler is not None:
return "AO"
return FP8BackendType.AO
elif self.te_recipe_handler is not None:
return "TE"
return FP8BackendType.TE
elif self.msamp_recipe_handler is not None:
return "MSAMP"
return FP8BackendType.MSAMP
elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp:
return "MSAMP"
return None
return FP8BackendType.MSAMP

return FP8BackendType(parse_choice_from_env("ACCELERATE_FP8_BACKEND", "NO"))
2 changes: 2 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -616,8 +616,10 @@ class FP8BackendType(str, enum.Enum):
"""

# Subclassing str as well as Enum allows the `FP8BackendType` to be JSON-serializable out of the box.
NO = "NO"
TE = "TE"
MSAMP = "MSAMP"
AO = "AO"


class ComputeEnvironment(str, enum.Enum):
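
An illustrative aside on the JSON-serializability comment above (local sketch, not part of the diff):

```python
# Sketch (enum redefined locally for self-containment): because FP8BackendType also
# subclasses str, instances serialise to their plain string value without a custom encoder.
import enum
import json

class FP8BackendType(str, enum.Enum):
    NO = "NO"
    TE = "TE"
    MSAMP = "MSAMP"
    AO = "AO"

print(json.dumps({"backend": FP8BackendType.TE}))  # {"backend": "TE"}
```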
6 changes: 3 additions & 3 deletions src/accelerate/utils/launch.py
@@ -89,9 +89,9 @@ def setup_fp8_env(args: argparse.Namespace, current_env: dict[str, str]):
value = getattr(args, arg)
if value is not None:
if arg == "fp8_override_linear_precision":
current_env[prefix + "FP8_OVERRIDE_FPROP"] = value[0]
current_env[prefix + "FP8_OVERRIDE_DGRAD"] = value[1]
current_env[prefix + "FP8_OVERRIDE_WGRAD"] = value[2]
current_env[prefix + "FP8_OVERRIDE_FPROP"] = str(value[0])
current_env[prefix + "FP8_OVERRIDE_DGRAD"] = str(value[1])
current_env[prefix + "FP8_OVERRIDE_WGRAD"] = str(value[2])
else:
current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
return current_env
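
A small illustrative sketch (not from the PR; the exact environment variable name is assumed) of why the explicit str() matters:

```python
# Sketch of why str() is needed: environment variables only accept strings, so the
# parsed booleans can't be passed through unchanged once the env is actually applied.
# The variable name below is assumed for illustration.
import os

try:
    os.environ["ACCELERATE_FP8_OVERRIDE_FPROP"] = False  # roughly what the old code set up
except TypeError as exc:
    print(exc)  # "str expected, not bool"

os.environ["ACCELERATE_FP8_OVERRIDE_FPROP"] = str(False)  # the fixed behaviour: "False"
```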