Skip to content

Commit e32c9dc

Browse files
committed
put batch in group with prompt info
- resolves #67 and #69
1 parent 7092411 commit e32c9dc

File tree

2 files changed

+53
-17
lines changed

2 files changed

+53
-17
lines changed

frontends/krita/krita_diff/script.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import os
23
import time
34
from typing import Union
@@ -27,8 +28,8 @@
2728
)
2829
from .utils import (
2930
b64_to_img,
30-
create_layer,
3131
find_optimal_selection_region,
32+
get_desc_from_resp,
3233
img_to_ba,
3334
save_img,
3435
)
@@ -160,10 +161,22 @@ def get_mask_image(self) -> Union[QImage, None]:
160161
QImage.Format_RGBA8888,
161162
).rgbSwapped()
162163

163-
def img_inserter(self, x, y, width, height):
164+
def img_inserter(self, x, y, width, height, group: str = None):
164165
"""Return frozen image inserter to insert images as new layer."""
165166
# Selection may change before callback, so freeze selection region
166167
has_selection = self.selection is not None
168+
glayer = self.doc.createGroupLayer(group) if group else None
169+
170+
def create_layer(name: str):
171+
"""Create new layer in document or group"""
172+
layer = self.doc.createNode(name, "paintLayer")
173+
parent = self.doc.rootNode()
174+
if glayer:
175+
glayer.addChildNode(layer, None)
176+
parent.addChildNode(glayer, None)
177+
else:
178+
parent.addChildNode(layer, None)
179+
return layer
167180

168181
# TODO: Insert images inside a group layer for better organization
169182
# Group layer name can contain model name, prompt, etc
@@ -206,7 +219,7 @@ def insert(layer_name, enc):
206219
self.doc.resizeImage(0, 0, new_width, new_height)
207220

208221
ba = img_to_ba(image)
209-
layer = create_layer(self.doc, layer_name)
222+
layer = create_layer(layer_name)
210223
# layer.setColorSpace() doesn't pernamently convert layer depth etc...
211224

212225
# Don't fail silently for setPixelData(); fails if bit depth or number of channels mismatch
@@ -218,23 +231,30 @@ def insert(layer_name, enc):
218231
layer.setPixelData(ba, x, y, width, height)
219232
return layer
220233

234+
if glayer:
235+
return insert, glayer
221236
return insert
222237

223238
def apply_txt2img(self):
224239
# freeze selection region
225-
insert = self.img_inserter(self.x, self.y, self.width, self.height)
240+
insert, glayer = self.img_inserter(
241+
self.x, self.y, self.width, self.height, group="a"
242+
)
226243
mask_trigger = self.transparency_mask_inserter()
227244

228245
def cb(response):
229246
if len(self.client.long_reqs) == 1: # last request
230247
self.eta_timer.stop()
231248
assert response is not None, "Backend Error, check terminal"
232249
outputs = response["outputs"]
250+
glayer_name, layer_names = get_desc_from_resp(response, "txt2img")
233251
layers = [
234-
insert(f"txt2img {i + 1}", output) for i, output in enumerate(outputs)
252+
insert(name if name else f"txt2img {i + 1}", output)
253+
for output, name, i in zip(outputs, layer_names, itertools.count())
235254
]
236255
for layer in layers[:-1]:
237256
layer.setVisible(False)
257+
glayer.setName(glayer_name)
238258
self.doc.refreshProjection()
239259
mask_trigger(layers)
240260

@@ -244,7 +264,9 @@ def cb(response):
244264
)
245265

246266
def apply_img2img(self, mode):
247-
insert = self.img_inserter(self.x, self.y, self.width, self.height)
267+
insert, glayer = self.img_inserter(
268+
self.x, self.y, self.width, self.height, group="a"
269+
)
248270
mask_trigger = self.transparency_mask_inserter()
249271
mask_image = self.get_mask_image()
250272

@@ -272,12 +294,14 @@ def cb(response):
272294
layer_name_prefix = (
273295
"inpaint" if mode == 1 else "sd upscale" if mode == 2 else "img2img"
274296
)
297+
glayer_name, layer_names = get_desc_from_resp(response, layer_name_prefix)
275298
layers = [
276-
insert(f"{layer_name_prefix} {i + 1}", output)
277-
for i, output in enumerate(outputs)
299+
insert(name if name else f"{layer_name_prefix} {i + 1}", output)
300+
for output, name, i in zip(outputs, layer_names, itertools.count())
278301
]
279302
for layer in layers[:-1]:
280303
layer.setVisible(False)
304+
glayer.setName(glayer_name)
281305
self.doc.refreshProjection()
282306
# dont need transparency mask for inpaint mode
283307
if mode == 0:

frontends/krita/krita_diff/utils.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from itertools import cycle
44
from math import ceil
55

6-
from krita import Document, Krita, QBuffer, QByteArray, QImage, QIODevice, Qt
6+
from krita import Krita, QBuffer, QByteArray, QImage, QIODevice, Qt
77

88
from .config import Config
99
from .defaults import (
@@ -149,14 +149,6 @@ def find_optimal_selection_region(
149149
return best_x, best_y, best_width, best_height
150150

151151

152-
def create_layer(doc: Document, name: str):
153-
"""Create new layer in document"""
154-
root = doc.rootNode()
155-
layer = doc.createNode(name, "paintLayer")
156-
root.addChildNode(layer, None)
157-
return layer
158-
159-
160152
def save_img(img: QImage, path: str):
161153
"""Expects QImage"""
162154
# png is lossless; setting compression to max (0) won't affect quality
@@ -194,6 +186,26 @@ def bytewise_xor(msg: bytes, key: bytes):
194186
return bytes(v ^ k for v, k in zip(msg, cycle(key)))
195187

196188

189+
def get_desc_from_resp(resp: dict, type: str = ""):
190+
"""Get description of image generation from backend response."""
191+
try:
192+
info = json.loads(resp["info"])
193+
seeds = info["all_seeds"]
194+
glayer_desc = f"""[{type}]
195+
Prompt: {info['prompt']},
196+
Negative Prompt: {info['negative_prompt']},
197+
Model: {info['sd_model_hash']},
198+
Sampler: {info['sampler_name']},
199+
Scale: {info['cfg_scale']},
200+
Steps: {info['steps']}"""
201+
layers_desc = []
202+
for (seed,) in zip(seeds):
203+
layers_desc.append(f"Seed: {seed}")
204+
return glayer_desc, layers_desc
205+
except:
206+
return f"[{type}]", cycle([None])
207+
208+
197209
def reset_docker_layout():
198210
"""NOTE: Default stacking of dockers hardcoded here."""
199211
docker_ids = {

0 commit comments

Comments
 (0)