
Commit 098cf0e

add script to convert HunyuanVideo diffusers lora to original weights format (#255)
1 parent 836ac78 commit 098cf0e
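
In short (summarizing the script below): it loads a diffusers-format HunyuanVideo LoRA from --input_lora, remaps its keys to the original transformer layout (double_blocks / single_blocks), and fuses the separate q/k/v (and proj_mlp) LoRA factors into the combined qkv and linear1 projections. It can optionally write .alpha tensors (--alpha) and cast to a target dtype (--dtype), and it saves the result next to the input with a _converted.safetensors suffix; --debug prints the converted keys and shapes instead of saving.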

File tree

1 file changed: 145 additions, 0 deletions

import os
import argparse
import torch
from safetensors.torch import load_file, save_file


def convert_lora_sd(diffusers_lora_sd):
    double_block_patterns = {
        "attn.to_out.0": "img_attn.proj",
        "ff.net.0.proj": "img_mlp.0",
        "ff.net.2": "img_mlp.2",
        "attn.to_add_out": "txt_attn.proj",
        "ff_context.net.0.proj": "txt_mlp.0",
        "ff_context.net.2": "txt_mlp.2",
    }

    prefix = "diffusion_model."

    converted_lora_sd = {}
    for key in diffusers_lora_sd.keys():
        # double_blocks
        if key.startswith("transformer_blocks"):
            # img_attn
            if key.endswith("to_q.lora_A.weight"):
                # lora_A
                to_q_A = diffusers_lora_sd[key]
                to_k_A = diffusers_lora_sd[key.replace("to_q", "to_k")]
                to_v_A = diffusers_lora_sd[key.replace("to_q", "to_v")]

                to_qkv_A = torch.cat([to_q_A, to_k_A, to_v_A], dim=0)
                qkv_A_key = key.replace("transformer_blocks", prefix + "double_blocks").replace("attn.to_q", "img_attn.qkv")
                converted_lora_sd[qkv_A_key] = to_qkv_A

                # lora_B
                to_q_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_q.lora_B")]
                to_k_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_k.lora_B")]
                to_v_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_v.lora_B")]

                to_qkv_B = torch.block_diag(to_q_B, to_k_B, to_v_B)
                qkv_B_key = qkv_A_key.replace("lora_A", "lora_B")
                converted_lora_sd[qkv_B_key] = to_qkv_B

            # txt_attn
            elif key.endswith("add_q_proj.lora_A.weight"):
                # lora_A
                to_q_A = diffusers_lora_sd[key]
                to_k_A = diffusers_lora_sd[key.replace("add_q_proj", "add_k_proj")]
                to_v_A = diffusers_lora_sd[key.replace("add_q_proj", "add_v_proj")]

                to_qkv_A = torch.cat([to_q_A, to_k_A, to_v_A], dim=0)
                qkv_A_key = key.replace("transformer_blocks", prefix + "double_blocks").replace("attn.add_q_proj", "txt_attn.qkv")
                converted_lora_sd[qkv_A_key] = to_qkv_A

                # lora_B
                to_q_B = diffusers_lora_sd[key.replace("add_q_proj.lora_A", "add_q_proj.lora_B")]
                to_k_B = diffusers_lora_sd[key.replace("add_q_proj.lora_A", "add_k_proj.lora_B")]
                to_v_B = diffusers_lora_sd[key.replace("add_q_proj.lora_A", "add_v_proj.lora_B")]

                to_qkv_B = torch.block_diag(to_q_B, to_k_B, to_v_B)
                qkv_B_key = qkv_A_key.replace("lora_A", "lora_B")
                converted_lora_sd[qkv_B_key] = to_qkv_B

            # just rename
            for k, v in double_block_patterns.items():
                if k in key:
                    new_key = key.replace(k, v).replace("transformer_blocks", prefix + "double_blocks")
                    converted_lora_sd[new_key] = diffusers_lora_sd[key]

        # single_blocks
        elif key.startswith("single_transformer_blocks"):
            if key.endswith("to_q.lora_A.weight"):
                # lora_A
                to_q_A = diffusers_lora_sd[key]
                to_k_A = diffusers_lora_sd[key.replace("to_q", "to_k")]
                to_v_A = diffusers_lora_sd[key.replace("to_q", "to_v")]
                proj_mlp_A = diffusers_lora_sd[key.replace("attn.to_q", "proj_mlp")]

                linear1_A = torch.cat([to_q_A, to_k_A, to_v_A, proj_mlp_A], dim=0)
                linear1_A_key = key.replace("single_transformer_blocks", prefix + "single_blocks").replace("attn.to_q", "linear1")
                converted_lora_sd[linear1_A_key] = linear1_A

                # lora_B
                to_q_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_q.lora_B")]
                to_k_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_k.lora_B")]
                to_v_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_v.lora_B")]
                proj_mlp_B = diffusers_lora_sd[key.replace("attn.to_q.lora_A", "proj_mlp.lora_B")]

                linear1_B = torch.block_diag(to_q_B, to_k_B, to_v_B, proj_mlp_B)
                linear1_B_key = linear1_A_key.replace("lora_A", "lora_B")
                converted_lora_sd[linear1_B_key] = linear1_B

            elif "proj_out" in key:
                new_key = key.replace("proj_out", "linear2").replace("single_transformer_blocks", prefix + "single_blocks")
                converted_lora_sd[new_key] = diffusers_lora_sd[key]

        else:
            print(f"unknown or not implemented: {key}")

    return converted_lora_sd


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_lora", type=str, required=True, help="Path to LoRA .safetensors")
    parser.add_argument("--alpha", type=float, default=None, help="Optional alpha value, defaults to rank")
    parser.add_argument("--dtype", type=str, default=None, help="Optional dtype (bfloat16, float16, float32), defaults to input dtype")
    parser.add_argument("--debug", action="store_true", help="Print converted keys instead of saving")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    converted_lora_sd = convert_lora_sd(load_file(args.input_lora))

    if args.alpha is not None:
        for key in list(converted_lora_sd.keys()):
            if "lora_A" in key:
                alpha_name = key.replace(".lora_A.weight", ".alpha")
                converted_lora_sd[alpha_name] = torch.tensor([args.alpha], dtype=converted_lora_sd[key].dtype)

    dtype = None
    if args.dtype == "bfloat16":
        dtype = torch.bfloat16
    elif args.dtype == "float16":
        dtype = torch.float16
    elif args.dtype == "float32":
        dtype = torch.float32

    if dtype is not None:
        dtype_min = torch.finfo(dtype).min
        dtype_max = torch.finfo(dtype).max
        for key in converted_lora_sd.keys():
            if converted_lora_sd[key].min() < dtype_min or converted_lora_sd[key].max() > dtype_max:
                print(f"warning: {key} has values outside of {dtype} {dtype_min} {dtype_max} range")
            converted_lora_sd[key] = converted_lora_sd[key].to(dtype)

    if args.debug:
        for key in sorted(list(converted_lora_sd.keys())):
            print(key, converted_lora_sd[key].shape, converted_lora_sd[key].dtype)
        exit()

    output_path = os.path.splitext(args.input_lora)[0] + "_converted.safetensors"
    save_file(converted_lora_sd, output_path)
    print(f"saved to {output_path}")
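
The qkv fusion above relies on the identity that a block-diagonal lora_B multiplied by row-concatenated lora_A reproduces the separate low-rank updates stacked along the output dimension, which is exactly how the original checkpoint stores q, k, and v inside a single qkv weight. A minimal self-contained check of that identity (illustrative only, not part of the committed file; shapes are arbitrary):

    import torch

    r, d_in, d_out = 4, 32, 32
    A_q, A_k, A_v = (torch.randn(r, d_in) for _ in range(3))
    B_q, B_k, B_v = (torch.randn(d_out, r) for _ in range(3))

    # fused factors, built the same way as in convert_lora_sd
    fused_A = torch.cat([A_q, A_k, A_v], dim=0)   # [3r, d_in]
    fused_B = torch.block_diag(B_q, B_k, B_v)     # [3*d_out, 3r]

    # the fused delta equals the per-projection deltas stacked in q, k, v order
    stacked = torch.cat([B_q @ A_q, B_k @ A_k, B_v @ A_v], dim=0)
    assert torch.allclose(fused_B @ fused_A, stacked)

The same reasoning covers single_blocks, where proj_mlp is appended as a fourth block so the fused factors line up with the original linear1 projection.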
