diff --git a/src/agent.py b/src/agent.py index 172bf23df..1ff9f390a 100644 --- a/src/agent.py +++ b/src/agent.py @@ -11,9 +11,9 @@ active_conversations = {} -def get_user_conversations(user_id: str): +async def get_user_conversations(user_id: str): """Get conversation metadata for a user from persistent storage""" - return conversation_persistence.get_user_conversations(user_id) + return await conversation_persistence.get_user_conversations(user_id) def get_conversation_thread(user_id: str, previous_response_id: str = None): @@ -82,7 +82,7 @@ async def store_conversation_thread(user_id: str, response_id: str, conversation # Legacy function for backward compatibility -def get_user_conversation(user_id: str): +async def get_user_conversation(user_id: str): """Get the most recent conversation for a user (for backward compatibility)""" # Check in-memory conversations first (with function calls) if user_id in active_conversations and active_conversations[user_id]: @@ -93,7 +93,7 @@ def get_user_conversation(user_id: str): return active_conversations[user_id][latest_response_id] # Fallback to metadata-only conversations - conversations = get_user_conversations(user_id) + conversations = await get_user_conversations(user_id) if not conversations: return get_conversation_thread(user_id) @@ -461,7 +461,7 @@ async def async_chat( ) # Debug: Check what's in user_conversations now - conversations = get_user_conversations(user_id) + conversations = await get_user_conversations(user_id) logger.debug( "User conversations updated", user_id=user_id, @@ -668,7 +668,7 @@ async def async_langflow_chat( ) # Debug: Check what's in user_conversations now - conversations = get_user_conversations(user_id) + conversations = await get_user_conversations(user_id) logger.debug( "User conversations updated", user_id=user_id, diff --git a/src/main.py b/src/main.py index 76fbf0cf8..a873eacac 100644 --- a/src/main.py +++ b/src/main.py @@ -59,6 +59,10 @@ from services.auth_service import AuthService from services.langflow_mcp_service import LangflowMCPService from services.chat_service import ChatService +from services.conversation_persistence_service import ( + CONVERSATION_METADATA_INDEX_BODY, + CONVERSATION_METADATA_INDEX_NAME, +) # Services from services.document_service import DocumentService @@ -230,6 +234,22 @@ async def init_index(): index_name=API_KEYS_INDEX_NAME, ) + # Create chat conversation metadata index for horizontally scaled backends + if not await clients.opensearch.indices.exists(index=CONVERSATION_METADATA_INDEX_NAME): + await clients.opensearch.indices.create( + index=CONVERSATION_METADATA_INDEX_NAME, + body=CONVERSATION_METADATA_INDEX_BODY, + ) + logger.info( + "Created chat conversation metadata index", + index_name=CONVERSATION_METADATA_INDEX_NAME, + ) + else: + logger.info( + "Chat conversation metadata index already exists, skipping creation", + index_name=CONVERSATION_METADATA_INDEX_NAME, + ) + # Configure alerting plugin security settings await configure_alerting_security() diff --git a/src/services/chat_service.py b/src/services/chat_service.py index d2bfa8b8f..9e4326e6b 100644 --- a/src/services/chat_service.py +++ b/src/services/chat_service.py @@ -368,7 +368,7 @@ async def get_chat_history(self, user_id: str): return {"error": "User ID is required", "conversations": []} # Get metadata from persistent storage - conversations_dict = get_user_conversations(user_id) + conversations_dict = await get_user_conversations(user_id) # Get in-memory conversations (with function calls) in_memory_conversations = active_conversations.get(user_id, {}) @@ -484,7 +484,7 @@ async def get_langflow_history(self, user_id: str): try: # 1. Get local conversation metadata (no actual messages stored here) - conversations_dict = get_user_conversations(user_id) + conversations_dict = await get_user_conversations(user_id) local_metadata = {} for response_id, conversation_metadata in conversations_dict.items(): diff --git a/src/services/conversation_persistence_service.py b/src/services/conversation_persistence_service.py index 8af36efff..e9ab6a68e 100644 --- a/src/services/conversation_persistence_service.py +++ b/src/services/conversation_persistence_service.py @@ -1,139 +1,223 @@ """ Conversation Persistence Service -Simple service to persist chat conversations to disk so they survive server restarts +Persists chat conversation metadata in OpenSearch so it can be shared across backend instances. """ -import json -import os -import asyncio -from typing import Dict, Any from datetime import datetime -import threading +from typing import Any, Dict + +from config.settings import clients from utils.logging_config import get_logger logger = get_logger(__name__) + +CONVERSATION_METADATA_INDEX_NAME = "chat_conversation_metadata" +CONVERSATION_METADATA_INDEX_BODY = { + "settings": { + "number_of_shards": 1, + "number_of_replicas": 0, + }, + "mappings": { + "properties": { + "user_id": {"type": "keyword"}, + "response_id": {"type": "keyword"}, + "title": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}, + "endpoint": {"type": "keyword"}, + "created_at": {"type": "date"}, + "last_activity": {"type": "date"}, + "previous_response_id": {"type": "keyword"}, + "filter_id": {"type": "keyword"}, + "total_messages": {"type": "integer"}, + } + }, +} + + class ConversationPersistenceService: - """Simple service to persist conversations to disk""" - - def __init__(self, storage_file: str = "data/conversations.json"): - self.storage_file = storage_file - # Ensure data directory exists - os.makedirs(os.path.dirname(self.storage_file), exist_ok=True) - self.lock = threading.Lock() - self._conversations = self._load_conversations() - - def _load_conversations(self) -> Dict[str, Dict[str, Any]]: - """Load conversations from disk""" - if os.path.exists(self.storage_file): - try: - with open(self.storage_file, 'r', encoding='utf-8') as f: - data = json.load(f) - logger.debug(f"Loaded {self._count_total_conversations(data)} conversations from {self.storage_file}") - return data - except Exception as e: - logger.error(f"Error loading conversations from {self.storage_file}: {e}") - return {} - return {} - - def _save_conversations_sync(self): - """Synchronous save conversations to disk (runs in executor)""" - try: - with self.lock: - with open(self.storage_file, 'w', encoding='utf-8') as f: - json.dump(self._conversations, f, indent=2, ensure_ascii=False, default=str) - logger.debug(f"Saved {self._count_total_conversations(self._conversations)} conversations to {self.storage_file}") - except Exception as e: - logger.error(f"Error saving conversations to {self.storage_file}: {e}") - - async def _save_conversations(self): - """Async save conversations to disk (non-blocking)""" - # Run the synchronous file I/O in a thread pool to avoid blocking the event loop - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, self._save_conversations_sync) - - def _count_total_conversations(self, data: Dict[str, Any]) -> int: - """Count total conversations across all users""" - total = 0 - for user_conversations in data.values(): - if isinstance(user_conversations, dict): - total += len(user_conversations) - return total - - def get_user_conversations(self, user_id: str) -> Dict[str, Any]: - """Get all conversations for a user""" - if user_id not in self._conversations: - self._conversations[user_id] = {} - return self._conversations[user_id] - + """Persists conversation metadata in OpenSearch with in-memory fallback.""" + + def __init__(self): + self._fallback_conversations: Dict[str, Dict[str, Any]] = {} + + def _get_document_id(self, user_id: str, response_id: str) -> str: + return f"{user_id}:{response_id}" + def _serialize_datetime(self, obj: Any) -> Any: - """Recursively convert datetime objects to ISO strings for JSON serialization""" + """Recursively convert datetime objects to ISO strings for JSON serialization.""" if isinstance(obj, datetime): return obj.isoformat() - elif isinstance(obj, dict): + if isinstance(obj, dict): return {key: self._serialize_datetime(value) for key, value in obj.items()} - elif isinstance(obj, list): + if isinstance(obj, list): return [self._serialize_datetime(item) for item in obj] - else: - return obj - - async def store_conversation_thread(self, user_id: str, response_id: str, conversation_state: Dict[str, Any]): - """Store a conversation thread and persist to disk (async, non-blocking)""" - if user_id not in self._conversations: - self._conversations[user_id] = {} - - # Recursively convert datetime objects to strings for JSON serialization - serialized_conversation = self._serialize_datetime(conversation_state) - - self._conversations[user_id][response_id] = serialized_conversation - - # Save to disk asynchronously (non-blocking) - await self._save_conversations() - - def get_conversation_thread(self, user_id: str, response_id: str) -> Dict[str, Any]: - """Get a specific conversation thread""" - user_conversations = self.get_user_conversations(user_id) - return user_conversations.get(response_id, {}) - + return obj + + async def get_user_conversations(self, user_id: str) -> Dict[str, Any]: + """Get all persisted metadata for a user keyed by response_id.""" + if not user_id: + return {} + + try: + if not clients.opensearch: + return self._fallback_conversations.get(user_id, {}) + + result = await clients.opensearch.search( + index=CONVERSATION_METADATA_INDEX_NAME, + body={ + "query": {"term": {"user_id": user_id}}, + "size": 1000, + "sort": [{"last_activity": {"order": "desc", "unmapped_type": "date"}}], + }, + ) + + conversations: Dict[str, Any] = {} + for hit in result.get("hits", {}).get("hits", []): + source = hit.get("_source", {}) + response_id = source.get("response_id") + if response_id: + metadata = dict(source) + metadata.pop("user_id", None) + conversations[response_id] = metadata + + self._fallback_conversations[user_id] = conversations + return conversations + except Exception as e: + logger.warning( + "Failed to load conversation metadata from OpenSearch, using fallback", + user_id=user_id, + error=str(e), + ) + return self._fallback_conversations.get(user_id, {}) + + async def store_conversation_thread( + self, user_id: str, response_id: str, conversation_state: Dict[str, Any] + ): + """Store conversation metadata in OpenSearch.""" + if user_id not in self._fallback_conversations: + self._fallback_conversations[user_id] = {} + + serialized = self._serialize_datetime(conversation_state) + serialized["user_id"] = user_id + serialized["response_id"] = response_id + + self._fallback_conversations[user_id][response_id] = { + key: value for key, value in serialized.items() if key != "user_id" + } + + try: + if not clients.opensearch: + return + + await clients.opensearch.index( + index=CONVERSATION_METADATA_INDEX_NAME, + id=self._get_document_id(user_id, response_id), + body=serialized, + refresh=True, + ) + except Exception as e: + logger.warning( + "Failed to persist conversation metadata to OpenSearch", + user_id=user_id, + response_id=response_id, + error=str(e), + ) + + async def get_conversation_thread(self, user_id: str, response_id: str) -> Dict[str, Any]: + """Get a specific conversation metadata record.""" + try: + if not clients.opensearch: + return self._fallback_conversations.get(user_id, {}).get(response_id, {}) + + result = await clients.opensearch.get( + index=CONVERSATION_METADATA_INDEX_NAME, + id=self._get_document_id(user_id, response_id), + ) + source = result.get("_source", {}) + source.pop("user_id", None) + return source + except Exception: + return self._fallback_conversations.get(user_id, {}).get(response_id, {}) + async def delete_conversation_thread(self, user_id: str, response_id: str) -> bool: - """Delete a specific conversation thread (async, non-blocking)""" - if user_id in self._conversations and response_id in self._conversations[user_id]: - del self._conversations[user_id][response_id] - await self._save_conversations() - logger.debug(f"Deleted conversation {response_id} for user {user_id}") + """Delete a specific conversation metadata record.""" + deleted = False + + if user_id in self._fallback_conversations and response_id in self._fallback_conversations[user_id]: + del self._fallback_conversations[user_id][response_id] + deleted = True + + try: + if not clients.opensearch: + return deleted + + await clients.opensearch.delete( + index=CONVERSATION_METADATA_INDEX_NAME, + id=self._get_document_id(user_id, response_id), + refresh=True, + ) return True - return False - + except Exception as e: + logger.debug( + "Failed to delete conversation metadata from OpenSearch", + user_id=user_id, + response_id=response_id, + error=str(e), + ) + return deleted + async def clear_user_conversations(self, user_id: str): - """Clear all conversations for a user (async, non-blocking)""" - if user_id in self._conversations: - del self._conversations[user_id] - await self._save_conversations() - logger.debug(f"Cleared all conversations for user {user_id}") - - def get_storage_stats(self) -> Dict[str, Any]: - """Get statistics about stored conversations""" - total_users = len(self._conversations) - total_conversations = self._count_total_conversations(self._conversations) - - user_stats = {} - for user_id, conversations in self._conversations.items(): - user_stats[user_id] = { - 'conversation_count': len(conversations), - 'latest_activity': max( - (conv.get('last_activity', '') for conv in conversations.values()), - default='' - ) + """Clear all conversation metadata for a user.""" + self._fallback_conversations.pop(user_id, None) + + try: + if not clients.opensearch: + return + + await clients.opensearch.delete_by_query( + index=CONVERSATION_METADATA_INDEX_NAME, + body={"query": {"term": {"user_id": user_id}}}, + refresh=True, + ) + except Exception as e: + logger.warning( + "Failed to clear conversation metadata for user", + user_id=user_id, + error=str(e), + ) + + async def get_storage_stats(self) -> Dict[str, Any]: + """Get basic storage statistics for conversation metadata.""" + fallback_total = sum(len(v) for v in self._fallback_conversations.values()) + + try: + if not clients.opensearch: + return { + "total_users": len(self._fallback_conversations), + "total_conversations": fallback_total, + "index": CONVERSATION_METADATA_INDEX_NAME, + "opensearch_available": False, + } + + count_response = await clients.opensearch.count( + index=CONVERSATION_METADATA_INDEX_NAME, + body={"query": {"match_all": {}}}, + ) + return { + "total_users": len(self._fallback_conversations), + "total_conversations": count_response.get("count", 0), + "index": CONVERSATION_METADATA_INDEX_NAME, + "opensearch_available": True, + } + except Exception as e: + logger.warning("Failed to get OpenSearch storage stats", error=str(e)) + return { + "total_users": len(self._fallback_conversations), + "total_conversations": fallback_total, + "index": CONVERSATION_METADATA_INDEX_NAME, + "opensearch_available": False, } - - return { - 'total_users': total_users, - 'total_conversations': total_conversations, - 'storage_file': self.storage_file, - 'file_exists': os.path.exists(self.storage_file), - 'user_stats': user_stats - } # Global instance -conversation_persistence = ConversationPersistenceService() \ No newline at end of file +conversation_persistence = ConversationPersistenceService() diff --git a/tests/unit/test_conversation_persistence_service.py b/tests/unit/test_conversation_persistence_service.py new file mode 100644 index 000000000..ed5f205c9 --- /dev/null +++ b/tests/unit/test_conversation_persistence_service.py @@ -0,0 +1,76 @@ +import pytest + +from services.conversation_persistence_service import ( + ConversationPersistenceService, + CONVERSATION_METADATA_INDEX_NAME, +) +from config.settings import clients + + +class FakeOpenSearch: + def __init__(self): + self.docs = {} + + async def index(self, index, id, body, refresh=True): + self.docs[(index, id)] = body + + async def search(self, index, body): + user_id = body["query"]["term"]["user_id"] + hits = [] + for (idx, _doc_id), source in self.docs.items(): + if idx == index and source.get("user_id") == user_id: + hits.append({"_source": source}) + return {"hits": {"hits": hits}} + + async def get(self, index, id): + return {"_source": self.docs[(index, id)]} + + async def delete(self, index, id, refresh=True): + self.docs.pop((index, id), None) + + +@pytest.mark.asyncio +async def test_store_and_fetch_user_conversations_from_opensearch(): + service = ConversationPersistenceService() + original = clients.opensearch + clients.opensearch = FakeOpenSearch() + try: + await service.store_conversation_thread( + "user-1", + "resp-1", + { + "title": "Hello", + "endpoint": "chat", + "created_at": "2026-01-01T00:00:00", + "last_activity": "2026-01-01T00:00:00", + "total_messages": 2, + }, + ) + + conversations = await service.get_user_conversations("user-1") + assert "resp-1" in conversations + assert conversations["resp-1"]["title"] == "Hello" + assert (CONVERSATION_METADATA_INDEX_NAME, "user-1:resp-1") in clients.opensearch.docs + finally: + clients.opensearch = original + + +@pytest.mark.asyncio +async def test_fallback_data_used_when_opensearch_unavailable(): + service = ConversationPersistenceService() + original = clients.opensearch + clients.opensearch = None + try: + await service.store_conversation_thread( + "user-2", + "resp-2", + { + "title": "Fallback", + "endpoint": "chat", + "total_messages": 1, + }, + ) + conversations = await service.get_user_conversations("user-2") + assert conversations["resp-2"]["title"] == "Fallback" + finally: + clients.opensearch = original