# convert_deepseek_unscanned_ckpt.py
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
r"""Convert weights from a DeepSeek style model to a MaxText one.
Example cmd:
python3 -m MaxText.convert_deepseek_unscanned_ckpt --base_model_path <path/to/huggingface/ckpt> \
--maxtext_model_path <GCS/path/to/save/new/maxtext/ckpt> --model_size deepseek2-16b
"""
# pylint: disable=line-too-long
# pylint: disable=unsupported-assignment-operation
# pytype: disable=unsupported-operands
import argparse
import pathlib
import os
import gc
import logging
import numpy as np
import torch
import psutil
from tqdm import tqdm
from MaxText import convert_deepseek_ckpt as ds_ckpt
from MaxText import llama_or_mistral_ckpt
from MaxText import max_logging
from MaxText.inference_utils import str2bool
from safetensors import safe_open
def _log_memory_usage(mem_info) -> None:
    """Log the resident set size of the tracked process, in GB, at DEBUG level.

    Args:
      mem_info: Object whose ``memory_info().rss`` gives the resident size in bytes.
    """
    logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))


def _load_checkpoint_tensors(base_model_path, num_experts, first_num_dense_layers) -> dict:
    """Read every supported tensor from the safetensors shards, keyed by its MaxText name.

    Args:
      base_model_path: Directory that holds the ``*.safetensors`` shards.
      num_experts: Forwarded to ``ds_ckpt.hf_to_maxtext_mapping``.
      first_num_dense_layers: Forwarded to ``ds_ckpt.hf_to_maxtext_mapping``.

    Returns:
      A dict from mapped MaxText key to the torch tensor loaded on CPU.

    Raises:
      FileNotFoundError: If ``base_model_path`` contains no non-hidden ``*.safetensors`` file.
      ValueError: If a key ends with ``_scale_inv`` (fp8 checkpoints are not supported).
    """
    # "[!.]" skips hidden files whose name starts with a dot.
    ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors"))
    if not ckpt_paths:
        # Fail here with a clear message instead of on the first tensor lookup later on.
        raise FileNotFoundError(f"No .safetensors checkpoint files found in '{base_model_path}'.")

    # The key mapping is requested once per layer index rather than once per tensor key.
    # NOTE(review): assumes hf_to_maxtext_mapping depends only on its arguments - confirm.
    mapping_by_layer = {}
    chkpt_vars = {}
    for i, ckpt_path in enumerate(ckpt_paths):
        max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
        with safe_open(ckpt_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                parts = key.split(".")
                # Keys that contain "layers" carry the layer index as their third dotted component.
                layer = int(parts[2]) if "layers" in key else 0
                if key.endswith("_scale_inv"):
                    raise ValueError("fp8 checkpoint is not supported.")
                if ds_ckpt.is_key_ending_allowed(key, ds_ckpt.MTP_KEYS_SUFFIX):
                    if layer not in mapping_by_layer:
                        mapping_by_layer[layer] = ds_ckpt.hf_to_maxtext_mapping(
                            layer, num_experts, first_num_dense_layers
                        )
                    chkpt_vars[mapping_by_layer[layer][key]] = f.get_tensor(key)
    return chkpt_vars


def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info) -> dict:
    """Convert a Hugging Face DeepSeek safetensors checkpoint into per-layer MaxText weights.

    Every tensor is cast to float16 and returned as a numpy array. Kernels are transposed,
    the attention projections are reshaped to expose the head dimension, and the per-expert
    kernels of each MoE layer are stacked along a new leading axis of size ``num_experts``.

    Args:
      base_model_path: Directory that holds the ``*.safetensors`` shards.
      model_params: Mapping with the keys ``num_layers``, ``first_num_dense_layers``,
        ``base_num_query_heads``, ``base_emb_dim``, ``num_experts``, ``q_lora_rank``,
        ``kv_lora_rank``, ``qk_nope_head_dim``, ``qk_rope_head_dim`` and ``v_head_dim``.
      mem_info: Object whose ``memory_info().rss`` is logged for memory tracking.

    Returns:
      A nested dict with a ``"token_embedder"`` entry and a ``"decoder"`` entry; the latter
      holds ``decoder_norm``, ``logits_dense`` and one entry per layer named
      ``dense_layers_<i>`` or ``moe_layers_<i>``.

    Raises:
      FileNotFoundError: If no safetensors shard is found under ``base_model_path``.
      ValueError: If the checkpoint contains fp8 ``_scale_inv`` tensors.
      KeyError: If an expected tensor is missing from the checkpoint.
    """
    base_num_decoder_layers = model_params["num_layers"]
    first_num_dense_layers = model_params["first_num_dense_layers"]
    base_num_query_heads = model_params["base_num_query_heads"]
    base_emb_dim = model_params["base_emb_dim"]
    num_experts = model_params["num_experts"]
    q_lora_rank = model_params["q_lora_rank"]
    kv_lora_rank = model_params["kv_lora_rank"]
    qk_nope_head_dim = model_params["qk_nope_head_dim"]
    qk_rope_head_dim = model_params["qk_rope_head_dim"]
    v_head_dim = model_params["v_head_dim"]

    chkpt_vars = _load_checkpoint_tensors(base_model_path, num_experts, first_num_dense_layers)
    _log_memory_usage(mem_info)

    def take(name, transpose=False):
        """Remove tensor ``name`` from the loaded checkpoint and return it as a float16 numpy array."""
        # Each tensor is consumed exactly once, so popping lets the source tensor be released
        # as soon as it has been converted instead of staying alive until the end.
        array = chkpt_vars.pop(name).to(torch.float16).numpy()
        return array.transpose() if transpose else array

    # initialize the data structure for storing jax_weights
    jax_weights = {
        "decoder": {
            "decoder_norm": {"scale": None},
            "logits_dense": {"kernel": None},
        },
        "token_embedder": {"embedding": None},
    }
    decoder = jax_weights["decoder"]

    # decoder norm scale ###########################################
    max_logging.log("Processing decoder norm scale")
    decoder["decoder_norm"]["scale"] = take("decoder_norm.scale")
    _log_memory_usage(mem_info)

    # logits dense #################################################
    max_logging.log("Processing logits dense")
    decoder["logits_dense"]["kernel"] = take("logits_dense.kernel", transpose=True)
    _log_memory_usage(mem_info)

    # token embedding ##############################################
    max_logging.log("Processing token embeddings")
    jax_weights["token_embedder"]["embedding"] = take("token_embedder.embedding")
    _log_memory_usage(mem_info)

    layers = {
        "dense_layers": first_num_dense_layers,
        "moe_layers": base_num_decoder_layers - first_num_dense_layers,
    }
    # A non-zero q_lora_rank selects the low-rank query projection (q_norm / wq_a / wq_b)
    # instead of a single "query" kernel.
    has_q_lora = q_lora_rank != 0
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
    ffn_kernel_names = ("wi_0", "wi_1", "wo")

    # self attention and normalization ###############################################
    max_logging.log("Processing self attention and normalization in dense layer")
    for layer_key, layer_count in layers.items():
        ffn_key = "mlp" if layer_key == "dense_layers" else "DeepSeekMoeBlock_0"
        for layer_idx in tqdm(range(layer_count), desc=layer_key, leave=False):
            prefix = f"{layer_key}.{layer_idx}"
            attn_prefix = f"{prefix}.self_attention"

            pre_norm_scale = take(f"{prefix}.pre_self_attention_layer_norm.scale")
            post_norm_scale = take(f"{prefix}.post_self_attention_layer_norm.scale")

            # reshape to match maxtext
            self_attention = {
                "kv_norm": {"scale": take(f"{attn_prefix}.kv_norm.scale", transpose=True)},
                "wkv_a": {"kernel": take(f"{attn_prefix}.wkv_a.kernel", transpose=True)},
                "wkv_b": {
                    "kernel": np.reshape(
                        take(f"{attn_prefix}.wkv_b.kernel", transpose=True),
                        [kv_lora_rank, base_num_query_heads, (qk_nope_head_dim + v_head_dim)],
                    )
                },
                "out": {
                    "kernel": np.reshape(
                        take(f"{attn_prefix}.out.kernel", transpose=True),
                        [base_num_query_heads, v_head_dim, base_emb_dim],
                    )
                },
            }
            if has_q_lora:
                self_attention["q_norm"] = {"scale": take(f"{attn_prefix}.q_norm.scale")}
                self_attention["wq_a"] = {"kernel": take(f"{attn_prefix}.wq_a.kernel", transpose=True)}
                self_attention["wq_b"] = {
                    "kernel": np.reshape(
                        take(f"{attn_prefix}.wq_b.kernel", transpose=True),
                        [q_lora_rank, base_num_query_heads, qk_head_dim],
                    )
                }
            else:
                self_attention["query"] = {
                    "kernel": np.reshape(
                        take(f"{attn_prefix}.query.kernel", transpose=True),
                        [base_emb_dim, base_num_query_heads, qk_head_dim],
                    )
                }

            decoder[f"{layer_key}_{layer_idx}"] = {
                # Placeholder; the feed-forward weights are filled in by the next pass.
                ffn_key: None,
                "self_attention": self_attention,
                "pre_self_attention_layer_norm": {"scale": pre_norm_scale},
                "post_self_attention_layer_norm": {"scale": post_norm_scale},
            }

    # layer weights ################################################
    max_logging.log("Processing layer weights")
    for layer_key, layer_count in layers.items():
        for layer_idx in tqdm(range(layer_count), desc=layer_key, leave=False):
            prefix = f"{layer_key}.{layer_idx}"
            layer_weights = decoder[f"{layer_key}_{layer_idx}"]

            if layer_key == "dense_layers":
                layer_weights["mlp"] = {
                    name: {"kernel": take(f"{prefix}.mlp.{name}.kernel", transpose=True)}
                    for name in ffn_kernel_names
                }
                continue

            moe_prefix = f"{prefix}.DeepSeekMoeBlock_0"
            gate = {"kernel": take(f"{moe_prefix}.MoeBlock_0.gate.kernel", transpose=True)}
            if has_q_lora:
                # NOTE(review): the gate bias is keyed off q_lora_rank rather than a dedicated
                # flag - confirm this holds for every supported model size.
                gate["bias"] = take(f"{moe_prefix}.MoeBlock_0.gate.bias", transpose=True)

            shared_experts = {
                name: {"kernel": take(f"{moe_prefix}.shared_experts.{name}.kernel", transpose=True)}
                for name in ffn_kernel_names
            }

            # Stack the per-expert kernels along a new leading axis of size num_experts.
            routed_experts = {name: None for name in ffn_kernel_names}
            for k in tqdm(range(num_experts), desc="experts", leave=False):
                for name in ffn_kernel_names:
                    expert_kernel = take(f"{moe_prefix}.MoeBlock_0.{k}.{name}", transpose=True)
                    if routed_experts[name] is None:
                        # Allocate lazily: the per-expert shape is only known once one is loaded.
                        routed_experts[name] = np.zeros((num_experts,) + expert_kernel.shape, dtype=np.float16)
                    routed_experts[name][k, ...] = expert_kernel
            routed_experts["gate"] = gate

            layer_weights["DeepSeekMoeBlock_0"] = {
                "MoeBlock_0": routed_experts,
                "shared_experts": shared_experts,
            }

    # Drop whatever is left of the source tensors before handing the weights back.
    chkpt_vars.clear()
    gc.collect()
    _log_memory_usage(mem_info)
    return jax_weights
