diff --git a/README.md b/README.md index 82c3270..ab513b6 100644 --- a/README.md +++ b/README.md @@ -48,33 +48,63 @@ await comfyui.fetch(backend_filepath, on_load) To test the library with a sample SDXL workflow, run the following after installing (replace the address with your ComfyUI endpoint). Make sure your ComfyUI has `sd_xl_base_1.0.safetensors` and `sd_xl_refiner_1.0.safetensors` installed (or replace the workflow). +Program options : + +```usage: e2e.py [-h] --address ADDRESS --prompt PROMPT [--output_path OUTPUT_PATH] [--output_dir OUTPUT_DIR] [--sample_dir SAMPLE_DIR] + +Run an SDXL or other workflow on a deployed ComfyUI server. + +options: + -h, --help show this help message and exit + --address ADDRESS the ComfyUI endpoint + --prompt PROMPT the user prompt + --output_path OUTPUT_PATH + The filename to store the final image received from ComfyUI + --output_dir OUTPUT_DIR + A folder to store the final image received from ComfyUI, name will be prompt_RANDOM.png + --sample_dir SAMPLE_DIR + Folder to store sample images received from ComfyUI, images names will be prompt_sample_RANDOM_COUNTER.png +``` +Example: ```python comfy_ui_example_e2e\ --address='192.168.0.10:11010'\ --prompt='a smiling potato $base_steps=8$refiner_steps=3'\ - --output='./potato.png' + --output_path='./potato.png' ``` The single quotes are important so your shell doesn't try to parse the `$`'s. Expected output: ``` -Queuing workflow. -Queue position: #0 -Base... -Base: 1/8 -Base: 2/8 -Base: 3/8 -Base: 4/8 -Base: 5/8 -Base: 6/8 -Base: 7/8 -Base: 8/8 -Refiner... -Refiner: 1/3 -Refiner: 2/3 -Refiner: 3/3 -Decoding... -Saving image on backend... -Result (cached: no): -{'images': [{'filename': 'ComfyUI_00101_.png', 'subfolder': '', 'type': 'output'}]} +2023-12-18 22:23:37,441 [INFO ][__main__ ] Queuing workflow. +2023-12-18 22:23:37,533 [INFO ][root ] {'prompt_id': 'ae431d7f-e212-4b05-8bcd-ebc8d346554f', 'number': 0, 'node_errors': {}} +2023-12-18 22:23:37,609 [INFO ][__main__ ] Queue position: #0 +2023-12-18 22:23:44,484 [INFO ][__main__ ] Base... +2023-12-18 22:23:46,335 [INFO ][__main__ ] Base: 1/8 +2023-12-18 22:23:46,351 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE) +2023-12-18 22:23:46,988 [INFO ][__main__ ] Base: 2/8 +2023-12-18 22:23:47,002 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE) +2023-12-18 22:23:47,663 [INFO ][__main__ ] Base: 3/8 +2023-12-18 22:23:47,678 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE) +2023-12-18 22:23:48,314 [INFO ][__main__ ] Base: 4/8 +2023-12-18 22:23:48,330 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE) +2023-12-18 22:23:48,964 [INFO ][__main__ ] Base: 5/8 +2023-12-18 22:23:48,984 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE) +2023-12-18 22:23:49,614 [INFO ][__main__ ] Base: 6/8 +2023-12-18 22:23:49,712 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE) +2023-12-18 22:23:50,271 [INFO ][__main__ ] Base: 7/8 +2023-12-18 22:23:50,294 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE) +2023-12-18 22:23:50,920 [INFO ][__main__ ] Base: 8/8 +2023-12-18 22:23:50,947 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE) +2023-12-18 22:23:50,948 [INFO ][__main__ ] Refiner... +2023-12-18 22:23:55,148 [INFO ][__main__ ] Refiner: 1/3 +2023-12-18 22:23:55,291 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE) +2023-12-18 22:23:55,826 [INFO ][__main__ ] Refiner: 2/3 +2023-12-18 22:23:55,855 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE) +2023-12-18 22:23:56,511 [INFO ][__main__ ] Refiner: 3/3 +2023-12-18 22:23:56,541 [INFO ][root ] Received an JPEG (PREVIEW_IMAGE) +2023-12-18 22:23:56,542 [INFO ][__main__ ] Decoding... +2023-12-18 22:23:57,466 [INFO ][__main__ ] Saving image on backend... +2023-12-18 22:23:57,901 [INFO ][__main__ ] Result (cached: no): {'images': [{'filename': 'ComfyUI_00022_.png', 'subfolder': '', 'type': 'output'}]} +2023-12-18 22:23:57,901 [INFO ][__main__ ] Saving backend image ComfyUI_00022_.png to potato.png ``` The file will be saved in the root directory. diff --git a/comfyui_utils/comfy.py b/comfyui_utils/comfy.py index a12ea58..7a3eac6 100644 --- a/comfyui_utils/comfy.py +++ b/comfyui_utils/comfy.py @@ -10,11 +10,13 @@ import io import json import logging -from typing import Any, Callable, Union +from typing import Any, Callable, Union, Optional import uuid - +import struct +from PIL import Image +from io import BytesIO import aiohttp - +import aiohttp.client_exceptions # Inherit this class to specify callbacks during prompt execution. class Callbacks(abc.ABC): @@ -27,6 +29,9 @@ async def in_progress(self, node_id: int, progress: int, total: int): @abc.abstractmethod async def completed(self, outputs: dict[str, Any], cached: bool): """Called when the prompt completes, with the final output.""" + @abc.abstractmethod + async def image_received(self, image: Image.Image): + """Called when the prompt's queue return a sampler image, a pillow image object is pass to this function""" StrDict = dict[str, Any] # parsed JSON of an API-formatted ComfyUI workflow. @@ -88,33 +93,69 @@ async def _prompt_websocket(sess: PromptSession, callbacks: Callbacks) -> None: # logging.warning(msg) if msg.type == aiohttp.WSMsgType.ERROR: raise BrokenPipeError(f"WebSocket error: {msg.data}") - assert msg.type == aiohttp.WSMsgType.TEXT - message = json.loads(msg.data) - # Handle prompt being started. - if message["type"] == "status": - queue_or_result = await _get_queue_position_or_cached_result(sess) - if isinstance(queue_or_result, int): - await callbacks.queue_position(queue_or_result) - else: - await callbacks.completed(queue_or_result, True) + if msg.type == aiohttp.WSMsgType.TEXT: + message = json.loads(msg.data) + # Handle prompt being started. + if message["type"] == "status": + queue_or_result = await _get_queue_position_or_cached_result(sess) + if isinstance(queue_or_result, int): + await callbacks.queue_position(queue_or_result) + else: + await callbacks.completed(queue_or_result, True) + break + # Handle a node being executed. + if message["type"] == "executing": + if message["data"]["node"] is not None: + node_id = int(message["data"]["node"]) + current_node = node_id + await callbacks.in_progress(current_node, 0, 0) + # Handle completion of the request. + if message["type"] == "executed": + assert message["data"]["prompt_id"] == sess.prompt_id + await callbacks.completed(message["data"]["output"], False) break - # Handle a node being executed. - if message["type"] == "executing": - if message["data"]["node"] is not None: - node_id = int(message["data"]["node"]) - current_node = node_id - await callbacks.in_progress(current_node, 0, 0) - # Handle completion of the request. - if message["type"] == "executed": - assert message["data"]["prompt_id"] == sess.prompt_id - await callbacks.completed(message["data"]["output"], False) - break - # Handle progress on a node. - if message["type"] == "progress": - progress = int(message["data"]["value"]) - total = int(message["data"]["max"]) - await callbacks.in_progress(current_node, progress, total) - + # Handle progress on a node. + if message["type"] == "progress": + progress = int(message["data"]["value"]) + total = int(message["data"]["max"]) + await callbacks.in_progress(current_node, progress, total) + elif msg.type == aiohttp.WSMsgType.BINARY: + image= await receive_image(msg.data) + if image is not None: + await callbacks.image_received(image) + else: + logging.warning("Not text message, message type: %r", msg.type) + + +async def receive_image(image_data) -> Optional[Image.Image]: + ''' + Reference : https://github.com/comfyanonymous/ComfyUI.git server.py function send_image + Rebuild an PIL Image from the data received on the websocket. Return None on any errors. + ''' + try: + type_num, = struct.unpack_from('>I', image_data, 0) + event_type_num, = struct.unpack_from('>I', image_data, 4) + if type_num == 1: + image_type= "JPEG" + elif type_num == 2: + image_type= "PNG" + else: + logging.error("Unsuported type received : %d", type_num) + return None + + if event_type_num == 1: + event_type= "PREVIEW_IMAGE" + elif event_type_num == 2: + event_type= "UNENCODED_PREVIEW_IMAGE" + else: + event_type= f"UNKNOWN {event_type_num}" + logging.info(f"Received an {image_type} ({event_type})") + bytesIO = BytesIO(image_data[8:]) + image= Image.open(bytesIO) + return image + except Exception as e: + logging.exception("Error on receiving image.") + return None class ComfyAPI: def __init__(self, address): @@ -138,18 +179,24 @@ async def submit(self, prompt: StrDict, callbacks: Callbacks): async with aiohttp.ClientSession() as session: # Enqueue and get prompt ID. async with session.post(f"http://{self.address}/prompt", data=init_data) as resp: - response_json = await resp.json() - logging.info(response_json) - if "error" in response_json: - if "node_errors" not in response_json: - raise ValueError(response_json["error"]["message"]) - errors = [] - for node_id, data in response_json["node_errors"].items(): - for node_error in data["errors"]: - errors.append(f"Node {node_id}, {node_error['details']}: {node_error['message']}") - raise ValueError("\n" + "\n".join(errors)) - - prompt_id = response_json['prompt_id'] + try: + response_json = await resp.json() + logging.info(response_json) + if "error" in response_json: + if "node_errors" not in response_json: + raise ValueError(response_json["error"]["message"]) + errors = [] + for node_id, data in response_json["node_errors"].items(): + for node_error in data["errors"]: + errors.append(f"Node {node_id}, {node_error['details']}: {node_error['message']}") + raise ValueError("\n" + "\n".join(errors)) + + prompt_id = response_json['prompt_id'] + except aiohttp.client_exceptions.ContentTypeError as e: + text= await resp.text() + logging.error("Error, unexpected response, not a json response. %s \nReceived : \n%s", e, text) + raise ValueError("Not a JSON response : \n%s" % text) + # Listen on a websocket until the prompt completes and invoke callbacks. await _prompt_websocket(PromptSession( client_id=client_id, diff --git a/examples/e2e.py b/examples/e2e.py index 051ac69..ec9939a 100644 --- a/examples/e2e.py +++ b/examples/e2e.py @@ -1,4 +1,17 @@ -"""End-to-end example for creating an image with SDXL base + refiner""" +import logging +import sys + +formatter = logging.Formatter('%(asctime)s [%(levelname)-9s][%(name)-20s] %(message)s') +sh= logging.StreamHandler(stream=sys.stdout) +sh.setLevel(logging.DEBUG) +sh.setFormatter(formatter) +helper_logger= logging.root +helper_logger.addHandler(sh) +helper_logger.setLevel(logging.INFO) +logger= helper_logger.getChild(__name__) + +import pathlib +import tempfile import argparse import io @@ -6,8 +19,9 @@ import json from comfyui_utils import comfy from comfyui_utils import gen_prompts +from PIL.Image import Image -async def run_base_and_refiner(address: str, user_string: str, output_path=None): +async def run_base_and_refiner(address: str, user_string: str, output_path: pathlib.Path=None, sample_dir: pathlib.Path=None): comfyui = comfy.ComfyAPI(address) # Load the stored API format prompt. @@ -25,7 +39,7 @@ async def run_base_and_refiner(address: str, user_string: str, output_path=None) # Returns a list of warnings if any argument is invalid. parsed = gen_prompts.parse_args(user_string, prompt_config) for warning in parsed.warnings: - print(f"WARNING: {warning}") + logger.warning(warning) # Adjust the prompt with user args. prompt["6"]["inputs"]["text"] = parsed.cleaned prompt["10"]["inputs"]["end_at_step"] = parsed.result.base_steps @@ -41,26 +55,46 @@ async def run_base_and_refiner(address: str, user_string: str, output_path=None) } # Configure the callbacks which will write to it during execution while printing updates. class Callbacks(comfy.Callbacks): + def __init__(self): + self.sample_count=0 + self.sample_prefix= None + if sample_dir is not None: + sample_file_name= None + while sample_file_name is None or pathlib.Path(sample_file_name).exists(): + sample_file_name= pathlib.Path(tempfile.mktemp('.png', 'prompt_sample_', sample_dir)) + self.sample_prefix= sample_file_name.stem async def queue_position(self, position): - print(f"Queue position: #{position}") + logger.info(f"Queue position: #{position}") async def in_progress(self, node_id, progress, total): progress = f"{progress}/{total}" if total else None if node_id == 10: - print(f"Base: {progress}" if progress else "Base...") + logger.info(f"Base: {progress}" if progress else "Base...") if node_id == 11: - print(f"Refiner: {progress}" if progress else "Refiner...") + logger.info(f"Refiner: {progress}" if progress else "Refiner...") elif node_id == 17: - print("Decoding...") + logger.info("Decoding...") elif node_id == 19: - print("Saving image on backend...") + logger.info("Saving image on backend...") async def completed(self, outputs, cached): result["output"] = outputs result["cached"] = cached + async def image_received(self, image: Image): + if sample_dir is not None: + sample_file= None + while sample_file is None or sample_file.exists(): + sample_file= sample_dir.joinpath(f"{self.sample_prefix}_{self.sample_count:04d}.png") + self.sample_count+= 1 + logger.info("Save sample file : %s", sample_file.name) + image.save(sample_file) # Run the prompt and print the result. - print("Queuing workflow.") - await comfyui.submit(prompt, Callbacks()) - print(f"Result (cached: {'yes' if result['cached'] else 'no'}):\n{result['output']}") + logger.info("Queuing workflow.") + try: + await comfyui.submit(prompt, Callbacks()) + except ValueError as e: + logger.error("Error processing template: %s", e) + return + logger.info(f"Result (cached: {'yes' if result['cached'] else 'no'}): {result['output']}") # Write the result to a local file. if output_path is not None: @@ -68,6 +102,7 @@ async def completed(self, outputs, cached): async def on_load(data_file : io.BytesIO): with open(output_path, "wb") as f: f.write(data_file.getbuffer()) + logger.info("Saving backend image %s to %s", backend_filepath.get("filename", None), output_path) await comfyui.fetch(backend_filepath, on_load) @@ -75,10 +110,25 @@ def main(): parser = argparse.ArgumentParser(description='Run an SDXL or other workflow on a deployed ComfyUI server.') parser.add_argument("--address", type=str, help="the ComfyUI endpoint", required=True) parser.add_argument("--prompt", type=str, help="the user prompt", required=True) - parser.add_argument("--output_path", type=str, help="the output path", default=None) + parser.add_argument("--output_path", type=pathlib.Path, help="The filename to store the final image received from ComfyUI", default=None) + parser.add_argument("--output_dir", type=pathlib.Path, help="A folder to store the final image received from ComfyUI, name will be prompt_RANDOM.png", default=None) + parser.add_argument("--sample_dir", type=pathlib.Path, help="Folder to store sample images received from ComfyUI, images names will be prompt_sample_RANDOM_COUNTER.png", default=None) args = parser.parse_args() - asyncio.run(run_base_and_refiner(args.address, args.prompt, args.output_path)) + outfile=args.output_path + if outfile is None and args.output_dir is not None: + if not args.output_dir.is_dir(): + logger.error("Argument --output_dir is %s but it's not a directory (full path resolved: %s)", args.output_dir, args.output_dir.resolve()) + exit(1) + #Find a temporary name for the samples + while outfile is None or pathlib.Path(outfile).exists(): + outfile= pathlib.Path(tempfile.mktemp('.png', 'prompt_', args.output_dir)) + logger.debug("(%s) %r", type(outfile), outfile) + if args.sample_dir is not None and not args.sample_dir.is_dir(): + logger.error("Argument --sample_dir is %s but it's not a directory (full path resolved: %s)", args.sample_dir, args.sample_dir.resolve()) + exit(1) + + asyncio.run(run_base_and_refiner(args.address, args.prompt, outfile, args.sample_dir)) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ee4ba4f..6d69677 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ aiohttp +pillow \ No newline at end of file