Skip to content

Commit 0272ce1

Browse files
committed
check torch version
1 parent 5c7f700 commit 0272ce1

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/accelerate/utils/modeling.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from .memory import clear_device_cache
4444
from .offload import load_offloaded_weight, offload_weight, save_offload_index
4545
from .tqdm import is_tqdm_available, tqdm
46-
from .versions import compare_versions
46+
from .versions import compare_versions, is_torch_version
4747

4848

4949
if is_npu_available(check_device=False):
@@ -161,8 +161,10 @@ def dtype_byte_size(dtype: torch.dtype):
161161
return 1 / 4
162162
elif dtype == CustomDtype.INT4:
163163
return 1 / 2
164-
elif dtype in [CustomDtype.FP8, torch.float8_e4m3fn]:
164+
elif dtype == CustomDtype.FP8:
165165
return 1
166+
elif is_torch_version(">=", "2.1.0") and dtype == torch.float8_e4m3fn:
167+
return 1
166168
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
167169
if bit_search is None:
168170
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")

0 commit comments

Comments
 (0)