Fix FP8 tests, enable FP8 to be used without direct Accelerator() configuring (#3677)
@@ -11,8 +11,8 @@ fp8_config:
   fp8_format: E4M3
   interval: 1
   margin: 0
-  override_linear_precision: (false, false, false)
+  override_linear_precision: [false, false, false]
   # Generally this should always be set to `false` to have the most realistic fp8 eval performance
   use_autocast_during_eval: false
   # If using MS-AMP, we ignore all of the prior and set a opt_level
   #opt_level: O1

Review comment: this isn't exercised in CI anywhere, but I caught this bug while using this to debug locally 🤷.

Reply: nice catch
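For context on the fix above (not part of the diff): YAML has no tuple literal, so the old parenthesized value round-trips as a plain string, while the bracketed flow sequence parses into the list of booleans the recipe expects. A minimal sketch, assuming PyYAML is installed:

```python
# Minimal sketch showing why the old value was a bug: YAML has no tuple syntax,
# so "(false, false, false)" is read back as a single string, not three booleans.
import yaml

old = yaml.safe_load("override_linear_precision: (false, false, false)")
new = yaml.safe_load("override_linear_precision: [false, false, false]")

print(type(old["override_linear_precision"]))  # <class 'str'> -- not usable as a flag triple
print(new["override_linear_precision"])        # [False, False, False]
```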
@@ -33,6 +33,8 @@
 import torch.utils.hooks as hooks
 from huggingface_hub import split_torch_state_dict_into_shards

+from accelerate.utils.dataclasses import FP8BackendType
+
 from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
 from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
 from .logging import get_logger

Review comment: we already had this enum, so I figured it was worth using it here instead of the string comparisons.
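For readers unfamiliar with the enum, here is an illustrative stand-in (not the actual definition in `accelerate.utils.dataclasses`; the member names come from the diff, the `str` base class is an assumption) showing why replacing the string literals with enum members is a drop-in change:

```python
# Illustrative stand-in for FP8BackendType; only the member names are taken from the diff.
from enum import Enum

class FP8BackendType(str, Enum):
    TE = "TE"
    MSAMP = "MSAMP"
    AO = "AO"
    NO = "NO"

# A str-backed enum stays compatible with the old string comparisons while giving
# typo-safety and one canonical list of valid backends.
assert FP8BackendType.TE == "TE"
assert FP8BackendType("MSAMP") is FP8BackendType.MSAMP  # construction from a config string
```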
@@ -301,6 +303,7 @@ def __init__(
             self.project_configuration = ProjectConfiguration(project_dir=project_dir)
         if project_dir is not None and self.project_dir is None:
             self.project_configuration.set_directories(project_dir)
+
         if mixed_precision is not None:
             mixed_precision = str(mixed_precision)
             if mixed_precision not in PrecisionType:
@@ -458,27 +461,34 @@ def __init__(

         # Check for automatic FP8 recipe creation
         if self.fp8_enabled and not self.has_fp8_handler:
-            # Prioritize AO -> TE -> MSAMP
-            if is_torchao_available():
-                logger.info("Found `torchao` installed, using it for FP8 training.")
+            if self.fp8_backend == FP8BackendType.AO:
                 self.ao_recipe_handler = AORecipeKwargs()
-            elif is_transformer_engine_available():
-                logger.info("Found `transformer-engine` installed, using it for FP8 training.")
+            elif self.fp8_backend == FP8BackendType.TE:
                 self.te_recipe_handler = TERecipeKwargs()
-            elif is_msamp_available():
-                logger.info("Found `msamp` installed, using it for FP8 training.")
+            elif self.fp8_backend == FP8BackendType.MSAMP:
                 self.msamp_recipe_handler = MSAMPRecipeKwargs()
-            else:
-                raise ImportError(
-                    "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
-                    "Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
-                )
+            elif self.fp8_backend == FP8BackendType.NO:
+                # Prioritize AO -> TE -> MSAMP
+                if is_torchao_available():
+                    logger.info("Found `torchao` installed, using it for FP8 training.")
+                    self.ao_recipe_handler = AORecipeKwargs()
+                elif is_transformer_engine_available():
+                    logger.info("Found `transformer-engine` installed, using it for FP8 training.")
+                    self.te_recipe_handler = TERecipeKwargs()
+                elif is_msamp_available():
+                    logger.info("Found `msamp` installed, using it for FP8 training.")
+                    self.msamp_recipe_handler = MSAMPRecipeKwargs()
+                else:
+                    raise ImportError(
+                        "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
+                        "Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
+                    )
             self.has_fp8_handler = True

         self.delayed_fp8_autocast = False
         if self.has_fp8_handler:
             # We already check if FP8 is available during `self.state`
-            if mixed_precision != "fp8" and (
+            if not self.fp8_enabled and (
                 self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)
             ):
                 raise ValueError("Passing in an FP8 configuration requires setting `mixed_precision='fp8'`.")

Review comment: Here we first defer to the fp8_backend specified in the yaml, and only if we didn't specify one do we revert to the AO -> TE -> MSAMP preference.
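A rough usage sketch of what this change enables (assumptions: the launcher exports the configured backend as `ACCELERATE_FP8_BACKEND`, `transformer-engine` is installed, and the hardware supports FP8). The point of the PR is that no recipe kwargs need to be passed to `Accelerator()` directly:

```python
# Hypothetical end-user sketch, not taken from the PR: select the TE backend via the
# environment (normally exported by `accelerate launch` from the config yaml) and let
# Accelerator() create the TERecipeKwargs() on its own.
import os

os.environ["ACCELERATE_FP8_BACKEND"] = "TE"  # "AO", "TE", "MSAMP", or "NO" for auto-detect

from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp8")  # no kwargs_handlers needed anymore
```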
@@ -488,7 +498,11 @@ def __init__(
             )

             # TODO: S1ro - this is probably gonna be a problem with other fp8 backends too
-            if self.fp8_backend == "AO" and self.state.fsdp_plugin.cpu_ram_efficient_loading:
+            if (
+                self.fp8_backend == FP8BackendType.AO
+                and self.state.distributed_type == DistributedType.FSDP
+                and self.state.fsdp_plugin.cpu_ram_efficient_loading
+            ):
                 raise ValueError(
                     "torchao with FSDP2 and cpu_ram_efficient_loading is not supported, setting `cpu_ram_efficient_loading` to False will fix the issue and work as intended."
                 )

Comment on lines +501 to +505:

Review comment: I was about to submit a PR that made this exact change :)

Reply: Related: huggingface/transformers#39370

Reply: yeah IIRC this was where some of the existing fp8 tests were failing
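For anyone hitting the new error above, a hedged workaround sketch following what the error message suggests (the plugin parameter names are assumptions, not taken from this PR):

```python
# Hypothetical workaround sketch: disable CPU-RAM-efficient loading on the FSDP plugin
# when combining the torchao FP8 backend with FSDP2. Parameter names
# (cpu_ram_efficient_loading, fsdp_version) are assumed from accelerate's FSDP plugin options.
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    cpu_ram_efficient_loading=False,
)
accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin)
```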
@@ -572,7 +586,7 @@ def __init__(
         elif self.fp8_enabled:
             # We always enable `native_amp` for FP8
             self.native_amp = True
-            if self.fp8_backend == "MSAMP":
+            if self.fp8_backend == FP8BackendType.MSAMP:
                 if self.distributed_type == DistributedType.FSDP:
                     raise NotImplementedError(
                         "`accelerate` + `MS-AMP` + `FSDP` is not supported at this time. "
@@ -1419,9 +1433,9 @@ def prepare(self, *args, device_placement=None):
                 "You are using lower version of PyTorch(< 2.7.0) with ipex acceleration on Intel CPU or XPU, Intel has upstreamed most of the optimizations into stock PyTorch from 2.7.0, we enourage you to install the latest stock PyTorch and enjoy the out-of-experience on Intel CPU/XPU."
             )
             args = self._prepare_ipex(*args)
-        if self.fp8_backend == "TE":
+        if self.fp8_backend == FP8BackendType.TE:
             args = self._prepare_te(*args)
-        elif self.fp8_backend == "AO":
+        elif self.fp8_backend == FP8BackendType.AO:
             args = self._prepare_ao(*args)
         if self.distributed_type == DistributedType.DEEPSPEED:
             result = self._prepare_deepspeed(*args)
@@ -1430,7 +1444,7 @@ def prepare(self, *args, device_placement=None):
         elif self.is_fsdp2:
             result = self._prepare_fsdp2(*args)
         else:
-            if self.fp8_backend == "MSAMP":
+            if self.fp8_backend == FP8BackendType.MSAMP:
                 args, device_placement = self._prepare_msamp(*args, device_placement=device_placement)
             result = tuple(
                 self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
@@ -1570,7 +1584,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
             model._original_forward = model.forward
             autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
             # NOTE: MS-AMP adds `__func__` already to `model.forward`, so we should always use `model.forward`
-            if self.fp8_backend == "MSAMP" or not hasattr(model.forward, "__func__"):
+            if self.fp8_backend == FP8BackendType.MSAMP or not hasattr(model.forward, "__func__"):
                 model_forward_func = model.forward
                 model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func))
             else:
@@ -1580,7 +1594,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                 model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)

         # We prepare TE after, allowing for bf16 autocast to happen first
-        if self.fp8_backend == "TE" and not self.delayed_fp8_autocast:
+        if self.fp8_backend == FP8BackendType.TE and not self.delayed_fp8_autocast:
             model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)

         if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
@@ -1806,7 +1820,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
         elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:
             model = xmp.MpModelWrapper(model).to(self.device)
         # Now we can apply the FP8 autocast
-        if self.fp8_backend == "TE" and self.delayed_fp8_autocast:
+        if self.fp8_backend == FP8BackendType.TE and self.delayed_fp8_autocast:
             model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
         # torch.compile should be called last and only if the model isn't already compiled
         if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
@@ -1884,7 +1898,7 @@ def _prepare_deepspeed(self, *args):
         import deepspeed

         ds_initialize = deepspeed.initialize
-        if self.fp8_backend == "MSAMP":
+        if self.fp8_backend == FP8BackendType.MSAMP:
             # MS-AMP requires DeepSpeed patches
             from msamp import deepspeed as msamp_deepspeed
@@ -2022,7 +2036,7 @@ def _prepare_deepspeed(self, *args):

         if model is not None:
             # If we are using FP8, we need to apply the autowrap now
-            if self.fp8_backend == "TE":
+            if self.fp8_backend == FP8BackendType.TE:
                 model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
             # if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules
             deepspeed_plugin.set_moe_leaf_modules(model)
@@ -2479,7 +2493,7 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=N
             device_placement = self.device_placement
         # NOTE: Special case with MS-AMP we do *not* pass in the scaler explicitly to the `AcceleratedOptimizer`,
         # Their optimizer handles it for us.
-        scaler = None if self.fp8_backend == "MSAMP" else self.scaler
+        scaler = None if self.fp8_backend == FP8BackendType.MSAMP else self.scaler
         optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=scaler)
         self._optimizers.append(optimizer)
         return optimizer
@@ -3668,7 +3682,7 @@ def _get_named_parameters(self, *args, drop_refs=False):

         # we need this bit as `WeightWithDynamic...` returns 0 when `data_ptr()` is called,
         # the underlying pointer is actually hidden in `_tensor` attribute
-        if self.fp8_backend == "AO":
+        if self.fp8_backend == FP8BackendType.AO:
             from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor

             accessor_mapping[WeightWithDynamicFloat8CastTensor] = "_tensor"
@@ -3977,17 +3991,18 @@ def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None:
             )

     @property
-    def fp8_backend(self):
+    def fp8_backend(self) -> FP8BackendType:
         "Returns the configured backend for training in FP8"
         if self.has_fp8_handler:
             if self.fp8_recipe_handler is not None:
-                return self.fp8_recipe_handler.backend
+                return FP8BackendType(self.fp8_recipe_handler.backend)
             elif self.ao_recipe_handler is not None:
-                return "AO"
+                return FP8BackendType.AO
             elif self.te_recipe_handler is not None:
-                return "TE"
+                return FP8BackendType.TE
             elif self.msamp_recipe_handler is not None:
-                return "MSAMP"
+                return FP8BackendType.MSAMP
         elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp:
-            return "MSAMP"
-        return None
+            return FP8BackendType.MSAMP
+
+        return FP8BackendType(parse_choice_from_env("ACCELERATE_FP8_BACKEND", "NO"))
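A small sketch of the new fallback semantics: when no handler is configured, the property now always returns an enum member instead of `None`, defaulting to `FP8BackendType.NO`, which is what routes `__init__` into the AO -> TE -> MSAMP auto-detection above. This approximates `parse_choice_from_env` with `os.environ.get`:

```python
# Rough approximation of the fallback branch shown in the diff; parse_choice_from_env
# is replaced by os.environ.get for illustration only.
import os
from accelerate.utils.dataclasses import FP8BackendType

def fp8_backend_fallback() -> FP8BackendType:
    # Read the backend from the environment, defaulting to NO (auto-detect).
    return FP8BackendType(os.environ.get("ACCELERATE_FP8_BACKEND", "NO"))

assert fp8_backend_fallback() == FP8BackendType.NO  # when the env var is unset
```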
Review comment: this is the container used for fp8 CI tests, which include deepspeed. so even though we don't use it in these benchmark scripts (which we should double-check are still functional 😄), this allows the `requires_deepspeed` tests to run in the fp8 tests.