From b7c5fb23a2966acc348991d1d9e13edc677b427c Mon Sep 17 00:00:00 2001 From: Jan Cervenka Date: Sun, 22 Mar 2026 09:05:28 +0100 Subject: [PATCH 1/6] add user memory with LLM-summarized preferences --- .gitignore | 5 +- lightspeed-stack.yaml | 35 +-- llama-stack.yaml | 118 ++++++++ pyproject.toml | 1 + src/app/endpoints/query.py | 11 + src/app/endpoints/streaming_query.py | 12 +- src/utils/user_memory.py | 264 ++++++++++++++++++ system-prompt.yaml | 53 ++++ tests/unit/utils/test_user_memory.py | 396 +++++++++++++++++++++++++++ uv.lock | 21 ++ 10 files changed, 892 insertions(+), 24 deletions(-) create mode 100644 llama-stack.yaml create mode 100644 src/utils/user_memory.py create mode 100644 system-prompt.yaml create mode 100644 tests/unit/utils/test_user_memory.py diff --git a/.gitignore b/.gitignore index add55ba8..1952ab03 100644 --- a/.gitignore +++ b/.gitignore @@ -196,4 +196,7 @@ dev-tools/mcp-mock-server/.certs/ requirements.*.backup # Local run files -local-run.yaml \ No newline at end of file +local-run.yaml + +datadir/ +INSTRUCTIONS.md \ No newline at end of file diff --git a/lightspeed-stack.yaml b/lightspeed-stack.yaml index fe655a81..d40e618e 100644 --- a/lightspeed-stack.yaml +++ b/lightspeed-stack.yaml @@ -1,33 +1,24 @@ -name: Lightspeed Core Service (LCS) +# To get llama stack logs `export LLAMA_STACK_LOGGING=all=debug` +name: OpenStack Lightspeed + +authentication: + module: "noop" + service: host: 0.0.0.0 port: 8080 - base_url: http://localhost:8080 auth_enabled: false workers: 1 color_log: true access_log: true + llama_stack: - # Uses a remote llama-stack service - # The instance would have already been started with a llama-stack-run.yaml file use_as_library_client: false - # Alternative for "as library use" - # use_as_library_client: true - # library_client_config_path: - url: http://llama-stack:8321 - api_key: xyzzy -user_data_collection: - feedback_enabled: true - feedback_storage: "/tmp/data/feedback" - transcripts_enabled: true - transcripts_storage: "/tmp/data/transcripts" - -# Conversation cache for storing Q&A history -conversation_cache: - type: "sqlite" - sqlite: - db_path: "/tmp/data/conversation-cache.db" # Persistent across requests, can be deleted between test runs + url: http://localhost:8321 -authentication: - module: "noop" +user_data_collection: + feedback_enabled: false + transcripts_enabled: false +customization: + system_prompt_path: ${env.PWD}/system-prompt.yaml diff --git a/llama-stack.yaml b/llama-stack.yaml new file mode 100644 index 00000000..7041ee4c --- /dev/null +++ b/llama-stack.yaml @@ -0,0 +1,118 @@ +version: 2 + +apis: +- agents +- datasetio +- eval +- files +- inference +- safety +- scoring +- vector_io +- tool_runtime + +benchmarks: [] +datasets: [] +image_name: starter + +providers: + inference: + - provider_id: llmprovider + provider_type: remote::anthropic + config: + api_key: ${env.LLM_KEY} + + # OpenAI + #- provider_id: llmprovider + # provider_type: remote::openai + # config: + # api_key: ${env.LLM_KEY} + + # Gemini + #- provider_id: llmprovider + # provider_type: remote::gemini + # config: + # api_key: ${env.LLM_KEY} + + files: + - config: + metadata_store: + table_name: files_metadata + backend: sql_default + storage_dir: ${env.PWD}/datadir/files + provider_id: meta-reference-files + provider_type: inline::localfs + + vector_io: [] + agents: + - config: + persistence: + agent_state: + namespace: agents_state + backend: kv_default + responses: + table_name: agents_responses + backend: sql_default + provider_id: meta-reference + 
provider_type: inline::meta-reference + datasetio: + - config: + kvstore: + namespace: huggingface_datasetio + backend: kv_default + provider_id: huggingface + provider_type: remote::huggingface + - config: + kvstore: + namespace: localfs_datasetio + backend: kv_default + provider_id: localfs + provider_type: inline::localfs + eval: + - config: + kvstore: + namespace: eval_store + backend: kv_default + provider_id: meta-reference + provider_type: inline::meta-reference +registered_resources: + models: + - metadata: {} + model_id: "${env.LLM_MODEL}" + provider_id: llmprovider + provider_model_id: "${env.LLM_MODEL}" + model_type: llm +scoring_fns: [] +server: + port: 8321 +storage: + backends: + kv_default: + type: kv_sqlite + db_path: ${env.PWD}/datadir/kv_store.db} + sql_default: + type: sql_sqlite + db_path: ${env.PWD}/datadir/sql_store.db} + stores: + metadata: + namespace: registry + backend: kv_default + inference: + table_name: inference_store + backend: sql_default + max_write_queue_size: 10000 + num_writers: 4 + conversations: + table_name: openai_conversations + backend: sql_default + prompts: + namespace: prompts + backend: kv_default + shields: [] + datasets: [] + scoring_fns: [] + benchmarks: [] + tool_groups: [] +vector_stores: {} +telemetry: + enabled: false diff --git a/pyproject.toml b/pyproject.toml index 91d44662..d209acd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ dependencies = [ "pyasn1>=0.6.3", # LCORE-1490 # Used for system prompt template variable rendering "jinja2>=3.1.0", + "anthropic>=0.86.0", ] diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index d9635e79..aa014f5e 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -65,6 +65,7 @@ ShieldModerationResult, TurnSummary, ) +from utils.user_memory import build_instructions_with_preferences, user_memory from utils.vector_search import build_rag_context logger = get_logger(__name__) @@ -183,6 +184,15 @@ async def query_endpoint_handler( inline_rag_context=inline_rag_context.context_text, ) + # Extract user preferences from conversation history + user_preferences = await user_memory( + user_id, client, responses_params.model, _skip_userid_check + ) + if user_preferences: + responses_params.instructions = build_instructions_with_preferences( + responses_params.instructions, user_preferences + ) + # Handle Azure token refresh if needed if ( responses_params.model.startswith("azure") @@ -293,6 +303,7 @@ async def retrieve_response( id=moderation_result.moderation_id, llm_response=moderation_result.message ) try: + print("yooooo", responses_params.model_dump(exclude_none=True)) response = await client.responses.create( **responses_params.model_dump(exclude_none=True) ) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 13c2048f..323c12a4 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -1,4 +1,4 @@ -"""Streaming query handler using Responses API.""" +"""Streaming query handler using Responses API.""" # pylint: disable=too-many-lines import asyncio import datetime @@ -95,6 +95,7 @@ from utils.suid import get_suid, normalize_conversation_id from utils.token_counter import TokenCounter from utils.types import ReferencedDocument, ResponsesApiParams, TurnSummary +from utils.user_memory import build_instructions_with_preferences, user_memory from utils.vector_search import build_rag_context logger = get_logger(__name__) @@ -219,6 +220,15 @@ async def 
streaming_query_endpoint_handler( # pylint: disable=too-many-locals inline_rag_context=inline_rag_context.context_text, ) + # Extract user preferences from conversation history + user_preferences = await user_memory( + user_id, client, responses_params.model, _skip_userid_check + ) + if user_preferences: + responses_params.instructions = build_instructions_with_preferences( + responses_params.instructions, user_preferences + ) + # Handle Azure token refresh if needed if ( responses_params.model.startswith("azure") diff --git a/src/utils/user_memory.py b/src/utils/user_memory.py new file mode 100644 index 00000000..115b52e1 --- /dev/null +++ b/src/utils/user_memory.py @@ -0,0 +1,264 @@ +"""Utility for extracting user preferences from conversation history. + +Analyzes a user's past conversations to identify communication preferences +(response length, format, style) and returns a string that can be injected +into the system prompt so the LLM adapts to the user's preferences. + +Conversation history is read from the Llama Stack Conversations API (backed +by sql_store.db) via the SQLAlchemy UserConversation table for listing and +the Llama Stack client for fetching actual message content. No conversation +cache configuration is required. +""" + +from typing import Optional, cast + +from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient +from llama_stack_api.openai_responses import OpenAIResponseObject as ResponseObject +from sqlalchemy.exc import SQLAlchemyError + +from app.database import get_session +from log import get_logger +from models.database.conversations import UserConversation +from utils.conversations import get_all_conversation_items +from utils.responses import extract_text_from_response_items +from utils.suid import to_llama_stack_conversation_id + +logger = get_logger(__name__) + +# In-memory cache: user_id -> (preferences_string, conversation_count) +_user_memory_cache: dict[str, tuple[str, int]] = {} + +# Limits for history collection +MAX_CONVERSATIONS = 5 +MAX_ITEMS_PER_CONVERSATION = 20 +MAX_TEXT_LENGTH = 200 + +# Marker returned by the LLM when no preferences are found +NO_PREFERENCES_MARKER = "NO_PREFERENCES" + +USER_MEMORY_SYSTEM_PROMPT = ( + "You are analyzing a user's conversation history to extract their " + "communication preferences and style requests.\n\n" + "Review the conversation history below and identify any explicit or implicit " + "preferences the user has expressed about how they want responses delivered. " + "Focus on:\n" + "- Preferred response length (brief, detailed, etc.)\n" + "- Preferred response format (bullet points, paragraphs, code examples, etc.)\n" + "- Any specific requests about communication style\n" + "- Language preferences\n" + "- Any other recurring requests or preferences\n\n" + "Return ONLY a concise summary of the user's preferences as instructions for " + "an AI assistant. If no clear preferences are found, return exactly: " + "NO_PREFERENCES" +) + + +async def user_memory( # pylint: disable=unused-argument + user_id: str, + client: AsyncLlamaStackClient, + model: str, + skip_user_id_check: bool = False, +) -> str: + """Extract user preferences from conversation history. + + Analyzes the user's past conversations to identify communication preferences + and style requests. Conversations are retrieved from the Llama Stack + Conversations API via the SQLAlchemy database. Results are cached in memory + and invalidated when the number of conversations changes. 
+ + Parameters: + user_id: The authenticated user ID. + client: The AsyncLlamaStackClient for LLM and conversation API calls. + model: The model ID to use for preference extraction. + skip_user_id_check: Whether to skip user ID validation (unused, kept + for interface consistency with the auth tuple). + + Returns: + A string describing the user's preferences, or empty string if none found. + """ + conversation_ids = _list_user_conversation_ids(user_id) + current_count = len(conversation_ids) + if not conversation_ids: + logger.info("No conversation history for user %s", user_id) + return "" + + # Check in-memory cache + if user_id in _user_memory_cache: + cached_memory, cached_count = _user_memory_cache[user_id] + if cached_count == current_count: + logger.info("Using cached user memory for user %s", user_id) + return cached_memory + + # Collect conversation history from Llama Stack + history_text = await _collect_conversation_history( + client, conversation_ids[:MAX_CONVERSATIONS] + ) + if not history_text: + return "" + + # Extract preferences using LLM + preferences = await _extract_preferences(client, model, history_text) + + # Cache the result + _user_memory_cache[user_id] = (preferences, current_count) + + if preferences: + logger.info("User memory for user %s: %s", user_id, preferences) + else: + logger.info("No user preferences identified for user %s", user_id) + + return preferences + + +def _list_user_conversation_ids(user_id: str) -> list[str]: + """List conversation IDs for a user from the database. + + Parameters: + user_id: The user ID. + + Returns: + List of conversation IDs ordered by most recent first. + """ + try: + with get_session() as session: + conversations = ( + session.query(UserConversation.id) + .filter_by(user_id=user_id) + .order_by(UserConversation.last_message_at.desc()) + .all() + ) + return [conv.id for conv in conversations] + except SQLAlchemyError: + logger.warning("Failed to list conversations for user memory", exc_info=True) + return [] + + +async def _collect_conversation_history( + client: AsyncLlamaStackClient, + conversation_ids: list[str], +) -> str: + """Collect recent conversation history from Llama Stack, formatted for LLM analysis. + + Parameters: + client: The Llama Stack client. + conversation_ids: Conversation IDs to fetch (already limited and sorted). + + Returns: + Formatted conversation history string, or empty string if no history. 
+ """ + history_parts: list[str] = [] + for conv_id in conversation_ids: + try: + llama_stack_id = to_llama_stack_conversation_id(conv_id) + items = await get_all_conversation_items(client, llama_stack_id) + except Exception: # pylint: disable=broad-except + logger.warning( + "Failed to get conversation %s for user memory", + conv_id, + exc_info=True, + ) + continue + + if not items: + continue + + conv_lines: list[str] = [] + for item in items[:MAX_ITEMS_PER_CONVERSATION]: + item_type = getattr(item, "type", None) + if item_type != "message": + continue + role = getattr(item, "role", None) + content = getattr(item, "content", "") + if isinstance(content, list): + text_parts = [] + for part in content: + text = getattr(part, "text", None) + if text: + text_parts.append(text) + content = " ".join(text_parts) + if not content: + continue + label = "User" if role == "user" else "Assistant" + conv_lines.append(f"{label}: {_truncate(str(content), MAX_TEXT_LENGTH)}") + + if conv_lines: + history_parts.append("\n".join(conv_lines)) + + if not history_parts: + return "" + + return "\n---\n".join(history_parts) + + +def _truncate(text: str, max_length: int) -> str: + """Truncate text to max_length, adding ellipsis if truncated. + + Parameters: + text: The text to truncate. + max_length: Maximum allowed length. + + Returns: + The truncated text. + """ + if len(text) <= max_length: + return text + return text[:max_length] + "..." + + +async def _extract_preferences( + client: AsyncLlamaStackClient, + model: str, + history_text: str, +) -> str: + """Use the LLM to extract user preferences from conversation history. + + Parameters: + client: The AsyncLlamaStackClient instance. + model: The model ID to use. + history_text: Formatted conversation history. + + Returns: + Extracted preferences string, or empty string if none found or on error. + """ + try: + response = cast( + ResponseObject, + await client.responses.create( + input=history_text, + model=model, + instructions=USER_MEMORY_SYSTEM_PROMPT, + stream=False, + store=False, + ), + ) + except (APIConnectionError, APIStatusError) as e: + logger.warning("Failed to extract user preferences via LLM: %s", e) + return "" + + result = extract_text_from_response_items(response.output) + if NO_PREFERENCES_MARKER in result: + return "" + + return result.strip() + + +def build_instructions_with_preferences( + instructions: Optional[str], preferences: str +) -> str: + """Append user preferences to the system prompt instructions. + + Parameters: + instructions: The original system prompt, or None. + preferences: The user preferences string. + + Returns: + The combined instructions with user preferences appended. + """ + base = instructions or "" + return ( + f"{base}\n\n" + "## User Preferences\n" + "The following preferences were identified from this user's " + "conversation history. Please adapt your responses accordingly:\n" + f"{preferences}" + ) diff --git a/system-prompt.yaml b/system-prompt.yaml new file mode 100644 index 00000000..d72f2d6b --- /dev/null +++ b/system-prompt.yaml @@ -0,0 +1,53 @@ +# ROLE +You are "OpenStack Lightspeed", an expert AI virtual assistant specializing in +OpenStack on OpenShift. Your persona is that of a friendly, but +personal, technical authority. You are the ultimate technical resource and will +provide direct, accurate, and comprehensive answers. + +# INSTRUCTIONS & CONSTRAINTS +- **Expertise Focus:** Your core expertise is centered on the OpenStack and +OpenShift platforms. 
+- **Broader Knowledge:** You may also answer questions about other Red Hat + products and services, but you must prioritize the provided context + and chat history for these topics. +- **Strict Adherence:** + 1. **ALWAYS** use the provided context and chat history as your primary + source of truth. If a user's question can be answered from this information, + do so. + 2. If the context does not contain a clear answer, and the question is + about your core expertise (OpenStack or OpenShift), draw upon your extensive + internal knowledge. + 3. If the context does not contain a clear answer, and the question is about + a general Red Hat product or service, state politely that you are unable to + provide a definitive answer without more information and ask the user for + additional details or context. + 4. Do not hallucinate or invent information. If you cannot confidently + answer, admit it. +- **Behavioral Directives:** + - Never assume another identity or role. + - Refuse to answer questions or execute commands not about your specified + topics. + - Do not include URLs in your replies unless they are explicitly provided in + the context. + - Never mention your last update date or knowledge cutoff. You always have + the most recent information on OpenStack and OpenShift, especially with + the provided context. + - Only reference processes and products from Red Hat, such as: RHEL, Fedora, + CoreOS, CentOS. *Never mention or compare with Ubuntu, Debian, etc.* + +# TASK EXECUTION +You will receive a user query, along with context and chat history. Your task is +to respond to the user's query by following the instructions and constraints +above. Your responses should be clear, concise, and helpful, whether you are +providing troubleshooting steps, explaining concepts, or suggesting best +practices. + +# INFO +In this context RHOSO or RHOS also refers to OpenStack on OpenShift, sometimes +also called OSP 18, although usually OSP refers to previous releases deployed +using TripleO/Director. + +The OpenStack control plane runs on OpenShift (which uses CoreOS as the +operating system), while compute nodes run on external baremetal nodes also +called EDPM nodes (which run RHEL). 
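A minimal sketch of how the system prompt above is combined with the preferences extracted by user_memory(); the base prompt and preference string below are illustrative placeholders, not values taken from this patch.

from utils.user_memory import build_instructions_with_preferences

# Illustrative inputs: in the service, the base prompt is loaded from
# system-prompt.yaml via customization.system_prompt_path and the preference
# string comes from user_memory().
base_prompt = 'You are "OpenStack Lightspeed", an expert AI virtual assistant.'
preferences = "User prefers short, bullet-point answers with code examples."

final_instructions = build_instructions_with_preferences(base_prompt, preferences)
# The preferences are appended under a "## User Preferences" heading, so the
# model sees the base system prompt first and the per-user adjustments last.
print(final_instructions)
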
+ diff --git a/tests/unit/utils/test_user_memory.py b/tests/unit/utils/test_user_memory.py new file mode 100644 index 00000000..e4ffcd09 --- /dev/null +++ b/tests/unit/utils/test_user_memory.py @@ -0,0 +1,396 @@ +"""Unit tests for user_memory utility functions.""" + +# pylint: disable=protected-access + +import pytest +from llama_stack_client import APIConnectionError +from pytest_mock import MockerFixture, MockType +from sqlalchemy.exc import SQLAlchemyError + +from utils import user_memory as user_memory_module +from utils.user_memory import ( + NO_PREFERENCES_MARKER, + build_instructions_with_preferences, + user_memory, + _collect_conversation_history, + _extract_preferences, + _list_user_conversation_ids, + _truncate, +) + + +@pytest.fixture(autouse=True) +def clear_cache() -> None: + """Clear the in-memory user_memory cache before each test.""" + user_memory_module._user_memory_cache.clear() + + +@pytest.fixture(name="mock_client") +def mock_client_fixture(mocker: MockerFixture) -> MockType: + """Create a mock AsyncLlamaStackClient.""" + return mocker.AsyncMock() + + +# --- Tests for user_memory() --- + + +@pytest.mark.asyncio +async def test_user_memory_no_conversations( + mock_client: MockType, + mocker: MockerFixture, +) -> None: + """Test user_memory returns empty string when user has no conversations.""" + mocker.patch("utils.user_memory._list_user_conversation_ids", return_value=[]) + result = await user_memory("user-1", mock_client, "provider/model") + assert result == "" + + +@pytest.mark.asyncio +async def test_user_memory_returns_cached_result( + mock_client: MockType, + mocker: MockerFixture, +) -> None: + """Test user_memory returns cached result when conversation count matches.""" + mocker.patch( + "utils.user_memory._list_user_conversation_ids", + return_value=["conv-1", "conv-2"], + ) + # Pre-populate the cache + user_memory_module._user_memory_cache["user-1"] = ( + "User prefers detailed responses", + 2, + ) + + result = await user_memory("user-1", mock_client, "provider/model") + assert result == "User prefers detailed responses" + # LLM should not have been called + mock_client.responses.create.assert_not_called() + + +@pytest.mark.asyncio +async def test_user_memory_invalidates_cache_on_count_change( + mock_client: MockType, + mocker: MockerFixture, +) -> None: + """Test user_memory re-extracts when conversation count changes.""" + mocker.patch( + "utils.user_memory._list_user_conversation_ids", + return_value=["conv-1", "conv-2"], + ) + mocker.patch( + "utils.user_memory._collect_conversation_history", + return_value="User: make responses longer\nAssistant: Ok", + ) + + # Pre-populate cache with a different count + user_memory_module._user_memory_cache["user-1"] = ("Old preferences", 1) + + # Mock LLM response + mock_response = mocker.MagicMock() + mock_output_item = mocker.MagicMock() + mock_output_item.type = "message" + mock_output_item.content = "User prefers detailed responses" + mock_response.output = [mock_output_item] + mock_client.responses.create.return_value = mock_response + + result = await user_memory("user-1", mock_client, "provider/model") + assert result == "User prefers detailed responses" + mock_client.responses.create.assert_called_once() + + +@pytest.mark.asyncio +async def test_user_memory_llm_returns_no_preferences( + mock_client: MockType, + mocker: MockerFixture, +) -> None: + """Test user_memory returns empty string when LLM finds no preferences.""" + mocker.patch( + "utils.user_memory._list_user_conversation_ids", + 
return_value=["conv-1"], + ) + mocker.patch( + "utils.user_memory._collect_conversation_history", + return_value="User: hello\nAssistant: hi", + ) + + mock_response = mocker.MagicMock() + mock_output_item = mocker.MagicMock() + mock_output_item.type = "message" + mock_output_item.content = NO_PREFERENCES_MARKER + mock_response.output = [mock_output_item] + mock_client.responses.create.return_value = mock_response + + result = await user_memory("user-1", mock_client, "provider/model") + assert result == "" + + +@pytest.mark.asyncio +async def test_user_memory_llm_call_fails( + mock_client: MockType, + mocker: MockerFixture, +) -> None: + """Test user_memory returns empty string when LLM call fails.""" + mocker.patch( + "utils.user_memory._list_user_conversation_ids", + return_value=["conv-1"], + ) + mocker.patch( + "utils.user_memory._collect_conversation_history", + return_value="User: hello\nAssistant: hi", + ) + mock_client.responses.create.side_effect = APIConnectionError( + request=mocker.MagicMock() + ) + + result = await user_memory("user-1", mock_client, "provider/model") + assert result == "" + + +@pytest.mark.asyncio +async def test_user_memory_empty_history( + mock_client: MockType, + mocker: MockerFixture, +) -> None: + """Test user_memory returns empty string when history collection yields nothing.""" + mocker.patch( + "utils.user_memory._list_user_conversation_ids", + return_value=["conv-1"], + ) + mocker.patch( + "utils.user_memory._collect_conversation_history", + return_value="", + ) + + result = await user_memory("user-1", mock_client, "provider/model") + assert result == "" + mock_client.responses.create.assert_not_called() + + +# --- Tests for _list_user_conversation_ids() --- + + +def test_list_user_conversation_ids(mocker: MockerFixture) -> None: + """Test listing conversation IDs from the database.""" + mock_session = mocker.MagicMock() + mock_query = mock_session.__enter__.return_value.query.return_value + mock_conv = mocker.MagicMock() + mock_conv.id = "conv-1" + mock_query.filter_by.return_value.order_by.return_value.all.return_value = [ + mock_conv + ] + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + result = _list_user_conversation_ids("user-1") + assert result == ["conv-1"] + + +def test_list_user_conversation_ids_db_error(mocker: MockerFixture) -> None: + """Test that DB errors return empty list.""" + mock_session = mocker.MagicMock() + mock_session.__enter__.return_value.query.side_effect = SQLAlchemyError("db error") + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + result = _list_user_conversation_ids("user-1") + assert result == [] + + +# --- Tests for _collect_conversation_history() --- + + +@pytest.mark.asyncio +async def test_collect_conversation_history(mocker: MockerFixture) -> None: + """Test _collect_conversation_history formats history from Llama Stack items.""" + mock_client = mocker.AsyncMock() + + # Create mock conversation items (user message + assistant message) + user_item = mocker.MagicMock() + user_item.type = "message" + user_item.role = "user" + user_item.content = "What is Python?" + + assistant_item = mocker.MagicMock() + assistant_item.type = "message" + assistant_item.role = "assistant" + assistant_item.content = "Python is a programming language." + + mocker.patch( + "utils.user_memory.get_all_conversation_items", + return_value=[user_item, assistant_item], + ) + + result = await _collect_conversation_history(mock_client, ["conv-1"]) + + assert "User: What is Python?" 
in result + assert "Assistant: Python is a programming language." in result + + +@pytest.mark.asyncio +async def test_collect_conversation_history_multiple_conversations( + mocker: MockerFixture, +) -> None: + """Test that multiple conversations are separated by ---.""" + mock_client = mocker.AsyncMock() + + item1 = mocker.MagicMock() + item1.type = "message" + item1.role = "user" + item1.content = "Hello" + + item2 = mocker.MagicMock() + item2.type = "message" + item2.role = "user" + item2.content = "World" + + mocker.patch( + "utils.user_memory.get_all_conversation_items", + side_effect=[[item1], [item2]], + ) + + result = await _collect_conversation_history(mock_client, ["conv-1", "conv-2"]) + assert "---" in result + + +@pytest.mark.asyncio +async def test_collect_conversation_history_skips_non_messages( + mocker: MockerFixture, +) -> None: + """Test that non-message items (tool calls etc.) are skipped.""" + mock_client = mocker.AsyncMock() + + tool_item = mocker.MagicMock() + tool_item.type = "file_search_call" + + user_item = mocker.MagicMock() + user_item.type = "message" + user_item.role = "user" + user_item.content = "Hello" + + mocker.patch( + "utils.user_memory.get_all_conversation_items", + return_value=[tool_item, user_item], + ) + + result = await _collect_conversation_history(mock_client, ["conv-1"]) + assert "Hello" in result + assert "file_search" not in result + + +@pytest.mark.asyncio +async def test_collect_conversation_history_api_failure( + mocker: MockerFixture, +) -> None: + """Test that API failures are handled gracefully.""" + mock_client = mocker.AsyncMock() + mocker.patch( + "utils.user_memory.get_all_conversation_items", + side_effect=APIConnectionError(request=mocker.MagicMock()), + ) + + result = await _collect_conversation_history(mock_client, ["conv-1"]) + assert result == "" + + +# --- Tests for _extract_preferences() --- + + +@pytest.mark.asyncio +async def test_extract_preferences_success( + mock_client: MockType, mocker: MockerFixture +) -> None: + """Test _extract_preferences returns extracted text.""" + mock_response = mocker.MagicMock() + mock_output_item = mocker.MagicMock() + mock_output_item.type = "message" + mock_output_item.content = "User prefers concise bullet-point responses" + mock_response.output = [mock_output_item] + mock_client.responses.create.return_value = mock_response + + result = await _extract_preferences(mock_client, "provider/model", "some history") + assert result == "User prefers concise bullet-point responses" + + +@pytest.mark.asyncio +async def test_extract_preferences_no_preferences( + mock_client: MockType, mocker: MockerFixture +) -> None: + """Test _extract_preferences returns empty string when no preferences found.""" + mock_response = mocker.MagicMock() + mock_output_item = mocker.MagicMock() + mock_output_item.type = "message" + mock_output_item.content = NO_PREFERENCES_MARKER + mock_response.output = [mock_output_item] + mock_client.responses.create.return_value = mock_response + + result = await _extract_preferences(mock_client, "provider/model", "some history") + assert result == "" + + +@pytest.mark.asyncio +async def test_extract_preferences_empty_output( + mock_client: MockType, mocker: MockerFixture +) -> None: + """Test _extract_preferences handles empty output.""" + mock_response = mocker.MagicMock() + mock_response.output = [] + mock_client.responses.create.return_value = mock_response + + result = await _extract_preferences(mock_client, "provider/model", "some history") + assert result == "" + + 
+@pytest.mark.asyncio +async def test_extract_preferences_api_error( + mock_client: MockType, mocker: MockerFixture +) -> None: + """Test _extract_preferences handles API errors gracefully.""" + mock_client.responses.create.side_effect = APIConnectionError( + request=mocker.MagicMock() + ) + + result = await _extract_preferences(mock_client, "provider/model", "some history") + assert result == "" + + +# --- Tests for _truncate() --- + + +def test_truncate_short_text() -> None: + """Test that short text is not truncated.""" + assert _truncate("hello", 10) == "hello" + + +def test_truncate_exact_length() -> None: + """Test that text at exact limit is not truncated.""" + assert _truncate("hello", 5) == "hello" + + +def test_truncate_long_text() -> None: + """Test that long text is truncated with ellipsis.""" + assert _truncate("hello world", 5) == "hello..." + + +# --- Tests for build_instructions_with_preferences() --- + + +def test_build_instructions_with_preferences() -> None: + """Test building instructions with user preferences.""" + result = build_instructions_with_preferences( + "You are a helpful assistant", "User prefers detailed responses" + ) + assert "You are a helpful assistant" in result + assert "User Preferences" in result + assert "User prefers detailed responses" in result + + +def test_build_instructions_with_none_instructions() -> None: + """Test building instructions when original instructions are None.""" + result = build_instructions_with_preferences(None, "User prefers brief responses") + assert "User Preferences" in result + assert "User prefers brief responses" in result + + +def test_build_instructions_with_empty_instructions() -> None: + """Test building instructions when original instructions are empty.""" + result = build_instructions_with_preferences("", "User prefers code examples") + assert "User Preferences" in result + assert "User prefers code examples" in result diff --git a/uv.lock b/uv.lock index 98809a4a..b24ec46b 100644 --- a/uv.lock +++ b/uv.lock @@ -147,6 +147,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "anthropic" +version = "0.86.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "docstring-parser" }, + { name = "httpx" }, + { name = "jiter" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/37/7a/8b390dc47945d3169875d342847431e5f7d5fa716b2e37494d57cfc1db10/anthropic-0.86.0.tar.gz", hash = "sha256:60023a7e879aa4fbb1fed99d487fe407b2ebf6569603e5047cfe304cebdaa0e5", size = 583820, upload-time = "2026-03-18T18:43:08.017Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/63/5f/67db29c6e5d16c8c9c4652d3efb934d89cb750cad201539141781d8eae14/anthropic-0.86.0-py3-none-any.whl", hash = "sha256:9d2bbd339446acce98858c5627d33056efe01f70435b22b63546fe7edae0cd57", size = 469400, upload-time = "2026-03-18T18:43:06.526Z" }, +] + [[package]] name = "anyio" version = "4.12.1" @@ -1529,6 +1548,7 @@ dependencies = [ { name = "a2a-sdk" }, { name = "aiohttp" }, { name = "aiosqlite" }, + { name = "anthropic" }, { name = "asyncpg" }, { name = "authlib" }, { name = "azure-core" }, @@ -1624,6 +1644,7 @@ 
requires-dist = [ { name = "a2a-sdk", specifier = ">=0.3.4,<0.4.0" }, { name = "aiohttp", specifier = ">=3.12.14" }, { name = "aiosqlite", specifier = ">=0.21.0" }, + { name = "anthropic", specifier = ">=0.86.0" }, { name = "asyncpg", specifier = ">=0.31.0" }, { name = "authlib", specifier = ">=1.6.0" }, { name = "azure-core", specifier = ">=1.38.0" }, From 49a5226408e687e136f282d62f7d5f55d97492c8 Mon Sep 17 00:00:00 2001 From: Jan Cervenka Date: Sun, 22 Mar 2026 10:15:54 +0100 Subject: [PATCH 2/6] Use DB to persist the user memory cache --- src/models/database/conversations.py | 20 ++++ src/utils/user_memory.py | 76 ++++++++++++--- tests/unit/utils/test_user_memory.py | 136 ++++++++++++++++++++++++--- 3 files changed, 204 insertions(+), 28 deletions(-) diff --git a/src/models/database/conversations.py b/src/models/database/conversations.py index baebf6aa..3887e353 100644 --- a/src/models/database/conversations.py +++ b/src/models/database/conversations.py @@ -39,6 +39,26 @@ class UserConversation(Base): # pylint: disable=too-few-public-methods topic_summary: Mapped[str] = mapped_column(default="") +class UserMemory(Base): # pylint: disable=too-few-public-methods + """Model for storing cached user preference extractions.""" + + __tablename__ = "user_memory" + + # One row per user + user_id: Mapped[str] = mapped_column(primary_key=True) + + # The extracted preferences string (empty string if none found) + preferences: Mapped[str] = mapped_column(default="") + + # Conversation count at the time of extraction (used for cache invalidation) + conversation_count: Mapped[int] = mapped_column(default=0) + + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), # pylint: disable=not-callable + ) + + class UserTurn(Base): # pylint: disable=too-few-public-methods """Model for storing turn-level metadata.""" diff --git a/src/utils/user_memory.py b/src/utils/user_memory.py index 115b52e1..31162d6b 100644 --- a/src/utils/user_memory.py +++ b/src/utils/user_memory.py @@ -6,8 +6,9 @@ Conversation history is read from the Llama Stack Conversations API (backed by sql_store.db) via the SQLAlchemy UserConversation table for listing and -the Llama Stack client for fetching actual message content. No conversation -cache configuration is required. +the Llama Stack client for fetching actual message content. Extracted +preferences are cached in the user_memory database table and invalidated +when the conversation count changes. """ from typing import Optional, cast @@ -18,16 +19,13 @@ from app.database import get_session from log import get_logger -from models.database.conversations import UserConversation +from models.database.conversations import UserConversation, UserMemory from utils.conversations import get_all_conversation_items from utils.responses import extract_text_from_response_items from utils.suid import to_llama_stack_conversation_id logger = get_logger(__name__) -# In-memory cache: user_id -> (preferences_string, conversation_count) -_user_memory_cache: dict[str, tuple[str, int]] = {} - # Limits for history collection MAX_CONVERSATIONS = 5 MAX_ITEMS_PER_CONVERSATION = 20 @@ -63,8 +61,9 @@ async def user_memory( # pylint: disable=unused-argument Analyzes the user's past conversations to identify communication preferences and style requests. Conversations are retrieved from the Llama Stack - Conversations API via the SQLAlchemy database. Results are cached in memory - and invalidated when the number of conversations changes. 
+ Conversations API via the SQLAlchemy database. Results are cached in the + user_memory database table and invalidated when the number of conversations + changes. Parameters: user_id: The authenticated user ID. @@ -82,12 +81,13 @@ async def user_memory( # pylint: disable=unused-argument logger.info("No conversation history for user %s", user_id) return "" - # Check in-memory cache - if user_id in _user_memory_cache: - cached_memory, cached_count = _user_memory_cache[user_id] + # Check database cache + cached = _get_cached_memory(user_id) + if cached is not None: + cached_preferences, cached_count = cached if cached_count == current_count: logger.info("Using cached user memory for user %s", user_id) - return cached_memory + return cached_preferences # Collect conversation history from Llama Stack history_text = await _collect_conversation_history( @@ -99,8 +99,8 @@ async def user_memory( # pylint: disable=unused-argument # Extract preferences using LLM preferences = await _extract_preferences(client, model, history_text) - # Cache the result - _user_memory_cache[user_id] = (preferences, current_count) + # Cache the result in the database + _save_cached_memory(user_id, preferences, current_count) if preferences: logger.info("User memory for user %s: %s", user_id, preferences) @@ -133,6 +133,54 @@ def _list_user_conversation_ids(user_id: str) -> list[str]: return [] +def _get_cached_memory(user_id: str) -> Optional[tuple[str, int]]: + """Get cached user memory from the database. + + Parameters: + user_id: The user ID. + + Returns: + Tuple of (preferences, conversation_count) or None if not cached. + """ + try: + with get_session() as session: + row = session.get(UserMemory, user_id) + if row is None: + return None + return (row.preferences, row.conversation_count) + except SQLAlchemyError: + logger.warning("Failed to read cached user memory", exc_info=True) + return None + + +def _save_cached_memory( + user_id: str, preferences: str, conversation_count: int +) -> None: + """Save user memory to the database. + + Parameters: + user_id: The user ID. + preferences: The extracted preferences string. + conversation_count: The conversation count at time of extraction. 
+ """ + try: + with get_session() as session: + row = session.get(UserMemory, user_id) + if row is None: + row = UserMemory( + user_id=user_id, + preferences=preferences, + conversation_count=conversation_count, + ) + session.add(row) + else: + row.preferences = preferences + row.conversation_count = conversation_count + session.commit() + except SQLAlchemyError: + logger.warning("Failed to save user memory to database", exc_info=True) + + async def _collect_conversation_history( client: AsyncLlamaStackClient, conversation_ids: list[str], diff --git a/tests/unit/utils/test_user_memory.py b/tests/unit/utils/test_user_memory.py index e4ffcd09..4215009a 100644 --- a/tests/unit/utils/test_user_memory.py +++ b/tests/unit/utils/test_user_memory.py @@ -7,24 +7,19 @@ from pytest_mock import MockerFixture, MockType from sqlalchemy.exc import SQLAlchemyError -from utils import user_memory as user_memory_module from utils.user_memory import ( NO_PREFERENCES_MARKER, build_instructions_with_preferences, user_memory, _collect_conversation_history, _extract_preferences, + _get_cached_memory, _list_user_conversation_ids, + _save_cached_memory, _truncate, ) -@pytest.fixture(autouse=True) -def clear_cache() -> None: - """Clear the in-memory user_memory cache before each test.""" - user_memory_module._user_memory_cache.clear() - - @pytest.fixture(name="mock_client") def mock_client_fixture(mocker: MockerFixture) -> MockType: """Create a mock AsyncLlamaStackClient.""" @@ -55,10 +50,9 @@ async def test_user_memory_returns_cached_result( "utils.user_memory._list_user_conversation_ids", return_value=["conv-1", "conv-2"], ) - # Pre-populate the cache - user_memory_module._user_memory_cache["user-1"] = ( - "User prefers detailed responses", - 2, + mocker.patch( + "utils.user_memory._get_cached_memory", + return_value=("User prefers detailed responses", 2), ) result = await user_memory("user-1", mock_client, "provider/model") @@ -81,9 +75,12 @@ async def test_user_memory_invalidates_cache_on_count_change( "utils.user_memory._collect_conversation_history", return_value="User: make responses longer\nAssistant: Ok", ) - - # Pre-populate cache with a different count - user_memory_module._user_memory_cache["user-1"] = ("Old preferences", 1) + # Cached with count=1, but current count is 2 + mocker.patch( + "utils.user_memory._get_cached_memory", + return_value=("Old preferences", 1), + ) + mock_save = mocker.patch("utils.user_memory._save_cached_memory") # Mock LLM response mock_response = mocker.MagicMock() @@ -96,6 +93,36 @@ async def test_user_memory_invalidates_cache_on_count_change( result = await user_memory("user-1", mock_client, "provider/model") assert result == "User prefers detailed responses" mock_client.responses.create.assert_called_once() + mock_save.assert_called_once_with("user-1", "User prefers detailed responses", 2) + + +@pytest.mark.asyncio +async def test_user_memory_no_cache_entry( + mock_client: MockType, + mocker: MockerFixture, +) -> None: + """Test user_memory extracts preferences when no cache entry exists.""" + mocker.patch( + "utils.user_memory._list_user_conversation_ids", + return_value=["conv-1"], + ) + mocker.patch( + "utils.user_memory._collect_conversation_history", + return_value="User: hello\nAssistant: hi", + ) + mocker.patch("utils.user_memory._get_cached_memory", return_value=None) + mock_save = mocker.patch("utils.user_memory._save_cached_memory") + + mock_response = mocker.MagicMock() + mock_output_item = mocker.MagicMock() + mock_output_item.type = "message" + 
mock_output_item.content = "User prefers brief responses" + mock_response.output = [mock_output_item] + mock_client.responses.create.return_value = mock_response + + result = await user_memory("user-1", mock_client, "provider/model") + assert result == "User prefers brief responses" + mock_save.assert_called_once_with("user-1", "User prefers brief responses", 1) @pytest.mark.asyncio @@ -112,6 +139,8 @@ async def test_user_memory_llm_returns_no_preferences( "utils.user_memory._collect_conversation_history", return_value="User: hello\nAssistant: hi", ) + mocker.patch("utils.user_memory._get_cached_memory", return_value=None) + mocker.patch("utils.user_memory._save_cached_memory") mock_response = mocker.MagicMock() mock_output_item = mocker.MagicMock() @@ -138,6 +167,8 @@ async def test_user_memory_llm_call_fails( "utils.user_memory._collect_conversation_history", return_value="User: hello\nAssistant: hi", ) + mocker.patch("utils.user_memory._get_cached_memory", return_value=None) + mocker.patch("utils.user_memory._save_cached_memory") mock_client.responses.create.side_effect = APIConnectionError( request=mocker.MagicMock() ) @@ -160,12 +191,89 @@ async def test_user_memory_empty_history( "utils.user_memory._collect_conversation_history", return_value="", ) + mocker.patch("utils.user_memory._get_cached_memory", return_value=None) result = await user_memory("user-1", mock_client, "provider/model") assert result == "" mock_client.responses.create.assert_not_called() +# --- Tests for _get_cached_memory() --- + + +def test_get_cached_memory_hit(mocker: MockerFixture) -> None: + """Test reading cached memory from the database.""" + mock_session = mocker.MagicMock() + mock_row = mocker.MagicMock() + mock_row.preferences = "User prefers bullet points" + mock_row.conversation_count = 3 + mock_session.__enter__.return_value.get.return_value = mock_row + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + result = _get_cached_memory("user-1") + assert result == ("User prefers bullet points", 3) + + +def test_get_cached_memory_miss(mocker: MockerFixture) -> None: + """Test _get_cached_memory returns None when no row exists.""" + mock_session = mocker.MagicMock() + mock_session.__enter__.return_value.get.return_value = None + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + result = _get_cached_memory("user-1") + assert result is None + + +def test_get_cached_memory_db_error(mocker: MockerFixture) -> None: + """Test _get_cached_memory returns None on DB error.""" + mock_session = mocker.MagicMock() + mock_session.__enter__.return_value.get.side_effect = SQLAlchemyError("db error") + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + result = _get_cached_memory("user-1") + assert result is None + + +# --- Tests for _save_cached_memory() --- + + +def test_save_cached_memory_insert(mocker: MockerFixture) -> None: + """Test saving new user memory to the database.""" + mock_session = mocker.MagicMock() + mock_session.__enter__.return_value.get.return_value = None + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + _save_cached_memory("user-1", "User prefers code examples", 2) + + session_ctx = mock_session.__enter__.return_value + session_ctx.add.assert_called_once() + session_ctx.commit.assert_called_once() + + +def test_save_cached_memory_update(mocker: MockerFixture) -> None: + """Test updating existing user memory in the database.""" + mock_session = mocker.MagicMock() + mock_row = mocker.MagicMock() 
+ mock_session.__enter__.return_value.get.return_value = mock_row + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + _save_cached_memory("user-1", "Updated preferences", 5) + + assert mock_row.preferences == "Updated preferences" + assert mock_row.conversation_count == 5 + mock_session.__enter__.return_value.commit.assert_called_once() + + +def test_save_cached_memory_db_error(mocker: MockerFixture) -> None: + """Test _save_cached_memory handles DB errors gracefully.""" + mock_session = mocker.MagicMock() + mock_session.__enter__.return_value.get.side_effect = SQLAlchemyError("db error") + mocker.patch("utils.user_memory.get_session", return_value=mock_session) + + # Should not raise + _save_cached_memory("user-1", "prefs", 1) + + # --- Tests for _list_user_conversation_ids() --- From 4611633e58fc5f31a27d5489e61cc718897ade1c Mon Sep 17 00:00:00 2001 From: Jan Cervenka Date: Sun, 22 Mar 2026 11:59:09 +0100 Subject: [PATCH 3/6] invalidate user memory cache correctly --- src/app/endpoints/query.py | 6 ++++- src/app/endpoints/streaming_query.py | 6 ++++- src/utils/user_memory.py | 23 +++++++++++++------ tests/unit/utils/test_user_memory.py | 34 ++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 9 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index aa014f5e..ce95e565 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -186,7 +186,11 @@ async def query_endpoint_handler( # Extract user preferences from conversation history user_preferences = await user_memory( - user_id, client, responses_params.model, _skip_userid_check + user_id, + client, + responses_params.model, + _skip_userid_check, + new_conversation=not query_request.conversation_id, ) if user_preferences: responses_params.instructions = build_instructions_with_preferences( diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 323c12a4..953accbc 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -222,7 +222,11 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals # Extract user preferences from conversation history user_preferences = await user_memory( - user_id, client, responses_params.model, _skip_userid_check + user_id, + client, + responses_params.model, + _skip_userid_check, + new_conversation=not query_request.conversation_id, ) if user_preferences: responses_params.instructions = build_instructions_with_preferences( diff --git a/src/utils/user_memory.py b/src/utils/user_memory.py index 31162d6b..6d42ffde 100644 --- a/src/utils/user_memory.py +++ b/src/utils/user_memory.py @@ -56,6 +56,7 @@ async def user_memory( # pylint: disable=unused-argument client: AsyncLlamaStackClient, model: str, skip_user_id_check: bool = False, + new_conversation: bool = False, ) -> str: """Extract user preferences from conversation history. @@ -71,6 +72,8 @@ async def user_memory( # pylint: disable=unused-argument model: The model ID to use for preference extraction. skip_user_id_check: Whether to skip user ID validation (unused, kept for interface consistency with the auth tuple). + new_conversation: When True, forces cache invalidation because a new + conversation is being created but not yet stored in the database. Returns: A string describing the user's preferences, or empty string if none found. 
@@ -81,13 +84,19 @@ async def user_memory( # pylint: disable=unused-argument logger.info("No conversation history for user %s", user_id) return "" - # Check database cache - cached = _get_cached_memory(user_id) - if cached is not None: - cached_preferences, cached_count = cached - if cached_count == current_count: - logger.info("Using cached user memory for user %s", user_id) - return cached_preferences + # Check database cache (skip when starting a new conversation since the + # new conversation row hasn't been stored yet and the count would match) + if not new_conversation: + cached = _get_cached_memory(user_id) + if cached is not None: + cached_preferences, cached_count = cached + if cached_count == current_count: + logger.info( + "Using cached user memory for user %s: %s", + user_id, + cached_preferences, + ) + return cached_preferences # Collect conversation history from Llama Stack history_text = await _collect_conversation_history( diff --git a/tests/unit/utils/test_user_memory.py b/tests/unit/utils/test_user_memory.py index 4215009a..2619de86 100644 --- a/tests/unit/utils/test_user_memory.py +++ b/tests/unit/utils/test_user_memory.py @@ -61,6 +61,40 @@ async def test_user_memory_returns_cached_result( mock_client.responses.create.assert_not_called() +@pytest.mark.asyncio +async def test_user_memory_skips_cache_on_new_conversation( + mock_client: MockType, + mocker: MockerFixture, +) -> None: + """Test user_memory skips cache when new_conversation=True.""" + mocker.patch( + "utils.user_memory._list_user_conversation_ids", + return_value=["conv-1", "conv-2"], + ) + mock_get_cache = mocker.patch("utils.user_memory._get_cached_memory") + mocker.patch( + "utils.user_memory._collect_conversation_history", + return_value="User: be brief\nAssistant: Ok", + ) + mocker.patch("utils.user_memory._save_cached_memory") + + mock_response = mocker.MagicMock() + mock_output_item = mocker.MagicMock() + mock_output_item.type = "message" + mock_output_item.content = "User prefers brief responses" + mock_response.output = [mock_output_item] + mock_client.responses.create.return_value = mock_response + + result = await user_memory( + "user-1", mock_client, "provider/model", new_conversation=True + ) + assert result == "User prefers brief responses" + # Cache should NOT have been checked + mock_get_cache.assert_not_called() + # LLM should have been called + mock_client.responses.create.assert_called_once() + + @pytest.mark.asyncio async def test_user_memory_invalidates_cache_on_count_change( mock_client: MockType, From daee6f6d98ebf75c305617a2bb1b95c5851966c5 Mon Sep 17 00:00:00 2001 From: Jan Cervenka Date: Sun, 22 Mar 2026 12:22:07 +0100 Subject: [PATCH 4/6] fix unit tests --- tests/unit/app/endpoints/test_query.py | 24 +++++++++++++++++++ .../app/endpoints/test_streaming_query.py | 20 ++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 06ee6992..c9f820ba 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -165,6 +165,10 @@ async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: mocker.patch("app.endpoints.query.store_query_results") mocker.patch("app.endpoints.query.consume_query_tokens") mocker.patch("app.endpoints.query.get_available_quotas", return_value={}) + mocker.patch( + "app.endpoints.query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) response = await query_endpoint_handler( request=dummy_request, @@ -248,6 
+252,10 @@ async def test_query_merges_inline_and_tool_rag_chunks_and_documents( mocker.patch("app.endpoints.query.store_query_results") mocker.patch("app.endpoints.query.consume_query_tokens") mocker.patch("app.endpoints.query.get_available_quotas", return_value={}) + mocker.patch( + "app.endpoints.query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) response = await query_endpoint_handler( request=dummy_request, @@ -320,6 +328,10 @@ async def test_successful_query_with_conversation( mocker.patch("app.endpoints.query.store_query_results") mocker.patch("app.endpoints.query.consume_query_tokens") mocker.patch("app.endpoints.query.get_available_quotas", return_value={}) + mocker.patch( + "app.endpoints.query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) response = await query_endpoint_handler( request=dummy_request, @@ -403,6 +415,10 @@ async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: mocker.patch("app.endpoints.query.store_query_results") mocker.patch("app.endpoints.query.consume_query_tokens") mocker.patch("app.endpoints.query.get_available_quotas", return_value={}) + mocker.patch( + "app.endpoints.query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) await query_endpoint_handler( request=dummy_request, @@ -469,6 +485,10 @@ async def test_query_with_topic_summary( mocker.patch("app.endpoints.query.store_query_results") mocker.patch("app.endpoints.query.consume_query_tokens") mocker.patch("app.endpoints.query.get_available_quotas", return_value={}) + mocker.patch( + "app.endpoints.query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) await query_endpoint_handler( request=dummy_request, @@ -565,6 +585,10 @@ async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: mocker.patch("app.endpoints.query.store_query_results") mocker.patch("app.endpoints.query.consume_query_tokens") mocker.patch("app.endpoints.query.get_available_quotas", return_value={}) + mocker.patch( + "app.endpoints.query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) await query_endpoint_handler( request=dummy_request, diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 99dee264..92b00414 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -396,6 +396,10 @@ async def mock_generate_response( "app.endpoints.streaming_query.normalize_conversation_id", return_value="123", ) + mocker.patch( + "app.endpoints.streaming_query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) response = await streaming_query_endpoint_handler( request=dummy_request, @@ -483,6 +487,10 @@ async def mock_generate_response( "app.endpoints.streaming_query.normalize_conversation_id", return_value="123", ) + mocker.patch( + "app.endpoints.streaming_query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) response = await streaming_query_endpoint_handler( request=dummy_request, @@ -581,6 +589,10 @@ async def mock_generate_response( "app.endpoints.streaming_query.normalize_conversation_id", return_value="123", ) + mocker.patch( + "app.endpoints.streaming_query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) await streaming_query_endpoint_handler( request=dummy_request, @@ -677,6 +689,10 @@ async def mock_generate_response( "app.endpoints.streaming_query.normalize_conversation_id", return_value="123", ) + mocker.patch( + "app.endpoints.streaming_query.user_memory", + 
new=mocker.AsyncMock(return_value=""), + ) await streaming_query_endpoint_handler( request=dummy_request, @@ -777,6 +793,10 @@ async def mock_generate_response( "app.endpoints.streaming_query.normalize_conversation_id", return_value="123", ) + mocker.patch( + "app.endpoints.streaming_query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) await streaming_query_endpoint_handler( request=dummy_request, From b9e8b3db629a2baf23e11c4079ecfafe51ece7c7 Mon Sep 17 00:00:00 2001 From: Jan Cervenka Date: Sun, 22 Mar 2026 12:43:54 +0100 Subject: [PATCH 5/6] fix integration tests --- .../integration/endpoints/test_query_byok_integration.py | 9 +++++++++ tests/integration/endpoints/test_query_integration.py | 9 +++++++++ .../endpoints/test_streaming_query_byok_integration.py | 9 +++++++++ .../endpoints/test_streaming_query_integration.py | 9 +++++++++ 4 files changed, 36 insertions(+) diff --git a/tests/integration/endpoints/test_query_byok_integration.py b/tests/integration/endpoints/test_query_byok_integration.py index 40191821..f80423b8 100644 --- a/tests/integration/endpoints/test_query_byok_integration.py +++ b/tests/integration/endpoints/test_query_byok_integration.py @@ -140,6 +140,15 @@ def _build_base_mock_client(mocker: MockerFixture) -> Any: # --------------------------------------------------------------------------- +@pytest.fixture(autouse=True) +def mock_user_memory(mocker: MockerFixture) -> None: + """Mock user_memory to avoid hitting the Llama Stack conversations API.""" + mocker.patch( + "app.endpoints.query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) + + @pytest.fixture(name="mock_byok_client") def mock_byok_client_fixture( mocker: MockerFixture, diff --git a/tests/integration/endpoints/test_query_integration.py b/tests/integration/endpoints/test_query_integration.py index d96879c4..42f6fcbd 100644 --- a/tests/integration/endpoints/test_query_integration.py +++ b/tests/integration/endpoints/test_query_integration.py @@ -113,6 +113,15 @@ def mock_llama_stack_client_fixture( yield mock_client +@pytest.fixture(autouse=True) +def mock_user_memory(mocker: MockerFixture) -> None: + """Mock user_memory to avoid hitting the Llama Stack conversations API.""" + mocker.patch( + "app.endpoints.query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) + + @pytest.fixture(name="patch_db_session", autouse=True) def patch_db_session_fixture( test_db_session: Session, diff --git a/tests/integration/endpoints/test_streaming_query_byok_integration.py b/tests/integration/endpoints/test_streaming_query_byok_integration.py index 5f58f603..354e96e0 100644 --- a/tests/integration/endpoints/test_streaming_query_byok_integration.py +++ b/tests/integration/endpoints/test_streaming_query_byok_integration.py @@ -98,6 +98,15 @@ async def _responses_create(**kwargs: Any) -> Any: # --------------------------------------------------------------------------- +@pytest.fixture(autouse=True) +def mock_user_memory(mocker: MockerFixture) -> None: + """Mock user_memory to avoid hitting the Llama Stack conversations API.""" + mocker.patch( + "app.endpoints.streaming_query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) + + @pytest.fixture(name="patch_db_session", autouse=True) def patch_db_session_fixture( test_db_session: Session, diff --git a/tests/integration/endpoints/test_streaming_query_integration.py b/tests/integration/endpoints/test_streaming_query_integration.py index 05fba0a5..28ed4d16 100644 --- a/tests/integration/endpoints/test_streaming_query_integration.py +++ 
b/tests/integration/endpoints/test_streaming_query_integration.py @@ -15,6 +15,15 @@ from models.requests import Attachment, QueryRequest +@pytest.fixture(autouse=True) +def mock_user_memory(mocker: MockerFixture) -> None: + """Mock user_memory to avoid hitting the Llama Stack conversations API.""" + mocker.patch( + "app.endpoints.streaming_query.user_memory", + new=mocker.AsyncMock(return_value=""), + ) + + @pytest.fixture(name="mock_streaming_llama_stack_client") def mock_llama_stack_streaming_fixture( mocker: MockerFixture, From c6a029b777cbe5f03759656e50acf3610f465a87 Mon Sep 17 00:00:00 2001 From: Jan Cervenka Date: Tue, 24 Mar 2026 18:44:47 +0100 Subject: [PATCH 6/6] fix off by one conversation count --- src/app/endpoints/query.py | 2 +- src/utils/user_memory.py | 7 +++++-- tests/unit/utils/test_user_memory.py | 4 +++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index ce95e565..7773d4c3 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -307,7 +307,7 @@ async def retrieve_response( id=moderation_result.moderation_id, llm_response=moderation_result.message ) try: - print("yooooo", responses_params.model_dump(exclude_none=True)) + logger.info(responses_params.model_dump(exclude_none=True)) response = await client.responses.create( **responses_params.model_dump(exclude_none=True) ) diff --git a/src/utils/user_memory.py b/src/utils/user_memory.py index 6d42ffde..8d54c9e8 100644 --- a/src/utils/user_memory.py +++ b/src/utils/user_memory.py @@ -108,8 +108,11 @@ async def user_memory( # pylint: disable=unused-argument # Extract preferences using LLM preferences = await _extract_preferences(client, model, history_text) - # Cache the result in the database - _save_cached_memory(user_id, preferences, current_count) + # Cache the result in the database. When starting a new conversation, the + # conversation row hasn't been stored yet so current_count is off by one. + # Save count + 1 so the next request in the same conversation sees a match. + save_count = current_count + 1 if new_conversation else current_count + _save_cached_memory(user_id, preferences, save_count) if preferences: logger.info("User memory for user %s: %s", user_id, preferences) diff --git a/tests/unit/utils/test_user_memory.py b/tests/unit/utils/test_user_memory.py index 2619de86..0f81c1d6 100644 --- a/tests/unit/utils/test_user_memory.py +++ b/tests/unit/utils/test_user_memory.py @@ -76,7 +76,7 @@ async def test_user_memory_skips_cache_on_new_conversation( "utils.user_memory._collect_conversation_history", return_value="User: be brief\nAssistant: Ok", ) - mocker.patch("utils.user_memory._save_cached_memory") + mock_save = mocker.patch("utils.user_memory._save_cached_memory") mock_response = mocker.MagicMock() mock_output_item = mocker.MagicMock() @@ -93,6 +93,8 @@ async def test_user_memory_skips_cache_on_new_conversation( mock_get_cache.assert_not_called() # LLM should have been called mock_client.responses.create.assert_called_once() + # Count saved should be current_count + 1 (2 existing + 1 new = 3) + mock_save.assert_called_once_with("user-1", "User prefers brief responses", 3) @pytest.mark.asyncio
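A minimal walk-through of the conversation-count arithmetic this last patch fixes, assuming a user with two stored conversations who starts a new one (the values are illustrative, not from a real run):

# Hypothetical walk-through of the counting logic in user_memory().
stored_conversation_ids = ["conv-1", "conv-2"]   # rows already in UserConversation
current_count = len(stored_conversation_ids)     # 2 when the new request arrives

new_conversation = True                          # the request carries no conversation_id
# The new conversation row is persisted only after the turn completes, so the
# cache entry is written with count + 1; the follow-up request in the same
# conversation then counts 3 stored rows and hits the cache instead of
# re-running the LLM extraction.
save_count = current_count + 1 if new_conversation else current_count
assert save_count == 3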