Add emulate in float8 and relative checks #1214

Draft
mori360 wants to merge 16 commits into main

Conversation

@mori360 (Contributor) commented May 21, 2025

Add `emulate` in float8 to enable tests on older hardware.

Change the related warnings.
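
For orientation, a minimal sketch of how the flag could be threaded into torchao (not the actual diff: Float8LinearConfig(emulate=...) and convert_to_float8_training are assumptions about the current torchao float8 API, and build_float8_config is a hypothetical helper):

# A minimal sketch, assuming torchao exposes Float8LinearConfig(emulate=...) and
# convert_to_float8_training; build_float8_config is a hypothetical helper and
# not part of this PR's diff.
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


def build_float8_config(job_config) -> Float8LinearConfig:
    float8_config = job_config.float8
    return Float8LinearConfig(
        enable_fsdp_float8_all_gather=float8_config.enable_fsdp_float8_all_gather,
        # With emulate=True, float8 matmuls are emulated in higher precision,
        # so the conversion can run on GPUs without native float8 support.
        emulate=float8_config.emulate,
    )


def apply_float8(model, job_config):
    convert_to_float8_training(model, config=build_float8_config(job_config))
    return model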

Test result:
Tested locally on an 8xH100 server.
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd
[screenshot attached]

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd --float8.emulate
[screenshot attached]

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 21, 2025
@mori360 mori360 changed the title Add emulate in float 8 and relative checks Add emulate in float8 and relative checks May 21, 2025
@mori360 mori360 marked this pull request as ready for review May 22, 2025 03:09
@mori360 mori360 requested review from tianyu-l and vkuzo May 22, 2025 03:09
@tianyu-l (Contributor) left a comment


Thanks for working on this! I left some inline comments.

@@ -8,3 +8,4 @@ tabulate
 wandb
 fsspec
 tyro
+torchao
Contributor

I think the recommended way of installing torchao is still via nightly, similar to how we install the PyTorch nightly for CI
https://github.com/pytorch/torchtitan/blob/main/.github/workflows/integration_test_8gpu.yaml#L39
but for torchao:
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git

"To enable support on older hardware, set `float8.emulate` to True.",
)
return
elif float8_config.emulate and job_config.training.compile:
Contributor

I wonder if emulate + compile works on H100? The original comment from @vkuzo is:

torch.compile with float8 dtypes is not going to work on older hardware, so the emulation can only be used in eager mode.

Contributor Author
@mori360 May 22, 2025

Will run some tests on it.

Contributor Author

Tested it and it works, so I removed this exception.

Comment on lines 455 to 457
Whether to run on earlier hardware in CI test.
torch.compile with float8 dtypes is not going to work on older hardware, so the emulation can
only be used in eager mode.
Contributor

Suggested change
- Whether to run on earlier hardware in CI test.
- torch.compile with float8 dtypes is not going to work on older hardware, so the emulation can
- only be used in eager mode.
+ If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only, as the current CI does have sm_90 capability, required by Float8.
+ Not compatible with torch.compile.

This is assuming torch.compile + emulate doesn't work on >= H100 either. If it does, we'll need to further adjust the code and the helper message.

return
elif float8_config.emulate and job_config.training.compile:
logger.warning(
"Failed to run on emulate with compile on, please disable compile to allow on emulate.",
Contributor

We should just raise an exception if the configuration combination is not runnable.

@mori360 mori360 marked this pull request as draft May 22, 2025 17:35
@mori360 mori360 marked this pull request as ready for review May 22, 2025 18:41
@mori360 mori360 requested review from fegin and tianyu-l May 22, 2025 18:41
@@ -26,9 +26,10 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.enabled = False

         float8_config: Float8 = job_config.float8
-        if not has_cuda_capability(8, 9):
+        if not has_cuda_capability(8, 9) and not float8_config.emulate:
             logger.warning(
Contributor

According to #1214 (comment), we should raise an error instead of issuing this warning.

Comment on lines +455 to +457
If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only,
as the current CI does have sm_90 capability, required by Float8.
Not compatible with torch.compile.
Contributor

Suggested change
- If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only,
- as the current CI does have sm_90 capability, required by Float8.
- Not compatible with torch.compile.
+ If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only,
+ as the current CI does not have sm_89 capability, required by Float8.

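For context, a minimal sketch of what the config field could look like with this suggestion applied (the dataclass layout is an assumption about the tyro-based job config; only the field name and help text come from this thread):

# A minimal sketch, assuming the Float8 job config is a tyro-style dataclass;
# only `emulate` and its help text are taken from this review thread.
from dataclasses import dataclass


@dataclass
class Float8:
    enable_fsdp_float8_all_gather: bool = False
    """Whether to enable float8 all-gather in FSDP2"""

    emulate: bool = False
    """
    If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only,
    as the current CI does not have sm_89 capability, required by Float8.
    """
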
 logger.warning(
-    "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
+    "Failed to swap to Float8Linear because float8 is only supported on SM89 or later."
+    "To enable support on older hardware, set `float8.emulate` to True.",
Contributor

Suggested change
- "To enable support on older hardware, set `float8.emulate` to True.",
+ "To enable testing on older hardware, set `float8.emulate` to True in eager mode.",

@@ -26,9 +26,10 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.enabled = False

         float8_config: Float8 = job_config.float8
-        if not has_cuda_capability(8, 9):
+        if not has_cuda_capability(8, 9) and not float8_config.emulate:
Contributor

On sm < 89, we can't enable torch.compile with or without emulate, right? If so, let's do:

Suggested change
- if not has_cuda_capability(8, 9) and not float8_config.emulate:
+ if not has_cuda_capability(8, 9) and (job_config.training.compile or not float8_config.emulate):

Also, it's a bit hard to read. A better way may be:

if has_cuda_capability(8, 9) or (float8_config.emulate and not job_config.training.compile):
    pass
else:
    raise ValueError(...)
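
Putting that together, a minimal runnable sketch of the suggested guard (the import path for has_cuda_capability and the error message wording are assumptions; the attribute names come from the diffs above):

# A minimal sketch of the suggested guard; the import path and message wording
# are assumptions, not the final code in this PR.
from torchtitan.tools.utils import has_cuda_capability  # assumed location


def validate_float8_support(job_config, float8_config) -> None:
    # Native float8 gemm needs sm_89+; emulation is only allowed in eager mode.
    if has_cuda_capability(8, 9) or (
        float8_config.emulate and not job_config.training.compile
    ):
        return
    raise ValueError(
        "Float8 requires SM89 or later; on older hardware set float8.emulate=True "
        "and disable training.compile."
    )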

@mori360 mori360 marked this pull request as draft May 23, 2025 17:15
@tianyu-l (Contributor) left a comment


The CPU CI error is because we changed the warning into an exception when sm < 89.
I think we can just add the emulate flag to https://github.com/pytorch/torchtitan/blob/main/tests/unit_tests/test_model_converter.py#L42
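
A hypothetical sketch of that tweak (build_job_config and build_parallel_dims stand in for whatever the existing test already uses, and the Float8Converter import path is assumed; only setting float8.emulate for the CPU-only runner comes from this comment):

# A hypothetical sketch; build_job_config and build_parallel_dims stand in for
# whatever tests/unit_tests/test_model_converter.py already uses, and the
# Float8Converter import path is assumed.
from torchtitan.components.float8 import Float8Converter  # assumed location


def test_float8_converter_builds_on_cpu(build_job_config, build_parallel_dims):
    job_config = build_job_config(model_converters=["float8"])
    # Emulate so the new sm_89 exception is not raised on the CPU-only CI runner.
    job_config.float8.emulate = True
    parallel_dims = build_parallel_dims(job_config, world_size=8)
    # Should construct without raising now that emulation is enabled.
    Float8Converter(job_config, parallel_dims)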
