Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 50 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,33 +48,63 @@ await comfyui.fetch(backend_filepath, on_load)

To test the library with a sample SDXL workflow, run the following after installing (replace the address with your ComfyUI endpoint). Make sure your ComfyUI has `sd_xl_base_1.0.safetensors` and `sd_xl_refiner_1.0.safetensors` installed (or replace the workflow).

Program options:

```usage: e2e.py [-h] --address ADDRESS --prompt PROMPT [--output_path OUTPUT_PATH] [--output_dir OUTPUT_DIR] [--sample_dir SAMPLE_DIR]

Run an SDXL or other workflow on a deployed ComfyUI server.

options:
-h, --help show this help message and exit
--address ADDRESS the ComfyUI endpoint
--prompt PROMPT the user prompt
--output_path OUTPUT_PATH
The filename to store the final image received from ComfyUI
--output_dir OUTPUT_DIR
A folder to store the final image received from ComfyUI, name will be prompt_RANDOM.png
--sample_dir SAMPLE_DIR
Folder to store sample images received from ComfyUI, images names will be prompt_sample_RANDOM_COUNTER.png
```
Example:
```python
comfy_ui_example_e2e\
--address='192.168.0.10:11010'\
--prompt='a smiling potato $base_steps=8$refiner_steps=3'\
--output='./potato.png'
--output_path='./potato.png'
```
The single quotes are important so your shell doesn't try to parse the `$`'s. Expected output:
```
Queuing workflow.
Queue position: #0
Base...
Base: 1/8
Base: 2/8
Base: 3/8
Base: 4/8
Base: 5/8
Base: 6/8
Base: 7/8
Base: 8/8
Refiner...
Refiner: 1/3
Refiner: 2/3
Refiner: 3/3
Decoding...
Saving image on backend...
Result (cached: no):
{'images': [{'filename': 'ComfyUI_00101_.png', 'subfolder': '', 'type': 'output'}]}
2023-12-18 22:23:37,441 [INFO ][__main__ ] Queuing workflow.
2023-12-18 22:23:37,533 [INFO ][root ] {'prompt_id': 'ae431d7f-e212-4b05-8bcd-ebc8d346554f', 'number': 0, 'node_errors': {}}
2023-12-18 22:23:37,609 [INFO ][__main__ ] Queue position: #0
2023-12-18 22:23:44,484 [INFO ][__main__ ] Base...
2023-12-18 22:23:46,335 [INFO ][__main__ ] Base: 1/8
2023-12-18 22:23:46,351 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE)
2023-12-18 22:23:46,988 [INFO ][__main__ ] Base: 2/8
2023-12-18 22:23:47,002 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE)
2023-12-18 22:23:47,663 [INFO ][__main__ ] Base: 3/8
2023-12-18 22:23:47,678 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE)
2023-12-18 22:23:48,314 [INFO ][__main__ ] Base: 4/8
2023-12-18 22:23:48,330 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE)
2023-12-18 22:23:48,964 [INFO ][__main__ ] Base: 5/8
2023-12-18 22:23:48,984 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE)
2023-12-18 22:23:49,614 [INFO ][__main__ ] Base: 6/8
2023-12-18 22:23:49,712 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE)
2023-12-18 22:23:50,271 [INFO ][__main__ ] Base: 7/8
2023-12-18 22:23:50,294 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE)
2023-12-18 22:23:50,920 [INFO ][__main__ ] Base: 8/8
2023-12-18 22:23:50,947 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE)
2023-12-18 22:23:50,948 [INFO ][__main__ ] Refiner...
2023-12-18 22:23:55,148 [INFO ][__main__ ] Refiner: 1/3
2023-12-18 22:23:55,291 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE)
2023-12-18 22:23:55,826 [INFO ][__main__ ] Refiner: 2/3
2023-12-18 22:23:55,855 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE)
2023-12-18 22:23:56,511 [INFO ][__main__ ] Refiner: 3/3
2023-12-18 22:23:56,541 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE)
2023-12-18 22:23:56,542 [INFO ][__main__ ] Decoding...
2023-12-18 22:23:57,466 [INFO ][__main__ ] Saving image on backend...
2023-12-18 22:23:57,901 [INFO ][__main__ ] Result (cached: no): {'images': [{'filename': 'ComfyUI_00022_.png', 'subfolder': '', 'type': 'output'}]}
2023-12-18 22:23:57,901 [INFO ][__main__ ] Saving backend image ComfyUI_00022_.png to potato.png
```
The file will be saved in the current working directory.

