Skip to content

Commit 2ca55ae

Browse files
author
maxtext authors
committed
Merge pull request #1629 from AI-Hypercomputer:lihao/fp8_to_bf16
PiperOrigin-RevId: 751590057
2 parents 66f2a85 + ed5dfbc commit 2ca55ae

File tree

1 file changed

+172
-0
lines changed

1 file changed

+172
-0
lines changed

MaxText/deepseek_fp8_to_bf16.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
https://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
"""
13+
14+
r"""Convert weights from FP8 to BF16 for a HF model.
15+
16+
Install these dependencies before running this script:
17+
18+
pip install torch==2.4.1 safetensors==0.4.5
19+
20+
Example cmd:
21+
22+
python3 -m MaxText.deepseek_fp8_to_bf16 --input-fp8-hf-path <path/to/fp8/ckpt> \
23+
--output-bf16-hf-path <local/path/to/save/new/bf16/ckpt>
24+
"""
25+
26+
27+
import os
28+
import json
29+
from argparse import ArgumentParser
30+
from glob import glob
31+
import string
32+
from tqdm import tqdm
33+
34+
import torch
35+
from safetensors.torch import load_file, save_file
36+
37+
38+
def weight_dequant_cpu(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
39+
"""
40+
Dequantizes the given FP8 weight tensor using the provided scale tensor on CPU.
41+
42+
Args:
43+
x (torch.Tensor): The quantized FP8 weight tensor of shape (M, N), dtype=torch.float8.
44+
s (torch.Tensor): The scale tensor, dtype=torch.bfloat16 or float32.
45+
block_size (int, optional): Size of the block used in quantization.
46+
47+
Returns:
48+
torch.Tensor: The dequantized weight tensor, dtype=torch.bfloat16.
49+
50+
Raises:
51+
AssertionError: If the input tensors are not 2D.
52+
"""
53+
assert x.dim() == 2 and s.dim() == 2, "Both x and s must be 2D tensors"
54+
55+
M, N = x.shape
56+
57+
x = x.to(torch.float32)
58+
y = torch.empty_like(x, dtype=torch.get_default_dtype())
59+
60+
for i in range(0, M, block_size):
61+
for j in range(0, N, block_size):
62+
row_start = i
63+
row_end = min(i + block_size, M)
64+
col_start = j
65+
col_end = min(j + block_size, N)
66+
block = x[row_start:row_end, col_start:col_end]
67+
scale = s[i // block_size, j // block_size]
68+
y[row_start:row_end, col_start:col_end] = (block * scale).to(torch.get_default_dtype())
69+
70+
return y
71+
72+
73+
def convert_fp8_to_bf16(fp8_path: string, bf16_path: string, cache_file_num: int = 2):
74+
"""
75+
Converts a FP8 model to a BF16 model and saves the converted weights.
76+
77+
This function reads FP8 weights from the specified directory, converts them to BF16,
78+
and saves the converted weights to another specified directory. It also updates the
79+
model index file to reflect the changes. The conversion process runs on CPU devices.
80+
81+
Args:
82+
fp8_path (str): The path to the directory containing the FP8 weights and model index file.
83+
bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
84+
85+
Raises:
86+
KeyError: If a required scale_inv tensor is missing for a weight.
87+
88+
Notes:
89+
- The function assumes that the FP8 weights are stored in safetensor files.
90+
- The function caches loaded safetensor files to optimize memory usage.
91+
- The function updates the model index file to remove references to scale_inv tensors.
92+
"""
93+
torch.set_default_dtype(torch.bfloat16)
94+
os.makedirs(bf16_path, exist_ok=True)
95+
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
96+
with open(model_index_file, "r") as f:
97+
model_index = json.load(f)
98+
weight_map = model_index["weight_map"]
99+
100+
# Cache for loaded safetensor files
101+
loaded_files = {}
102+
fp8_weight_names = []
103+
104+
# Helper function to get tensor from the correct file
105+
def get_tensor(tensor_name):
106+
"""
107+
Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
108+
109+
Args:
110+
tensor_name (str): The name of the tensor to retrieve.
111+
112+
Returns:
113+
torch.Tensor: The retrieved tensor.
114+
115+
Raises:
116+
KeyError: If the tensor does not exist in the safetensor file.
117+
"""
118+
file_name = weight_map[tensor_name]
119+
if file_name not in loaded_files:
120+
file_path = os.path.join(fp8_path, file_name)
121+
loaded_files[file_name] = load_file(file_path, device="cpu")
122+
return loaded_files[file_name][tensor_name]
123+
124+
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
125+
safetensor_files.sort()
126+
for safetensor_file in tqdm(safetensor_files):
127+
file_name = os.path.basename(safetensor_file)
128+
current_state_dict = load_file(safetensor_file, device="cpu")
129+
loaded_files[file_name] = current_state_dict
130+
131+
new_state_dict = {}
132+
for weight_name, weight in current_state_dict.items():
133+
if weight_name.endswith("_scale_inv"):
134+
continue
135+
elif weight.element_size() == 1: # FP8 weight
136+
scale_inv_name = f"{weight_name}_scale_inv"
137+
try:
138+
# Get scale_inv from the correct file
139+
scale_inv = get_tensor(scale_inv_name)
140+
fp8_weight_names.append(weight_name)
141+
new_state_dict[weight_name] = weight_dequant_cpu(weight, scale_inv)
142+
except KeyError:
143+
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
144+
new_state_dict[weight_name] = weight
145+
else:
146+
new_state_dict[weight_name] = weight
147+
148+
new_safetensor_file = os.path.join(bf16_path, file_name)
149+
save_file(new_state_dict, new_safetensor_file)
150+
151+
# Memory management: keep only the `cache_file_num` most recently used files
152+
while len(loaded_files) > cache_file_num:
153+
oldest_file = next(iter(loaded_files))
154+
del loaded_files[oldest_file]
155+
156+
# Update model index
157+
for weight_name in fp8_weight_names:
158+
scale_inv_name = f"{weight_name}_scale_inv"
159+
if scale_inv_name in weight_map:
160+
weight_map.pop(scale_inv_name)
161+
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
162+
with open(new_model_index_file, "w") as f:
163+
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
164+
165+
166+
if __name__ == "__main__":
167+
parser = ArgumentParser()
168+
parser.add_argument("--input-fp8-hf-path", type=str, required=True)
169+
parser.add_argument("--output-bf16-hf-path", type=str, required=True)
170+
parser.add_argument("--cache-file-num", type=int, required=False, default=2)
171+
args = parser.parse_args()
172+
convert_fp8_to_bf16(args.input_fp8_hf_path, args.output_bf16_hf_path, args.cache_file_num)

0 commit comments

Comments
 (0)