From 0097813758d1f680d29085ec36f52fe99f197d14 Mon Sep 17 00:00:00 2001 From: jmjoy Date: Sat, 19 Apr 2025 01:44:25 +0800 Subject: [PATCH] refactor: update ClientBuilder to return a JoinHandle for background task management --- Cargo.lock | 1 + Cargo.toml | 1 + examples/text_to_image.rs | 4 +- src/lib.rs | 85 +++++++++++++++++++++------------------ tests/common/mod.rs | 3 +- tests/integration.rs | 7 +++- 6 files changed, 57 insertions(+), 44 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0eed3f5..32748d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -177,6 +177,7 @@ dependencies = [ "env_logger", "futures-util", "log", + "pin-project-lite", "reqwest", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 96c66f1..a550548 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ rustls = ["reqwest/rustls-tls", "tokio-tungstenite/rustls-tls-webpki-roots"] bytes = "1.10.1" futures-util = "0.3.31" log = { version = "0.4.26", features = ["kv"] } +pin-project-lite = "0.2.16" reqwest = { version = "0.12.12", features = [ "json", "multipart", diff --git a/examples/text_to_image.rs b/examples/text_to_image.rs index e494461..f68689b 100644 --- a/examples/text_to_image.rs +++ b/examples/text_to_image.rs @@ -136,7 +136,7 @@ async fn main() { .filter_level(log::LevelFilter::Debug) .init(); - let (client, mut stream) = ClientBuilder::new("http://localhost:8188") + let (client, mut stream, handle) = ClientBuilder::new("http://localhost:8188") .build() .await .unwrap(); @@ -200,4 +200,6 @@ async fn main() { } } } + + handle.abort(); } diff --git a/src/lib.rs b/src/lib.rs index f2d6b4d..ba09fdf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,9 +11,10 @@ pub use crate::errors::{ClientError, ClientResult}; use crate::meta::{FileInfo, PromptInfo}; use bytes::Bytes; use errors::{ApiBody, ApiError}; -use futures_util::StreamExt; +use futures_util::stream::{Stream, StreamExt}; use log::trace; use meta::{Event, History, OtherEvent, Prompt, PromptStatus}; +use pin_project_lite::pin_project; use reqwest::{ Body, IntoUrl, Response, multipart::{self}, @@ -21,7 +22,8 @@ use reqwest::{ use serde_json::{Value, json}; use std::{ collections::HashMap, - ops::{Deref, DerefMut}, + pin::Pin, + task::{Context, Poll}, }; use tokio::{ sync::mpsc, @@ -97,16 +99,24 @@ impl ClientBuilder { self } - /// Builds the [`ComfyUIClient`] along with an associated [`EventStream`]. + /// Builds the [`ComfyUIClient`] along with an associated [`EventStream`] + /// and a background task handle. /// /// This method establishes a websocket connection and spawns an - /// asynchronous task to process incoming messages. + /// asynchronous task to process incoming messages. If reconnection is + /// enabled, the task will automatically attempt to reconnect when the + /// WebSocket connection drops unexpectedly. /// /// # Returns /// - /// A tuple containing the [`ComfyUIClient`] and [`EventStream`] on success, - /// or an error. - pub async fn build(self) -> ClientResult<(ComfyUIClient, EventStream)> { + /// A tuple containing: + /// - The [`ComfyUIClient`] for HTTP API interactions + /// - An [`EventStream`] for receiving real-time events + /// - A [`JoinHandle`] for the background task that manages the WebSocket + /// connection + /// + /// Returns an error if the initial connection cannot be established. + pub async fn build(self) -> ClientResult<(ComfyUIClient, EventStream, JoinHandle<()>)> { let base_url = self.base_url.into_url()?; let http_client = reqwest::Client::new(); let client_id = Uuid::new_v4().to_string(); @@ -223,12 +233,9 @@ impl ClientBuilder { client_id, }; - let stream = EventStream { - stream_handle, - rx_stream, - }; + let stream = EventStream { rx_stream }; - Ok((client, stream)) + Ok((client, stream, stream_handle)) } /// Builds a [`ComfyUIClient`] instance configured for HTTP-only @@ -456,13 +463,21 @@ impl ComfyUIClient { } } -/// A structure representing the event stream received via a websocket -/// connection. -/// -/// This stream continuously processes events from the ComfyUI service. -pub struct EventStream { - stream_handle: JoinHandle<()>, - rx_stream: ReceiverStream>, +pin_project! { + /// A structure representing the event stream received via a websocket connection. + /// + /// This stream continuously processes events from the ComfyUI service. + /// It handles WebSocket connection management including automatic reconnection + /// when enabled through the [`ClientBuilder`]. + /// + /// The stream emits various events including execution status updates, errors, + /// and connection state changes. All WebSocket communication is managed by a + /// background task, allowing the stream to be consumed without worrying about + /// connection details. + pub struct EventStream { + #[pin] + rx_stream: ReceiverStream>, + } } impl EventStream { @@ -470,12 +485,13 @@ impl EventStream { /// [`Event`]. /// /// For text messages, it tries to deserialize the message into an - /// [`Event`]. If the deserialization fails, it wraps the message as - /// [`Event::Unknown`]. Other message types are ignored. + /// [`Event`]. If deserialization fails, it wraps the message as + /// [`Event::Unknown`]. Non-text message types are ignored and return + /// `None`. /// /// # Parameters /// - /// - `msg`: A result containing a [`Message`] from the websocket. + /// - `msg`: A [`Message`] from the websocket. /// /// # Returns /// @@ -496,27 +512,16 @@ impl EventStream { } } -impl Drop for EventStream { - /// When the [`EventStream`] is dropped, abort the associated websocket - /// handling task. - fn drop(&mut self) { - self.stream_handle.abort(); - } -} +impl Stream for EventStream { + type Item = ClientResult; -impl Deref for EventStream { - type Target = ReceiverStream>; - - /// Allows access to the inner [`ReceiverStream`] containing the events. - fn deref(&self) -> &Self::Target { - &self.rx_stream + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + this.rx_stream.poll_next(cx) } -} -impl DerefMut for EventStream { - /// Allows mutable access to the inner [`ReceiverStream`]. - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.rx_stream + fn size_hint(&self) -> (usize, Option) { + self.rx_stream.size_hint() } } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index e20f3b6..d5f2261 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,5 +1,6 @@ use comfyui_client::{ClientBuilder, ComfyUIClient, EventStream}; use std::sync::Once; +use tokio::task::JoinHandle; pub fn setup() { static START: Once = Once::new(); @@ -10,7 +11,7 @@ pub fn setup() { }); } -pub async fn build_client() -> (ComfyUIClient, EventStream) { +pub async fn build_client() -> (ComfyUIClient, EventStream, JoinHandle<()>) { ClientBuilder::new("http://localhost:8188") .build() .await diff --git a/tests/integration.rs b/tests/integration.rs index d7c5252..b624656 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -8,14 +8,15 @@ use tokio_stream::StreamExt; #[tokio::test] async fn test_get_prompt() { common::setup(); - let (client, _) = common::build_client().await; + let (client, _, handle) = common::build_client().await; client.get_prompt().await.unwrap(); + handle.abort(); } #[tokio::test] async fn test_integration() { common::setup(); - let (client, mut stream) = common::build_client().await; + let (client, mut stream, handle) = common::build_client().await; let file = File::open("./tests/data/cat.webp").await.unwrap(); let file_info = FileInfo { @@ -64,4 +65,6 @@ async fn test_integration() { let image2_buf = client.get_view(&image).await.unwrap(); assert_eq!(image_buf, image2_buf); + + handle.abort(); }