Add emulate in float8 and relative checks #1214
base: main
Conversation
Thanks for working on this! I left some inline comments.
.ci/docker/requirements.txt
@@ -8,3 +8,4 @@ tabulate
wandb
fsspec
tyro
+torchao
I think the recommended way of installing torchao is still via nightly, similar to how we install pytorch nightly for CI
https://github.com/pytorch/torchtitan/blob/main/.github/workflows/integration_test_8gpu.yaml#L39
but for torchao:
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
"To enable support on older hardware, set `float8.emulate` to True.", | ||
) | ||
return | ||
elif float8_config.emulate and job_config.training.compile: |
I wonder if emulate+compile works on H100? Since the original comment from @vkuzo is
torch.compile with float8 dtypes is not going to work on older hardware, so the emulation can only be used in eager mode.
Will run some tests on it.
Tested to be good; removed this exception.
torchtitan/config_manager.py
Whether to run on earlier hardware in CI test.
torch.compile with float8 dtypes is not going to work on older hardware, so the emulation can
only be used in eager mode.
Suggested change:
-Whether to run on earlier hardware in CI test.
-torch.compile with float8 dtypes is not going to work on older hardware, so the emulation can
-only be used in eager mode.
+If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only, as the current CI does have sm_90 capability, required by Float8.
+Not compatible with torch.compile.
This is assuming torch.compile + emulate doesn't work on >= H100 either. If it does, we'll need to further adjust the code and helper message.
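For reference, a minimal sketch of what the documented field could look like in the Float8 config section (the dataclass layout and surrounding docstring placement are assumptions, not the actual config_manager.py contents):

```python
from dataclasses import dataclass


@dataclass
class Float8:
    """Float8 training options (hypothetical subset for illustration)."""

    emulate: bool = False
    """
    If True, emulation is used instead of hardware accelerated gemm.
    This is for test purposes only, as the current CI does not have the sm_89
    capability required by Float8. Not compatible with torch.compile.
    """
```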
return
elif float8_config.emulate and job_config.training.compile:
logger.warning(
"Failed to run on emulate with compile on, please disable compile to allow on emulate.",
We should just raise an exception if the configuration combination is not runnable.
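i.e. roughly something like this (just a sketch; exact wording and placement are flexible):

```python
# Sketch: fail fast instead of warning when the requested combination cannot run.
if float8_config.emulate and job_config.training.compile:
    raise ValueError(
        "float8.emulate is not compatible with training.compile; "
        "disable compile to run float8 in emulation mode."
    )
```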
@@ -26,9 +26,10 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
self.enabled = False

float8_config: Float8 = job_config.float8
-if not has_cuda_capability(8, 9):
+if not has_cuda_capability(8, 9) and not float8_config.emulate:
logger.warning(
According to #1214 (comment), we should raise an error instead of this warning.
If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only,
as the current CI does have sm_90 capability, required by Float8.
Not compatible with torch.compile.
Suggested change:
-If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only,
-as the current CI does have sm_90 capability, required by Float8.
-Not compatible with torch.compile.
+If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only,
+as the current CI does not have sm_89 capability, required by Float8.
logger.warning(
-"Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
+"Failed to swap to Float8Linear because float8 is only supported on SM89 or later."
+"To enable support on older hardware, set `float8.emulate` to True.",
"To enable support on older hardware, set `float8.emulate` to True.", | |
"To enable testing on older hardware, set `float8.emulate` to True in eager mode.", |
@@ -26,9 +26,10 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
self.enabled = False

float8_config: Float8 = job_config.float8
-if not has_cuda_capability(8, 9):
+if not has_cuda_capability(8, 9) and not float8_config.emulate:
On sm < 89, we can't enable torch.compile with/without emulate, right? If so let's do
Suggested change:
-if not has_cuda_capability(8, 9) and not float8_config.emulate:
+if not has_cuda_capability(8, 9) and (job_config.training.compile or not float8_config.emulate):
Also it's a bit hard to read. A better way may be
if has_cuda_capability(8, 9) or (float8_config.emulate and not job_config.training.compile):
    pass
else:
    raise ValueError(...)
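Spelled out a bit more, the constructor check could read roughly like this (a sketch assuming emulation only works in eager mode; names come from the diff above, the error message is illustrative):

```python
# Sketch of the restructured capability / emulate / compile check.
if has_cuda_capability(8, 9) or (
    float8_config.emulate and not job_config.training.compile
):
    pass  # native float8 gemm available, or eager-mode emulation requested
else:
    raise ValueError(
        "Failed to swap to Float8Linear because float8 is only supported on "
        "SM89 or later. To test on older hardware, set float8.emulate to True "
        "and disable training.compile."
    )
```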
The CPU CI error is because we changed the warning to an exception when sm < 89.
I think we can just add the `emulate` flag to https://github.com/pytorch/torchtitan/blob/main/tests/unit_tests/test_model_converter.py#L42
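For example, roughly along these lines (the import path, fixture names, and attribute layout are my guesses at test_model_converter.py's structure, not copied from the repo):

```python
# Hypothetical sketch: run the float8 converter with emulation enabled so the
# CPU / pre-sm_89 CI job takes the emulate path instead of hitting the new exception.
from torchtitan.protocols.model_converter import build_model_converters  # assumed path


def test_float8_converter_with_emulate(job_config, parallel_dims):  # assumed fixtures
    job_config.model.converters = ["float8"]
    job_config.float8.emulate = True  # new flag from this PR
    job_config.training.compile = False  # emulation is eager-only
    converters = build_model_converters(job_config, parallel_dims)
    assert converters is not None
```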
Add emulate in float8, to enable tests on older hardware.
Change related warnings.
Test results:

Tested locally on an 8xH100 server.
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd --float8.emulate