File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change 43
43
from .memory import clear_device_cache
44
44
from .offload import load_offloaded_weight , offload_weight , save_offload_index
45
45
from .tqdm import is_tqdm_available , tqdm
46
- from .versions import compare_versions
46
+ from .versions import compare_versions , is_torch_version
47
47
48
48
49
49
if is_npu_available (check_device = False ):
@@ -161,8 +161,10 @@ def dtype_byte_size(dtype: torch.dtype):
161
161
return 1 / 4
162
162
elif dtype == CustomDtype .INT4 :
163
163
return 1 / 2
164
- elif dtype in [ CustomDtype .FP8 , torch . float8_e4m3fn ] :
164
+ elif dtype == CustomDtype .FP8 :
165
165
return 1
166
+ elif is_torch_version (">=" , "2.1.0" ) and dtype == torch .float8_e4m3fn :
167
+ return 1
166
168
bit_search = re .search (r"[^\d](\d+)$" , str (dtype ))
167
169
if bit_search is None :
168
170
raise ValueError (f"`dtype` is not a valid dtype: { dtype } ." )
You can’t perform that action at this time.
0 commit comments