From 27b35b54fcbf6cb9a624ca4328c5f672d64cbd93 Mon Sep 17 00:00:00 2001 From: Yuvinscria Werdxz Date: Fri, 18 Apr 2025 18:31:13 -0700 Subject: [PATCH] added context feature --- src/context.rs | 76 +++++++++++++++++++ src/events.rs | 17 +---- src/main.rs | 1 + src/{ => messages}/llm-prompt.txt | 4 - src/messages/message-context-format.txt | 3 + src/parser.rs | 98 ++++++++++++++++++------- 6 files changed, 156 insertions(+), 43 deletions(-) create mode 100644 src/context.rs rename src/{ => messages}/llm-prompt.txt (99%) create mode 100644 src/messages/message-context-format.txt diff --git a/src/context.rs b/src/context.rs new file mode 100644 index 0000000..62f02b9 --- /dev/null +++ b/src/context.rs @@ -0,0 +1,76 @@ +use std::fmt::Display; + +use chrono::{Local, NaiveDate}; +use serenity::all::{Context, Message, PartialGuild}; + +#[non_exhaustive] +pub struct MessageContext<'a> { + pub server_name: &'a str, + pub content: &'a str, + pub date: NaiveDate, +} + +impl<'a> MessageContext<'a> { + pub async fn from(guild: &'a PartialGuild, _ctx: &'a Context, msg: &'a Message) -> Self { + // The bot accepts two inputs + // 1. A message with information with mentions it with an @CalBot + // 2. Replying to a message with information and mentioning @CalBot in the reply + let (content, date) = match msg.referenced_message { + Some(ref message) => { + if let Some(edited) = message.edited_timestamp { + (&message.content, edited.date_naive()) + } else { + (&message.content, message.timestamp.date_naive()) + } + } + None => (&msg.content, msg.timestamp.date_naive()), + }; + + Self { + server_name: &guild.name, + content, + date, + } + } +} + +impl<'a> Display for MessageContext<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + include_str!("./messages/message-context-format.txt"), + self.server_name, self.content, + ) + } +} + +impl<'a> Default for MessageContext<'a> { + fn default() -> Self { + Self { + server_name: "Unknown Server", + content: "", + date: Local::now().date_naive(), + } + } +} + +#[cfg(test)] +mod test { + use super::MessageContext; + + #[test] + fn test_format() { + let ctx = MessageContext { + server_name: "Test Server", + content: "Some random message.", + date: chrono::NaiveDate::from_ymd_opt(2023, 10, 1).unwrap(), + }; + + let res = format!("{}", ctx); + + assert_eq!( + res, + "Context for the message:\n- Server Name: Test Server\n- Content: Some random message.\n" + ); + } +} diff --git a/src/events.rs b/src/events.rs index 114bd73..590cec9 100644 --- a/src/events.rs +++ b/src/events.rs @@ -6,6 +6,7 @@ use serenity::{ }; use crate::{ + context::MessageContext, parser::{parse_msg, Error}, utils::{calendar_message, upload_calendar}, }; @@ -51,19 +52,9 @@ impl EventHandler for Handler { return; } - // The bot accepts two inputs - // 1. A message with information with mentions it with an @CalBot - // 2. Replying to a message with information and mentioning @CalBot in the reply - let res = match msg.referenced_message { - Some(ref ref_msg) => { - if let Some(edited) = ref_msg.edited_timestamp { - parse_msg(&ref_msg.content, &edited.date_naive()).await - } else { - parse_msg(&ref_msg.content, &ref_msg.timestamp.date_naive()).await - } - } - None => parse_msg(&msg.content, &msg.timestamp.date_naive()).await, - }; + let context = MessageContext::from(&guild, &ctx, &msg).await; + + let res = parse_msg(&context).await; match res { Ok(calendar) => { diff --git a/src/main.rs b/src/main.rs index 60609e4..c872e2f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ mod events; mod parser; mod utils; +mod context; use shuttle_runtime::SecretStore; use events::Handler; diff --git a/src/llm-prompt.txt b/src/messages/llm-prompt.txt similarity index 99% rename from src/llm-prompt.txt rename to src/messages/llm-prompt.txt index e7be284..1c0ec7a 100644 --- a/src/llm-prompt.txt +++ b/src/messages/llm-prompt.txt @@ -39,7 +39,3 @@ Try to keep the title brief, no more than 5 words. If the message does not seem to be parseable, return an empty string. If there are no times in the message, do not attempt to guess the time. If there are no dates in the message, do not attempt to guess the date. - -Message: - - diff --git a/src/messages/message-context-format.txt b/src/messages/message-context-format.txt new file mode 100644 index 0000000..17b7d14 --- /dev/null +++ b/src/messages/message-context-format.txt @@ -0,0 +1,3 @@ +Context for the message: +- Server Name: {} +- Content: {} diff --git a/src/parser.rs b/src/parser.rs index e74ddbf..e3c35a7 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -3,8 +3,10 @@ use icalendar::{Calendar, Component, Event, EventLike}; use reqwest::header::{AUTHORIZATION, CONTENT_TYPE}; use serde::Deserialize; +use crate::context::MessageContext; + const GROQ_ENDPOINT: &str = "https://api.groq.com/openai/v1/chat/completions"; -const PROMPT_INSTRUCTIONS: &str = include_str!("llm-prompt.txt"); +const SYSTEM_PROMPT: &str = include_str!("./messages/llm-prompt.txt"); const MAX_COMPLETION_TOKEN: usize = 300; #[derive(Deserialize, Debug)] @@ -97,19 +99,23 @@ fn parse_date(date_str: &str, msg_date: &NaiveDate) -> Result } } -pub async fn parse_msg(msg: &str, message_date: &NaiveDate) -> Result { +pub async fn parse_msg<'a>(ctx: &MessageContext<'a>) -> Result { let groq_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY missing"); - let full_prompt = [PROMPT_INSTRUCTIONS, msg].join("\r\n"); + let prompt = format!("{}", ctx); let req_body = serde_json::json!({ "model": "llama-3.3-70b-versatile", // "model": "llama-3.2-90b-vision-preview", "max_completion_tokens": MAX_COMPLETION_TOKEN, "messages": [ + { + "role": "system", + "content": SYSTEM_PROMPT, + }, { "role": "user", - "content": full_prompt, + "content": prompt, } ]}); @@ -141,7 +147,7 @@ pub async fn parse_msg(msg: &str, message_date: &NaiveDate) -> Result