Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ rustls = ["reqwest/rustls-tls", "tokio-tungstenite/rustls-tls-webpki-roots"]
bytes = "1.10.1"
futures-util = "0.3.31"
log = { version = "0.4.26", features = ["kv"] }
pin-project-lite = "0.2.16"
reqwest = { version = "0.12.12", features = [
"json",
"multipart",
Expand Down
4 changes: 3 additions & 1 deletion examples/text_to_image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ async fn main() {
.filter_level(log::LevelFilter::Debug)
.init();

let (client, mut stream) = ClientBuilder::new("http://localhost:8188")
let (client, mut stream, handle) = ClientBuilder::new("http://localhost:8188")
.build()
.await
.unwrap();
Expand Down Expand Up @@ -200,4 +200,6 @@ async fn main() {
}
}
}

handle.abort();
}
85 changes: 45 additions & 40 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,19 @@ pub use crate::errors::{ClientError, ClientResult};
use crate::meta::{FileInfo, PromptInfo};
use bytes::Bytes;
use errors::{ApiBody, ApiError};
use futures_util::StreamExt;
use futures_util::stream::{Stream, StreamExt};
use log::trace;
use meta::{Event, History, OtherEvent, Prompt, PromptStatus};
use pin_project_lite::pin_project;
use reqwest::{
Body, IntoUrl, Response,
multipart::{self},
};
use serde_json::{Value, json};
use std::{
collections::HashMap,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
};
use tokio::{
sync::mpsc,
Expand Down Expand Up @@ -97,16 +99,24 @@ impl<U: IntoUrl> ClientBuilder<U> {
self
}

/// Builds the [`ComfyUIClient`] along with an associated [`EventStream`].
/// Builds the [`ComfyUIClient`] along with an associated [`EventStream`]
/// and a background task handle.
///
/// This method establishes a websocket connection and spawns an
/// asynchronous task to process incoming messages.
/// asynchronous task to process incoming messages. If reconnection is
/// enabled, the task will automatically attempt to reconnect when the
/// WebSocket connection drops unexpectedly.
///
/// # Returns
///
/// A tuple containing the [`ComfyUIClient`] and [`EventStream`] on success,
/// or an error.
pub async fn build(self) -> ClientResult<(ComfyUIClient, EventStream)> {
/// A tuple containing:
/// - The [`ComfyUIClient`] for HTTP API interactions
/// - An [`EventStream`] for receiving real-time events
/// - A [`JoinHandle`] for the background task that manages the WebSocket
/// connection
///
/// Returns an error if the initial connection cannot be established.
pub async fn build(self) -> ClientResult<(ComfyUIClient, EventStream, JoinHandle<()>)> {
let base_url = self.base_url.into_url()?;
let http_client = reqwest::Client::new();
let client_id = Uuid::new_v4().to_string();
Expand Down Expand Up @@ -223,12 +233,9 @@ impl<U: IntoUrl> ClientBuilder<U> {
client_id,
};

let stream = EventStream {
stream_handle,
rx_stream,
};
let stream = EventStream { rx_stream };

Ok((client, stream))
Ok((client, stream, stream_handle))
}

/// Builds a [`ComfyUIClient`] instance configured for HTTP-only
Expand Down Expand Up @@ -456,26 +463,35 @@ impl ComfyUIClient {
}
}

/// A structure representing the event stream received via a websocket
/// connection.
///
/// This stream continuously processes events from the ComfyUI service.
pub struct EventStream {
stream_handle: JoinHandle<()>,
rx_stream: ReceiverStream<ClientResult<Event>>,
pin_project! {
/// A structure representing the event stream received via a websocket connection.
///
/// This stream continuously processes events from the ComfyUI service.
/// It handles WebSocket connection management including automatic reconnection
/// when enabled through the [`ClientBuilder`].
///
/// The stream emits various events including execution status updates, errors,
/// and connection state changes. All WebSocket communication is managed by a
/// background task, allowing the stream to be consumed without worrying about
/// connection details.
pub struct EventStream {
#[pin]
rx_stream: ReceiverStream<ClientResult<Event>>,
}
}

impl EventStream {
/// Handles a single websocket message and attempts to parse it as an
/// [`Event`].
///
/// For text messages, it tries to deserialize the message into an
/// [`Event`]. If the deserialization fails, it wraps the message as
/// [`Event::Unknown`]. Other message types are ignored.
/// [`Event`]. If deserialization fails, it wraps the message as
/// [`Event::Unknown`]. Non-text message types are ignored and return
/// `None`.
///
/// # Parameters
///
/// - `msg`: A result containing a [`Message`] from the websocket.
/// - `msg`: A [`Message`] from the websocket.
///
/// # Returns
///
Expand All @@ -496,27 +512,16 @@ impl EventStream {
}
}

impl Drop for EventStream {
/// When the [`EventStream`] is dropped, abort the associated websocket
/// handling task.
fn drop(&mut self) {
self.stream_handle.abort();
}
}
impl Stream for EventStream {
type Item = ClientResult<Event>;

impl Deref for EventStream {
type Target = ReceiverStream<ClientResult<Event>>;

/// Allows access to the inner [`ReceiverStream`] containing the events.
fn deref(&self) -> &Self::Target {
&self.rx_stream
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
this.rx_stream.poll_next(cx)
}
}

impl DerefMut for EventStream {
/// Allows mutable access to the inner [`ReceiverStream`].
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.rx_stream
fn size_hint(&self) -> (usize, Option<usize>) {
self.rx_stream.size_hint()
}
}

Expand Down
3 changes: 2 additions & 1 deletion tests/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use comfyui_client::{ClientBuilder, ComfyUIClient, EventStream};
use std::sync::Once;
use tokio::task::JoinHandle;

pub fn setup() {
static START: Once = Once::new();
Expand All @@ -10,7 +11,7 @@ pub fn setup() {
});
}

pub async fn build_client() -> (ComfyUIClient, EventStream) {
pub async fn build_client() -> (ComfyUIClient, EventStream, JoinHandle<()>) {
ClientBuilder::new("http://localhost:8188")
.build()
.await
Expand Down
7 changes: 5 additions & 2 deletions tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ use tokio_stream::StreamExt;
#[tokio::test]
async fn test_get_prompt() {
common::setup();
let (client, _) = common::build_client().await;
let (client, _, handle) = common::build_client().await;
client.get_prompt().await.unwrap();
handle.abort();
}

#[tokio::test]
async fn test_integration() {
common::setup();
let (client, mut stream) = common::build_client().await;
let (client, mut stream, handle) = common::build_client().await;

let file = File::open("./tests/data/cat.webp").await.unwrap();
let file_info = FileInfo {
Expand Down Expand Up @@ -64,4 +65,6 @@ async fn test_integration() {
let image2_buf = client.get_view(&image).await.unwrap();

assert_eq!(image_buf, image2_buf);

handle.abort();
}