diff --git a/Cargo.lock b/Cargo.lock index 759fa9d..8888f95 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,6 +28,17 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -1898,6 +1909,7 @@ version = "0.1.0" dependencies = [ "aes", "anyhow", + "async-trait", "base64", "block-padding", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 0caf7a8..a553c4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ cipher = { version = "0.4", features = ["alloc"] } qrcode = "0.14" mime_guess = "2" tempfile = "3" +async-trait = "0.1" snafu = "0.9" rand = "0.9" hex = "0.4" diff --git a/examples/echo_bot.rs b/examples/echo_bot.rs index cef680b..8024c3b 100644 --- a/examples/echo_bot.rs +++ b/examples/echo_bot.rs @@ -1,19 +1,16 @@ -use std::{future::Future, pin::Pin, sync::Arc}; +use std::sync::Arc; +use async_trait::async_trait; use wechat_agent_rs::{Agent, ChatRequest, ChatResponse, LoginOptions, StartOptions, login, start}; struct EchoAgent; +#[async_trait] impl Agent for EchoAgent { - fn chat( - &self, - request: ChatRequest, - ) -> Pin> + Send + '_>> { - Box::pin(async move { - Ok(ChatResponse { - text: Some(format!("You said: {}", request.text)), - media: None, - }) + async fn chat(&self, request: ChatRequest) -> wechat_agent_rs::Result { + Ok(ChatResponse { + text: Some(format!("You said: {}", request.text)), + media: None, }) } } diff --git a/examples/openai_bot.rs b/examples/openai_bot.rs index 0fc63d9..849e6ee 100644 --- a/examples/openai_bot.rs +++ b/examples/openai_bot.rs @@ -3,6 +3,7 @@ use std::{ sync::{Arc, Mutex}, }; +use async_trait::async_trait; use base64::Engine; use reqwest::Client; use serde_json::{Value, json}; @@ -37,88 +38,82 @@ impl OpenAIAgent { } } +#[async_trait] impl Agent for OpenAIAgent { - fn chat( - &self, - request: ChatRequest, - ) -> std::pin::Pin< - Box> + Send + '_>, - > { - Box::pin(async move { - let user_content = if let Some(ref media) = request.media { - match media.media_type { - wechat_agent_rs::MediaType::Image => { - let data = std::fs::read(&media.file_path).context(IoSnafu)?; - let b64 = base64::engine::general_purpose::STANDARD.encode(&data); - json!([ - {"type": "text", "text": request.text}, - {"type": "image_url", "image_url": {"url": format!("data:{};base64,{b64}", media.mime_type)}} - ]) - } - _ => { - json!(format!( - "{}\n[Attachment: {} ({})]", - request.text, - media.file_name.as_deref().unwrap_or("file"), - media.mime_type - )) - } + async fn chat(&self, request: ChatRequest) -> wechat_agent_rs::Result { + let user_content = if let Some(ref media) = request.media { + match media.media_type { + wechat_agent_rs::MediaType::Image => { + let data = std::fs::read(&media.file_path).context(IoSnafu)?; + let b64 = base64::engine::general_purpose::STANDARD.encode(&data); + json!([ + {"type": "text", "text": request.text}, + {"type": "image_url", "image_url": {"url": format!("data:{};base64,{b64}", media.mime_type)}} + ]) } - } else { - json!(request.text) - }; - - // Build messages while holding the lock, then drop it before await - let messages = { - let mut histories = self.histories.lock().unwrap(); - let history = histories - .entry(request.conversation_id.clone()) - .or_default(); - - history.push(json!({"role": "user", "content": user_content})); - - if history.len() > 50 { - history.drain(0..history.len() - 50); + _ => { + json!(format!( + "{}\n[Attachment: {} ({})]", + request.text, + media.file_name.as_deref().unwrap_or("file"), + media.mime_type + )) } - - let mut messages = vec![json!({"role": "system", "content": self.system_prompt})]; - messages.extend(history.iter().cloned()); - drop(histories); - messages - }; - - let resp = self - .client - .post(format!("{}/chat/completions", self.base_url)) - .header("Authorization", format!("Bearer {}", self.api_key)) - .json(&json!({ - "model": self.model, - "messages": messages, - })) - .send() - .await - .context(HttpSnafu)? - .json::() - .await - .context(HttpSnafu)?; - - let reply = resp["choices"][0]["message"]["content"] - .as_str() - .unwrap_or("(no response)") - .to_string(); - - // Re-acquire lock to store assistant reply - self.histories - .lock() - .unwrap() - .entry(request.conversation_id) - .or_default() - .push(json!({"role": "assistant", "content": &reply})); - - Ok(ChatResponse { - text: Some(reply), - media: None, - }) + } + } else { + json!(request.text) + }; + + // Build messages while holding the lock, then drop it before await + let messages = { + let mut histories = self.histories.lock().unwrap(); + let history = histories + .entry(request.conversation_id.clone()) + .or_default(); + + history.push(json!({"role": "user", "content": user_content})); + + if history.len() > 50 { + history.drain(0..history.len() - 50); + } + + let mut messages = vec![json!({"role": "system", "content": self.system_prompt})]; + messages.extend(history.iter().cloned()); + drop(histories); + messages + }; + + let resp = self + .client + .post(format!("{}/chat/completions", self.base_url)) + .header("Authorization", format!("Bearer {}", self.api_key)) + .json(&json!({ + "model": self.model, + "messages": messages, + })) + .send() + .await + .context(HttpSnafu)? + .json::() + .await + .context(HttpSnafu)?; + + let reply = resp["choices"][0]["message"]["content"] + .as_str() + .unwrap_or("(no response)") + .to_string(); + + // Re-acquire lock to store assistant reply + self.histories + .lock() + .unwrap() + .entry(request.conversation_id) + .or_default() + .push(json!({"role": "assistant", "content": &reply})); + + Ok(ChatResponse { + text: Some(reply), + media: None, }) } } diff --git a/src/models.rs b/src/models.rs index 11765cb..09ef8ef 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,17 +1,14 @@ -use std::{future::Future, pin::Pin}; - +use async_trait::async_trait; use serde::{Deserialize, Serialize}; /// Trait that application code implements to handle incoming chat messages. /// /// The SDK calls [`Agent::chat`] for every incoming message and sends the /// returned [`ChatResponse`] back to the `WeChat` user. +#[async_trait] pub trait Agent: Send + Sync { /// Processes an incoming chat request and returns a response. - fn chat( - &self, - request: ChatRequest, - ) -> Pin> + Send + '_>>; + async fn chat(&self, request: ChatRequest) -> crate::Result; } /// An incoming chat message delivered to the agent.