Skip to content

Commit b7493a8

Browse files
Add support for e5e2 and default to hybrid when launcher is used (#3640)
* add support for e5e2 and defaumt to hybrid when launcher is used * style
1 parent a16d2bb commit b7493a8

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

src/accelerate/commands/config/cluster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -774,8 +774,8 @@ def get_cluster_input():
774774
)
775775
fp8_config["fp8_format"] = _ask_options(
776776
"Which weight format should be used?",
777-
["HYBRID", "E4M3"],
778-
lambda x: "HYBRID" if x == 0 else "E4M3",
777+
["HYBRID", "E4M3", "E5M2"],
778+
lambda i: ["HYBRID", "E4M3", "E5M2"][i],
779779
default=0,
780780
)
781781
fp8_config["amax_history_length"] = _ask_field(

src/accelerate/commands/launch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -692,8 +692,8 @@ def launch_command_parser(subparsers=None):
692692
fp8_args.add_argument(
693693
"--fp8_format",
694694
type=str,
695-
default="E4M3",
696-
choices=["E4M3", "HYBRID"],
695+
default="HYBRID",
696+
choices=["HYBRID", "E4M3", "E5M2"],
697697
help="The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).",
698698
)
699699
fp8_args.add_argument(

src/accelerate/utils/dataclasses.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def __post_init__(self):
289289
# Literals
290290
Backend = Literal["MSAMP", "TE"]
291291
OptLevel = Literal["O1", "O2"]
292-
FP8Format = Literal["E4M3", "HYBRID"]
292+
FP8Format = Literal["HYBRID", "E4M3", "E5M2"]
293293
AmaxComputeAlgorithm = Literal["max", "most_recent"]
294294

295295

@@ -342,8 +342,8 @@ class TERecipeKwargs(KwargsHandler):
342342
interval (`int`, *optional*, default to 1):
343343
The interval to use for how often the scaling factor is recomputed.
344344
fp8_format (`str`, *optional*, default to "HYBRID"):
345-
The format to use for the FP8 recipe. Must be one of `HYBRID` or `E4M3`. (Generally `HYBRID` for training,
346-
`E4M3` for evaluation)
345+
The format to use for the FP8 recipe. Must be one of `HYBRID`, `E4M3` or `E5M2`. (Generally `HYBRID` for
346+
training, `E4M3` or `E5M2` for evaluation)
347347
amax_history_len (`int`, *optional*, default to 1024):
348348
The length of the history to use for the scaling factor computation
349349
amax_compute_algo (`str`, *optional*, default to "most_recent"):

0 commit comments

Comments
 (0)