diff --git a/.gitignore b/.gitignore index add55ba8..1952ab03 100644 --- a/.gitignore +++ b/.gitignore @@ -196,4 +196,7 @@ dev-tools/mcp-mock-server/.certs/ requirements.*.backup # Local run files -local-run.yaml \ No newline at end of file +local-run.yaml + +datadir/ +INSTRUCTIONS.md \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 91d44662..d209acd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ dependencies = [ "pyasn1>=0.6.3", # LCORE-1490 # Used for system prompt template variable rendering "jinja2>=3.1.0", + "anthropic>=0.86.0", ] diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index d9635e79..7773d4c3 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -65,6 +65,7 @@ ShieldModerationResult, TurnSummary, ) +from utils.user_memory import build_instructions_with_preferences, user_memory from utils.vector_search import build_rag_context logger = get_logger(__name__) @@ -183,6 +184,19 @@ async def query_endpoint_handler( inline_rag_context=inline_rag_context.context_text, ) + # Extract user preferences from conversation history + user_preferences = await user_memory( + user_id, + client, + responses_params.model, + _skip_userid_check, + new_conversation=not query_request.conversation_id, + ) + if user_preferences: + responses_params.instructions = build_instructions_with_preferences( + responses_params.instructions, user_preferences + ) + # Handle Azure token refresh if needed if ( responses_params.model.startswith("azure") @@ -293,6 +307,7 @@ async def retrieve_response( id=moderation_result.moderation_id, llm_response=moderation_result.message ) try: + logger.info(responses_params.model_dump(exclude_none=True)) response = await client.responses.create( **responses_params.model_dump(exclude_none=True) ) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 13c2048f..953accbc 100644 --- a/src/app/endpoints/streaming_query.py +++ 
class UserMemory(Base):  # pylint: disable=too-few-public-methods
    """Model for storing cached user preference extractions.

    A single row is kept per user. The row holds the most recently
    extracted preferences together with the conversation count observed at
    extraction time, which readers compare against the current count to
    decide whether the cached value is still valid.
    """

    __tablename__ = "user_memory"

    # One row per user
    user_id: Mapped[str] = mapped_column(primary_key=True)

    # The extracted preferences string (empty string if none found)
    preferences: Mapped[str] = mapped_column(default="")

    # Conversation count at the time of extraction (used for cache invalidation)
    conversation_count: Mapped[int] = mapped_column(default=0)

    # Time of the last write. ``server_default`` only applies on INSERT, so
    # ``onupdate`` is required for the value to advance when an existing row
    # is refreshed with newly extracted preferences.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),  # pylint: disable=not-callable
        onupdate=func.now(),  # pylint: disable=not-callable
    )
async def user_memory(  # pylint: disable=unused-argument
    user_id: str,
    client: AsyncLlamaStackClient,
    model: str,
    skip_user_id_check: bool = False,
    new_conversation: bool = False,
) -> str:
    """Return the user's communication preferences derived from past conversations.

    The user's conversation IDs are listed from the database. If a cached
    extraction exists whose stored conversation count equals the current
    count, it is returned directly. Otherwise the most recent conversations
    are fetched, summarised by the LLM, and the result is written back to
    the cache.

    Parameters:
        user_id: The authenticated user ID.
        client: The AsyncLlamaStackClient for LLM and conversation API calls.
        model: The model ID to use for preference extraction.
        skip_user_id_check: Unused; accepted so the signature lines up with
            the auth tuple passed around by the endpoints.
        new_conversation: When True the cache lookup is bypassed, because the
            conversation being started is not yet stored and the count would
            otherwise appear unchanged.

    Returns:
        A string describing the user's preferences, or empty string if none found.
    """
    all_ids = _list_user_conversation_ids(user_id)
    if not all_ids:
        logger.info("No conversation history for user %s", user_id)
        return ""
    total = len(all_ids)

    # The cache is only consulted for follow-up turns of an existing
    # conversation; a brand-new conversation always triggers re-extraction.
    entry = None if new_conversation else _get_cached_memory(user_id)
    if entry is not None:
        stored_preferences, stored_count = entry
        if stored_count == total:
            logger.info(
                "Using cached user memory for user %s: %s",
                user_id,
                stored_preferences,
            )
            return stored_preferences

    # Only the most recent conversations feed the extraction prompt.
    recent_ids = all_ids[:MAX_CONVERSATIONS]
    transcript = await _collect_conversation_history(client, recent_ids)
    if not transcript:
        return ""

    extracted = await _extract_preferences(client, model, transcript)

    # For a new conversation the row has not been stored yet, so the count
    # is recorded one higher; later turns of that conversation then match.
    count_to_store = total + 1 if new_conversation else total
    _save_cached_memory(user_id, extracted, count_to_store)

    if not extracted:
        logger.info("No user preferences identified for user %s", user_id)
    else:
        logger.info("User memory for user %s: %s", user_id, extracted)

    return extracted
