Skip to content

Commit 7a14c8a

Browse files
committed
add an option to enable sections from extras tab in txt2img/img2img
fix some style inconsistenices
1 parent 645f4e7 commit 7a14c8a

9 files changed

+133
-23
lines changed

modules/processing.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from typing import Any, Dict, List, Optional
1414

1515
import modules.sd_hijack
16-
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx
16+
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
1717
from modules.sd_hijack import model_hijack
1818
from modules.shared import opts, cmd_opts, state
1919
import modules.shared as shared
@@ -658,6 +658,11 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
658658

659659
image = Image.fromarray(x_sample)
660660

661+
if p.scripts is not None:
662+
pp = scripts.PostprocessImageArgs(image)
663+
p.scripts.postprocess_image(p, pp)
664+
image = pp.image
665+
661666
if p.color_corrections is not None and i < len(p.color_corrections):
662667
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
663668
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)

modules/scripts.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66

77
import gradio as gr
88

9-
from modules.processing import StableDiffusionProcessing
109
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
1110

1211
AlwaysVisible = object()
1312

1413

14+
class PostprocessImageArgs:
15+
def __init__(self, image):
16+
self.image = image
17+
18+
1519
class Script:
1620
filename = None
1721
args_from = None
@@ -65,7 +69,7 @@ def run(self, p, *args):
6569
args contains all values returned by components from ui()
6670
"""
6771

68-
raise NotImplementedError()
72+
pass
6973

7074
def process(self, p, *args):
7175
"""
@@ -100,6 +104,13 @@ def postprocess_batch(self, p, *args, **kwargs):
100104

101105
pass
102106

