Add support for TE MXFP8 recipe in accelerate #3688

Open · wants to merge 4 commits into base: main
11 changes: 8 additions & 3 deletions src/accelerate/state.py
@@ -948,8 +948,13 @@ def __init__(
                     "Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
                     "before using any functionality from the `accelerate` library."
                 )
-            # deepspeed handles mixed_precision using deepspeed_config
-            self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision
+            # deepspeed handles mixed_precision using deepspeed_config. But we need to set it to fp8
+            # if we're using fp8.
+            if self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != "fp8":
+                self._mixed_precision = "no"
+            else:
+                self._mixed_precision = mixed_precision

             if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
                 if mixed_precision == "bf16":
                     if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
@@ -1037,7 +1042,7 @@ def _check_initialized(self, mixed_precision=None, cpu=None):

     @property
     def mixed_precision(self):
-        if self.distributed_type == DistributedType.DEEPSPEED:
+        if self.distributed_type == DistributedType.DEEPSPEED and self._mixed_precision != "fp8":
             config = self.deepspeed_plugin.deepspeed_config
             if config.get("fp16", {}).get("enabled", False):
                 mixed_precision = "fp16"
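In other words, the state used to always report "no" for its own mixed precision under DeepSpeed and defer to the DeepSpeed config; with this change an explicit "fp8" request survives, so the TE recipe can be layered on top of DeepSpeed. A minimal sketch of the resulting resolution rule (illustrative only, not code from the PR; the real code uses `DistributedType` enums):

```python
# Illustrative stand-in for the new AcceleratorState mixed-precision resolution.
def resolve_mixed_precision(distributed_type: str, requested: str) -> str:
    if distributed_type == "DEEPSPEED" and requested != "fp8":
        # fp16/bf16 stay owned by the DeepSpeed config, so the state reports "no".
        return "no"
    return requested


assert resolve_mixed_precision("DEEPSPEED", "bf16") == "no"
assert resolve_mixed_precision("DEEPSPEED", "fp8") == "fp8"
assert resolve_mixed_precision("MULTI_GPU", "fp8") == "fp8"
```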
1 change: 1 addition & 0 deletions src/accelerate/test_utils/__init__.py
@@ -53,6 +53,7 @@
     require_torchvision,
     require_tpu,
     require_transformer_engine,
+    require_transformer_engine_mxfp8,
     require_xpu,
     run_first,
     skip,
11 changes: 11 additions & 0 deletions src/accelerate/test_utils/testing.py
@@ -71,6 +71,7 @@
     is_torchvision_available,
     is_trackio_available,
     is_transformer_engine_available,
+    is_transformer_engine_mxfp8_available,
     is_transformers_available,
     is_triton_available,
     is_wandb_available,
@@ -540,6 +541,16 @@ def require_transformer_engine(test_case):
     return unittest.skipUnless(is_transformer_engine_available(), "test requires transformers engine")(test_case)


+def require_transformer_engine_mxfp8(test_case):
+    """
+    Decorator marking a test that requires Transformer Engine MXFP8 block scaling to be available. These tests are
+    skipped when MXFP8 block scaling isn't available.
+    """
+    return unittest.skipUnless(
+        is_transformer_engine_mxfp8_available(), "test requires transformers engine MXFP8 block scaling"
+    )(test_case)
+
+
 def require_torchao(test_case):
     """
     Decorator marking a test that requires torchao installed. These tests are skipped when torchao isn't installed
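A hedged sketch of how the new decorator would typically be applied in a test module (the test class and body below are made up for illustration):

```python
import unittest

from accelerate.test_utils import require_transformer_engine_mxfp8


class MXFP8RecipeTester(unittest.TestCase):
    @require_transformer_engine_mxfp8
    def test_mxfp8_path(self):
        # Runs only when Transformer Engine reports MXFP8 block-scaling support
        # (recent TE on Blackwell-class hardware); skipped everywhere else.
        ...
```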
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -133,6 +133,7 @@
     is_torchvision_available,
     is_trackio_available,
     is_transformer_engine_available,
+    is_transformer_engine_mxfp8_available,
     is_transformers_available,
     is_triton_available,
     is_wandb_available,
3 changes: 3 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -370,6 +370,7 @@ class TERecipeKwargs(KwargsHandler):
     amax_history_len: int = None
     amax_compute_algo: AmaxComputeAlgorithm = None
     override_linear_precision: tuple[bool, bool, bool] = None
+    use_mxfp8_block_scaling: bool = None

     def __post_init__(self):
         env_prefix = "ACCELERATE_FP8_"
@@ -398,6 +399,8 @@ def __post_init__(self):
             dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD")
             wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD")
             self.override_linear_precision = (fprop, dgrad, wgrad)
+        if self.use_mxfp8_block_scaling is None:
+            self.use_mxfp8_block_scaling = parse_flag_from_env(env_prefix + "USE_MXFP8_BLOCK_SCALING")


@dataclass
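With the new field, MXFP8 block scaling can be requested either directly on the recipe handler or through the environment, since `__post_init__` falls back to the `ACCELERATE_FP8_` prefix plus `USE_MXFP8_BLOCK_SCALING`. A minimal sketch, assuming the usual `kwargs_handlers` route for passing `TERecipeKwargs` to the `Accelerator`:

```python
from accelerate import Accelerator
from accelerate.utils import TERecipeKwargs

# Option 1: set the flag explicitly on the recipe handler.
accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[TERecipeKwargs(use_mxfp8_block_scaling=True)],
)

# Option 2: leave the field unset and drive it from the environment instead:
#     export ACCELERATE_FP8_USE_MXFP8_BLOCK_SCALING=true
```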
8 changes: 8 additions & 0 deletions src/accelerate/utils/imports.py
@@ -114,6 +114,14 @@ def is_transformer_engine_available():
     return _is_package_available("transformer_engine", "transformer-engine")


+def is_transformer_engine_mxfp8_available():
+    if _is_package_available("transformer_engine", "transformer-engine"):
+        import transformer_engine.pytorch as te
+
+        return te.fp8.check_mxfp8_support()[0]
+    return False
+
+
 def is_lomo_available():
     return _is_package_available("lomo_optim")
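`te.fp8.check_mxfp8_support()` returns a `(supported, reason)` tuple, which is why only element `[0]` feeds the boolean helper. A small usage sketch (illustrative; assumes transformer_engine is installed, otherwise use the helper above):

```python
import transformer_engine.pytorch as te

# Equivalent check done by hand, showing the (supported, reason) tuple shape.
supported, reason = te.fp8.check_mxfp8_support()
if not supported:
    print(f"Falling back to DelayedScaling: {reason}")
```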
22 changes: 21 additions & 1 deletion src/accelerate/utils/transformer_engine.py
@@ -146,14 +146,34 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):

     if is_hpu_available():
         import intel_transformer_engine.recipe as te_recipe
+
+        is_fp8_block_scaling_available = False
+        message = "MXFP8 block scaling is not available on HPU."
+
     else:
         import transformer_engine.common.recipe as te_recipe
+        import transformer_engine.pytorch as te
+
+        is_fp8_block_scaling_available, message = te.fp8.check_mxfp8_support()

     kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
     if "fp8_format" in kwargs:
         kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
     use_during_eval = kwargs.pop("use_autocast_during_eval", False)
-    fp8_recipe = te_recipe.DelayedScaling(**kwargs)
+    use_mxfp8_block_scaling = kwargs.pop("use_mxfp8_block_scaling", False)
+
+    if use_mxfp8_block_scaling and not is_fp8_block_scaling_available:
+        raise ValueError(f"MXFP8 block scaling is not available: {message}")
+
+    if use_mxfp8_block_scaling:
+        if "amax_compute_algo" in kwargs:
+            raise ValueError("`amax_compute_algo` is not supported for MXFP8 block scaling.")
+        if "amax_history_len" in kwargs:
+            raise ValueError("`amax_history_len` is not supported for MXFP8 block scaling.")
+        fp8_recipe = te_recipe.MXFP8BlockScaling(**kwargs)
+    else:
+        fp8_recipe = te_recipe.DelayedScaling(**kwargs)

     new_forward = contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval)

     if hasattr(model.forward, "__func__"):
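For reference, a standalone sketch of what the autocast wrapper built here ultimately does with the chosen recipe, assuming a Transformer Engine build that exposes `MXFP8BlockScaling` and a GPU with MXFP8 support (illustrative, not part of the PR):

```python
import torch
import transformer_engine.pytorch as te
import transformer_engine.common.recipe as te_recipe

supported, _ = te.fp8.check_mxfp8_support()
# MXFP8 block scaling keeps no amax history, hence the guard against those kwargs above.
recipe = te_recipe.MXFP8BlockScaling() if supported else te_recipe.DelayedScaling()

layer = te.Linear(128, 128).cuda()
x = torch.randn(16, 128, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
```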
52 changes: 52 additions & 0 deletions tests/test_fp8.py
@@ -31,6 +31,7 @@
     require_multi_device,
     require_torchao,
     require_transformer_engine,
+    require_transformer_engine_mxfp8,
     run_first,
 )
 from accelerate.test_utils.testing import require_deepspeed, run_command
@@ -49,6 +50,8 @@ def can_convert_te_model(from_config=False):
         accelerator_kwargs = {}

     accelerator = Accelerator(**accelerator_kwargs)
+    assert accelerator.fp8_enabled, "FP8 is not enabled"
+
     dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
     model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 16))
     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
@@ -109,6 +112,26 @@ def test_can_prepare_model_single_gpu_from_config(self):
         command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
         run_command(command)

+    @require_transformer_engine_mxfp8
+    def test_can_prepare_model_with_mxfp8_block_scaling(self):
+        with tempfile.TemporaryDirectory() as dir_name:
+            config_file = Path(dir_name) / "config.yaml"
+            config_file.write_text(
+                textwrap.dedent(
+                    """
+                    distributed_type: "NO"
+                    num_processes: 1
+                    mixed_precision: fp8
+                    fp8_config:
+                      backend: TE
+                      use_mxfp8_block_scaling: true
+                    """
+                )
+            )
+            command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
+            command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
+            run_command(command)
+
     @require_multi_device
     def test_can_prepare_model_multi_gpu(self):
         command = get_launch_command(num_processes=2, monitor_interval=0.1)
@@ -147,6 +170,35 @@ def test_can_prepare_model_multigpu_deepspeed(self):
         command += ["-m", "tests.test_fp8", "--test_te"]
         run_command(command)

+    @require_deepspeed
+    @require_multi_device
+    def test_can_prepare_model_multigpu_deepspeed_from_config(self):
+        os.environ["ZERO_STAGE"] = str(1)
+        with tempfile.TemporaryDirectory() as dir_name:
+            config_file = Path(dir_name) / "config.yaml"
+            config_file.write_text(
+                textwrap.dedent(
+                    """
+                    distributed_type: "DEEPSPEED"
+                    deepspeed_config:
+                      gradient_clipping: 1.0
+                      gradient_accumulation_steps: 1
+                      offload_optimizer_device: none
+                      offload_param_device: none
+                      zero3_init_flag: false
+                      zero_stage: 1
+                      deepspeed_multinode_launcher: standard
+                    num_processes: 2
+                    mixed_precision: fp8
+                    fp8_config:
+                      backend: TE
+                    """
+                )
+            )
+            command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
+            command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
+            run_command(command)
+

 @require_torchao
 @require_huggingface_suite
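Taken together, the launched tests above exercise roughly the following user-side flow (a sketch mirroring `can_convert_te_model`; the FP8/MXFP8 selection itself comes from the config file passed to `accelerate launch`):

```python
import torch
from accelerate import Accelerator

# Launched via: accelerate launch --config_file config.yaml this_script.py
# where config.yaml sets mixed_precision: fp8 and fp8_config.use_mxfp8_block_scaling: true.
accelerator = Accelerator()
assert accelerator.fp8_enabled, "FP8 is not enabled"

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 16))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch).sum()
    accelerator.backward(loss)
    optimizer.step()
```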