Skip to content

Commit 567b2eb

Browse files
Charlie Joyntruchej
Charlie Joynt
authored andcommitted
Allow use of mutiple styles csv files
* AUTOMATIC1111#14122 Fix edge case where style text has multiple {prompt} placeholders * AUTOMATIC1111#14005
1 parent 322ca56 commit 567b2eb

File tree

1 file changed

+171
-32
lines changed

1 file changed

+171
-32
lines changed

modules/styles.py

+171-32
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import csv
2+
import fnmatch
23
import os
34
import os.path
45
import re
@@ -10,6 +11,23 @@ class PromptStyle(typing.NamedTuple):
1011
name: str
1112
prompt: str
1213
negative_prompt: str
14+
path: str = None
15+
16+
17+
def clean_text(text: str) -> str:
18+
"""
19+
Iterating through a list of regular expressions and replacement strings, we
20+
clean up the prompt and style text to make it easier to match against each
21+
other.
22+
"""
23+
re_list = [
24+
("multiple commas", re.compile("(,+\s+)+,?"), ", "),
25+
("multiple spaces", re.compile("\s{2,}"), " "),
26+
]
27+
for _, regex, replace in re_list:
28+
text = regex.sub(replace, text)
29+
30+
return text.strip(", ")
1331

1432

1533
def merge_prompts(style_prompt: str, prompt: str) -> str:
@@ -26,41 +44,64 @@ def apply_styles_to_prompt(prompt, styles):
2644
for style in styles:
2745
prompt = merge_prompts(style, prompt)
2846

29-
return prompt
47+
return clean_text(prompt)
3048

3149

32-
re_spaces = re.compile(" +")
50+
def unwrap_style_text_from_prompt(style_text, prompt):
51+
"""
52+
Checks the prompt to see if the style text is wrapped around it. If so,
53+
returns True plus the prompt text without the style text. Otherwise, returns
54+
False with the original prompt.
3355
34-
35-
def extract_style_text_from_prompt(style_text, prompt):
36-
stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
37-
stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
56+
Note that the "cleaned" version of the style text is only used for matching
57+
purposes here. It isn't returned; the original style text is not modified.
58+
"""
59+
stripped_prompt = clean_text(prompt)
60+
stripped_style_text = clean_text(style_text)
3861
if "{prompt}" in stripped_style_text:
39-
left, right = stripped_style_text.split("{prompt}", 2)
62+
# Work out whether the prompt is wrapped in the style text. If so, we
63+
# return True and the "inner" prompt text that isn't part of the style.
64+
try:
65+
left, right = stripped_style_text.split("{prompt}", 2)
66+
except ValueError as e:
67+
# If the style text has multple "{prompt}"s, we can't split it into
68+
# two parts. This is an error, but we can't do anything about it.
69+
print(f"Unable to compare style text to prompt:\n{style_text}")
70+
print(f"Error: {e}")
71+
return False, prompt
4072
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
41-
prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
73+
prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
4274
return True, prompt
4375
else:
76+
# Work out whether the given prompt ends with the style text. If so, we
77+
# return True and the prompt text up to where the style text starts.
4478
if stripped_prompt.endswith(stripped_style_text):
45-
prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
46-
47-
if prompt.endswith(', '):
79+
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
80+
if prompt.endswith(", "):
4881
prompt = prompt[:-2]
49-
5082
return True, prompt
5183

5284
return False, prompt
5385

5486

55-
def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
87+
def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
88+
"""
89+
Takes a style and compares it to the prompt and negative prompt. If the style
90+
matches, returns True plus the prompt and negative prompt with the style text
91+
removed. Otherwise, returns False with the original prompt and negative prompt.
92+
"""
5693
if not style.prompt and not style.negative_prompt:
5794
return False, prompt, negative_prompt
5895

59-
match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
96+
match_positive, extracted_positive = unwrap_style_text_from_prompt(
97+
style.prompt, prompt
98+
)
6099
if not match_positive:
61100
return False, prompt, negative_prompt
62101

63-
match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
102+
match_negative, extracted_negative = unwrap_style_text_from_prompt(
103+
style.negative_prompt, negative_prompt
104+
)
64105
if not match_negative:
65106
return False, prompt, negative_prompt
66107

@@ -69,25 +110,88 @@ def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
69110

70111
class StyleDatabase:
71112
def __init__(self, path: str):
72-
self.no_style = PromptStyle("None", "", "")
113+
self.no_style = PromptStyle("None", "", "", None)
73114
self.styles = {}
74115
self.path = path
75116

117+
folder, file = os.path.split(self.path)
118+
self.default_file = file.split("*")[0] + ".csv"
119+
if self.default_file == ".csv":
120+
self.default_file = "styles.csv"
121+
self.default_path = os.path.join(folder, self.default_file)
122+
123+
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
124+
76125
self.reload()
77126

78127
def reload(self):
128+
"""
129+
Clears the style database and reloads the styles from the CSV file(s)
130+
matching the path used to initialize the database.
131+
"""
79132
self.styles.clear()
80133

