diff --git a/discord_rig_bot/Cargo.toml b/discord_rig_bot/Cargo.toml index 8dc0ba7..14db016 100644 --- a/discord_rig_bot/Cargo.toml +++ b/discord_rig_bot/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -rig-core = "0.2.1" +rig-core = "0.9" tokio = { version = "1.34.0", features = ["full"] } serenity = { version = "0.11", default-features = false, features = ["client", "gateway", "rustls_backend", "cache", "model", "http"] } dotenv = "0.15.0" diff --git a/discord_rig_bot/src/main.rs b/discord_rig_bot/src/main.rs index 395faad..85db998 100644 --- a/discord_rig_bot/src/main.rs +++ b/discord_rig_bot/src/main.rs @@ -3,18 +3,18 @@ mod rig_agent; use anyhow::Result; +use dotenv::dotenv; +use rig_agent::RigAgent; use serenity::async_trait; use serenity::model::application::command::Command; +use serenity::model::application::command::CommandOptionType; use serenity::model::application::interaction::{Interaction, InteractionResponseType}; -use serenity::model::gateway::Ready; use serenity::model::channel::Message; +use serenity::model::gateway::Ready; use serenity::prelude::*; -use serenity::model::application::command::CommandOptionType; use std::env; use std::sync::Arc; -use tracing::{error, info, debug}; -use rig_agent::RigAgent; -use dotenv::dotenv; +use tracing::{debug, error, info}; // Define a key for storing the bot's user ID in the TypeMap struct BotUserId; @@ -30,12 +30,37 @@ struct Handler { #[async_trait] impl EventHandler for Handler { async fn interaction_create(&self, ctx: Context, interaction: Interaction) { - debug!("Received an interaction"); + debug!("\n\n======> Received an interaction"); if let Interaction::ApplicationCommand(command) = interaction { - debug!("Received command: {}", command.data.name); - let content = match command.data.name.as_str() { - "hello" => "Hello! I'm your helpful Rust and Rig-powered assistant. 
How can I assist you today?".to_string(), + debug!("\n\n======> Received command: {}", command.data.name); + + match command.data.name.as_str() { + "hello" => { + let content = "Hello! I'm your helpful Rust and Rig-powered assistant. How can I assist you today?".to_string(); + + if let Err(why) = command + .create_interaction_response(&ctx.http, |response| { + response + .kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|message| message.content(content)) + }) + .await + { + error!("Cannot respond to slash command: {}", why); + } + }, "ask" => { + // Step 1: Acknowledge quickly + if let Err(e) = command + .create_interaction_response(&ctx.http, |response| { + response.kind(InteractionResponseType::DeferredChannelMessageWithSource) + }) + .await + { + error!("Failed to create deferred response: {:?}", e); + return; + } + let query = command .data .options @@ -43,38 +68,55 @@ impl EventHandler for Handler { .and_then(|opt| opt.value.as_ref()) .and_then(|v| v.as_str()) .unwrap_or("What would you like to ask?"); - debug!("Query: {}", query); - match self.rig_agent.process_message(query).await { - Ok(response) => response, + + debug!("\n\n======> Query: {}", query); + + let response = match self.rig_agent.process_string(query).await { + Ok(response) => { + if response.len() > 1900 { + format!("Response truncated due to Discord limits:\n{}", &response[..1897]) + } else { + response + } + }, Err(e) => { error!("Error processing request: {:?}", e); format!("Error processing request: {:?}", e) } + }; + + // Step 3: Edit the original response + if let Err(e) = command + .edit_original_interaction_response(&ctx.http, |message| { + message.content(response) + }) + .await + { + error!("Failed to edit interaction response: {:?}", e); + } + }, + _ => { + if let Err(why) = command + .create_interaction_response(&ctx.http, |response| { + response + .kind(InteractionResponseType::ChannelMessageWithSource) + .interaction_response_data(|message| + 
message.content("Not implemented :(")) + }) + .await + { + error!("Cannot respond to slash command: {}", why); } } - _ => "Not implemented :(".to_string(), - }; - - debug!("Sending response: {}", content); - - if let Err(why) = command - .create_interaction_response(&ctx.http, |response| { - response - .kind(InteractionResponseType::ChannelMessageWithSource) - .interaction_response_data(|message| message.content(content)) - }) - .await - { - error!("Cannot respond to slash command: {}", why); - } else { - debug!("Response sent successfully"); } + + debug!("\n\n======> Response sent successfully"); } } - + async fn message(&self, ctx: Context, msg: Message) { if msg.mentions_me(&ctx.http).await.unwrap_or(false) { - debug!("Bot mentioned in message: {}", msg.content); + debug!("\n\n=====> Bot mentioned in message: {}", msg.content); let bot_id = { let data = ctx.data.read().await; @@ -85,25 +127,55 @@ impl EventHandler for Handler { let mention = format!("<@{}>", bot_id); let content = msg.content.replace(&mention, "").trim().to_string(); - debug!("Processed content after removing mention: {}", content); + debug!( + "\n\n=====> Processed content after removing mention: {}", + content + ); - match self.rig_agent.process_message(&content).await { + match self.rig_agent.process_message(&ctx, &msg).await { Ok(response) => { - if let Err(why) = msg.channel_id.say(&ctx.http, response).await { - error!("Error sending message: {:?}", why); + println!("Response sent successfully."); + println!("{}", response); + } + Err(e) => { + println!("Error processing request: {:?}", e); + if let Err(why) = msg.channel_id.say(&ctx.http, format!("Error processing request: {:?}", e)).await { + println!("Error sending error message: {:?}", why); } } + } + + match self.rig_agent.process_message(&ctx, &msg).await { + Ok(response) => { + println!("Response sent successfully."); + println!("{}", response); + } Err(e) => { - error!("Error processing message: {:?}", e); - if let Err(why) = msg - 
.channel_id - .say(&ctx.http, format!("Error processing message: {:?}", e)) - .await - { - error!("Error sending error message: {:?}", why); + println!("Error processing request: {:?}", e); + if let Err(why) = msg.channel_id.say(&ctx.http, format!("Error processing request: {:?}", e)).await { + println!("Error sending error message: {:?}", why); } } } + + + // match self.rig_agent.process_message(&content).await { + // Ok(response) => { + // if let Err(why) = msg.channel_id.say(&ctx.http, response).await { + // error!("Error sending message: {:?}", why); + // } + // } + // Err(e) => { + // error!("Error processing message: {:?}", e); + // if let Err(why) = msg + // .channel_id + // .say(&ctx.http, format!("Error processing message: {:?}", e)) + // .await + // { + // error!("Error sending error message: {:?}", why); + // } + // } + // } } else { error!("Bot user ID not found in TypeMap"); } @@ -121,9 +193,7 @@ impl EventHandler for Handler { let commands = Command::set_global_application_commands(&ctx.http, |commands| { commands .create_application_command(|command| { - command - .name("hello") - .description("Say hello to the bot") + command.name("hello").description("Say hello to the bot") }) .create_application_command(|command| { command @@ -172,4 +242,4 @@ async fn main() -> Result<()> { } Ok(()) -} \ No newline at end of file +} diff --git a/discord_rig_bot/src/rig_agent.rs b/discord_rig_bot/src/rig_agent.rs index d535e91..e6f90b1 100644 --- a/discord_rig_bot/src/rig_agent.rs +++ b/discord_rig_bot/src/rig_agent.rs @@ -1,16 +1,17 @@ // rig_agent.rs use anyhow::{Context, Result}; -use rig::providers::openai; -use rig::vector_store::in_memory_store::InMemoryVectorStore; -use rig::vector_store::VectorStore; -use rig::embeddings::EmbeddingsBuilder; -use rig::agent::Agent; -use rig::completion::Prompt; -use std::path::Path; +use rig::{ + agent::Agent, completion::Prompt, embeddings::EmbeddingsBuilder, providers::openai, + 
vector_store::in_memory_store::InMemoryVectorStore, +}; use std::fs; +use std::path::Path; use std::sync::Arc; +use serenity::client::Context as SerenityContext; +use serenity::model::channel::Message; + pub struct RigAgent { agent: Arc<Agent<openai::CompletionModel>>, } @@ -37,36 +38,40 @@ impl RigAgent { let md2_content = Self::load_md_content(&md2_path)?; let md3_content = Self::load_md_content(&md3_path)?; - // Create embeddings and add to vector store + // Create embeddings and add to vector store let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .simple_document("Rig_guide", &md1_content) - .simple_document("Rig_faq", &md2_content) - .simple_document("Rig_examples", &md3_content) + .document(md1_content)? + .document(md2_content)? + .document(md3_content)? .build() .await?; - vector_store.add_documents(embeddings).await?; + vector_store.add_documents(embeddings); // Create index let index = vector_store.index(embedding_model); // Create Agent - let agent = Arc::new(openai_client.agent(openai::GPT_4O) - .preamble("You are an advanced AI assistant powered by Rig, a Rust library for building LLM applications. Your primary function is to provide accurate, helpful, and context-aware responses by leveraging both your general knowledge and specific information retrieved from a curated knowledge base. + let agent = Arc::new( + openai_client + .agent(openai::GPT_4O) + .preamble( + "You are an advanced AI assistant powered by Rig, a Rust library for building LLM applications. Your primary function is to provide accurate, helpful, and context-aware responses by leveraging both your general knowledge and specific information retrieved from a curated knowledge base. Key responsibilities and behaviors: 1. Information Retrieval: You have access to a vast knowledge base. When answering questions, always consider the context provided by the retrieved information. 2. Clarity and Conciseness: Provide clear and concise answers. Ensure responses are short and concise. 
Use bullet points or numbered lists for complex information when appropriate. 3. Technical Proficiency: You have deep knowledge about Rig and its capabilities. When discussing Rig or answering related questions, provide detailed and technically accurate information. - 4. Code Examples: When appropriate, provide Rust code examples to illustrate concepts, especially when discussing Rig's functionalities. Always format code examples for proper rendering in Discord by wrapping them in triple backticks and specifying the language as 'rust'. For example: + 4. Code Examples: When appropriate, provide Rust code examples to illustrate concepts, especially when discussing Rig's functionalities. Always format code examples for proper rendering in Discord by wrapping them in triple backticks and specifying the language as 'rust'. For example: ```rust let example_code = \"This is how you format Rust code for Discord\"; println!(\"{}\", example_code); ``` - 5. Keep your responses short and concise. If the user needs more information, they can ask follow-up questions. 
- ") - .dynamic_context(2, index) - .build()); + ", + ) + .dynamic_context(2, index) + .build(), + ); Ok(Self { agent }) } @@ -75,8 +80,43 @@ impl RigAgent { fs::read_to_string(file_path.as_ref()) .with_context(|| format!("Failed to read markdown file: {:?}", file_path.as_ref())) } - - pub async fn process_message(&self, message: &str) -> Result<String> { - self.agent.prompt(message).await.map_err(anyhow::Error::from) + + // Add this function for messages that only need a string input/output + pub async fn process_string(&self, message: &str) -> Result<String> { + self.agent + .prompt(message) + .await + .map_err(anyhow::Error::from) } -} \ No newline at end of file + + pub async fn process_message(&self, ctx: &SerenityContext, msg: &Message) -> Result<String> { + // First, create a typing indicator + msg.channel_id.broadcast_typing(&ctx.http).await?; + + // Send deferred response to meet 3-second requirement + let mut deferred_msg = msg.channel_id.say(&ctx.http, "Thinking...").await?; + + // Use the string content directly, not a reference + let response = self.agent.prompt(msg.content.clone()).await.map_err(anyhow::Error::from)?; + + // Truncate if needed + let truncated_response = if response.len() > 1900 { + format!("Response truncated due to Discord limits:\n{}", &response[..1897]) + } else { + response + }; + + // Edit the deferred message + deferred_msg.edit(&ctx.http, |m| m.content(truncated_response.clone())).await?; + + Ok(truncated_response) + } + + // OLD process_message WITHOUT DEFERRAL AND TRUNCATION + // pub async fn process_message(&self, message: &str) -> Result<String> { + // self.agent + // .prompt(message) + // .await + // .map_err(anyhow::Error::from) + // } +}