Merged
24 changes: 13 additions & 11 deletions ai/llmagent.py
@@ -1,9 +1,9 @@
import base64
import logging

import aiohttp
from langchain_core.messages import BaseMessage, HumanMessage
from langgraph.prebuilt import create_react_agent
from requests import RequestException

from ai.llmbot import LLMBot
from ai.tools import get_tools
@@ -37,14 +37,14 @@ def __init__(self, config: Config, system_instructions: list[BaseMessage]):
# logging.debug(f"Generated feedback message: {feedback_message}")
# return feedback_message

def answer_message(self, chat_id: int, message: str) -> BaseMessage:
async def answer_message(self, chat_id: int, message: str) -> BaseMessage:
self.chats[chat_id]["messages"].append(HumanMessage(content=message))
self.truncate_chat_context(chat_id)

ai_msg = self.agent.invoke({"messages": self.system_instructions + self.chats[chat_id]["messages"]})
ai_msg = await self.agent.ainvoke({"messages": self.system_instructions + self.chats[chat_id]["messages"]})
return ai_msg["messages"][-1]

def answer_image_message(self, chat_id: int, text: str, image: str) -> BaseMessage:
async def answer_image_message(self, chat_id: int, text: str, image: str) -> BaseMessage:
"""
Answer an image message.
:param chat_id: Chat ID
@@ -55,10 +55,12 @@ def answer_image_message(self, chat_id: int, text: str, image: str) -> BaseMessage:
logging.debug(f"Image message: {text}")

try:
# Use the shared session to download the image
response = self._session.get(image, timeout=self.config.web_content_request_timeout)
response.raise_for_status()
image_data = base64.b64encode(response.content).decode("utf-8")
# Use aiohttp to download the image
session = await self._get_session()
async with session.get(image) as response:
response.raise_for_status()
image_bytes = await response.read()
image_data = base64.b64encode(image_bytes).decode("utf-8")

llm_message = HumanMessage(
content=[
@@ -71,9 +73,9 @@
)
self.chats[chat_id]["messages"].append(llm_message)
self.truncate_chat_context(chat_id)
response = self.agent.invoke({"messages": self.chats[chat_id]["messages"]})["messages"][-1]
except (RequestException, Exception) as e:
if isinstance(e, RequestException):
response = (await self.agent.ainvoke({"messages": self.chats[chat_id]["messages"]}))["messages"][-1]
except (aiohttp.ClientError, Exception) as e:
if isinstance(e, aiohttp.ClientError):
logging.error(f"Failed to get image: {image}")
logging.exception(e)
response = BaseMessage(content="NO_ANSWER", type="text")
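Since `answer_message` and `answer_image_message` are coroutines after this change, every call site has to run inside an event loop and `await` them. A minimal sketch of a caller, assuming the class in `ai/llmagent.py` is named `LLMAgent`, that it inherits `close()` from `LLMBot`, and that chat contexts are seeded as plain dicts (all three are inferences from the diff, not confirmed by it):

```python
import asyncio

from langchain_core.messages import SystemMessage

from ai.llmagent import LLMAgent  # assumed class name
from config import Config


async def main() -> None:
    config = Config()  # hypothetical construction; real setup may differ
    agent = LLMAgent(config, [SystemMessage(content="You are a helpful bot.")])
    agent.chats[1] = {"messages": []}  # chats is a plain dict keyed by chat ID
    try:
        reply = await agent.answer_message(chat_id=1, message="Hello!")
        print(reply.content)
    finally:
        # The session is created lazily, so close() is a no-op if no
        # HTTP request was ever made.
        await agent.close()


asyncio.run(main())
```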
133 changes: 72 additions & 61 deletions ai/llmbot.py
@@ -5,15 +5,13 @@
from typing import Any
from urllib.parse import urljoin

import requests
import aiohttp
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from langchain.chains.combine_documents.stuff import create_stuff_documents_chain
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.rate_limiters import InMemoryRateLimiter
from requests import ConnectTimeout, ReadTimeout, RequestException

from ai.tools import get_tool, get_tools
from config import Config
@@ -27,15 +25,30 @@ def __init__(self, config: Config, system_instructions: list[BaseMessage]):
self.llm = None
self.chats: dict = {} # {'chat_id': {"messages": []}}
self._load_llm()
self._session = requests.Session() # Single session for all HTTP requests
self._session = None # Will be initialized asynchronously

if self.config.use_tools:
self._load_tools()

def __del__(self):
# Clean up the session when the object is garbage collected
if hasattr(self, "_session"):
self._session.close()
async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create the aiohttp session."""
if self._session is None or self._session.closed:
timeout = aiohttp.ClientTimeout(total=self.config.web_content_request_timeout)
self._session = aiohttp.ClientSession(timeout=timeout)
return self._session

async def close(self):
"""Close the aiohttp session."""
if self._session and not self._session.closed:
await self._session.close()

async def __aenter__(self):
"""Async context manager entry."""
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.close()
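Deferring session creation to `_get_session` matters because an `aiohttp.ClientSession` should be created inside a running event loop; building it in the synchronous `__init__` (as the deleted `requests.Session` was) can bind to the wrong loop or warn at teardown. With `__aenter__`/`__aexit__` in place, callers can scope the session to an `async with` block and cleanup runs even if the body raises. A hedged usage sketch (constructor arguments are illustrative):

```python
import asyncio

from ai.llmbot import LLMBot
from config import Config


async def demo() -> None:
    # The session opens lazily on the first HTTP call and is closed by
    # __aexit__ on block exit, even if the body raises.
    async with LLMBot(Config(), system_instructions=[]) as bot:
        image_b64 = await bot.generate_image("a lighthouse at dusk")
        print("got image" if image_b64 else "SDAPI unavailable")


asyncio.run(demo())
```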

def _get_rate_limiter(self):
return InMemoryRateLimiter(
@@ -124,7 +137,7 @@ def truncate_chat_context(self, chat_id: int) -> None:
self.chats[chat_id]["messages"] = self.chats[chat_id]["messages"][1:]
logging.debug(f"Chat context truncated for chat {chat_id}")

def call_sdapi(self, prompt: str) -> dict[str, Any] | None:
async def call_sdapi(self, prompt: str) -> dict[str, Any] | None:
"""
Call the StableDiffusion API.
:param prompt: The prompt to send to the StableDiffusion API.
@@ -137,14 +150,14 @@ def call_sdapi(self, prompt: str) -> dict[str, Any] | None:
if self.config.sdapi_negative_prompt:
params["negative_prompt"] = self.config.sdapi_negative_prompt

# Use the shared session instead of requests.post
response = self._session.post(
# Use aiohttp for async HTTP requests
session = await self._get_session()
async with session.post(
urljoin(self.config.sdapi_url, "/sdapi/v1/txt2img"),
json=params,
timeout=self.config.web_content_request_timeout,
)
if response.status_code == 200:
return response.json()
) as response:
if response.status == 200:
return await response.json()
except Exception as e:
logging.error("Failed to call SDAPI")
logging.exception(e)
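In isolation, the txt2img round trip now has this shape; the endpoint and payload mirror what the diff shows (an AUTOMATIC1111-style API), while the base URL, timeout value, and prompt are placeholders. Note the per-request `timeout=` argument disappears because the session is constructed with a default `aiohttp.ClientTimeout`:

```python
import asyncio
from urllib.parse import urljoin

import aiohttp


async def txt2img(base_url: str, prompt: str) -> str | None:
    """Return the first generated image as base64, or None on failure."""
    params = {"prompt": prompt}
    timeout = aiohttp.ClientTimeout(total=60)  # placeholder timeout
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(
            urljoin(base_url, "/sdapi/v1/txt2img"), json=params
        ) as response:
            if response.status != 200:
                return None
            payload = await response.json()
    # generate_image() below reads payload["images"][0] the same way.
    return payload["images"][0] if payload.get("images") else None


if __name__ == "__main__":
    print(asyncio.run(txt2img("http://127.0.0.1:7860", "a lighthouse at dusk")))
```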
@@ -158,20 +171,20 @@ def clean_context(self, chat_id: int) -> None:
self.chats[chat_id]["messages"] = []
logging.debug(f"Chat context cleaned for chat {chat_id}")

def answer_message(self, chat_id: int, message: str) -> BaseMessage:
async def answer_message(self, chat_id: int, message: str) -> BaseMessage:
self.chats[chat_id]["messages"].append(HumanMessage(content=message))
self.truncate_chat_context(chat_id)
ai_msg = self.llm.invoke(self.system_instructions + self.chats[chat_id]["messages"])
ai_msg = await self.llm.ainvoke(self.system_instructions + self.chats[chat_id]["messages"])
if ai_msg.tool_calls:
self.chats[chat_id]["messages"].append(ai_msg)
for tool_call in ai_msg.tool_calls:
selected_tool = get_tool(tool_call["name"])
tool_msg = selected_tool.invoke(tool_call)
tool_msg = await selected_tool.ainvoke(tool_call)
self.chats[chat_id]["messages"].append(tool_msg)
ai_msg = self.llm.invoke(self.system_instructions + self.chats[chat_id]["messages"])
ai_msg = await self.llm.ainvoke(self.system_instructions + self.chats[chat_id]["messages"])
return ai_msg
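The loop now awaits `ainvoke` on each selected tool, so async tools do their I/O on the event loop instead of blocking it. A self-contained sketch of the same ask/run-tools/ask-again shape, using a hypothetical async tool rather than the project's `get_tool` registry (and assuming `llm` was bound with tools so that `tool_calls` can be populated):

```python
import aiohttp
from langchain_core.tools import tool


@tool
async def page_size(url: str) -> int:
    """Hypothetical async tool: return the byte length of a page body."""
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return len(await response.read())


async def tool_round_trip(llm, messages: list):
    # Same shape as answer_message above: ask, run tool calls, ask again.
    ai_msg = await llm.ainvoke(messages)
    if ai_msg.tool_calls:
        messages.append(ai_msg)
        for tool_call in ai_msg.tool_calls:
            tool_msg = await page_size.ainvoke(tool_call)  # single-tool sketch
            messages.append(tool_msg)
        ai_msg = await llm.ainvoke(messages)
    return ai_msg
```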

def answer_image_message(self, chat_id: int, text: str, image: str) -> BaseMessage:
async def answer_image_message(self, chat_id: int, text: str, image: str) -> BaseMessage:
"""
Answer an image message.
:param chat_id: Chat ID
@@ -182,10 +195,12 @@ def answer_image_message(self, chat_id: int, text: str, image: str) -> BaseMessage:
logging.debug(f"Image message: {text}")

try:
# Use the shared session to download the image
response = self._session.get(image, timeout=self.config.web_content_request_timeout)
response.raise_for_status()
image_data = base64.b64encode(response.content).decode("utf-8")
# Use aiohttp to download the image
session = await self._get_session()
async with session.get(image) as response:
response.raise_for_status()
image_bytes = await response.read()
image_data = base64.b64encode(image_bytes).decode("utf-8")

llm_message = HumanMessage(
content=[
@@ -198,17 +213,17 @@ def answer_image_message(self, chat_id: int, text: str, image: str) -> BaseMessage:
)
self.chats[chat_id]["messages"].append(llm_message)
self.truncate_chat_context(chat_id)
response = self.llm.invoke(self.chats[chat_id]["messages"])
except (RequestException, Exception) as e:
if isinstance(e, RequestException):
response = await self.llm.ainvoke(self.chats[chat_id]["messages"])
except (aiohttp.ClientError, Exception) as e:
if isinstance(e, aiohttp.ClientError):
logging.error(f"Failed to get image: {image}")
logging.exception(e)
response = BaseMessage(content="NO_ANSWER", type="text")

logging.debug(f"Image message response: {response}")
return response

def postprocess_response(self, response: BaseMessage, message_text: str, chat_id: int) -> dict | None:
async def postprocess_response(self, response: BaseMessage, message_text: str, chat_id: int) -> dict | None:
"""
Postprocess the response from the LLM.
:param response: Response from the LLM
@@ -228,21 +243,21 @@ def postprocess_response(self, response: BaseMessage, message_text: str, chat_id: int) -> dict | None:
final_response = None
if response_content.startswith("GENERATE_IMAGE"):
logging.debug(f"GENERATE_IMAGE response, generating image for chat {chat_id}")
image = self.generate_image(response_content[len("GENERATE_IMAGE ") :])
image = await self.generate_image(response_content[len("GENERATE_IMAGE ") :])
if image:
final_response = {
"type": "image",
"content": image,
}
elif "WEBCONTENT_RESUME" in response_content:
logging.debug(f"WEBCONTENT_RESUME response, generating web content abstract for chat {chat_id}")
response_content = self.answer_webcontent(message_text, response_content, chat_id)
response_content = await self.answer_webcontent(message_text, response_content, chat_id)
# TODO: find a way to graciously handle failed web content requests
response_content = response_content if response_content else "😐"
final_response = {"type": "text", "data": response_content}
elif "WEBCONTENT_OPINION" in response_content:
logging.debug(f"WEBCONTENT_OPINION response, generating web content opinion for chat {chat_id}")
response_content = self.answer_webcontent(message_text, response_content, chat_id)
response_content = await self.answer_webcontent(message_text, response_content, chat_id)
# TODO: find a way to graciously handle failed web content requests
response_content = response_content if response_content else "😐"
final_response = {"type": "text", "data": response_content}
@@ -257,14 +272,14 @@ def postprocess_response(self, response: BaseMessage, message_text: str, chat_id: int) -> dict | None:

return final_response

def generate_image(self, prompt: str) -> str | None:
async def generate_image(self, prompt: str) -> str | None:
"""
Generate an image.
:param prompt: Prompt to generate the image
:return: Image representation in base64 format if the call was successful, None otherwise
"""
logging.debug(f"Generate image: {prompt}")
response = self.call_sdapi(prompt)
response = await self.call_sdapi(prompt)
if response and "images" in response:
return response["images"][0]
return None
@@ -290,7 +305,7 @@ def count_tokens(self, messages: list[BaseMessage]) -> int:

return self.llm.get_num_tokens(context_text) + extra_tokens

def answer_webcontent(self, message_text: str, response_content: str, chat_id: int) -> str | None:
async def answer_webcontent(self, message_text: str, response_content: str, chat_id: int) -> str | None:
"""
Answer a web content message.
:param message_text: Text to answer
@@ -302,15 +317,21 @@ def answer_webcontent(self, message_text: str, response_content: str, chat_id: int) -> str | None:
if url:
logging.debug(f"Obtaining web content for {url}")

# Configure WebBaseLoader with the shared session
loader = WebBaseLoader(
url,
requests_kwargs={
"timeout": self.config.web_content_request_timeout,
"session": self._session, # Reuse the session
},
)
docs = loader.load()
# Use aiohttp to fetch web content
session = await self._get_session()
async with session.get(url) as response:
response.raise_for_status()
html_content = await response.text()

from bs4 import BeautifulSoup

soup = BeautifulSoup(html_content, "html.parser")
page_content = soup.get_text(separator=" ", strip=True)

# Create a simple document for the chain
from langchain_core.documents import Document

docs = [Document(page_content=page_content)]

template = self._remove_urls(message_text) + "\n" + '"{text}"'
prompt = PromptTemplate.from_template(template)
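The `WebBaseLoader` replacement inlines its imports next to the call site; hoisting `BeautifulSoup` and `Document` to the module header and wrapping the fetch in a small helper would keep `answer_webcontent` focused. A hedged sketch of such a helper (the name is invented):

```python
import aiohttp
from bs4 import BeautifulSoup
from langchain_core.documents import Document


async def fetch_page_as_document(session: aiohttp.ClientSession, url: str) -> Document:
    """Download a page and flatten it to plain text, mirroring the code above."""
    async with session.get(url) as response:
        response.raise_for_status()
        html_content = await response.text()
    # get_text(separator=" ", strip=True) collapses the DOM into one string,
    # roughly what WebBaseLoader produced before this change.
    text = BeautifulSoup(html_content, "html.parser").get_text(separator=" ", strip=True)
    return Document(page_content=text)
```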
@@ -324,12 +345,12 @@ def answer_webcontent(self, message_text: str, response_content: str, chat_id: int) -> str | None:
)

# The key should match the document_variable_name parameter
response = stuff_chain.invoke({"text": docs})
response = await stuff_chain.ainvoke({"text": docs})
logging.debug(f"Web content response: {response}")
return response
else:
logging.debug(f"No URL found for web content: {message_text}")
except ConnectionError as e:
except aiohttp.ClientError as e:
logging.error("Connection error connecting to web content")
logging.exception(e)
error_prompt = (
@@ -338,18 +359,8 @@ def answer_webcontent(self, message_text: str, response_content: str, chat_id: int) -> str | None:
f"Suggest checking the URL or trying again later. "
f"Keep your response under 150 characters and maintain your character's style."
)
return self.generate_feedback_message(error_prompt)
except ReadTimeout as e:
logging.error("Read timeout error connecting to web content")
logging.exception(e)
error_prompt = (
f"Generate a brief response in {self.config.preferred_language} "
f"explaining that the webpage {url} took too long to send data. "
f"Suggest it might be unavailable or too large. "
f"Keep your response under 150 characters and maintain your character's style."
)
return self.generate_feedback_message(error_prompt)
except ConnectTimeout as e:
return await self.generate_feedback_message(error_prompt)
except TimeoutError as e:
logging.error("Timeout error connecting to web content")
logging.exception(e)
error_prompt = (
Expand All @@ -358,7 +369,7 @@ def answer_webcontent(self, message_text: str, response_content: str, chat_id: i
f"Suggest it might be unavailable or too large. "
f"Keep your response under 150 characters and maintain your character's style."
)
return self.generate_feedback_message(error_prompt)
return await self.generate_feedback_message(error_prompt)
except Exception as e:
logging.error("Error connecting to web content")
logging.exception(e)
@@ -368,10 +379,10 @@ def answer_webcontent(self, message_text: str, response_content: str, chat_id: int) -> str | None:
f"Suggest trying again later or trying a different URL. "
f"Keep your response under 150 characters and maintain your character's style."
)
return self.generate_feedback_message(error_prompt)
return await self.generate_feedback_message(error_prompt)
return None

def generate_feedback_message(self, prompt: str, max_length: int = 200) -> str:
async def generate_feedback_message(self, prompt: str, max_length: int = 200) -> str:
"""
Generate a feedback message using the LLM.

@@ -383,7 +394,7 @@ def generate_feedback_message(self, prompt: str, max_length: int = 200) -> str:

# Create a simple message list with just the prompt
messages = [HumanMessage(content=prompt)]
response = self.llm.invoke(messages)
response = await self.llm.ainvoke(messages)

# Clean up the response if needed
feedback_message = response.content.strip()