Expand Down
129 changes: 88 additions & 41 deletions comfyui_utils/comfy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import io
import json
import logging
from typing import Any, Callable, Union
from typing import Any, Callable, Union, Optional
import uuid

import struct
from PIL import Image
from io import BytesIO
import aiohttp

import aiohttp.client_exceptions

# Inherit this class to specify callbacks during prompt execution.
class Callbacks(abc.ABC):
Expand All @@ -27,6 +29,9 @@ async def in_progress(self, node_id: int, progress: int, total: int):
@abc.abstractmethod
async def completed(self, outputs: dict[str, Any], cached: bool):
"""Called when the prompt completes, with the final output."""
@abc.abstractmethod
async def image_received(self, image: Image.Image):
"""Called when the prompt's queue return a sampler image, a pillow image object is pass to this function"""
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: "is pass" -> "is passed" + period at end just for consistency



StrDict = dict[str, Any] # parsed JSON of an API-formatted ComfyUI workflow.
Expand Down Expand Up @@ -88,33 +93,69 @@ async def _prompt_websocket(sess: PromptSession, callbacks: Callbacks) -> None:
# logging.warning(msg)
if msg.type == aiohttp.WSMsgType.ERROR:
raise BrokenPipeError(f"WebSocket error: {msg.data}")
assert msg.type == aiohttp.WSMsgType.TEXT
message = json.loads(msg.data)
# Handle prompt being started.
if message["type"] == "status":
queue_or_result = await _get_queue_position_or_cached_result(sess)
if isinstance(queue_or_result, int):
await callbacks.queue_position(queue_or_result)
else:
await callbacks.completed(queue_or_result, True)
if msg.type == aiohttp.WSMsgType.TEXT:
message = json.loads(msg.data)
# Handle prompt being started.
if message["type"] == "status":
queue_or_result = await _get_queue_position_or_cached_result(sess)
if isinstance(queue_or_result, int):
await callbacks.queue_position(queue_or_result)
else:
await callbacks.completed(queue_or_result, True)
break
# Handle a node being executed.
if message["type"] == "executing":
if message["data"]["node"] is not None:
node_id = int(message["data"]["node"])
current_node = node_id
await callbacks.in_progress(current_node, 0, 0)
# Handle completion of the request.
if message["type"] == "executed":
assert message["data"]["prompt_id"] == sess.prompt_id
await callbacks.completed(message["data"]["output"], False)
break
# Handle a node being executed.
if message["type"] == "executing":
if message["data"]["node"] is not None:
node_id = int(message["data"]["node"])
current_node = node_id
await callbacks.in_progress(current_node, 0, 0)
# Handle completion of the request.
if message["type"] == "executed":
assert message["data"]["prompt_id"] == sess.prompt_id
await callbacks.completed(message["data"]["output"], False)
break
# Handle progress on a node.
if message["type"] == "progress":
progress = int(message["data"]["value"])
total = int(message["data"]["max"])
await callbacks.in_progress(current_node, progress, total)

# Handle progress on a node.
if message["type"] == "progress":
progress = int(message["data"]["value"])
total = int(message["data"]["max"])
await callbacks.in_progress(current_node, progress, total)
elif msg.type == aiohttp.WSMsgType.BINARY:
image= await receive_image(msg.data)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: space before "="

if image is not None:
await callbacks.image_received(image)
else:
logging.warning("Not text message, message type: %r", msg.type)