async def _collect_conversation_history(
    client: AsyncLlamaStackClient,
    conversation_ids: list[str],
) -> str:
    """Build an LLM-ready transcript from the given conversations.

    Each conversation is fetched from Llama Stack; conversations that fail
    to load or contain no items are skipped. Within a conversation only the
    first MAX_ITEMS_PER_CONVERSATION items are inspected, and only items of
    type "message" with non-empty content produce a line.

    Parameters:
        client: The Llama Stack client.
        conversation_ids: Conversation IDs to fetch (already limited and sorted).

    Returns:
        Formatted conversation history string, or empty string if no history.
    """

    def render(item: object) -> Optional[str]:
        """Turn one conversation item into a "Label: text" line, or None to skip it."""
        if getattr(item, "type", None) != "message":
            return None
        body = getattr(item, "content", "")
        if isinstance(body, list):
            # Multi-part content: keep only the parts carrying non-empty text.
            candidate_texts = (getattr(part, "text", None) for part in body)
            body = " ".join(text for text in candidate_texts if text)
        if not body:
            return None
        speaker = "User" if getattr(item, "role", None) == "user" else "Assistant"
        return f"{speaker}: {_truncate(str(body), MAX_TEXT_LENGTH)}"

    transcripts: list[str] = []
    for conversation_id in conversation_ids:
        try:
            remote_id = to_llama_stack_conversation_id(conversation_id)
            items = await get_all_conversation_items(client, remote_id)
        except Exception:  # pylint: disable=broad-except
            # Best effort: one unreadable conversation must not block the rest.
            logger.warning(
                "Failed to get conversation %s for user memory",
                conversation_id,
                exc_info=True,
            )
            continue

        if not items:
            continue

        rendered = [
            line
            for line in map(render, items[:MAX_ITEMS_PER_CONVERSATION])
            if line is not None
        ]
        if rendered:
            transcripts.append("\n".join(rendered))

    # Joining an empty list yields "", which is the "no history" result.
    return "\n---\n".join(transcripts)
def build_instructions_with_preferences(
    instructions: Optional[str], preferences: str
) -> str:
    """Append user preferences to the system prompt instructions.

    When there is no base prompt (None or empty string) the preferences
    section is returned on its own, so the resulting prompt does not start
    with stray blank lines.

    Parameters:
        instructions: The original system prompt, or None.
        preferences: The user preferences string.

    Returns:
        The combined instructions with user preferences appended.
    """
    section = (
        "## User Preferences\n"
        "The following preferences were identified from this user's "
        "conversation history. Please adapt your responses accordingly:\n"
        f"{preferences}"
    )
    if not instructions:
        return section
    return f"{instructions}\n\n{section}"
@pytest.fixture(autouse=True)
def mock_user_memory(mocker: MockerFixture) -> None:
    """Replace user_memory in the streaming endpoint with a stub returning "".

    Applied automatically to every test in the module so that no test
    reaches the Llama Stack conversations API through preference extraction.
    """
    empty_preferences_stub = mocker.AsyncMock(return_value="")
    mocker.patch(
        "app.endpoints.streaming_query.user_memory",
        new=empty_preferences_stub,
    )
