Skip to content

Commit bd61751

Browse files
authored
feat: enhance comfyui workflow (#10085)
1 parent 6692e8c commit bd61751

File tree

3 files changed

+68
-24
lines changed

3 files changed

+68
-24
lines changed

api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py

+17-14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import base64
2-
import io
31
import json
42
import random
53
import uuid
@@ -8,7 +6,7 @@
86
from websocket import WebSocket
97
from yarl import URL
108

11-
from core.file.file_manager import _get_encoded_string
9+
from core.file.file_manager import download
1210
from core.file.models import File
1311

1412

@@ -29,8 +27,7 @@ def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
2927
return response.content
3028

3129
def upload_image(self, image_file: File) -> dict:
32-
image_content = base64.b64decode(_get_encoded_string(image_file))
33-
file = io.BytesIO(image_content)
30+
file = download(image_file)
3431
files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"}
3532
res = httpx.post(str(self.base_url / "upload/image"), files=files)
3633
return res.json()
@@ -47,12 +44,7 @@ def open_websocket_connection(self) -> tuple[WebSocket, str]:
4744
ws.connect(ws_address)
4845
return ws, client_id
4946

50-
def set_prompt(
51-
self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = ""
52-
) -> dict:
53-
"""
54-
find the first KSampler, then can find the prompt node through it.
55-
"""
47+
def set_prompt_by_ksampler(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "") -> dict:
5648
prompt = origin_prompt.copy()
5749
id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
5850
k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0]
@@ -64,9 +56,20 @@ def set_prompt(
6456
negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
6557
prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt
6658

67-
if image_name != "":
68-
image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0]
69-
prompt.get(image_loader)["inputs"]["image"] = image_name
59+
return prompt
60+
61+
def set_prompt_images_by_ids(self, origin_prompt: dict, image_names: list[str], image_ids: list[str]) -> dict:
62+
prompt = origin_prompt.copy()
63+
for index, image_node_id in enumerate(image_ids):
64+
prompt[image_node_id]["inputs"]["image"] = image_names[index]
65+
return prompt
66+
67+
def set_prompt_images_by_default(self, origin_prompt: dict, image_names: list[str]) -> dict:
68+
prompt = origin_prompt.copy()
69+
id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
70+
load_image_nodes = [key for key, value in id_to_class_type.items() if value == "LoadImage"]
71+
for load_image, image_name in zip(load_image_nodes, image_names):
72+
prompt.get(load_image)["inputs"]["image"] = image_name
7073
return prompt
7174

7275
def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):

api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py

+35-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import json
22
from typing import Any
33

4+
from core.file import FileType
45
from core.tools.entities.tool_entities import ToolInvokeMessage
6+
from core.tools.errors import ToolParameterValidationError
57
from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient
68
from core.tools.tool.builtin_tool import BuiltinTool
79

@@ -10,19 +12,46 @@ class ComfyUIWorkflowTool(BuiltinTool):
1012
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
1113
comfyui = ComfyUiClient(self.runtime.credentials["base_url"])
1214

13-
positive_prompt = tool_parameters.get("positive_prompt")
14-
negative_prompt = tool_parameters.get("negative_prompt")
15+
positive_prompt = tool_parameters.get("positive_prompt", "")
16+
negative_prompt = tool_parameters.get("negative_prompt", "")
17+
images = tool_parameters.get("images") or []
1518
workflow = tool_parameters.get("workflow_json")
16-
image_name = ""
17-
if image := tool_parameters.get("image"):
19+
image_names = []
20+
for image in images:
21+
if image.type != FileType.IMAGE:
22+
continue
1823
image_name = comfyui.upload_image(image).get("name")
24+
image_names.append(image_name)
25+
26+
set_prompt_with_ksampler = True
27+
if "{{positive_prompt}}" in workflow:
28+
set_prompt_with_ksampler = False
29+
workflow = workflow.replace("{{positive_prompt}}", positive_prompt)
30+
workflow = workflow.replace("{{negative_prompt}}", negative_prompt)
1931

2032
try:
21-
origin_prompt = json.loads(workflow)
33+
prompt = json.loads(workflow)
2234
except:
2335
return self.create_text_message("the Workflow JSON is not correct")
2436

25-
prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt, image_name)
37+
if set_prompt_with_ksampler:
38+
try:
39+
prompt = comfyui.set_prompt_by_ksampler(prompt, positive_prompt, negative_prompt)
40+
except:
41+
raise ToolParameterValidationError(
42+
"Failed set prompt with KSampler, try replace prompt to {{positive_prompt}} in your workflow json"
43+
)
44+
45+
if image_names:
46+
if image_ids := tool_parameters.get("image_ids"):
47+
image_ids = image_ids.split(",")
48+
try:
49+
prompt = comfyui.set_prompt_images_by_ids(prompt, image_names, image_ids)
50+
except:
51+
raise ToolParameterValidationError("the Image Node ID List not match your upload image files.")
52+
else:
53+
prompt = comfyui.set_prompt_images_by_default(prompt, image_names)
54+
2655
images = comfyui.generate_image_by_prompt(prompt)
2756
result = []
2857
for img in images:

api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml

+16-4
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ parameters:
2424
zh_Hans: 负面提示词
2525
llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English.
2626
form: llm
27-
- name: image
28-
type: file
27+
- name: images
28+
type: files
2929
label:
30-
en_US: Input Image
30+
en_US: Input Images
3131
zh_Hans: 输入的图片
32-
llm_description: The input image, used to transfer to the comfyui workflow to generate another image.
32+
llm_description: The input images, used to transfer to the comfyui workflow to generate another image.
3333
form: llm
3434
- name: workflow_json
3535
type: string
@@ -40,3 +40,15 @@ parameters:
4040
en_US: exported from ComfyUI workflow
4141
zh_Hans: 从ComfyUI的工作流中导出
4242
form: form
43+
- name: image_ids
44+
type: string
45+
label:
46+
en_US: Image Node ID List
47+
zh_Hans: 图片节点ID列表
48+
placeholder:
49+
en_US: Use commas to separate multiple node ID
50+
zh_Hans: 多个节点ID时使用半角逗号分隔
51+
human_description:
52+
en_US: When the workflow has multiple image nodes, enter the ID list of these nodes, and the images will be passed to ComfyUI in the order of the list.
53+
zh_Hans: 当工作流有多个图片节点时,输入这些节点的ID列表,图片将按列表顺序传给ComfyUI
54+
form: form

0 commit comments

Comments
 (0)