
Commit 77cb1d4

Add support for MXFP8 recipe in accelerate
1 parent 2f075c7 commit 77cb1d4

File tree

7 files changed: +60 -1 lines changed

  src/accelerate/test_utils/__init__.py
  src/accelerate/test_utils/testing.py
  src/accelerate/utils/__init__.py
  src/accelerate/utils/dataclasses.py
  src/accelerate/utils/imports.py
  src/accelerate/utils/transformer_engine.py
  tests/test_fp8.py

src/accelerate/test_utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@
     require_torchvision,
     require_tpu,
     require_transformer_engine,
+    require_transformer_engine_mxfp8,
     require_xpu,
     run_first,
     skip,

src/accelerate/test_utils/testing.py

Lines changed: 8 additions & 0 deletions
@@ -71,6 +71,7 @@
     is_torchvision_available,
     is_trackio_available,
     is_transformer_engine_available,
+    is_transformer_engine_mxfp8_available,
     is_transformers_available,
     is_triton_available,
     is_wandb_available,
@@ -540,6 +541,13 @@ def require_transformer_engine(test_case):
     return unittest.skipUnless(is_transformer_engine_available(), "test requires transformers engine")(test_case)


+def require_transformer_engine_mxfp8(test_case):
+    """
+    Decorator marking a test that requires transformers engine FP8 block scaling available. These tests are skipped
+    when transformers engine FP8 block scaling isn't available
+    """
+    return unittest.skipUnless(is_transformer_engine_mxfp8_available(), "test requires transformers engine FP8 block scaling")(test_case)
+
 def require_torchao(test_case):
     """
     Decorator marking a test that requires torchao installed. These tests are skipped when torchao isn't installed
src/accelerate/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -131,6 +131,7 @@
     is_torchvision_available,
     is_trackio_available,
     is_transformer_engine_available,
+    is_transformer_engine_mxfp8_available,
     is_transformers_available,
     is_triton_available,
     is_wandb_available,

src/accelerate/utils/dataclasses.py

Lines changed: 3 additions & 0 deletions
@@ -359,6 +359,7 @@ class TERecipeKwargs(KwargsHandler):
     amax_history_len: int = None
     amax_compute_algo: AmaxComputeAlgorithm = None
     override_linear_precision: tuple[bool, bool, bool] = None
+    use_mxfp8_block_scaling: bool = None

     def __post_init__(self):
         env_prefix = "ACCELERATE_FP8_"
@@ -387,6 +388,8 @@ def __post_init__(self):
             dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD")
             wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD")
             self.override_linear_precision = (fprop, dgrad, wgrad)
+        if self.use_mxfp8_block_scaling is None:
+            self.use_mxfp8_block_scaling = parse_flag_from_env(env_prefix + "USE_MXFP8_BLOCK_SCALING")


 @dataclass
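
As a usage sketch (not part of this diff), the new field can be set directly on the kwargs handler, or left unset and picked up from the ACCELERATE_FP8_USE_MXFP8_BLOCK_SCALING environment variable that __post_init__ parses; the Accelerator wiring below is assumed from the existing kwargs_handlers API:

from accelerate import Accelerator
from accelerate.utils import TERecipeKwargs

# Request the MXFP8 block-scaling recipe for the TE backend explicitly.
# Leaving use_mxfp8_block_scaling unset would instead defer to the
# ACCELERATE_FP8_USE_MXFP8_BLOCK_SCALING environment variable.
te_kwargs = TERecipeKwargs(use_mxfp8_block_scaling=True)

accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[te_kwargs])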

src/accelerate/utils/imports.py

Lines changed: 6 additions & 0 deletions
@@ -114,6 +114,12 @@ def is_transformer_engine_available():
     return _is_package_available("transformer_engine", "transformer-engine")


+def is_transformer_engine_mxfp8_available():
+    if _is_package_available("transformer_engine", "transformer-engine"):
+        import transformer_engine.pytorch as te
+        return te.fp8.check_mxfp8_support()[0]
+    return False
+
 def is_lomo_available():
     return _is_package_available("lomo_optim")

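
The helper follows the same pattern as the other is_*_available checks and is re-exported from accelerate.utils in this commit. Outside the test suite it can guard recipe selection, e.g. (an illustrative sketch):

from accelerate.utils import is_transformer_engine_mxfp8_available

# Prefer MXFP8 block scaling when the installed transformer_engine build
# and the current GPU support it; otherwise keep the delayed-scaling default.
use_mxfp8 = is_transformer_engine_mxfp8_available()
if not use_mxfp8:
    print("MXFP8 block scaling unavailable; falling back to DelayedScaling.")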

src/accelerate/utils/transformer_engine.py

Lines changed: 19 additions & 1 deletion
@@ -146,14 +146,32 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):

     if is_hpu_available():
         import intel_transformer_engine.recipe as te_recipe
+        is_fp8_block_scaling_available = False
+        message = "MXFP8 block scaling is not available on HPU."
+
     else:
         import transformer_engine.common.recipe as te_recipe
+        import transformer_engine.pytorch as te
+        is_fp8_block_scaling_available, message = te.fp8.check_mxfp8_support()

     kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
     if "fp8_format" in kwargs:
         kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
     use_during_eval = kwargs.pop("use_autocast_during_eval", False)
-    fp8_recipe = te_recipe.DelayedScaling(**kwargs)
+    use_mxfp8_block_scaling = kwargs.pop("use_mxfp8_block_scaling", False)
+
+    if use_mxfp8_block_scaling and not is_fp8_block_scaling_available:
+        raise ValueError(f"MXFP8 block scaling is not available: {message}")
+
+    if use_mxfp8_block_scaling:
+        if "amax_compute_algo" in kwargs:
+            raise ValueError("`amax_compute_algo` is not supported for MXFP8 block scaling.")
+        if "amax_history_len" in kwargs:
+            raise ValueError("`amax_history_len` is not supported for MXFP8 block scaling.")
+        fp8_recipe = te_recipe.MXFP8BlockScaling(**kwargs)
+    else:
+        fp8_recipe = te_recipe.DelayedScaling(**kwargs)
+
     new_forward = contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval)

     if hasattr(model.forward, "__func__"):
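
To make the branch above concrete, these are roughly the two recipe objects the wrapper ends up building. A sketch against Transformer Engine's recipe API; the argument values are illustrative, not taken from this commit:

import transformer_engine.common.recipe as te_recipe

# Delayed scaling: scaling factors come from an amax history, so the
# amax_* options in TERecipeKwargs apply here.
delayed = te_recipe.DelayedScaling(
    fp8_format=te_recipe.Format.HYBRID,
    amax_history_len=1024,
    amax_compute_algo="max",
)

# MXFP8 block scaling: per-block scaling factors with no amax history,
# which is why apply_fp8_autowrap rejects amax_compute_algo and
# amax_history_len when use_mxfp8_block_scaling is set.
mxfp8 = te_recipe.MXFP8BlockScaling(fp8_format=te_recipe.Format.E4M3)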

tests/test_fp8.py

Lines changed: 22 additions & 0 deletions
@@ -31,6 +31,7 @@
     require_multi_device,
     require_torchao,
     require_transformer_engine,
+    require_transformer_engine_mxfp8,
     run_first,
 )
 from accelerate.test_utils.testing import require_deepspeed, run_command
@@ -109,6 +110,27 @@ def test_can_prepare_model_single_gpu_from_config(self):
         command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
         run_command(command)

+
+    @require_transformer_engine_mxfp8
+    def test_can_prepare_model_with_mxfp8_block_scaling(self):
+        with tempfile.TemporaryDirectory() as dir_name:
+            config_file = Path(dir_name) / "config.yaml"
+            config_file.write_text(
+                textwrap.dedent(
+                    """
+                    distributed_type: "NO"
+                    num_processes: 1
+                    mixed_precision: fp8
+                    fp8_config:
+                      backend: TE
+                      use_mxfp8_block_scaling: true
+                    """
+                )
+            )
+            command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
+            command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
+            run_command(command)
+
     @require_multi_device
     def test_can_prepare_model_multi_gpu(self):
         command = get_launch_command(num_processes=2, monitor_interval=0.1)
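
The test exercises the same path an end user would take: put the flag under fp8_config in the accelerate config file and launch as usual. An illustrative config (train.py is a placeholder script name, not from this commit):

# config.yaml
distributed_type: "NO"
num_processes: 1
mixed_precision: fp8
fp8_config:
  backend: TE
  use_mxfp8_block_scaling: true

# launch with:
#   accelerate launch --config_file config.yaml train.py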

0 commit comments