def _convert_to_jax_weights(base_model_path, model_size, mem_info) -> dict:
    """Resolve the hyper-parameters for ``model_size`` and convert the checkpoint.

    Args:
      base_model_path: Path to the Hugging Face model checkpoint.
      model_size: Key into ``ds_ckpt.MODEL_PARAMS_DICT`` that selects the model configuration.
      mem_info: A process instance used for memory tracking.

    Returns:
      The converted JAX weights, as produced by ``_convert_huggingface_to_jax_weights``.
    """
    # The lookup comes first so an unknown model size fails before anything is logged or loaded.
    params_for_size = ds_ckpt.MODEL_PARAMS_DICT[model_size]

    resident_gb = mem_info.memory_info().rss / (1024**3)
    logging.debug("Memory usage: %f GB", resident_gb)

    max_logging.log(f"Loading the base model from {base_model_path}")
    converted = _convert_huggingface_to_jax_weights(base_model_path, params_for_size, mem_info)
    return converted
def main() -> None:
    """Parse the command-line flags, convert the DeepSeek checkpoint and save it for MaxText.

    Raises:
      NotImplementedError: If ``--model_size`` is not a key of ``ds_ckpt.MODEL_PARAMS_DICT``.
    """
    argument_specs = (
        ("--base_model_path", {"type": str, "required": True}),
        ("--maxtext_model_path", {"type": str, "required": True}),
        ("--model_size", {"type": str, "required": True}),
        ("--simulated_cpu_devices_count", {"type": int, "required": False, "default": 16}),
        ("--use-ocdbt", {"type": str2bool, "required": False, "default": True}),
        ("--use-zarr3", {"type": str2bool, "required": False, "default": True}),
    )
    parser = argparse.ArgumentParser(description="Convert and save DeepSeek model weights.")
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    model_size = args.model_size
    if model_size not in ds_ckpt.MODEL_PARAMS_DICT:
        raise NotImplementedError(f"Model '{model_size}' is not supported.")

    # Restrict JAX to the CPU platform and ask XLA for the requested number of host devices.
    device_count = args.simulated_cpu_devices_count
    os.environ["JAX_PLATFORMS"] = "cpu"
    os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={device_count}"

    # Handle on the current process, used for the memory-usage debug logs during conversion.
    process = psutil.Process()
    converted_weights = _convert_to_jax_weights(args.base_model_path, model_size, process)

    llama_or_mistral_ckpt.save_weights_to_checkpoint(
        args.maxtext_model_path,
        converted_weights,
        device_count,
        args.use_ocdbt,
        args.use_zarr3,
    )
# Run the conversion only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()