Skip to content

Commit 7631bda

Browse files
committed
make interrupt work
- use wrap_gradio_gpu_call to use webUI's queue system - remove async from fastAPI routes (turns out was what blocked requests) - remove waiting for no requests before getting backend state
1 parent 0ec2173 commit 7631bda

File tree

5 files changed

+40
-15
lines changed

5 files changed

+40
-15
lines changed

backend/app.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from fastapi import APIRouter, Request
99
from fastapi.responses import StreamingResponse
1010
from modules import shared
11+
from modules.call_queue import wrap_gradio_gpu_call
1112
from PIL import Image, ImageOps
1213
from starlette.concurrency import iterate_in_threadpool
1314

@@ -53,8 +54,9 @@
5354

5455
# TODO: Consider using pipeline directly instead of Gradio API for less surprises & better control
5556

57+
5658
@router.get("/config", response_model=ConfigResponse)
57-
async def get_state():
59+
def get_state():
5860
"""Get information about backend API.
5961
6062
Returns config from `krita_config.yaml`, other metadata,
@@ -83,7 +85,7 @@ async def get_state():
8385

8486

8587
@router.post("/txt2img", response_model=ImageResponse)
86-
async def f_txt2img(req: Txt2ImgRequest):
88+
def f_txt2img(req: Txt2ImgRequest):
8789
"""Post request for Txt2Img.
8890
8991
Args:
@@ -105,7 +107,7 @@ async def f_txt2img(req: Txt2ImgRequest):
105107
req.base_size, req.max_size, req.orig_width, req.orig_height
106108
)
107109

108-
images, info, html = modules.txt2img.txt2img(
110+
images, info, html = wrap_gradio_gpu_call(modules.txt2img.txt2img)(
109111
parse_prompt(req.prompt), # prompt
110112
parse_prompt(req.negative_prompt), # negative_prompt
111113
"None", # prompt_style: saved prompt styles (unsupported)
@@ -131,6 +133,10 @@ async def f_txt2img(req: Txt2ImgRequest):
131133
req.firstphase_height, # firstphase_height (yes its inconsistently width/height first)
132134
*args,
133135
)
136+
if len(images) < 1:
137+
log.warning("Interrupted!")
138+
return {"outputs": [], "info": info}
139+
134140
if shared.opts.return_grid:
135141
if not req.include_grid and len(images) > 1 and script_ind == 0:
136142
images = images[1:]
@@ -160,7 +166,7 @@ async def f_txt2img(req: Txt2ImgRequest):
160166

161167

162168
@router.post("/img2img", response_model=ImageResponse)
163-
async def f_img2img(req: Img2ImgRequest):
169+
def f_img2img(req: Img2ImgRequest):
164170
"""Post request for Img2Img.
165171
166172
Args:
@@ -203,7 +209,7 @@ async def f_img2img(req: Img2ImgRequest):
203209
# - new color sketch functionality in webUI is irrelevant so None is used for their options.
204210
# - the internal code for img2img is confusing and duplicative...
205211

206-
images, info, html = modules.img2img.img2img(
212+
images, info, html = wrap_gradio_gpu_call(modules.img2img.img2img)(
207213
req.mode, # mode
208214
parse_prompt(req.prompt), # prompt
209215
parse_prompt(req.negative_prompt), # negative_prompt
@@ -243,6 +249,10 @@ async def f_img2img(req: Img2ImgRequest):
243249
"", # img2img_batch_output_dir (unspported)
244250
*args,
245251
)
252+
if len(images) < 1:
253+
log.warning("Interrupted!")
254+
return {"outputs": [], "info": info}
255+
246256
if shared.opts.return_grid:
247257
if not req.include_grid and len(images) > 1 and script_ind == 0:
248258
images = images[1:]
@@ -288,7 +298,7 @@ def apply_mask(img):
288298

289299

290300
@router.post("/upscale", response_model=UpscaleResponse)
291-
async def f_upscale(req: UpscaleRequest):
301+
def f_upscale(req: UpscaleRequest):
292302
"""Post request for upscaling.
293303
294304
Args:

frontends/krita/krita_diff/client.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
ERR_BAD_URL,
1313
ERR_NO_CONNECTION,
1414
GET_CONFIG_TIMEOUT,
15+
OFFICIAL_ROUTE_PREFIX,
1516
POST_TIMEOUT,
1617
ROUTE_PREFIX,
1718
STATE_READY,
@@ -24,11 +25,11 @@
2425
# except to prevent user from spamming themselves
2526

2627

27-
def get_url(cfg: Config, route: str = ...):
28+
def get_url(cfg: Config, route: str = ..., prefix: str = ROUTE_PREFIX):
2829
base = cfg("base_url", str)
2930
if not urlparse(base).scheme in {"http", "https"}:
3031
return None
31-
url = urljoin(base, ROUTE_PREFIX)
32+
url = urljoin(base, prefix)
3233
if route is not ...:
3334
url = urljoin(url, route)
3435
# print("url:", url)
@@ -137,7 +138,6 @@ def __init__(self, cfg: Config, ext_cfg: Config):
137138
self.ext_cfg = ext_cfg
138139
self.reqs = []
139140
# NOTE: this is a hacky workaround for detecting if backend is reachable
140-
# this is to prevent zombie post requests (since they have no timeout)
141141
self.is_connected = False
142142

143143
def handle_api_error(self, exc: Exception):
@@ -255,11 +255,6 @@ def cb(obj):
255255
self.is_connected = True
256256
self.status.emit(STATE_READY)
257257

258-
# only get config if there are no pending post requests jamming the backend
259-
# NOTE: this might prevent get_config() from ever working if zombie requests can happen
260-
if len(self.reqs) > 0:
261-
return
262-
263258
url = get_url(self.cfg, "config")
264259
if not url:
265260
self.status.emit(ERR_BAD_URL)
@@ -375,3 +370,8 @@ def post_upscale(self, cb, src_img):
375370
else {"src_img": img_to_b64(src_img)}
376371
)
377372
self.post("upscale", params, cb)
373+
374+
def post_interrupt(self, cb):
375+
# get official API url
376+
url = get_url(self.cfg, prefix=OFFICIAL_ROUTE_PREFIX)
377+
self.post("interrupt", {}, cb, base_url=url)

