Skip to content

Commit f7c823a

Browse files
authored
Merge pull request #25 from kohya-ss/dev
metadata inspect, model sorting/filtering
2 parents f9ef5d5 + a41a988 commit f7c823a

File tree

1 file changed

+99
-19
lines changed

1 file changed

+99
-19
lines changed

scripts/additional_networks.py

Lines changed: 99 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import os
22
import glob
3+
import zipfile
4+
import json
5+
import stat
36
from collections import OrderedDict
47

58
import torch
@@ -16,26 +19,58 @@
1619

1720

1821
MAX_MODEL_COUNT = 5
19-
LORA_MODEL_EXTS = ["pt", "ckpt", "safetensors"]
22+
LORA_MODEL_EXTS = [".pt", ".ckpt", ".safetensors"]
2023
lora_models = {}
2124
lora_models_dir = os.path.join(scripts.basedir(), "models/LoRA")
2225
os.makedirs(lora_models_dir, exist_ok=True)
2326

2427

28+
def traverse_all_files(curr_path, model_list):
  """Recursively collect model files under *curr_path*.

  Appends a ``(full_path, os.stat_result)`` tuple to *model_list* for every
  file whose extension is in ``LORA_MODEL_EXTS`` and recurses into
  subdirectories.

  Args:
    curr_path: Directory to scan.
    model_list: Accumulator list; extended in place and also returned.

  Returns:
    The same *model_list* object, for convenient chaining.
  """
  # Context manager closes the scandir handle promptly instead of leaking
  # it until garbage collection; stat only entries we actually keep.
  with os.scandir(curr_path) as entries:
    for entry in entries:
      full_path = os.path.join(curr_path, entry.name)
      if os.path.splitext(entry.name)[1] in LORA_MODEL_EXTS:
        model_list.append((full_path, entry.stat()))
      elif entry.is_dir():
        traverse_all_files(full_path, model_list)
  return model_list
37+
38+
39+
def get_all_models(sort_by, filter_by, path):
  """Scan *path* recursively and map model display names to file paths.

  Args:
    sort_by: "name" (by basename), "date" (newest first) or "path name"
      (by full path); any other value keeps filesystem order.
    filter_by: Case-insensitive substring filter on the basename.
      Surrounding spaces are stripped; empty means no filtering.
    path: Root directory to scan.

  Returns:
    OrderedDict mapping "<stem>(<model hash>)" -> full file path, in the
    requested order.
  """
  fileinfos = traverse_all_files(path, [])

  filter_by = filter_by.strip(" ")
  if filter_by:
    # Lowercase the needle once instead of per file.
    needle = filter_by.lower()
    fileinfos = [x for x in fileinfos if needle in os.path.basename(x[0]).lower()]

  if sort_by == "name":
    fileinfos.sort(key=lambda x: os.path.basename(x[0]))
  elif sort_by == "date":
    # Newest first; reverse=True keeps ties in scan order, matching a
    # sort on the negated mtime.
    fileinfos.sort(key=lambda x: x[1].st_mtime, reverse=True)
  elif sort_by == "path name":
    fileinfos.sort()

  res = OrderedDict()
  for filename, _fstat in fileinfos:
    name = os.path.splitext(os.path.basename(filename))[0]
    # Prevent a hypothetical "None.pt" from being listed; "None" is the
    # sentinel entry in the model dropdowns.
    if name != "None":
      res[name + f"({sd_models.model_hash(filename)})"] = filename

  return res
60+
61+
2562
def update_lora_models():
  """Rescan the model directories and rebuild the global ``lora_models`` map.

  Scans the bundled ``models/LoRA`` directory plus the optional
  user-configured extra path, honouring the sort/filter settings, and
  always puts the "None" sentinel first.
  """
  global lora_models
  paths = [lora_models_dir]
  extra_lora_path = shared.opts.data.get("additional_networks_extra_lora_path", None)
  if extra_lora_path and os.path.exists(extra_lora_path):
    paths.append(extra_lora_path)
  # The settings are loop-invariant; read them once instead of per path.
  sort_by = shared.opts.data.get("additional_networks_sort_models_by", "name")
  filter_by = shared.opts.data.get("additional_networks_model_name_filter", "")
  res = OrderedDict()
  for path in paths:
    found = get_all_models(sort_by, filter_by, path)
    # Later-scanned paths are listed first; on a name clash the
    # earlier path's entry wins (res overrides found).
    res = {**found, **res}
  lora_models = OrderedDict(**{"None": None}, **res)
4075

4176