107+
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
108+
"""
109+
Called for every image after it has been generated.
110+
"""
111+
112+
pass
113+
103114
def postprocess(self, p, processed, *args):
104115
"""
105116
This function is called after processing ends for AlwaysVisible scripts.
@@ -247,11 +258,15 @@ def __init__(self):
247258
self.infotext_fields = []
248259

249260
def initialize_scripts(self, is_img2img):
261+
from modules import scripts_auto_postprocessing
262+
250263
self.scripts.clear()
251264
self.alwayson_scripts.clear()
252265
self.selectable_scripts.clear()
253266

254-
for script_class, path, basedir, script_module in scripts_data:
267+
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
268+
269+
for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
255270
script = script_class()
256271
script.filename = path
257272
script.is_txt2img = not is_img2img
@@ -332,7 +347,7 @@ def init_field(title):
332347

333348
return inputs
334349

335-
def run(self, p: StableDiffusionProcessing, *args):
350+
def run(self, p, *args):
336351
script_index = args[0]
337352

338353
if script_index == 0:
@@ -386,6 +401,15 @@ def postprocess_batch(self, p, images, **kwargs):
386401
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
387402
print(traceback.format_exc(), file=sys.stderr)
388403

404+
def postprocess_image(self, p, pp: PostprocessImageArgs):
405+
for script in self.alwayson_scripts:
406+
try:
407+
script_args = p.script_args[script.args_from:script.args_to]
408+
script.postprocess_image(p, pp, *script_args)
409+
except Exception:
410+
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
411+
print(traceback.format_exc(), file=sys.stderr)
412+
389413
def before_component(self, component, **kwargs):
390414
for script in self.scripts:
391415
try:
+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from modules import scripts, scripts_postprocessing, shared
2+
3+
4+
class ScriptPostprocessingForMainUI(scripts.Script):
5+
def __init__(self, script_postproc):
6+
self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
7+
self.postprocessing_controls = None
8+
9+
def title(self):
10+
return self.script.name
11+
12+
def show(self, is_img2img):
13+
return scripts.AlwaysVisible
14+
15+
def ui(self, is_img2img):
16+
self.postprocessing_controls = self.script.ui()
17+
return self.postprocessing_controls.values()
18+
19+
def postprocess_image(self, p, script_pp, *args):
20+
args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
21+
22+
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
23+
pp.info = {}
24+
self.script.process(pp, **args_dict)
25+
p.extra_generation_params.update(pp.info)
26+
script_pp.image = pp.image
27+
28+
29+
def create_auto_preprocessing_script_data():
30+
from modules import scripts
31+
32+
res = []
33+
34+
for name in shared.opts.postprocessing_enable_in_main_ui:
35+
script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
36+
if script is None:
37+
continue
38+
39+
constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
40+
res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
41+
42+
return res

modules/scripts_postprocessing.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def image_changed(self):
4646
pass
4747

4848

49+
50+
4951
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
5052
try:
5153
res = func(*args, **kwargs)
@@ -68,6 +70,9 @@ def initialize_scripts(self, scripts_data):
6870
script: ScriptPostprocessing = script_class()
6971
script.filename = path
7072

73+
if script.name == "Simple Upscale":
74+
continue
75+
7176
self.scripts.append(script)
7277

7378
def create_script_ui(self, script, inputs):
@@ -87,12 +92,11 @@ def scripts_in_preferred_order(self):
8792
import modules.scripts
8893
self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
8994

90-
scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
95+
scripts_order = shared.opts.postprocessing_operation_order
9196

9297
def script_score(name):
93-
name = name.lower()
9498
for i, possible_match in enumerate(scripts_order):
95-
if possible_match in name:
99+
if possible_match == name:
96100
return i
97101

98102
return len(self.scripts)
@@ -145,3 +149,4 @@ def create_args_for_run(self, scripts_args):
145149
def image_changed(self):
146150
for script in self.scripts_in_preferred_order():
147151
script.image_changed()
152+

modules/shared.py

+5-10
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
import modules.memmon
1414
import modules.styles
1515
import modules.devices as devices
16-
from modules import localization, sd_vae, extensions, script_loading, errors, ui_components
17-
from modules.paths import models_path, script_path, sd_path
16+
from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items
17+
from modules.paths import models_path, script_path
1818

1919

2020
demo = None
@@ -264,12 +264,6 @@ def assign_current_image(self, image):
264264

265265
face_restorers = []
266266

267-
268-
def realesrgan_models_names():
269-
import modules.realesrgan_model
270-
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
271-
272-
273267
class OptionInfo:
274268
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
275269
self.default = default
@@ -360,7 +354,7 @@ def list_samplers():
360354
options_templates.update(options_section(('upscaling', "Upscaling"), {
361355
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
362356
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
363-
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
357+
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
364358
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
365359
}))
366360

@@ -483,7 +477,8 @@ def list_samplers():
483477
}))
484478

485479
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
486-
'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"),
480+
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
481+
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
487482
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
488483
}))
489484

modules/shared_items.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
2+
3+
def realesrgan_models_names():
4+
import modules.realesrgan_model
5+
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
6+
7+
def postprocessing_scripts():
8+
import modules.scripts
9+
10+
return modules.scripts.scripts_postproc.scripts

modules/ui_components.py

+8
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,11 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
4848
def get_block_name(self):
4949
return "colorpicker"
5050

51+
52+
class DropdownMulti(gr.Dropdown):
53+
"""Same as gr.Dropdown but always multiselect"""
54+
def __init__(self, **kwargs):
55+
super().__init__(multiselect=True, **kwargs)
56+
57+
def get_block_name(self):
58+
return "dropdown"

scripts/postprocessing_upscale.py

+25
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,28 @@ def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1,
104104

105105
def image_changed(self):
106106
upscale_cache.clear()
107+
108+
109+
class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
110+
name = "Simple Upscale"
111+
order = 900
112+
113+
def ui(self):
114+
with FormRow():
115+
upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
116+
upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2)
117+
118+
return {
119+
"upscale_by": upscale_by,
120+
"upscaler_name": upscaler_name,
121+
}
122+
123+
def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
124+
if upscaler_name is None or upscaler_name == "None":
125+
return
126+
127+
upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None)
128+
assert upscaler1, f'could not find upscaler named {upscaler_name}'
129+
130+
pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
131+
pp.info[f"Postprocess upscaler"] = upscaler1.name

style.css

+1-5
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@
164164
min-height: 3.2em;
165165
}
166166

167-
#txt2img_styles ul, #img2img_styles ul{
167+
ul.list-none{
168168
max-height: 35em;
169169
z-index: 2000;
170170
}
@@ -714,9 +714,6 @@ footer {
714714
white-space: nowrap;
715715
min-width: auto;
716716
}
717-
#txt2img_hires_fix{
718-
margin-left: -0.8em;
719-
}
720717

721718
#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{
722719
margin-left: 0em;
@@ -744,7 +741,6 @@ footer {
744741

745742
.dark .gr-compact{
746743
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
747-
margin-left: 0.8em;
748744
}
749745

750746
.gr-compact{

0 commit comments

Comments
 (0)