frontends/krita/krita_diff/defaults.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
STATE_IMG2IMG = "img2img done!"
1515
STATE_INPAINT = "inpaint done!"
1616
STATE_UPSCALE = "upscale done!"
17+
STATE_INTERRUPT = "Interrupted!"
1718

1819
# Other currently hardcoded stuff
1920
GET_CONFIG_TIMEOUT = 10 # there is prevention for get request accumulation
@@ -26,6 +27,7 @@
2627
ADD_MASK_TIMEOUT = 200
2728
THREADED = True
2829
ROUTE_PREFIX = "/sdapi/interpause/"
30+
OFFICIAL_ROUTE_PREFIX = "/sdapi/v1/"
2931

3032
# error messages
3133
ERR_MISSING_CONFIG = "Report this bug, developer missed out a config key somewhere."

frontends/krita/krita_diff/pages/common.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from krita import QHBoxLayout, QVBoxLayout, QWidget
1+
from krita import QHBoxLayout, QPushButton, QVBoxLayout, QWidget
22

33
from ..script import script
44
from ..widgets import QCheckBox, QComboBoxLayout, QLabel, QSpinBoxLayout
@@ -63,6 +63,9 @@ def __init__(self, *args, **kwargs):
6363
# Tiling mode
6464
self.tiling = QCheckBox(script.cfg, "sd_tiling", "Tiling mode")
6565

66+
# Interrupt button
67+
self.interrupt_btn = QPushButton("Interrupt")
68+
6669
layout = QVBoxLayout()
6770
layout.setContentsMargins(0, 0, 0, 0)
6871

@@ -74,6 +77,7 @@ def __init__(self, *args, **kwargs):
7477
layout.addLayout(self.sd_model_layout)
7578
layout.addLayout(batch_layout)
7679
layout.addLayout(size_layout)
80+
layout.addWidget(self.interrupt_btn)
7781

7882
self.setLayout(layout)
7983

@@ -112,3 +116,5 @@ def toggle_codeformer_weights(visible):
112116
toggle_codeformer_weights(
113117
self.face_restorer_layout.qcombo.currentText() == "CodeFormer"
114118
)
119+
120+
self.interrupt_btn.released.connect(script.action_interrupt)

frontends/krita/krita_diff/script.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
EXT_CFG_NAME,
2323
STATE_IMG2IMG,
2424
STATE_INPAINT,
25+
STATE_INTERRUPT,
2526
STATE_RESET_DEFAULT,
2627
STATE_TXT2IMG,
2728
STATE_UPSCALE,
@@ -363,5 +364,11 @@ def action_simple_upscale(self):
363364
return
364365
self.apply_simple_upscale()
365366

367+
def action_interrupt(self):
368+
def cb(response=None):
369+
self.status_changed.emit(STATE_INTERRUPT)
370+
371+
self.client.post_interrupt(cb)
372+
366373

367374
script = Script()

0 commit comments

Comments
 (0)