import os
import glob
+ import zipfile
+ import json
+ import stat
from collections import OrderedDict

import torch

@@ -16,26 +19,58 @@


MAX_MODEL_COUNT = 5
- LORA_MODEL_EXTS = ["pt", "ckpt", "safetensors"]
+ LORA_MODEL_EXTS = [".pt", ".ckpt", ".safetensors"]
lora_models = {}
lora_models_dir = os.path.join(scripts.basedir(), "models/LoRA")
os.makedirs(lora_models_dir, exist_ok=True)


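+ # Recursively collect (path, os.stat_result) pairs for every file under curr_path with a LoRA model extension.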
+ def traverse_all_files(curr_path, model_list):
+   f_list = [(os.path.join(curr_path, entry.name), entry.stat()) for entry in os.scandir(curr_path)]
+   for f_info in f_list:
+     fname, fstat = f_info
+     if os.path.splitext(fname)[1] in LORA_MODEL_EXTS:
+       model_list.append(f_info)
+     elif stat.S_ISDIR(fstat.st_mode):
+       model_list = traverse_all_files(fname, model_list)
+   return model_list
+
+
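+ # Map "model_name(hash)" -> file path for every model under path, honoring the name filter and sort order from the settings.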
+ def get_all_models(sort_by, filter_by, path):
+   res = OrderedDict()
+   fileinfos = traverse_all_files(path, [])
+   filter_by = filter_by.strip(" ")
+   if len(filter_by) != 0:
+     fileinfos = [x for x in fileinfos if filter_by.lower() in os.path.basename(x[0]).lower()]
+   if sort_by == "name":
+     fileinfos = sorted(fileinfos, key=lambda x: os.path.basename(x[0]))
+   elif sort_by == "date":
+     fileinfos = sorted(fileinfos, key=lambda x: -x[1].st_mtime)
+   elif sort_by == "path name":
+     fileinfos = sorted(fileinfos)
+
+   for finfo in fileinfos:
+     filename = finfo[0]
+     name = os.path.splitext(os.path.basename(filename))[0]
+     # Prevent a hypothetical "None.pt" from being listed.
+     if name != "None":
+       res[name + f"({sd_models.model_hash(filename)})"] = filename
+
+   return res
+
+
def update_lora_models():
  global lora_models
-   res = {}
+   res = OrderedDict()
  paths = [lora_models_dir]
  extra_lora_path = shared.opts.data.get("additional_networks_extra_lora_path", None)
  if extra_lora_path and os.path.exists(extra_lora_path):
    paths.append(extra_lora_path)
  for path in paths:
-     for ext in LORA_MODEL_EXTS:
-       for filename in sorted(glob.iglob(os.path.join(path, f"**/*.{ext}"), recursive=True)):
-         name = os.path.splitext(os.path.basename(filename))[0]
-         # Prevent a hypothetical "None.pt" from being listed.
-         if name != "None":
-           res[name + f"({sd_models.model_hash(filename)})"] = filename
+     sort_by = shared.opts.data.get("additional_networks_sort_models_by", "name")
+     filter_by = shared.opts.data.get("additional_networks_model_name_filter", "")
+     found = get_all_models(sort_by, filter_by, path)
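+     # Merge, keeping entries from earlier paths when model names collide.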
+     res = {**found, **res}
  lora_models = OrderedDict(**{"None": None}, **res)

@@ -68,9 +103,9 @@ def ui(self, is_img2img):
      for i in range(MAX_MODEL_COUNT):
        with gr.Row():
          module = gr.Dropdown(["LoRA"], label=f"Network module {i+1}", value="LoRA")
-           model = gr.Dropdown(sorted(lora_models.keys()),
-                               label=f"Model {i+1}",
-                               value="None")
+           model = gr.Dropdown(list(lora_models.keys()),
+                               label=f"Model {i+1}",
+                               value="None")

          weight = gr.Slider(label=f"Weight {i+1}", value=1.0, minimum=-1.0, maximum=2.0, step=.05)
        ctrls.extend((module, model, weight))
@@ -90,7 +125,7 @@ def refresh_all_models(*dropdowns):
          selected = dd
        else:
          selected = "None"
-       update = gr.Dropdown.update(value=selected, choices=sorted(lora_models.keys()))
+       update = gr.Dropdown.update(value=selected, choices=list(lora_models.keys()))
        updates.append(update)
      return updates

@@ -106,11 +141,11 @@ def set_infotext_fields(self, p, params):
      if model is None or model == "None" or len(model) == 0 or weight == 0:
        continue
      p.extra_generation_params.update({
-         "AddNet Enabled": True,
-         f"AddNet Module {i+1}": module,
-         f"AddNet Model {i+1}": model,
-         f"AddNet Weight {i+1}": weight,
-       })
+         "AddNet Enabled": True,
+         f"AddNet Module {i+1}": module,
+         f"AddNet Model {i+1}": model,
+         f"AddNet Weight {i+1}": weight,
+       })

  def process(self, p, *args):
    unet = p.sd_model.model.diffusion_model
@@ -185,9 +220,54 @@ def restore_networks():
    self.set_infotext_fields(p, self.latest_params)


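+ # New "Additional Networks" tab: select a LoRA model and inspect the metadata embedded in its file.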
+ def on_ui_tabs():
+   with gr.Blocks(analytics_enabled=False) as additional_networks_interface:
+     with gr.Row().style(equal_height=False):
+       with gr.Column(variant='panel'):
+         gr.HTML(value="Inspect additional network metadata")
+
+         with gr.Row():
+           module = gr.Dropdown(["LoRA"], label=f"Network module", value="LoRA", interactive=True)
+           model = gr.Dropdown(list(lora_models.keys()), label=f"Model", value="None", interactive=True)
+           modules.ui.create_refresh_button(model, update_lora_models, lambda: {"choices": list(lora_models.keys())}, "refresh_lora_models")
+       with gr.Column():
+         metadata_view = gr.JSON(value="test")
+
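+     # Return the metadata stored in the selected model file for display in the JSON view
+     # (only .safetensors files carry readable metadata here).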
+     def update_metadata(module, model):
+       if model == "None":
+         return {}
+       model_path = lora_models.get(model, None)
+       if model_path is None:
+         return f"model not found: {model}"
+
+       if model_path.startswith("\"") and model_path.endswith("\""):  # trim '"' at start/end
+         model_path = model_path[1:-1]
+       if not os.path.exists(model_path):
+         return f"file not found: {model_path}"
+
+       metadata = None
+       if module == "LoRA":
+         if os.path.splitext(model_path)[1] == '.safetensors':
+           from safetensors.torch import safe_open
+           with safe_open(model_path, framework="pt") as f:  # default device is 'cpu'
+             metadata = f.metadata()
+
+       if metadata is None:
+         return "No metadata found."
+       else:
+         return metadata
+
+     model.change(update_metadata, inputs=[module, model], outputs=[metadata_view])
+
+   return [(additional_networks_interface, "Additional Networks", "additional_networks")]
+
+
def on_ui_settings():
-   section = ('additional_networks', "Additional Networks")
-   shared.opts.add_option("additional_networks_extra_lora_path", shared.OptionInfo("", "Extra path to scan for LoRA models (e.g. training output directory)", section=section))
+   section = ('additional_networks', "Additional Networks")
+   shared.opts.add_option("additional_networks_extra_lora_path", shared.OptionInfo("", "Extra path to scan for LoRA models (e.g. training output directory)", section=section))
+   shared.opts.add_option("additional_networks_sort_models_by", shared.OptionInfo("name", "Sort LoRA models by", gr.Radio, {"choices": ["name", "date", "path name"]}, section=section))
+   shared.opts.add_option("additional_networks_model_name_filter", shared.OptionInfo("", "LoRA model name filter", section=section))


+ script_callbacks.on_ui_tabs(on_ui_tabs)
script_callbacks.on_ui_settings(on_ui_settings)