Skip to content

feat: enhance comfyui workflow #10085

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import base64
import io
import json
import random
import uuid
Expand All @@ -8,7 +6,7 @@
from websocket import WebSocket
from yarl import URL

from core.file.file_manager import _get_encoded_string
from core.file.file_manager import download
from core.file.models import File


Expand All @@ -29,8 +27,7 @@ def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
return response.content

def upload_image(self, image_file: File) -> dict:
image_content = base64.b64decode(_get_encoded_string(image_file))
file = io.BytesIO(image_content)
file = download(image_file)
files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"}
res = httpx.post(str(self.base_url / "upload/image"), files=files)
return res.json()
Expand All @@ -47,12 +44,7 @@ def open_websocket_connection(self) -> tuple[WebSocket, str]:
ws.connect(ws_address)
return ws, client_id

def set_prompt(
self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = ""
) -> dict:
"""
find the first KSampler, then can find the prompt node through it.
"""
def set_prompt_by_ksampler(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "") -> dict:
prompt = origin_prompt.copy()
id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0]
Expand All @@ -64,9 +56,20 @@ def set_prompt(
negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt

if image_name != "":
image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0]
prompt.get(image_loader)["inputs"]["image"] = image_name
return prompt

def set_prompt_images_by_ids(self, origin_prompt: dict, image_names: list[str], image_ids: list[str]) -> dict:
    """Return a copy of the workflow with uploaded images bound to explicit nodes.

    The n-th entry of ``image_ids`` receives the n-th entry of
    ``image_names`` as its ``inputs["image"]`` value.

    Args:
        origin_prompt: Workflow mapping of node id -> node definition. It is
            left untouched; a deep copy is modified and returned.
        image_names: Image names to assign, in order.
        image_ids: Node ids to assign to, in the same order. Surrounding
            whitespace is ignored so that "12, 13" style input works.

    Returns:
        The modified deep copy of ``origin_prompt``.

    Raises:
        KeyError: If a node id is not present in the workflow (or the node
            has no ``inputs`` entry).
        IndexError: If there are more node ids than image names.
    """
    import copy

    # dict.copy() is shallow: writing into the nested "inputs" dict would
    # also rewrite the caller's workflow, so take a deep copy instead.
    prompt = copy.deepcopy(origin_prompt)
    for index, image_node_id in enumerate(image_ids):
        # Index explicitly (rather than zip) so that a surplus of node ids
        # still surfaces as an error instead of being silently dropped.
        prompt[image_node_id.strip()]["inputs"]["image"] = image_names[index]
    return prompt

def set_prompt_images_by_default(self, origin_prompt: dict, image_names: list[str]) -> dict:
    """Return a copy of the workflow with images bound to its ``LoadImage`` nodes.

    ``LoadImage`` nodes are taken in the workflow's own iteration order and
    paired with ``image_names`` one by one. Pairing stops at the shorter of
    the two sequences, so surplus nodes keep their current image and surplus
    names are ignored.

    Args:
        origin_prompt: Workflow mapping of node id -> node definition. It is
            left untouched; a deep copy is modified and returned.
        image_names: Image names to assign, in order.

    Returns:
        The modified deep copy of ``origin_prompt``.
    """
    import copy

    # dict.copy() is shallow: writing into the nested "inputs" dict would
    # also rewrite the caller's workflow, so take a deep copy instead.
    prompt = copy.deepcopy(origin_prompt)
    # .get() so a node without a "class_type" is skipped rather than
    # aborting the whole lookup with KeyError.
    load_image_nodes = [
        node_id for node_id, details in prompt.items() if details.get("class_type") == "LoadImage"
    ]
    for node_id, image_name in zip(load_image_nodes, image_names):
        prompt[node_id]["inputs"]["image"] = image_name
    return prompt

def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
Expand Down
41 changes: 35 additions & 6 deletions api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
from typing import Any

from core.file import FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolParameterValidationError
from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient
from core.tools.tool.builtin_tool import BuiltinTool

Expand All @@ -10,19 +12,46 @@ class ComfyUIWorkflowTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
comfyui = ComfyUiClient(self.runtime.credentials["base_url"])

positive_prompt = tool_parameters.get("positive_prompt")
negative_prompt = tool_parameters.get("negative_prompt")
positive_prompt = tool_parameters.get("positive_prompt", "")
negative_prompt = tool_parameters.get("negative_prompt", "")
images = tool_parameters.get("images") or []
workflow = tool_parameters.get("workflow_json")
image_name = ""
if image := tool_parameters.get("image"):
image_names = []
for image in images:
if image.type != FileType.IMAGE:
continue
image_name = comfyui.upload_image(image).get("name")
image_names.append(image_name)

set_prompt_with_ksampler = True
if "{{positive_prompt}}" in workflow:
set_prompt_with_ksampler = False
workflow = workflow.replace("{{positive_prompt}}", positive_prompt)
workflow = workflow.replace("{{negative_prompt}}", negative_prompt)

try:
origin_prompt = json.loads(workflow)
prompt = json.loads(workflow)
except:
return self.create_text_message("the Workflow JSON is not correct")

prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt, image_name)
if set_prompt_with_ksampler:
try:
prompt = comfyui.set_prompt_by_ksampler(prompt, positive_prompt, negative_prompt)
except:
raise ToolParameterValidationError(
"Failed set prompt with KSampler, try replace prompt to {{positive_prompt}} in your workflow json"
)

if image_names:
if image_ids := tool_parameters.get("image_ids"):
image_ids = image_ids.split(",")
try:
prompt = comfyui.set_prompt_images_by_ids(prompt, image_names, image_ids)
except:
raise ToolParameterValidationError("the Image Node ID List not match your upload image files.")
else:
prompt = comfyui.set_prompt_images_by_default(prompt, image_names)

images = comfyui.generate_image_by_prompt(prompt)
result = []
for img in images:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ parameters:
zh_Hans: 负面提示词
llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English.
form: llm
- name: image
type: file
- name: images
type: files
label:
en_US: Input Image
en_US: Input Images
zh_Hans: 输入的图片
llm_description: The input image, used to transfer to the comfyui workflow to generate another image.
llm_description: The input images, used to transfer to the comfyui workflow to generate another image.
form: llm
- name: workflow_json
type: string
Expand All @@ -40,3 +40,15 @@ parameters:
en_US: exported from ComfyUI workflow
zh_Hans: 从ComfyUI的工作流中导出
form: form
- name: image_ids
type: string
label:
en_US: Image Node ID List
zh_Hans: 图片节点ID列表
placeholder:
en_US: Use commas to separate multiple node ID
zh_Hans: 多个节点ID时使用半角逗号分隔
human_description:
en_US: When the workflow has multiple image nodes, enter the ID list of these nodes, and the images will be passed to ComfyUI in the order of the list.
zh_Hans: 当工作流有多个图片节点时,输入这些节点的ID列表,图片将按列表顺序传给ComfyUI
form: form