
Commit 168b520

Author: Salman Mohammadi
Commit message: reverting upcast
Parent: 379daa0

File tree: 1 file changed (+13, −0)

src/accelerate/utils/fsdp_utils.py (13 additions, 0 deletions)
@@ -15,6 +15,7 @@
 import functools
 import os
 import shutil
+import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from pathlib import Path
@@ -693,6 +694,18 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
     if hasattr(model, "tie_weights"):
         model.tie_weights()
 
+    # There is no `dtype` attribute on nn.Module
+    # Set it to None if it doesn't exist and always do the upcast
+    model_dtype = getattr(model, "dtype", None)
+    if accelerator.mixed_precision != "no" and (model_dtype is None or model_dtype != torch.float32):
+        # We upcast the model following `deepspeed`'s implementation
+        # More info about this can be found in `accelerator.py:prepare_model`'s FSDP1 section
+        model = model.to(torch.float32)
+        if accelerator.is_main_process:
+            # TODO(siro1): Add a warning for each parameter that was upcasted
+            warnings.warn(
+                "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints."
+            )
     return model
 
 
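The guard relies on the fact that a plain nn.Module defines no `dtype` attribute (Hugging Face's PreTrainedModel does expose one), so `model_dtype` falls back to None for vanilla modules and the upcast runs whenever mixed precision is enabled. A minimal sketch of that condition, with a local `mixed_precision` variable standing in for `accelerator.mixed_precision`:

# Minimal sketch of the upcast condition in the diff above;
# `mixed_precision` is a stand-in for `accelerator.mixed_precision`.
import torch
import torch.nn as nn

model = nn.Linear(8, 8).to(torch.bfloat16)

# nn.Module has no `dtype` attribute, so this resolves to None
model_dtype = getattr(model, "dtype", None)

mixed_precision = "bf16"  # assumed accelerator setting for this sketch
if mixed_precision != "no" and (model_dtype is None or model_dtype != torch.float32):
    # the branch fires: parameters are upcast to fp32, mirroring the diff
    model = model.to(torch.float32)

print(next(model.parameters()).dtype)  # torch.float32

As the committed warning notes, checkpoints saved after this point will contain fp32 parameters even if the model was originally loaded in lower precision.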