@@ -17,10 +17,13 @@ def convert_lora_sd(diffusers_lora_sd):
17
17
18
18
prefix = "diffusion_model."
19
19
20
+ double_block_pattern = "transformer.transformer_blocks"
21
+ single_block_pattern = "transformer.single_transformer_blocks"
22
+
20
23
converted_lora_sd = {}
21
24
for key in diffusers_lora_sd .keys ():
22
25
# double_blocks
23
- if key .startswith ("transformer_blocks" ):
26
+ if key .startswith (double_block_pattern ):
24
27
# img_attn
25
28
if key .endswith ("to_q.lora_A.weight" ):
26
29
# lora_A
@@ -29,7 +32,7 @@ def convert_lora_sd(diffusers_lora_sd):
29
32
to_v_A = diffusers_lora_sd [key .replace ("to_q" , "to_v" )]
30
33
31
34
to_qkv_A = torch .cat ([to_q_A , to_k_A , to_v_A ], dim = 0 )
32
- qkv_A_key = key .replace ("transformer_blocks" , prefix + "double_blocks" ).replace (
35
+ qkv_A_key = key .replace (double_block_pattern , prefix + "double_blocks" ).replace (
33
36
"attn.to_q" , "img_attn_qkv"
34
37
)
35
38
converted_lora_sd [qkv_A_key ] = to_qkv_A
@@ -51,7 +54,7 @@ def convert_lora_sd(diffusers_lora_sd):
51
54
to_v_A = diffusers_lora_sd [key .replace ("add_q_proj" , "add_v_proj" )]
52
55
53
56
to_qkv_A = torch .cat ([to_q_A , to_k_A , to_v_A ], dim = 0 )
54
- qkv_A_key = key .replace ("transformer_blocks" , prefix + "double_blocks" ).replace (
57
+ qkv_A_key = key .replace (double_block_pattern , prefix + "double_blocks" ).replace (
55
58
"attn.add_q_proj" , "txt_attn_qkv"
56
59
)
57
60
converted_lora_sd [qkv_A_key ] = to_qkv_A
@@ -68,20 +71,23 @@ def convert_lora_sd(diffusers_lora_sd):
68
71
# just rename
69
72
for k , v in double_block_patterns .items ():
70
73
if k in key :
71
- new_key = key .replace (k , v ).replace ("transformer_blocks" , prefix + "double_blocks" )
74
+ new_key = key .replace (k , v ).replace (double_block_pattern , prefix + "double_blocks" )
72
75
converted_lora_sd [new_key ] = diffusers_lora_sd [key ]
73
76
74
77
# single_blocks
75
- elif key .startswith ("single_transformer_blocks" ):
78
+ elif key .startswith (single_block_pattern ):
76
79
if key .endswith ("to_q.lora_A.weight" ):
77
80
# lora_A
78
81
to_q_A = diffusers_lora_sd [key ]
79
82
to_k_A = diffusers_lora_sd [key .replace ("to_q" , "to_k" )]
80
83
to_v_A = diffusers_lora_sd [key .replace ("to_q" , "to_v" )]
81
- proj_mlp_A = diffusers_lora_sd [key .replace ("attn.to_q" , "proj_mlp" )]
82
-
84
+ proj_mlp_A_key = key .replace ("attn.to_q" , "proj_mlp" )
85
+ if proj_mlp_A_key in diffusers_lora_sd :
86
+ proj_mlp_A = diffusers_lora_sd [proj_mlp_A_key ]
87
+ else :
88
+ proj_mlp_A = torch .zeros ((to_q_A .shape [0 ], to_q_A .shape [1 ]))
83
89
linear1_A = torch .cat ([to_q_A , to_k_A , to_v_A , proj_mlp_A ], dim = 0 )
84
- linear1_A_key = key .replace ("single_transformer_blocks" , prefix + "single_blocks" ).replace (
90
+ linear1_A_key = key .replace (single_block_pattern , prefix + "single_blocks" ).replace (
85
91
"attn.to_q" , "linear1"
86
92
)
87
93
converted_lora_sd [linear1_A_key ] = linear1_A
@@ -90,16 +96,17 @@ def convert_lora_sd(diffusers_lora_sd):
90
96
to_q_B = diffusers_lora_sd [key .replace ("to_q.lora_A" , "to_q.lora_B" )]
91
97
to_k_B = diffusers_lora_sd [key .replace ("to_q.lora_A" , "to_k.lora_B" )]
92
98
to_v_B = diffusers_lora_sd [key .replace ("to_q.lora_A" , "to_v.lora_B" )]
93
- proj_mlp_B = diffusers_lora_sd [key .replace ("attn.to_q.lora_A" , "proj_mlp.lora_B" )]
94
-
99
+ proj_mlp_B_key = key .replace ("to_q.lora_A" , "attn.to_q.lora_B" )
100
+ if proj_mlp_B_key in diffusers_lora_sd :
101
+ proj_mlp_B = diffusers_lora_sd [proj_mlp_B_key ]
102
+ else :
103
+ proj_mlp_B = torch .zeros ((to_q_B .shape [0 ] * 4 , to_q_B .shape [1 ]))
95
104
linear1_B = torch .block_diag (to_q_B , to_k_B , to_v_B , proj_mlp_B )
96
105
linear1_B_key = linear1_A_key .replace ("lora_A" , "lora_B" )
97
106
converted_lora_sd [linear1_B_key ] = linear1_B
98
107
99
108
elif "proj_out" in key :
100
- new_key = key .replace ("proj_out" , "linear2" ).replace (
101
- "single_transformer_blocks" , prefix + "single_blocks"
102
- )
109
+ new_key = key .replace ("proj_out" , "linear2" ).replace (single_block_pattern , prefix + "single_blocks" )
103
110
converted_lora_sd [new_key ] = diffusers_lora_sd [key ]
104
111
105
112
else :
0 commit comments