test_query_merges_inline_and_tool_rag_chunks_and_documents( mocker.patch("app.endpoints.query.store_query_results") mocker.patch("app.endpoints.query.consume_query_tokens") mocker.patch("app.endpoints.query.get_available_quotas", return_value={}) + mocker.patch( + "app.endpoints.query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) response = await query_endpoint_handler( request=dummy_request, @@ -320,6 +328,10 @@ async def test_successful_query_with_conversation( mocker.patch("app.endpoints.query.store_query_results") mocker.patch("app.endpoints.query.consume_query_tokens") mocker.patch("app.endpoints.query.get_available_quotas", return_value={}) + mocker.patch( + "app.endpoints.query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) response = await query_endpoint_handler( request=dummy_request, @@ -403,6 +415,10 @@ async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: mocker.patch("app.endpoints.query.store_query_results") mocker.patch("app.endpoints.query.consume_query_tokens") mocker.patch("app.endpoints.query.get_available_quotas", return_value={}) + mocker.patch( + "app.endpoints.query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) await query_endpoint_handler( request=dummy_request, @@ -469,6 +485,10 @@ async def test_query_with_topic_summary( mocker.patch("app.endpoints.query.store_query_results") mocker.patch("app.endpoints.query.consume_query_tokens") mocker.patch("app.endpoints.query.get_available_quotas", return_value={}) + mocker.patch( + "app.endpoints.query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) await query_endpoint_handler( request=dummy_request, @@ -565,6 +585,10 @@ async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: mocker.patch("app.endpoints.query.store_query_results") mocker.patch("app.endpoints.query.consume_query_tokens") mocker.patch("app.endpoints.query.get_available_quotas", return_value={}) + mocker.patch( + 
"app.endpoints.query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) await query_endpoint_handler( request=dummy_request, diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 99dee264..92b00414 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -396,6 +396,10 @@ async def mock_generate_response( "app.endpoints.streaming_query.normalize_conversation_id", return_value="123", ) + mocker.patch( + "app.endpoints.streaming_query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) response = await streaming_query_endpoint_handler( request=dummy_request, @@ -483,6 +487,10 @@ async def mock_generate_response( "app.endpoints.streaming_query.normalize_conversation_id", return_value="123", ) + mocker.patch( + "app.endpoints.streaming_query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) response = await streaming_query_endpoint_handler( request=dummy_request, @@ -581,6 +589,10 @@ async def mock_generate_response( "app.endpoints.streaming_query.normalize_conversation_id", return_value="123", ) + mocker.patch( + "app.endpoints.streaming_query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) await streaming_query_endpoint_handler( request=dummy_request, @@ -677,6 +689,10 @@ async def mock_generate_response( "app.endpoints.streaming_query.normalize_conversation_id", return_value="123", ) + mocker.patch( + "app.endpoints.streaming_query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) await streaming_query_endpoint_handler( request=dummy_request, @@ -777,6 +793,10 @@ async def mock_generate_response( "app.endpoints.streaming_query.normalize_conversation_id", return_value="123", ) + mocker.patch( + "app.endpoints.streaming_query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) await streaming_query_endpoint_handler( request=dummy_request, diff --git a/tests/unit/utils/test_user_memory.py 
@pytest.mark.asyncio
async def test_user_memory_returns_cached_result(
    mock_client: MockType,
    mocker: MockerFixture,
) -> None:
    """Test that a cache entry with a matching conversation count is returned as-is."""
    known_conversations = ["conv-1", "conv-2"]
    cached_text = "User prefers detailed responses"

    mocker.patch(
        "utils.user_memory._list_user_conversation_ids",
        return_value=known_conversations,
    )
    # The stored count equals the number of listed conversations -> cache hit.
    mocker.patch(
        "utils.user_memory._get_cached_memory",
        return_value=(cached_text, len(known_conversations)),
    )

    outcome = await user_memory("user-1", mock_client, "provider/model")

    # A cache hit must short-circuit before any LLM call is made.
    mock_client.responses.create.assert_not_called()
    assert outcome == cached_text
@pytest.mark.asyncio
async def test_user_memory_invalidates_cache_on_count_change(
    mock_client: MockType,
    mocker: MockerFixture,
) -> None:
    """Test that a stale cache entry (count mismatch) triggers a fresh extraction."""
    fresh_preferences = "User prefers detailed responses"

    mocker.patch(
        "utils.user_memory._list_user_conversation_ids",
        return_value=["conv-1", "conv-2"],
    )
    mocker.patch(
        "utils.user_memory._collect_conversation_history",
        return_value="User: make responses longer\nAssistant: Ok",
    )
    # Cache was written when the user had one conversation; there are now two.
    mocker.patch(
        "utils.user_memory._get_cached_memory",
        return_value=("Old preferences", 1),
    )
    save_spy = mocker.patch("utils.user_memory._save_cached_memory")

    # LLM reply consisting of a single message item carrying the new preferences.
    llm_message = mocker.MagicMock()
    llm_message.type = "message"
    llm_message.content = fresh_preferences
    llm_reply = mocker.MagicMock()
    llm_reply.output = [llm_message]
    mock_client.responses.create.return_value = llm_reply

    outcome = await user_memory("user-1", mock_client, "provider/model")

    assert outcome == fresh_preferences
    mock_client.responses.create.assert_called_once()
    # The refreshed entry is stored against the current conversation count.
    save_spy.assert_called_once_with("user-1", fresh_preferences, 2)
@pytest.mark.asyncio
async def test_user_memory_llm_call_fails(
    mock_client: MockType,
    mocker: MockerFixture,
) -> None:
    """Test that a connection error during extraction yields an empty result."""
    stubbed_returns = (
        ("utils.user_memory._list_user_conversation_ids", ["conv-1"]),
        (
            "utils.user_memory._collect_conversation_history",
            "User: hello\nAssistant: hi",
        ),
        ("utils.user_memory._get_cached_memory", None),
    )
    for target, value in stubbed_returns:
        mocker.patch(target, return_value=value)
    mocker.patch("utils.user_memory._save_cached_memory")

    # Simulate the Llama Stack client being unreachable.
    connection_failure = APIConnectionError(request=mocker.MagicMock())
    mock_client.responses.create.side_effect = connection_failure

    outcome = await user_memory("user-1", mock_client, "provider/model")

    assert outcome == ""
"""Test reading cached memory from the database.""" + mock_session = mocker.MagicMock() + mock_row = mocker.MagicMock() + mock_row.preferences = "User prefers bullet points" + mock_row.conversation_count = 3 + mock_session.__enter__.return_value.get.return_value = mock_row + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + result = _get_cached_memory("user-1") + assert result == ("User prefers bullet points", 3) + + +def test_get_cached_memory_miss(mocker: MockerFixture) -> None: + """Test _get_cached_memory returns None when no row exists.""" + mock_session = mocker.MagicMock() + mock_session.__enter__.return_value.get.return_value = None + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + result = _get_cached_memory("user-1") + assert result is None + + +def test_get_cached_memory_db_error(mocker: MockerFixture) -> None: + """Test _get_cached_memory returns None on DB error.""" + mock_session = mocker.MagicMock() + mock_session.__enter__.return_value.get.side_effect = SQLAlchemyError("db error") + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + result = _get_cached_memory("user-1") + assert result is None + + +# --- Tests for _save_cached_memory() --- + + +def test_save_cached_memory_insert(mocker: MockerFixture) -> None: + """Test saving new user memory to the database.""" + mock_session = mocker.MagicMock() + mock_session.__enter__.return_value.get.return_value = None + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + _save_cached_memory("user-1", "User prefers code examples", 2) + + session_ctx = mock_session.__enter__.return_value + session_ctx.add.assert_called_once() + session_ctx.commit.assert_called_once() + + +def test_save_cached_memory_update(mocker: MockerFixture) -> None: + """Test updating existing user memory in the database.""" + mock_session = mocker.MagicMock() + mock_row = mocker.MagicMock() + 
mock_session.__enter__.return_value.get.return_value = mock_row + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + _save_cached_memory("user-1", "Updated preferences", 5) + + assert mock_row.preferences == "Updated preferences" + assert mock_row.conversation_count == 5 + mock_session.__enter__.return_value.commit.assert_called_once() + + +def test_save_cached_memory_db_error(mocker: MockerFixture) -> None: + """Test _save_cached_memory handles DB errors gracefully.""" + mock_session = mocker.MagicMock() + mock_session.__enter__.return_value.get.side_effect = SQLAlchemyError("db error") + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + # Should not raise + _save_cached_memory("user-1", "prefs", 1) + + +# --- Tests for _list_user_conversation_ids() --- + + +def test_list_user_conversation_ids(mocker: MockerFixture) -> None: + """Test listing conversation IDs from the database.""" + mock_session = mocker.MagicMock() + mock_query = mock_session.__enter__.return_value.query.return_value + mock_conv = mocker.MagicMock() + mock_conv.id = "conv-1" + mock_query.filter_by.return_value.order_by.return_value.all.return_value = [ + mock_conv + ] + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + result = _list_user_conversation_ids("user-1") + assert result == ["conv-1"] + + +def test_list_user_conversation_ids_db_error(mocker: MockerFixture) -> None: + """Test that DB errors return empty list.""" + mock_session = mocker.MagicMock() + mock_session.__enter__.return_value.query.side_effect = SQLAlchemyError("db error") + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + result = _list_user_conversation_ids("user-1") + assert result == [] + + +# --- Tests for _collect_conversation_history() --- + + +@pytest.mark.asyncio +async def test_collect_conversation_history(mocker: MockerFixture) -> None: + """Test _collect_conversation_history formats history from Llama Stack 
items.""" + mock_client = mocker.AsyncMock() + + # Create mock conversation items (user message + assistant message) + user_item = mocker.MagicMock() + user_item.type = "message" + user_item.role = "user" + user_item.content = "What is Python?" + + assistant_item = mocker.MagicMock() + assistant_item.type = "message" + assistant_item.role = "assistant" + assistant_item.content = "Python is a programming language." + + mocker.patch( + "utils.user_memory.get_all_conversation_items", + return_value=[user_item, assistant_item], + ) + + result = await _collect_conversation_history(mock_client, ["conv-1"]) + + assert "User: What is Python?" in result + assert "Assistant: Python is a programming language." in result + + +@pytest.mark.asyncio +async def test_collect_conversation_history_multiple_conversations( + mocker: MockerFixture, +) -> None: + """Test that multiple conversations are separated by ---.""" + mock_client = mocker.AsyncMock() + + item1 = mocker.MagicMock() + item1.type = "message" + item1.role = "user" + item1.content = "Hello" + + item2 = mocker.MagicMock() + item2.type = "message" + item2.role = "user" + item2.content = "World" + + mocker.patch( + "utils.user_memory.get_all_conversation_items", + side_effect=[[item1], [item2]], + ) + + result = await _collect_conversation_history(mock_client, ["conv-1", "conv-2"]) + assert "---" in result + + +@pytest.mark.asyncio +async def test_collect_conversation_history_skips_non_messages( + mocker: MockerFixture, +) -> None: + """Test that non-message items (tool calls etc.) 
are skipped.""" + mock_client = mocker.AsyncMock() + + tool_item = mocker.MagicMock() + tool_item.type = "file_search_call" + + user_item = mocker.MagicMock() + user_item.type = "message" + user_item.role = "user" + user_item.content = "Hello" + + mocker.patch( + "utils.user_memory.get_all_conversation_items", + return_value=[tool_item, user_item], + ) + + result = await _collect_conversation_history(mock_client, ["conv-1"]) + assert "Hello" in result + assert "file_search" not in result + + +@pytest.mark.asyncio +async def test_collect_conversation_history_api_failure( + mocker: MockerFixture, +) -> None: + """Test that API failures are handled gracefully.""" + mock_client = mocker.AsyncMock() + mocker.patch( + "utils.user_memory.get_all_conversation_items", + side_effect=APIConnectionError(request=mocker.MagicMock()), + ) + + result = await _collect_conversation_history(mock_client, ["conv-1"]) + assert result == "" + + +# --- Tests for _extract_preferences() --- + + +@pytest.mark.asyncio +async def test_extract_preferences_success( + mock_client: MockType, mocker: MockerFixture +) -> None: + """Test _extract_preferences returns extracted text.""" + mock_response = mocker.MagicMock() + mock_output_item = mocker.MagicMock() + mock_output_item.type = "message" + mock_output_item.content = "User prefers concise bullet-point responses" + mock_response.output = [mock_output_item] + mock_client.responses.create.return_value = mock_response + + result = await _extract_preferences(mock_client, "provider/model", "some history") + assert result == "User prefers concise bullet-point responses" + + +@pytest.mark.asyncio +async def test_extract_preferences_no_preferences( + mock_client: MockType, mocker: MockerFixture +) -> None: + """Test _extract_preferences returns empty string when no preferences found.""" + mock_response = mocker.MagicMock() + mock_output_item = mocker.MagicMock() + mock_output_item.type = "message" + mock_output_item.content = NO_PREFERENCES_MARKER + 
mock_response.output = [mock_output_item] + mock_client.responses.create.return_value = mock_response + + result = await _extract_preferences(mock_client, "provider/model", "some history") + assert result == "" + + +@pytest.mark.asyncio +async def test_extract_preferences_empty_output( + mock_client: MockType, mocker: MockerFixture +) -> None: + """Test _extract_preferences handles empty output.""" + mock_response = mocker.MagicMock() + mock_response.output = [] + mock_client.responses.create.return_value = mock_response + + result = await _extract_preferences(mock_client, "provider/model", "some history") + assert result == "" + + +@pytest.mark.asyncio +async def test_extract_preferences_api_error( + mock_client: MockType, mocker: MockerFixture +) -> None: + """Test _extract_preferences handles API errors gracefully.""" + mock_client.responses.create.side_effect = APIConnectionError( + request=mocker.MagicMock() + ) + + result = await _extract_preferences(mock_client, "provider/model", "some history") + assert result == "" + + +# --- Tests for _truncate() --- + + +def test_truncate_short_text() -> None: + """Test that short text is not truncated.""" + assert _truncate("hello", 10) == "hello" + + +def test_truncate_exact_length() -> None: + """Test that text at exact limit is not truncated.""" + assert _truncate("hello", 5) == "hello" + + +def test_truncate_long_text() -> None: + """Test that long text is truncated with ellipsis.""" + assert _truncate("hello world", 5) == "hello..." 
+ + +# --- Tests for build_instructions_with_preferences() --- + + +def test_build_instructions_with_preferences() -> None: + """Test building instructions with user preferences.""" + result = build_instructions_with_preferences( + "You are a helpful assistant", "User prefers detailed responses" + ) + assert "You are a helpful assistant" in result + assert "User Preferences" in result + assert "User prefers detailed responses" in result + + +def test_build_instructions_with_none_instructions() -> None: + """Test building instructions when original instructions are None.""" + result = build_instructions_with_preferences(None, "User prefers brief responses") + assert "User Preferences" in result + assert "User prefers brief responses" in result + + +def test_build_instructions_with_empty_instructions() -> None: + """Test building instructions when original instructions are empty.""" + result = build_instructions_with_preferences("", "User prefers code examples") + assert "User Preferences" in result + assert "User prefers code examples" in result diff --git a/uv.lock b/uv.lock index 98809a4a..b24ec46b 100644 --- a/uv.lock +++ b/uv.lock @@ -147,6 +147,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "anthropic" +version = "0.86.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "docstring-parser" }, + { name = "httpx" }, + { name = "jiter" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/37/7a/8b390dc47945d3169875d342847431e5f7d5fa716b2e37494d57cfc1db10/anthropic-0.86.0.tar.gz", hash = 
"sha256:60023a7e879aa4fbb1fed99d487fe407b2ebf6569603e5047cfe304cebdaa0e5", size = 583820, upload-time = "2026-03-18T18:43:08.017Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/63/5f/67db29c6e5d16c8c9c4652d3efb934d89cb750cad201539141781d8eae14/anthropic-0.86.0-py3-none-any.whl", hash = "sha256:9d2bbd339446acce98858c5627d33056efe01f70435b22b63546fe7edae0cd57", size = 469400, upload-time = "2026-03-18T18:43:06.526Z" }, +] + [[package]] name = "anyio" version = "4.12.1" @@ -1529,6 +1548,7 @@ dependencies = [ { name = "a2a-sdk" }, { name = "aiohttp" }, { name = "aiosqlite" }, + { name = "anthropic" }, { name = "asyncpg" }, { name = "authlib" }, { name = "azure-core" }, @@ -1624,6 +1644,7 @@ requires-dist = [ { name = "a2a-sdk", specifier = ">=0.3.4,<0.4.0" }, { name = "aiohttp", specifier = ">=3.12.14" }, { name = "aiosqlite", specifier = ">=0.21.0" }, + { name = "anthropic", specifier = ">=0.86.0" }, { name = "asyncpg", specifier = ">=0.31.0" }, { name = "authlib", specifier = ">=1.6.0" }, { name = "azure-core", specifier = ">=1.38.0" },