async def receive_image(image_data) -> Optional[Image.Image]:
    """Rebuild a PIL image from binary data received on the websocket.

    The payload layout follows ComfyUI's ``server.py`` ``send_image``:
    two big-endian uint32 headers (image format, then event type) followed
    by the encoded image bytes.

    Reference: https://github.com/comfyanonymous/ComfyUI.git server.py function send_image

    Args:
        image_data: raw bytes of one binary websocket frame.

    Returns:
        The decoded ``PIL.Image.Image``, or ``None`` on any error.
    """
    try:
        # Bytes 0-3: image format (1 = JPEG, 2 = PNG).
        type_num, = struct.unpack_from('>I', image_data, 0)
        # Bytes 4-7: event type (1 = PREVIEW_IMAGE, 2 = UNENCODED_PREVIEW_IMAGE).
        event_type_num, = struct.unpack_from('>I', image_data, 4)
        if type_num == 1:
            image_type = "JPEG"
        elif type_num == 2:
            image_type = "PNG"
        else:
            logging.error("Unsupported type received: %d", type_num)
            return None

        if event_type_num == 1:
            event_type = "PREVIEW_IMAGE"
        elif event_type_num == 2:
            event_type = "UNENCODED_PREVIEW_IMAGE"
        else:
            event_type = f"UNKNOWN {event_type_num}"
        # Lazy %-args so formatting is skipped when INFO is disabled.
        logging.info("Received a %s (%s)", image_type, event_type)
        # The remaining bytes are the encoded image itself.
        return Image.open(BytesIO(image_data[8:]))
    except Exception:
        # Deliberately broad: a malformed preview frame must not abort the
        # running prompt; the contract is "return None on any error".
        logging.exception("Error on receiving image.")
        return None

class ComfyAPI:
def __init__(self, address):
Expand All @@ -138,18 +179,24 @@ async def submit(self, prompt: StrDict, callbacks: Callbacks):
async with aiohttp.ClientSession() as session:
# Enqueue and get prompt ID.
async with session.post(f"http://{self.address}/prompt", data=init_data) as resp:
response_json = await resp.json()
logging.info(response_json)
if "error" in response_json:
if "node_errors" not in response_json:
raise ValueError(response_json["error"]["message"])
errors = []
for node_id, data in response_json["node_errors"].items():
for node_error in data["errors"]:
errors.append(f"Node {node_id}, {node_error['details']}: {node_error['message']}")
raise ValueError("\n" + "\n".join(errors))

prompt_id = response_json['prompt_id']
try:
response_json = await resp.json()
logging.info(response_json)
if "error" in response_json:
if "node_errors" not in response_json:
raise ValueError(response_json["error"]["message"])
errors = []
for node_id, data in response_json["node_errors"].items():
for node_error in data["errors"]:
errors.append(f"Node {node_id}, {node_error['details']}: {node_error['message']}")
raise ValueError("\n" + "\n".join(errors))

prompt_id = response_json['prompt_id']
except aiohttp.client_exceptions.ContentTypeError as e:
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can the except be moved up to only wrap the await resp.json() so any other errors bubble up normally?

text= await resp.text()
logging.error("Error, unexpected response, not a json response. %s \nReceived : \n%s", e, text)
raise ValueError("Not a JSON response : \n%s" % text)

