1
1
import csv
2
+ import fnmatch
2
3
import os
3
4
import os .path
4
5
import re
@@ -10,6 +11,23 @@ class PromptStyle(typing.NamedTuple):
10
11
name : str
11
12
prompt : str
12
13
negative_prompt : str
14
+ path : str = None
15
+
16
+
17
+ def clean_text (text : str ) -> str :
18
+ """
19
+ Iterating through a list of regular expressions and replacement strings, we
20
+ clean up the prompt and style text to make it easier to match against each
21
+ other.
22
+ """
23
+ re_list = [
24
+ ("multiple commas" , re .compile ("(,+\s+)+,?" ), ", " ),
25
+ ("multiple spaces" , re .compile ("\s{2,}" ), " " ),
26
+ ]
27
+ for _ , regex , replace in re_list :
28
+ text = regex .sub (replace , text )
29
+
30
+ return text .strip (", " )
13
31
14
32
15
33
def merge_prompts (style_prompt : str , prompt : str ) -> str :
@@ -26,41 +44,64 @@ def apply_styles_to_prompt(prompt, styles):
26
44
for style in styles :
27
45
prompt = merge_prompts (style , prompt )
28
46
29
- return prompt
47
+ return clean_text ( prompt )
30
48
31
49
32
- re_spaces = re .compile (" +" )
50
+ def unwrap_style_text_from_prompt (style_text , prompt ):
51
+ """
52
+ Checks the prompt to see if the style text is wrapped around it. If so,
53
+ returns True plus the prompt text without the style text. Otherwise, returns
54
+ False with the original prompt.
33
55
34
-
35
- def extract_style_text_from_prompt (style_text , prompt ):
36
- stripped_prompt = re .sub (re_spaces , " " , prompt .strip ())
37
- stripped_style_text = re .sub (re_spaces , " " , style_text .strip ())
56
+ Note that the "cleaned" version of the style text is only used for matching
57
+ purposes here. It isn't returned; the original style text is not modified.
58
+ """
59
+ stripped_prompt = clean_text (prompt )
60
+ stripped_style_text = clean_text (style_text )
38
61
if "{prompt}" in stripped_style_text :
39
- left , right = stripped_style_text .split ("{prompt}" , 2 )
62
+ # Work out whether the prompt is wrapped in the style text. If so, we
63
+ # return True and the "inner" prompt text that isn't part of the style.
64
+ try :
65
+ left , right = stripped_style_text .split ("{prompt}" , 2 )
66
+ except ValueError as e :
67
+ # If the style text has multple "{prompt}"s, we can't split it into
68
+ # two parts. This is an error, but we can't do anything about it.
69
+ print (f"Unable to compare style text to prompt:\n { style_text } " )
70
+ print (f"Error: { e } " )
71
+ return False , prompt
40
72
if stripped_prompt .startswith (left ) and stripped_prompt .endswith (right ):
41
- prompt = stripped_prompt [len (left ): len (stripped_prompt )- len (right )]
73
+ prompt = stripped_prompt [len (left ) : len (stripped_prompt ) - len (right )]
42
74
return True , prompt
43
75
else :
76
+ # Work out whether the given prompt ends with the style text. If so, we
77
+ # return True and the prompt text up to where the style text starts.
44
78
if stripped_prompt .endswith (stripped_style_text ):
45
- prompt = stripped_prompt [:len (stripped_prompt )- len (stripped_style_text )]
46
-
47
- if prompt .endswith (', ' ):
79
+ prompt = stripped_prompt [: len (stripped_prompt ) - len (stripped_style_text )]
80
+ if prompt .endswith (", " ):
48
81
prompt = prompt [:- 2 ]
49
-
50
82
return True , prompt
51
83
52
84
return False , prompt
53
85
54
86
55
- def extract_style_from_prompts (style : PromptStyle , prompt , negative_prompt ):
87
+ def extract_original_prompts (style : PromptStyle , prompt , negative_prompt ):
88
+ """
89
+ Takes a style and compares it to the prompt and negative prompt. If the style
90
+ matches, returns True plus the prompt and negative prompt with the style text
91
+ removed. Otherwise, returns False with the original prompt and negative prompt.
92
+ """
56
93
if not style .prompt and not style .negative_prompt :
57
94
return False , prompt , negative_prompt
58
95
59
- match_positive , extracted_positive = extract_style_text_from_prompt (style .prompt , prompt )
96
+ match_positive , extracted_positive = unwrap_style_text_from_prompt (
97
+ style .prompt , prompt
98
+ )
60
99
if not match_positive :
61
100
return False , prompt , negative_prompt
62
101
63
- match_negative , extracted_negative = extract_style_text_from_prompt (style .negative_prompt , negative_prompt )
102
+ match_negative , extracted_negative = unwrap_style_text_from_prompt (
103
+ style .negative_prompt , negative_prompt
104
+ )
64
105
if not match_negative :
65
106
return False , prompt , negative_prompt
66
107
@@ -69,25 +110,88 @@ def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
69
110
70
111
class StyleDatabase :
71
112
def __init__ (self , path : str ):
72
- self .no_style = PromptStyle ("None" , "" , "" )
113
+ self .no_style = PromptStyle ("None" , "" , "" , None )
73
114
self .styles = {}
74
115
self .path = path
75
116
117
+ folder , file = os .path .split (self .path )
118
+ self .default_file = file .split ("*" )[0 ] + ".csv"
119
+ if self .default_file == ".csv" :
120
+ self .default_file = "styles.csv"
121
+ self .default_path = os .path .join (folder , self .default_file )
122
+
123
+ self .prompt_fields = [field for field in PromptStyle ._fields if field != "path" ]
124
+
76
125
self .reload ()
77
126
78
127
def reload (self ):
128
+ """
129
+ Clears the style database and reloads the styles from the CSV file(s)
130
+ matching the path used to initialize the database.
131
+ """
79
132
self .styles .clear ()
80
133
81
- if not os .path .exists (self .path ):
134
+ path , filename = os .path .split (self .path )
135
+
136
+ if "*" in filename :
137
+ fileglob = filename .split ("*" )[0 ] + "*.csv"
138
+ filelist = []
139
+ for file in os .listdir (path ):
140
+ if fnmatch .fnmatch (file , fileglob ):
141
+ filelist .append (file )
142
+ # Add a visible divider to the style list
143
+ half_len = round (len (file ) / 2 )
144
+ divider = f"{ '-' * (20 - half_len )} { file .upper ()} "
145
+ divider = f"{ divider } { '-' * (40 - len (divider ))} "
146
+ self .styles [divider ] = PromptStyle (
147
+ f"{ divider } " , None , None , "do_not_save"
148
+ )
149
+ # Add styles from this CSV file
150
+ self .load_from_csv (os .path .join (path , file ))
151
+ if len (filelist ) == 0 :
152
+ print (f"No styles found in { path } matching { fileglob } " )
153
+ return
154
+ elif not os .path .exists (self .path ):
155
+ print (f"Style database not found: { self .path } " )
82
156
return
157
+ else :
158
+ self .load_from_csv (self .path )
83
159
84
- with open (self .path , "r" , encoding = "utf-8-sig" , newline = '' ) as file :
160
+ def load_from_csv (self , path : str ):
161
+ with open (path , "r" , encoding = "utf-8-sig" , newline = "" ) as file :
85
162
reader = csv .DictReader (file , skipinitialspace = True )
86
163
for row in reader :
164
+ # Ignore empty rows or rows starting with a comment
165
+ if not row or row ["name" ].startswith ("#" ):
166
+ continue
87
167
# Support loading old CSV format with "name, text"-columns
88
168
prompt = row ["prompt" ] if "prompt" in row else row ["text" ]
89
169
negative_prompt = row .get ("negative_prompt" , "" )
90
- self .styles [row ["name" ]] = PromptStyle (row ["name" ], prompt , negative_prompt )
170
+ # Add style to database
171
+ self .styles [row ["name" ]] = PromptStyle (
172
+ row ["name" ], prompt , negative_prompt , path
173
+ )
174
+
175
+ def get_style_paths (self ) -> list ():
176
+ """
177
+ Returns a list of all distinct paths, including the default path, of
178
+ files that styles are loaded from."""
179
+ # Update any styles without a path to the default path
180
+ for style in list (self .styles .values ()):
181
+ if not style .path :
182
+ self .styles [style .name ] = style ._replace (path = self .default_path )
183
+
184
+ # Create a list of all distinct paths, including the default path
185
+ style_paths = set ()
186
+ style_paths .add (self .default_path )
187
+ for _ , style in self .styles .items ():
188
+ if style .path :
189
+ style_paths .add (style .path )
190
+
191
+ # Remove any paths for styles that are just list dividers
192
+ style_paths .remove ("do_not_save" )
193
+
194
+ return list (style_paths )
91
195
92
196
def get_style_prompts (self , styles ):
93
197
return [self .styles .get (x , self .no_style ).prompt for x in styles ]
@@ -96,20 +200,53 @@ def get_negative_style_prompts(self, styles):
96
200
return [self .styles .get (x , self .no_style ).negative_prompt for x in styles ]
97
201
98
202
def apply_styles_to_prompt (self , prompt , styles ):
99
- return apply_styles_to_prompt (prompt , [self .styles .get (x , self .no_style ).prompt for x in styles ])
203
+ return apply_styles_to_prompt (
204
+ prompt , [self .styles .get (x , self .no_style ).prompt for x in styles ]
205
+ )
100
206
101
207
def apply_negative_styles_to_prompt (self , prompt , styles ):
102
- return apply_styles_to_prompt (prompt , [self .styles .get (x , self .no_style ).negative_prompt for x in styles ])
103
-
104
- def save_styles (self , path : str ) -> None :
105
- # Always keep a backup file around
106
- if os .path .exists (path ):
107
- shutil .copy (path , f"{ path } .bak" )
108
-
109
- with open (path , "w" , encoding = "utf-8-sig" , newline = '' ) as file :
110
- writer = csv .DictWriter (file , fieldnames = PromptStyle ._fields )
111
- writer .writeheader ()
112
- writer .writerows (style ._asdict () for k , style in self .styles .items ())
208
+ return apply_styles_to_prompt (
209
+ prompt , [self .styles .get (x , self .no_style ).negative_prompt for x in styles ]
210
+ )
211
+
212
+ def save_styles (self , path : str = None ) -> None :
213
+ # The path argument is deprecated, but kept for backwards compatibility
214
+ _ = path
215
+
216
+ # Update any styles without a path to the default path
217
+ for style in list (self .styles .values ()):
218
+ if not style .path :
219
+ self .styles [style .name ] = style ._replace (path = self .default_path )
220
+
221
+ # Create a list of all distinct paths, including the default path
222
+ style_paths = set ()
223
+ style_paths .add (self .default_path )
224
+ for _ , style in self .styles .items ():
225
+ if style .path :
226
+ style_paths .add (style .path )
227
+
228
+ # Remove any paths for styles that are just list dividers
229
+ style_paths .remove ("do_not_save" )
230
+
231
+ csv_names = [os .path .split (path )[1 ].lower () for path in style_paths ]
232
+
233
+ for style_path in style_paths :
234
+ # Always keep a backup file around
235
+ if os .path .exists (style_path ):
236
+ shutil .copy (style_path , f"{ style_path } .bak" )
237
+
238
+ # Write the styles to the CSV file
239
+ with open (style_path , "w" , encoding = "utf-8-sig" , newline = "" ) as file :
240
+ writer = csv .DictWriter (file , fieldnames = self .prompt_fields )
241
+ writer .writeheader ()
242
+ for style in (s for s in self .styles .values () if s .path == style_path ):
243
+ # Skip style list dividers, e.g. "STYLES.CSV"
244
+ if style .name .lower ().strip ("# " ) in csv_names :
245
+ continue
246
+ # Write style fields, ignoring the path field
247
+ writer .writerow (
248
+ {k : v for k , v in style ._asdict ().items () if k != "path" }
249
+ )
113
250
114
251
def extract_styles_from_prompt (self , prompt , negative_prompt ):
115
252
extracted = []
@@ -120,7 +257,9 @@ def extract_styles_from_prompt(self, prompt, negative_prompt):
120
257
found_style = None
121
258
122
259
for style in applicable_styles :
123
- is_match , new_prompt , new_neg_prompt = extract_style_from_prompts (style , prompt , negative_prompt )
260
+ is_match , new_prompt , new_neg_prompt = extract_original_prompts (
261
+ style , prompt , negative_prompt
262
+ )
124
263
if is_match :
125
264
found_style = style
126
265
prompt = new_prompt
0 commit comments