Skip to content

Commit 3fe1db0

Browse files
authored
fix string matching for blocks (#360)
* fix string matching for blocks * format * empty proj_mlp
1 parent 66b3e82 commit 3fe1db0

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

examples/formats/hunyuan_video/convert_to_original_format.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@ def convert_lora_sd(diffusers_lora_sd):
1717

1818
prefix = "diffusion_model."
1919

20+
double_block_pattern = "transformer.transformer_blocks"
21+
single_block_pattern = "transformer.single_transformer_blocks"
22+
2023
converted_lora_sd = {}
2124
for key in diffusers_lora_sd.keys():
2225
# double_blocks
23-
if key.startswith("transformer_blocks"):
26+
if key.startswith(double_block_pattern):
2427
# img_attn
2528
if key.endswith("to_q.lora_A.weight"):
2629
# lora_A
@@ -29,7 +32,7 @@ def convert_lora_sd(diffusers_lora_sd):
2932
to_v_A = diffusers_lora_sd[key.replace("to_q", "to_v")]
3033

3134
to_qkv_A = torch.cat([to_q_A, to_k_A, to_v_A], dim=0)
32-
qkv_A_key = key.replace("transformer_blocks", prefix + "double_blocks").replace(
35+
qkv_A_key = key.replace(double_block_pattern, prefix + "double_blocks").replace(
3336
"attn.to_q", "img_attn_qkv"
3437
)
3538
converted_lora_sd[qkv_A_key] = to_qkv_A
@@ -51,7 +54,7 @@ def convert_lora_sd(diffusers_lora_sd):
5154
to_v_A = diffusers_lora_sd[key.replace("add_q_proj", "add_v_proj")]
5255

5356
to_qkv_A = torch.cat([to_q_A, to_k_A, to_v_A], dim=0)
54-
qkv_A_key = key.replace("transformer_blocks", prefix + "double_blocks").replace(
57+
qkv_A_key = key.replace(double_block_pattern, prefix + "double_blocks").replace(
5558
"attn.add_q_proj", "txt_attn_qkv"
5659
)
5760
converted_lora_sd[qkv_A_key] = to_qkv_A
@@ -68,20 +71,23 @@ def convert_lora_sd(diffusers_lora_sd):
6871
# just rename
6972
for k, v in double_block_patterns.items():
7073
if k in key:
71-
new_key = key.replace(k, v).replace("transformer_blocks", prefix + "double_blocks")
74+
new_key = key.replace(k, v).replace(double_block_pattern, prefix + "double_blocks")
7275
converted_lora_sd[new_key] = diffusers_lora_sd[key]
7376

7477
# single_blocks
75-
elif key.startswith("single_transformer_blocks"):
78+
elif key.startswith(single_block_pattern):
7679
if key.endswith("to_q.lora_A.weight"):
7780
# lora_A
7881
to_q_A = diffusers_lora_sd[key]
7982
to_k_A = diffusers_lora_sd[key.replace("to_q", "to_k")]
8083
to_v_A = diffusers_lora_sd[key.replace("to_q", "to_v")]
81-
proj_mlp_A = diffusers_lora_sd[key.replace("attn.to_q", "proj_mlp")]
82-
84+
proj_mlp_A_key = key.replace("attn.to_q", "proj_mlp")
85+
if proj_mlp_A_key in diffusers_lora_sd:
86+
proj_mlp_A = diffusers_lora_sd[proj_mlp_A_key]
87+
else:
88+
proj_mlp_A = torch.zeros((to_q_A.shape[0], to_q_A.shape[1]))
8389
linear1_A = torch.cat([to_q_A, to_k_A, to_v_A, proj_mlp_A], dim=0)
84-
linear1_A_key = key.replace("single_transformer_blocks", prefix + "single_blocks").replace(
90+
linear1_A_key = key.replace(single_block_pattern, prefix + "single_blocks").replace(
8591
"attn.to_q", "linear1"
8692
)
8793
converted_lora_sd[linear1_A_key] = linear1_A
@@ -90,16 +96,17 @@ def convert_lora_sd(diffusers_lora_sd):
9096
to_q_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_q.lora_B")]
9197
to_k_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_k.lora_B")]
9298
to_v_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_v.lora_B")]
93-
proj_mlp_B = diffusers_lora_sd[key.replace("attn.to_q.lora_A", "proj_mlp.lora_B")]
94-
99+
proj_mlp_B_key = key.replace("to_q.lora_A", "attn.to_q.lora_B")
100+
if proj_mlp_B_key in diffusers_lora_sd:
101+
proj_mlp_B = diffusers_lora_sd[proj_mlp_B_key]
102+
else:
103+
proj_mlp_B = torch.zeros((to_q_B.shape[0] * 4, to_q_B.shape[1]))
95104
linear1_B = torch.block_diag(to_q_B, to_k_B, to_v_B, proj_mlp_B)
96105
linear1_B_key = linear1_A_key.replace("lora_A", "lora_B")
97106
converted_lora_sd[linear1_B_key] = linear1_B
98107

99108
elif "proj_out" in key:
100-
new_key = key.replace("proj_out", "linear2").replace(
101-
"single_transformer_blocks", prefix + "single_blocks"
102-
)
109+
new_key = key.replace("proj_out", "linear2").replace(single_block_pattern, prefix + "single_blocks")
103110
converted_lora_sd[new_key] = diffusers_lora_sd[key]
104111

105112
else:

0 commit comments

Comments (0)