-
Notifications
You must be signed in to change notification settings - Fork 5
Handle websocket binary response from ComfyUI server #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,11 +10,13 @@ | |
| import io | ||
| import json | ||
| import logging | ||
| from typing import Any, Callable, Union | ||
| from typing import Any, Callable, Union, Optional | ||
| import uuid | ||
|
|
||
| import struct | ||
| from PIL import Image | ||
| from io import BytesIO | ||
| import aiohttp | ||
|
|
||
| import aiohttp.client_exceptions | ||
|
|
||
| # Inherit this class to specify callbacks during prompt execution. | ||
| class Callbacks(abc.ABC): | ||
|
|
@@ -27,6 +29,9 @@ async def in_progress(self, node_id: int, progress: int, total: int): | |
| @abc.abstractmethod | ||
| async def completed(self, outputs: dict[str, Any], cached: bool): | ||
| """Called when the prompt completes, with the final output.""" | ||
| @abc.abstractmethod | ||
| async def image_received(self, image: Image.Image): | ||
| """Called when the prompt's queue return a sampler image, a pillow image object is pass to this function""" | ||
|
|
||
|
|
||
| StrDict = dict[str, Any] # parsed JSON of an API-formatted ComfyUI workflow. | ||
|
|
@@ -88,33 +93,69 @@ async def _prompt_websocket(sess: PromptSession, callbacks: Callbacks) -> None: | |
| # logging.warning(msg) | ||
| if msg.type == aiohttp.WSMsgType.ERROR: | ||
| raise BrokenPipeError(f"WebSocket error: {msg.data}") | ||
| assert msg.type == aiohttp.WSMsgType.TEXT | ||
| message = json.loads(msg.data) | ||
| # Handle prompt being started. | ||
| if message["type"] == "status": | ||
| queue_or_result = await _get_queue_position_or_cached_result(sess) | ||
| if isinstance(queue_or_result, int): | ||
| await callbacks.queue_position(queue_or_result) | ||
| else: | ||
| await callbacks.completed(queue_or_result, True) | ||
| if msg.type == aiohttp.WSMsgType.TEXT: | ||
| message = json.loads(msg.data) | ||
| # Handle prompt being started. | ||
| if message["type"] == "status": | ||
| queue_or_result = await _get_queue_position_or_cached_result(sess) | ||
| if isinstance(queue_or_result, int): | ||
| await callbacks.queue_position(queue_or_result) | ||
| else: | ||
| await callbacks.completed(queue_or_result, True) | ||
| break | ||
| # Handle a node being executed. | ||
| if message["type"] == "executing": | ||
| if message["data"]["node"] is not None: | ||
| node_id = int(message["data"]["node"]) | ||
| current_node = node_id | ||
| await callbacks.in_progress(current_node, 0, 0) | ||
| # Handle completion of the request. | ||
| if message["type"] == "executed": | ||
| assert message["data"]["prompt_id"] == sess.prompt_id | ||
| await callbacks.completed(message["data"]["output"], False) | ||
| break | ||
| # Handle a node being executed. | ||
| if message["type"] == "executing": | ||
| if message["data"]["node"] is not None: | ||
| node_id = int(message["data"]["node"]) | ||
| current_node = node_id | ||
| await callbacks.in_progress(current_node, 0, 0) | ||
| # Handle completion of the request. | ||
| if message["type"] == "executed": | ||
| assert message["data"]["prompt_id"] == sess.prompt_id | ||
| await callbacks.completed(message["data"]["output"], False) | ||
| break | ||
| # Handle progress on a node. | ||
| if message["type"] == "progress": | ||
| progress = int(message["data"]["value"]) | ||
| total = int(message["data"]["max"]) | ||
| await callbacks.in_progress(current_node, progress, total) | ||
|
|
||
| # Handle progress on a node. | ||
| if message["type"] == "progress": | ||
| progress = int(message["data"]["value"]) | ||
| total = int(message["data"]["max"]) | ||
| await callbacks.in_progress(current_node, progress, total) | ||
| elif msg.type == aiohttp.WSMsgType.BINARY: | ||
| image= await receive_image(msg.data) | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: space before "=" |
||
| if image is not None: | ||
| await callbacks.image_received(image) | ||
| else: | ||
| logging.warning("Not text message, message type: %r", msg.type) | ||
|
|
||
|
|
||
| async def receive_image(image_data) -> Optional[Image.Image]: | ||
| ''' | ||
| Reference : https://github.com/comfyanonymous/ComfyUI.git server.py function send_image | ||
| Rebuild an PIL Image from the data received on the websocket. Return None on any errors. | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: "a PIL image" + perhaps it's ok to skip the try and just let exception bubble up (ComfyAPI already throws exceptions) |
||
| ''' | ||
| try: | ||
| type_num, = struct.unpack_from('>I', image_data, 0) | ||
| event_type_num, = struct.unpack_from('>I', image_data, 4) | ||
| if type_num == 1: | ||
| image_type= "JPEG" | ||
| elif type_num == 2: | ||
| image_type= "PNG" | ||
| else: | ||
| logging.error("Unsuported type received : %d", type_num) | ||
| return None | ||
|
|
||
| if event_type_num == 1: | ||
| event_type= "PREVIEW_IMAGE" | ||
| elif event_type_num == 2: | ||
| event_type= "UNENCODED_PREVIEW_IMAGE" | ||
| else: | ||
| event_type= f"UNKNOWN {event_type_num}" | ||
| logging.info(f"Received an {image_type} ({event_type})") | ||
| bytesIO = BytesIO(image_data[8:]) | ||
| image= Image.open(bytesIO) | ||
| return image | ||
| except Exception as e: | ||
| logging.exception("Error on receiving image.") | ||
| return None | ||
|
|
||
| class ComfyAPI: | ||
| def __init__(self, address): | ||
|
|
@@ -138,18 +179,24 @@ async def submit(self, prompt: StrDict, callbacks: Callbacks): | |
| async with aiohttp.ClientSession() as session: | ||
| # Enqueue and get prompt ID. | ||
| async with session.post(f"http://{self.address}/prompt", data=init_data) as resp: | ||
| response_json = await resp.json() | ||
| logging.info(response_json) | ||
| if "error" in response_json: | ||
| if "node_errors" not in response_json: | ||
| raise ValueError(response_json["error"]["message"]) | ||
| errors = [] | ||
| for node_id, data in response_json["node_errors"].items(): | ||
| for node_error in data["errors"]: | ||
| errors.append(f"Node {node_id}, {node_error['details']}: {node_error['message']}") | ||
| raise ValueError("\n" + "\n".join(errors)) | ||
|
|
||
| prompt_id = response_json['prompt_id'] | ||
| try: | ||
| response_json = await resp.json() | ||
| logging.info(response_json) | ||
| if "error" in response_json: | ||
| if "node_errors" not in response_json: | ||
| raise ValueError(response_json["error"]["message"]) | ||
| errors = [] | ||
| for node_id, data in response_json["node_errors"].items(): | ||
| for node_error in data["errors"]: | ||
| errors.append(f"Node {node_id}, {node_error['details']}: {node_error['message']}") | ||
| raise ValueError("\n" + "\n".join(errors)) | ||
|
|
||
| prompt_id = response_json['prompt_id'] | ||
| except aiohttp.client_exceptions.ContentTypeError as e: | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can the except be moved up to only wrap the |
||
| text= await resp.text() | ||
| logging.error("Error, unexpected response, not a json response. %s \nReceived : \n%s", e, text) | ||
| raise ValueError("Not a JSON response : \n%s" % text) | ||
|
|
||
| # Listen on a websocket until the prompt completes and invoke callbacks. | ||
| await _prompt_websocket(PromptSession( | ||
| client_id=client_id, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,13 +1,27 @@ | ||
| """End-to-end example for creating an image with SDXL base + refiner""" | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we keep the file comment? |
||
| import logging | ||
| import sys | ||
|
|
||
| formatter = logging.Formatter('%(asctime)s [%(levelname)-9s][%(name)-20s] %(message)s') | ||
| sh= logging.StreamHandler(stream=sys.stdout) | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: space before "="/"+=" throughout the file to follow standard style? |
||
| sh.setLevel(logging.DEBUG) | ||
| sh.setFormatter(formatter) | ||
| helper_logger= logging.root | ||
| helper_logger.addHandler(sh) | ||
| helper_logger.setLevel(logging.INFO) | ||
| logger= helper_logger.getChild(__name__) | ||
|
|
||
| import pathlib | ||
| import tempfile | ||
|
|
||
| import argparse | ||
| import io | ||
| import asyncio | ||
| import json | ||
| from comfyui_utils import comfy | ||
| from comfyui_utils import gen_prompts | ||
| from PIL.Image import Image | ||
|
|
||
| async def run_base_and_refiner(address: str, user_string: str, output_path=None): | ||
| async def run_base_and_refiner(address: str, user_string: str, output_path: pathlib.Path=None, sample_dir: pathlib.Path=None): | ||
| comfyui = comfy.ComfyAPI(address) | ||
|
|
||
| # Load the stored API format prompt. | ||
|
|
@@ -25,7 +39,7 @@ async def run_base_and_refiner(address: str, user_string: str, output_path=None) | |
| # Returns a list of warnings if any argument is invalid. | ||
| parsed = gen_prompts.parse_args(user_string, prompt_config) | ||
| for warning in parsed.warnings: | ||
| print(f"WARNING: {warning}") | ||
| logger.warning(warning) | ||
| # Adjust the prompt with user args. | ||
| prompt["6"]["inputs"]["text"] = parsed.cleaned | ||
| prompt["10"]["inputs"]["end_at_step"] = parsed.result.base_steps | ||
|
|
@@ -41,44 +55,80 @@ async def run_base_and_refiner(address: str, user_string: str, output_path=None) | |
| } | ||
| # Configure the callbacks which will write to it during execution while printing updates. | ||
| class Callbacks(comfy.Callbacks): | ||
| def __init__(self): | ||
| self.sample_count=0 | ||
| self.sample_prefix= None | ||
| if sample_dir is not None: | ||
| sample_file_name= None | ||
| while sample_file_name is None or pathlib.Path(sample_file_name).exists(): | ||
| sample_file_name= pathlib.Path(tempfile.mktemp('.png', 'prompt_sample_', sample_dir)) | ||
| self.sample_prefix= sample_file_name.stem | ||
| async def queue_position(self, position): | ||
| print(f"Queue position: #{position}") | ||
| logger.info(f"Queue position: #{position}") | ||
| async def in_progress(self, node_id, progress, total): | ||
| progress = f"{progress}/{total}" if total else None | ||
| if node_id == 10: | ||
| print(f"Base: {progress}" if progress else "Base...") | ||
| logger.info(f"Base: {progress}" if progress else "Base...") | ||
| if node_id == 11: | ||
| print(f"Refiner: {progress}" if progress else "Refiner...") | ||
| logger.info(f"Refiner: {progress}" if progress else "Refiner...") | ||
| elif node_id == 17: | ||
| print("Decoding...") | ||
| logger.info("Decoding...") | ||
| elif node_id == 19: | ||
| print("Saving image on backend...") | ||
| logger.info("Saving image on backend...") | ||
| async def completed(self, outputs, cached): | ||
| result["output"] = outputs | ||
| result["cached"] = cached | ||
| async def image_received(self, image: Image): | ||
| if sample_dir is not None: | ||
| sample_file= None | ||
| while sample_file is None or sample_file.exists(): | ||
| sample_file= sample_dir.joinpath(f"{self.sample_prefix}_{self.sample_count:04d}.png") | ||
| self.sample_count+= 1 | ||
| logger.info("Save sample file : %s", sample_file.name) | ||
| image.save(sample_file) | ||
|
|
||
| # Run the prompt and print the result. | ||
| print("Queuing workflow.") | ||
| await comfyui.submit(prompt, Callbacks()) | ||
| print(f"Result (cached: {'yes' if result['cached'] else 'no'}):\n{result['output']}") | ||
| logger.info("Queuing workflow.") | ||
| try: | ||
| await comfyui.submit(prompt, Callbacks()) | ||
| except ValueError as e: | ||
| logger.error("Error processing template: %s", e) | ||
| return | ||
| logger.info(f"Result (cached: {'yes' if result['cached'] else 'no'}): {result['output']}") | ||
|
|
||
| # Write the result to a local file. | ||
| if output_path is not None: | ||
| backend_filepath = result["output"]["images"][0] | ||
| async def on_load(data_file : io.BytesIO): | ||
| with open(output_path, "wb") as f: | ||
| f.write(data_file.getbuffer()) | ||
| logger.info("Saving backend image %s to %s", backend_filepath.get("filename", None), output_path) | ||
| await comfyui.fetch(backend_filepath, on_load) | ||
|
|
||
|
|
||
| def main(): | ||
| parser = argparse.ArgumentParser(description='Run an SDXL or other workflow on a deployed ComfyUI server.') | ||
| parser.add_argument("--address", type=str, help="the ComfyUI endpoint", required=True) | ||
| parser.add_argument("--prompt", type=str, help="the user prompt", required=True) | ||
| parser.add_argument("--output_path", type=str, help="the output path", default=None) | ||
| parser.add_argument("--output_path", type=pathlib.Path, help="The filename to store the final image received from ComfyUI", default=None) | ||
| parser.add_argument("--output_dir", type=pathlib.Path, help="A folder to store the final image received from ComfyUI, name will be prompt_RANDOM.png", default=None) | ||
| parser.add_argument("--sample_dir", type=pathlib.Path, help="Folder to store sample images received from ComfyUI, images names will be prompt_sample_RANDOM_COUNTER.png", default=None) | ||
| args = parser.parse_args() | ||
|
|
||
| asyncio.run(run_base_and_refiner(args.address, args.prompt, args.output_path)) | ||
| outfile=args.output_path | ||
| if outfile is None and args.output_dir is not None: | ||
| if not args.output_dir.is_dir(): | ||
| logger.error("Argument --output_dir is %s but it's not a directory (full path resolved: %s)", args.output_dir, args.output_dir.resolve()) | ||
| exit(1) | ||
| #Find a temporary name for the samples | ||
| while outfile is None or pathlib.Path(outfile).exists(): | ||
| outfile= pathlib.Path(tempfile.mktemp('.png', 'prompt_', args.output_dir)) | ||
| logger.debug("(%s) %r", type(outfile), outfile) | ||
| if args.sample_dir is not None and not args.sample_dir.is_dir(): | ||
| logger.error("Argument --sample_dir is %s but it's not a directory (full path resolved: %s)", args.sample_dir, args.sample_dir.resolve()) | ||
| exit(1) | ||
|
|
||
| asyncio.run(run_base_and_refiner(args.address, args.prompt, outfile, args.sample_dir)) | ||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
| main() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,2 @@ | ||
| aiohttp | ||
| pillow |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: "is pass" -> "is passed" + period at end just for consistency