81-
if not os.path.exists(self.path):
134+
path, filename = os.path.split(self.path)
135+
136+
if "*" in filename:
137+
fileglob = filename.split("*")[0] + "*.csv"
138+
filelist = []
139+
for file in os.listdir(path):
140+
if fnmatch.fnmatch(file, fileglob):
141+
filelist.append(file)
142+
# Add a visible divider to the style list
143+
half_len = round(len(file) / 2)
144+
divider = f"{'-' * (20 - half_len)} {file.upper()}"
145+
divider = f"{divider} {'-' * (40 - len(divider))}"
146+
self.styles[divider] = PromptStyle(
147+
f"{divider}", None, None, "do_not_save"
148+
)
149+
# Add styles from this CSV file
150+
self.load_from_csv(os.path.join(path, file))
151+
if len(filelist) == 0:
152+
print(f"No styles found in {path} matching {fileglob}")
153+
return
154+
elif not os.path.exists(self.path):
155+
print(f"Style database not found: {self.path}")
82156
return
157+
else:
158+
self.load_from_csv(self.path)
83159

84-
with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
160+
def load_from_csv(self, path: str):
161+
with open(path, "r", encoding="utf-8-sig", newline="") as file:
85162
reader = csv.DictReader(file, skipinitialspace=True)
86163
for row in reader:
164+
# Ignore empty rows or rows starting with a comment
165+
if not row or row["name"].startswith("#"):
166+
continue
87167
# Support loading old CSV format with "name, text"-columns
88168
prompt = row["prompt"] if "prompt" in row else row["text"]
89169
negative_prompt = row.get("negative_prompt", "")
90-
self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
170+
# Add style to database
171+
self.styles[row["name"]] = PromptStyle(
172+
row["name"], prompt, negative_prompt, path
173+
)
174+
175+
def get_style_paths(self) -> list():
176+
"""
177+
Returns a list of all distinct paths, including the default path, of
178+
files that styles are loaded from."""
179+
# Update any styles without a path to the default path
180+
for style in list(self.styles.values()):
181+
if not style.path:
182+
self.styles[style.name] = style._replace(path=self.default_path)
183+
184+
# Create a list of all distinct paths, including the default path
185+
style_paths = set()
186+
style_paths.add(self.default_path)
187+
for _, style in self.styles.items():
188+
if style.path:
189+
style_paths.add(style.path)
190+
191+
# Remove any paths for styles that are just list dividers
192+
style_paths.remove("do_not_save")
193+
194+
return list(style_paths)
91195

92196
def get_style_prompts(self, styles):
93197
return [self.styles.get(x, self.no_style).prompt for x in styles]
@@ -96,20 +200,53 @@ def get_negative_style_prompts(self, styles):
96200
return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
97201

98202
def apply_styles_to_prompt(self, prompt, styles):
99-
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
203+
return apply_styles_to_prompt(
204+
prompt, [self.styles.get(x, self.no_style).prompt for x in styles]
205+
)
100206

101207
def apply_negative_styles_to_prompt(self, prompt, styles):
102-
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
103-
104-
def save_styles(self, path: str) -> None:
105-
# Always keep a backup file around
106-
if os.path.exists(path):
107-
shutil.copy(path, f"{path}.bak")
108-
109-
with open(path, "w", encoding="utf-8-sig", newline='') as file:
110-
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
111-
writer.writeheader()
112-
writer.writerows(style._asdict() for k, style in self.styles.items())
208+
return apply_styles_to_prompt(
209+
prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]
210+
)
211+
212+
def save_styles(self, path: str = None) -> None:
213+
# The path argument is deprecated, but kept for backwards compatibility
214+
_ = path
215+
216+
# Update any styles without a path to the default path
217+
for style in list(self.styles.values()):
218+
if not style.path:
219+
self.styles[style.name] = style._replace(path=self.default_path)
220+
221+
# Create a list of all distinct paths, including the default path
222+
style_paths = set()
223+
style_paths.add(self.default_path)
224+
for _, style in self.styles.items():
225+
if style.path:
226+
style_paths.add(style.path)
227+
228+
# Remove any paths for styles that are just list dividers
229+
style_paths.remove("do_not_save")
230+
231+
csv_names = [os.path.split(path)[1].lower() for path in style_paths]
232+
233+
for style_path in style_paths:
234+
# Always keep a backup file around
235+
if os.path.exists(style_path):
236+
shutil.copy(style_path, f"{style_path}.bak")
237+
238+
# Write the styles to the CSV file
239+
with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
240+
writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
241+
writer.writeheader()
242+
for style in (s for s in self.styles.values() if s.path == style_path):
243+
# Skip style list dividers, e.g. "STYLES.CSV"
244+
if style.name.lower().strip("# ") in csv_names:
245+
continue
246+
# Write style fields, ignoring the path field
247+
writer.writerow(
248+
{k: v for k, v in style._asdict().items() if k != "path"}
249+
)
113250

114251
def extract_styles_from_prompt(self, prompt, negative_prompt):
115252
extracted = []
@@ -120,7 +257,9 @@ def extract_styles_from_prompt(self, prompt, negative_prompt):
120257
found_style = None
121258

122259
for style in applicable_styles:
123-
is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
260+
is_match, new_prompt, new_neg_prompt = extract_original_prompts(
261+
style, prompt, negative_prompt
262+
)
124263
if is_match:
125264
found_style = style
126265
prompt = new_prompt

0 commit comments

Comments
 (0)