import os
import argparse
import torch
from safetensors.torch import load_file, save_file


def convert_lora_sd(diffusers_lora_sd):
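    """Convert a diffusers-format LoRA state dict ("transformer_blocks.*" /
    "single_transformer_blocks.*" keys) to the original "diffusion_model.double_blocks.*" /
    "diffusion_model.single_blocks.*" naming, fusing the separate q/k/v (and proj_mlp)
    LoRA weights into the combined qkv / linear1 projections of that layout.
    """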
    double_block_patterns = {
        "attn.to_out.0": "img_attn.proj",
        "ff.net.0.proj": "img_mlp.0",
        "ff.net.2": "img_mlp.2",
        "attn.to_add_out": "txt_attn.proj",
        "ff_context.net.0.proj": "txt_mlp.0",
        "ff_context.net.2": "txt_mlp.2",
    }

    prefix = "diffusion_model."

    converted_lora_sd = {}
    for key in diffusers_lora_sd.keys():
        # double_blocks
        if key.startswith("transformer_blocks"):
            # img_attn
            if key.endswith("to_q.lora_A.weight"):
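                # Fuse the separate q/k/v LoRAs into a single LoRA on the fused qkv
                # projection: concatenating the A (down) matrices along dim 0 and
                # arranging the B (up) matrices block-diagonally reproduces the
                # per-projection updates B_q @ A_q, B_k @ A_k, B_v @ A_v stacked
                # along the qkv output dimension.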
                # lora_A
                to_q_A = diffusers_lora_sd[key]
                to_k_A = diffusers_lora_sd[key.replace("to_q", "to_k")]
                to_v_A = diffusers_lora_sd[key.replace("to_q", "to_v")]

                to_qkv_A = torch.cat([to_q_A, to_k_A, to_v_A], dim=0)
                qkv_A_key = key.replace("transformer_blocks", prefix + "double_blocks").replace("attn.to_q", "img_attn.qkv")
                converted_lora_sd[qkv_A_key] = to_qkv_A

                # lora_B
                to_q_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_q.lora_B")]
                to_k_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_k.lora_B")]
                to_v_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_v.lora_B")]

                to_qkv_B = torch.block_diag(to_q_B, to_k_B, to_v_B)
                qkv_B_key = qkv_A_key.replace("lora_A", "lora_B")
                converted_lora_sd[qkv_B_key] = to_qkv_B

            # txt_attn
            elif key.endswith("add_q_proj.lora_A.weight"):
                # lora_A
                to_q_A = diffusers_lora_sd[key]
                to_k_A = diffusers_lora_sd[key.replace("add_q_proj", "add_k_proj")]
                to_v_A = diffusers_lora_sd[key.replace("add_q_proj", "add_v_proj")]

                to_qkv_A = torch.cat([to_q_A, to_k_A, to_v_A], dim=0)
                qkv_A_key = key.replace("transformer_blocks", prefix + "double_blocks").replace("attn.add_q_proj", "txt_attn.qkv")
                converted_lora_sd[qkv_A_key] = to_qkv_A

                # lora_B
                to_q_B = diffusers_lora_sd[key.replace("add_q_proj.lora_A", "add_q_proj.lora_B")]
                to_k_B = diffusers_lora_sd[key.replace("add_q_proj.lora_A", "add_k_proj.lora_B")]
                to_v_B = diffusers_lora_sd[key.replace("add_q_proj.lora_A", "add_v_proj.lora_B")]

                to_qkv_B = torch.block_diag(to_q_B, to_k_B, to_v_B)
                qkv_B_key = qkv_A_key.replace("lora_A", "lora_B")
                converted_lora_sd[qkv_B_key] = to_qkv_B

            # just rename
            for k, v in double_block_patterns.items():
                if k in key:
                    new_key = key.replace(k, v).replace("transformer_blocks", prefix + "double_blocks")
                    converted_lora_sd[new_key] = diffusers_lora_sd[key]

        # single_blocks
        elif key.startswith("single_transformer_blocks"):
            if key.endswith("to_q.lora_A.weight"):
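                # Single blocks fuse q/k/v and the MLP input projection into one
                # "linear1" layer, so the four LoRAs are combined the same way:
                # A matrices concatenated along dim 0, B matrices block-diagonal.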
                # lora_A
                to_q_A = diffusers_lora_sd[key]
                to_k_A = diffusers_lora_sd[key.replace("to_q", "to_k")]
                to_v_A = diffusers_lora_sd[key.replace("to_q", "to_v")]
                proj_mlp_A = diffusers_lora_sd[key.replace("attn.to_q", "proj_mlp")]

                linear1_A = torch.cat([to_q_A, to_k_A, to_v_A, proj_mlp_A], dim=0)
                linear1_A_key = key.replace("single_transformer_blocks", prefix + "single_blocks").replace("attn.to_q", "linear1")
                converted_lora_sd[linear1_A_key] = linear1_A

                # lora_B
                to_q_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_q.lora_B")]
                to_k_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_k.lora_B")]
                to_v_B = diffusers_lora_sd[key.replace("to_q.lora_A", "to_v.lora_B")]
                proj_mlp_B = diffusers_lora_sd[key.replace("attn.to_q.lora_A", "proj_mlp.lora_B")]

                linear1_B = torch.block_diag(to_q_B, to_k_B, to_v_B, proj_mlp_B)
                linear1_B_key = linear1_A_key.replace("lora_A", "lora_B")
                converted_lora_sd[linear1_B_key] = linear1_B

            elif "proj_out" in key:
                new_key = key.replace("proj_out", "linear2").replace("single_transformer_blocks", prefix + "single_blocks")
                converted_lora_sd[new_key] = diffusers_lora_sd[key]

        else:
            print(f"unknown or not implemented: {key}")

    return converted_lora_sd


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_lora", type=str, required=True, help="Path to LoRA .safetensors")
    parser.add_argument("--alpha", type=float, default=None, help="Optional alpha value, defaults to rank")
    parser.add_argument("--dtype", type=str, default=None, help="Optional dtype (bfloat16, float16, float32), defaults to input dtype")
    parser.add_argument("--debug", action="store_true", help="Print converted keys instead of saving")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    converted_lora_sd = convert_lora_sd(load_file(args.input_lora))
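
    # Most LoRA loaders scale the update by alpha / rank, so writing explicit ".alpha"
    # entries makes the effective strength adjustable; when no alpha key is present,
    # loaders commonly assume alpha == rank (a scale of 1.0).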
    if args.alpha is not None:
        for key in list(converted_lora_sd.keys()):
            if "lora_A" in key:
                alpha_name = key.replace(".lora_A.weight", ".alpha")
                converted_lora_sd[alpha_name] = torch.tensor([args.alpha], dtype=converted_lora_sd[key].dtype)

    dtype = None
    if args.dtype == "bfloat16":
        dtype = torch.bfloat16
    elif args.dtype == "float16":
        dtype = torch.float16
    elif args.dtype == "float32":
        dtype = torch.float32

    if dtype is not None:
        dtype_min = torch.finfo(dtype).min
        dtype_max = torch.finfo(dtype).max
        for key in converted_lora_sd.keys():
            if converted_lora_sd[key].min() < dtype_min or converted_lora_sd[key].max() > dtype_max:
                print(f"warning: {key} has values outside of the {dtype} range ({dtype_min} to {dtype_max})")
            converted_lora_sd[key] = converted_lora_sd[key].to(dtype)

    if args.debug:
        for key in sorted(list(converted_lora_sd.keys())):
            print(key, converted_lora_sd[key].shape, converted_lora_sd[key].dtype)
        exit()

    output_path = os.path.splitext(args.input_lora)[0] + "_converted.safetensors"
    save_file(converted_lora_sd, output_path)
    print(f"saved to {output_path}")