@@ -68,9 +103,9 @@ def ui(self, is_img2img):
68103
for i in range(MAX_MODEL_COUNT):
69104
with gr.Row():
70105
module = gr.Dropdown(["LoRA"], label=f"Network module {i+1}", value="LoRA")
71-
model = gr.Dropdown(sorted(lora_models.keys()),
72-
label=f"Model {i+1}",
73-
value="None")
106+
model = gr.Dropdown(list(lora_models.keys()),
107+
label=f"Model {i+1}",
108+
value="None")
74109

75110
weight = gr.Slider(label=f"Weight {i+1}", value=1.0, minimum=-1.0, maximum=2.0, step=.05)
76111
ctrls.extend((module, model, weight))
@@ -90,7 +125,7 @@ def refresh_all_models(*dropdowns):
90125
selected = dd
91126
else:
92127
selected = "None"
93-
update = gr.Dropdown.update(value=selected, choices=sorted(lora_models.keys()))
128+
update = gr.Dropdown.update(value=selected, choices=list(lora_models.keys()))
94129
updates.append(update)
95130
return updates
96131

@@ -106,11 +141,11 @@ def set_infotext_fields(self, p, params):
106141
if model is None or model == "None" or len(model) == 0 or weight == 0:
107142
continue
108143
p.extra_generation_params.update({
109-
"AddNet Enabled": True,
110-
f"AddNet Module {i+1}": module,
111-
f"AddNet Model {i+1}": model,
112-
f"AddNet Weight {i+1}": weight,
113-
})
144+
"AddNet Enabled": True,
145+
f"AddNet Module {i+1}": module,
146+
f"AddNet Model {i+1}": model,
147+
f"AddNet Weight {i+1}": weight,
148+
})
114149

115150
def process(self, p, *args):
116151
unet = p.sd_model.model.diffusion_model
@@ -185,9 +220,54 @@ def restore_networks():
185220
self.set_infotext_fields(p, self.latest_params)
186221

187222

223+
def on_ui_tabs():
  """Build the "Additional Networks" tab for inspecting model metadata.

  Returns the (interface, title, elem_id) triple expected by
  ``script_callbacks.on_ui_tabs``.
  """
  with gr.Blocks(analytics_enabled=False) as additional_networks_interface:
    with gr.Row().style(equal_height=False):
      with gr.Column(variant='panel'):
        gr.HTML(value="Inspect additional network metadata")

        with gr.Row():
          # Plain string labels: the originals used f-strings with no
          # placeholders.
          module = gr.Dropdown(["LoRA"], label="Network module", value="LoRA", interactive=True)
          model = gr.Dropdown(list(lora_models.keys()), label="Model", value="None", interactive=True)
          modules.ui.create_refresh_button(model, update_lora_models, lambda: {"choices": list(lora_models.keys())}, "refresh_lora_models")
      with gr.Column():
        # NOTE(review): "test" looks like a leftover placeholder initial
        # value — confirm whether it should be an empty dict.
        metadata_view = gr.JSON(value="test")

    def update_metadata(module, model):
      """Return metadata for the selected model, or an error string."""
      if model == "None":
        return {}
      model_path = lora_models.get(model, None)
      if model_path is None:
        return f"model not found: {model}"

      if model_path.startswith("\"") and model_path.endswith("\""):  # trim '"' at start/end
        model_path = model_path[1:-1]
      if not os.path.exists(model_path):
        return f"file not found: {model_path}"

      metadata = None
      if module == "LoRA":
        # Only safetensors files carry embedded metadata.
        if os.path.splitext(model_path)[1] == '.safetensors':
          from safetensors.torch import safe_open
          with safe_open(model_path, framework="pt") as f:  # default device is 'cpu'
            metadata = f.metadata()

      if metadata is None:
        return "No metadata found."
      else:
        return metadata

    model.change(update_metadata, inputs=[module, model], outputs=[metadata_view])

  return [(additional_networks_interface, "Additional Networks", "additional_networks")]
263+
264+
188265
def on_ui_settings():
  """Register this extension's options in the WebUI settings panel."""
  section = ('additional_networks', "Additional Networks")
  add = shared.opts.add_option
  add("additional_networks_extra_lora_path",
      shared.OptionInfo("", "Extra path to scan for LoRA models (e.g. training output directory)", section=section))
  add("additional_networks_sort_models_by",
      shared.OptionInfo("name", "Sort LoRA models by", gr.Radio, {"choices": ["name", "date", "path name"]}, section=section))
  add("additional_networks_model_name_filter",
      shared.OptionInfo("", "LoRA model name filter", section=section))
191270

192271

272+
script_callbacks.on_ui_tabs(on_ui_tabs)
193273
script_callbacks.on_ui_settings(on_ui_settings)

0 commit comments

Comments
 (0)