# Listen on a websocket until the prompt completes and invoke callbacks.
await _prompt_websocket(PromptSession(
client_id=client_id,
Expand Down
78 changes: 64 additions & 14 deletions examples/e2e.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
"""End-to-end example for creating an image with SDXL base + refiner"""
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we keep the file comment?

import logging
import sys

formatter = logging.Formatter('%(asctime)s [%(levelname)-9s][%(name)-20s] %(message)s')
sh= logging.StreamHandler(stream=sys.stdout)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: space before "="/"+=" throughout the file to follow standard style?

sh.setLevel(logging.DEBUG)
sh.setFormatter(formatter)
helper_logger= logging.root
helper_logger.addHandler(sh)
helper_logger.setLevel(logging.INFO)
logger= helper_logger.getChild(__name__)

import pathlib
import tempfile

import argparse
import io
import asyncio
import json
from comfyui_utils import comfy
from comfyui_utils import gen_prompts
from PIL.Image import Image

async def run_base_and_refiner(address: str, user_string: str, output_path=None):
async def run_base_and_refiner(address: str, user_string: str, output_path: pathlib.Path=None, sample_dir: pathlib.Path=None):
comfyui = comfy.ComfyAPI(address)

# Load the stored API format prompt.
Expand All @@ -25,7 +39,7 @@ async def run_base_and_refiner(address: str, user_string: str, output_path=None)
# Returns a list of warnings if any argument is invalid.
parsed = gen_prompts.parse_args(user_string, prompt_config)
for warning in parsed.warnings:
print(f"WARNING: {warning}")
logger.warning(warning)
# Adjust the prompt with user args.
prompt["6"]["inputs"]["text"] = parsed.cleaned
prompt["10"]["inputs"]["end_at_step"] = parsed.result.base_steps
Expand All @@ -41,44 +55,80 @@ async def run_base_and_refiner(address: str, user_string: str, output_path=None)
}
# Configure the callbacks which will write to it during execution while printing updates.
class Callbacks(comfy.Callbacks):
def __init__(self):
self.sample_count=0
self.sample_prefix= None
if sample_dir is not None:
sample_file_name= None
while sample_file_name is None or pathlib.Path(sample_file_name).exists():
sample_file_name= pathlib.Path(tempfile.mktemp('.png', 'prompt_sample_', sample_dir))
self.sample_prefix= sample_file_name.stem
async def queue_position(self, position):
print(f"Queue position: #{position}")
logger.info(f"Queue position: #{position}")
async def in_progress(self, node_id, progress, total):
progress = f"{progress}/{total}" if total else None
if node_id == 10:
print(f"Base: {progress}" if progress else "Base...")
logger.info(f"Base: {progress}" if progress else "Base...")
if node_id == 11:
print(f"Refiner: {progress}" if progress else "Refiner...")
logger.info(f"Refiner: {progress}" if progress else "Refiner...")
elif node_id == 17:
print("Decoding...")
logger.info("Decoding...")
elif node_id == 19:
print("Saving image on backend...")
logger.info("Saving image on backend...")
async def completed(self, outputs, cached):
result["output"] = outputs
result["cached"] = cached
async def image_received(self, image: Image):
if sample_dir is not None:
sample_file= None
while sample_file is None or sample_file.exists():
sample_file= sample_dir.joinpath(f"{self.sample_prefix}_{self.sample_count:04d}.png")
self.sample_count+= 1
logger.info("Save sample file : %s", sample_file.name)
image.save(sample_file)

# Run the prompt and print the result.
print("Queuing workflow.")
await comfyui.submit(prompt, Callbacks())
print(f"Result (cached: {'yes' if result['cached'] else 'no'}):\n{result['output']}")
logger.info("Queuing workflow.")
try:
await comfyui.submit(prompt, Callbacks())
except ValueError as e:
logger.error("Error processing template: %s", e)
return
logger.info(f"Result (cached: {'yes' if result['cached'] else 'no'}): {result['output']}")

# Write the result to a local file.
if output_path is not None:
backend_filepath = result["output"]["images"][0]
async def on_load(data_file : io.BytesIO):
with open(output_path, "wb") as f:
f.write(data_file.getbuffer())
logger.info("Saving backend image %s to %s", backend_filepath.get("filename", None), output_path)
await comfyui.fetch(backend_filepath, on_load)


def main():
    """CLI entry point: parse arguments, resolve the output file, run the workflow.

    Exits with status 1 when --output_dir or --sample_dir is given but is not
    an existing directory.
    """
    parser = argparse.ArgumentParser(description='Run an SDXL or other workflow on a deployed ComfyUI server.')
    parser.add_argument("--address", type=str, help="the ComfyUI endpoint", required=True)
    parser.add_argument("--prompt", type=str, help="the user prompt", required=True)
    parser.add_argument("--output_path", type=pathlib.Path, help="The filename to store the final image received from ComfyUI", default=None)
    parser.add_argument("--output_dir", type=pathlib.Path, help="A folder to store the final image received from ComfyUI, name will be prompt_RANDOM.png", default=None)
    parser.add_argument("--sample_dir", type=pathlib.Path, help="Folder to store sample images received from ComfyUI, image names will be prompt_sample_RANDOM_COUNTER.png", default=None)
    args = parser.parse_args()

    # --output_path wins; otherwise derive a fresh random name inside --output_dir.
    outfile = args.output_path
    if outfile is None and args.output_dir is not None:
        if not args.output_dir.is_dir():
            logger.error("Argument --output_dir is %s but it's not a directory (full path resolved: %s)", args.output_dir, args.output_dir.resolve())
            exit(1)
        # Find a temporary name that does not collide with an existing file.
        while outfile is None or pathlib.Path(outfile).exists():
            outfile = pathlib.Path(tempfile.mktemp('.png', 'prompt_', args.output_dir))
    logger.debug("(%s) %r", type(outfile), outfile)
    if args.sample_dir is not None and not args.sample_dir.is_dir():
        logger.error("Argument --sample_dir is %s but it's not a directory (full path resolved: %s)", args.sample_dir, args.sample_dir.resolve())
        exit(1)

    asyncio.run(run_base_and_refiner(args.address, args.prompt, outfile, args.sample_dir))

if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
aiohttp
pillow