Skip to content

Commit dce71df

Browse files
committed
re-work multi --styles-file
--styles-file change to append str --styles-file is [] then defaults to [styles.csv] --styles-file accepts paths or paths with wildcard "*" the first `--styles-file` entry is use as the default styles file path if thers a "*" wildcard then the first matching file is used if no match is found, create a new "styles.csv" in the same dir as the first path when saving a new style it will be save in the default styles file when saving a existing style, it will be saved to file it belongs to order of the styles files in the styles dropdown can be controlled to a certain degree by the order of --styles-file
1 parent f939bce commit dce71df

File tree

4 files changed

+52
-47
lines changed

4 files changed

+52
-47
lines changed

modules/cmd_args.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
8989
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
9090
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
91-
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
91+
parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[])
9292
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
9393
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
9494
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)

modules/shared.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import sys
23

34
import gradio as gr
@@ -11,7 +12,7 @@
1112

1213
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
1314
parallel_processing_allowed = True
14-
styles_filename = cmd_opts.styles_file
15+
styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.styles_file) > 0 else [os.path.join(data_path, 'styles.csv')]
1516
config_filename = cmd_opts.ui_settings_file
1617
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
1718

modules/styles.py

+43-42
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
1+
from pathlib import Path
12
import csv
2-
import fnmatch
33
import os
4-
import os.path
54
import typing
65
import shutil
76

87

98
class PromptStyle(typing.NamedTuple):
109
name: str
11-
prompt: str
12-
negative_prompt: str
13-
path: str = None
10+
prompt: str | None
11+
negative_prompt: str | None
12+
path: str | None = None
1413

1514

1615
def merge_prompts(style_prompt: str, prompt: str) -> str:
@@ -79,14 +78,19 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
7978

8079

8180
class StyleDatabase:
82-
def __init__(self, path: str):
81+
def __init__(self, paths: list[str | Path]):
8382
self.no_style = PromptStyle("None", "", "", None)
8483
self.styles = {}
85-
self.path = path
86-
87-
folder, file = os.path.split(self.path)
88-
filename, _, ext = file.partition('*')
89-
self.default_path = os.path.join(folder, filename + ext)
84+
self.paths = paths
85+
self.all_styles_files: list[Path] = []
86+
87+
folder, file = os.path.split(self.paths[0])
88+
if '*' in file:
89+
# if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
90+
self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
91+
self.paths.insert(0, self.default_path)
92+
else:
93+
self.default_path = Path(self.paths[0])
9094

9195
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
9296

@@ -99,33 +103,31 @@ def reload(self):
99103
"""
100104
self.styles.clear()
101105

102-
path, filename = os.path.split(self.path)
103-
104-
if "*" in filename:
105-
fileglob = filename.split("*")[0] + "*.csv"
106-
filelist = []
107-
for file in os.listdir(path):
108-
if fnmatch.fnmatch(file, fileglob):
109-
filelist.append(file)
110-
# Add a visible divider to the style list
111-
half_len = round(len(file) / 2)
112-
divider = f"{'-' * (20 - half_len)} {file.upper()}"
113-
divider = f"{divider} {'-' * (40 - len(divider))}"
114-
self.styles[divider] = PromptStyle(
115-
f"{divider}", None, None, "do_not_save"
116-
)
117-
# Add styles from this CSV file
118-
self.load_from_csv(os.path.join(path, file))
119-
if len(filelist) == 0:
120-
print(f"No styles found in {path} matching {fileglob}")
121-
return
122-
elif not os.path.exists(self.path):
123-
print(f"Style database not found: {self.path}")
124-
return
125-
else:
126-
self.load_from_csv(self.path)
127-
128-
def load_from_csv(self, path: str):
106+
# scans for all styles files
107+
all_styles_files = []
108+
for pattern in self.paths:
109+
base_dir, file_pattern = os.path.split(pattern)
110+
if '*' in file_pattern:
111+
found_files = Path(base_dir).glob(file_pattern)
112+
[all_styles_files.append(file) for file in found_files]
113+
else:
114+
# if os.path.exists(pattern):
115+
all_styles_files.append(Path(pattern))
116+
117+
# Remove any duplicate entries
118+
seen = set()
119+
self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
120+
121+
for styles_file in self.all_styles_files:
122+
if len(all_styles_files) > 1:
123+
# add divider when more than styles file
124+
# '---------------- STYLES ----------------'
125+
divider = f' {styles_file.stem.upper()} '.center(40, '-')
126+
self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
127+
if styles_file.is_file():
128+
self.load_from_csv(styles_file)
129+
130+
def load_from_csv(self, path: str | Path):
129131
with open(path, "r", encoding="utf-8-sig", newline="") as file:
130132
reader = csv.DictReader(file, skipinitialspace=True)
131133
for row in reader:
@@ -137,19 +139,19 @@ def load_from_csv(self, path: str):
137139
negative_prompt = row.get("negative_prompt", "")
138140
# Add style to database
139141
self.styles[row["name"]] = PromptStyle(
140-
row["name"], prompt, negative_prompt, path
142+
row["name"], prompt, negative_prompt, str(path)
141143
)
142144

143145
def get_style_paths(self) -> set:
144146
"""Returns a set of all distinct paths of files that styles are loaded from."""
145147
# Update any styles without a path to the default path
146148
for style in list(self.styles.values()):
147149
if not style.path:
148-
self.styles[style.name] = style._replace(path=self.default_path)
150+
self.styles[style.name] = style._replace(path=str(self.default_path))
149151

150152
# Create a list of all distinct paths, including the default path
151153
style_paths = set()
152-
style_paths.add(self.default_path)
154+
style_paths.add(str(self.default_path))
153155
for _, style in self.styles.items():
154156
if style.path:
155157
style_paths.add(style.path)
@@ -177,7 +179,6 @@ def apply_negative_styles_to_prompt(self, prompt, styles):
177179

178180
def save_styles(self, path: str = None) -> None:
179181
# The path argument is deprecated, but kept for backwards compatibility
180-
_ = path
181182

182183
style_paths = self.get_style_paths()
183184

modules/ui_prompt_styles.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@ def save_style(name, prompt, negative_prompt):
2222
if not name:
2323
return gr.update(visible=False)
2424

25-
style = styles.PromptStyle(name, prompt, negative_prompt)
25+
existing_style = shared.prompt_styles.styles.get(name)
26+
path = existing_style.path if existing_style is not None else None
27+
28+
style = styles.PromptStyle(name, prompt, negative_prompt, path)
2629
shared.prompt_styles.styles[style.name] = style
27-
shared.prompt_styles.save_styles(shared.styles_filename)
30+
shared.prompt_styles.save_styles()
2831

2932
return gr.update(visible=True)
3033

@@ -34,7 +37,7 @@ def delete_style(name):
3437
return
3538

3639
shared.prompt_styles.styles.pop(name, None)
37-
shared.prompt_styles.save_styles(shared.styles_filename)
40+
shared.prompt_styles.save_styles()
3841

3942
return '', '', ''
4043

0 commit comments

Comments
 (0)