Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -196,4 +196,7 @@ dev-tools/mcp-mock-server/.certs/
requirements.*.backup

# Local run files
local-run.yaml
local-run.yaml

datadir/
INSTRUCTIONS.md
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ dependencies = [
"pyasn1>=0.6.3", # LCORE-1490
# Used for system prompt template variable rendering
"jinja2>=3.1.0",
"anthropic>=0.86.0",
]


Expand Down
15 changes: 15 additions & 0 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
ShieldModerationResult,
TurnSummary,
)
from utils.user_memory import build_instructions_with_preferences, user_memory
from utils.vector_search import build_rag_context

logger = get_logger(__name__)
Expand Down Expand Up @@ -183,6 +184,19 @@ async def query_endpoint_handler(
inline_rag_context=inline_rag_context.context_text,
)

# Extract user preferences from conversation history
user_preferences = await user_memory(
user_id,
client,
responses_params.model,
_skip_userid_check,
new_conversation=not query_request.conversation_id,
)
if user_preferences:
responses_params.instructions = build_instructions_with_preferences(
responses_params.instructions, user_preferences
)

# Handle Azure token refresh if needed
if (
responses_params.model.startswith("azure")
Expand Down Expand Up @@ -293,6 +307,7 @@ async def retrieve_response(
id=moderation_result.moderation_id, llm_response=moderation_result.message
)
try:
logger.info(responses_params.model_dump(exclude_none=True))
response = await client.responses.create(
**responses_params.model_dump(exclude_none=True)
)
Expand Down
16 changes: 15 additions & 1 deletion src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Streaming query handler using Responses API."""
"""Streaming query handler using Responses API.""" # pylint: disable=too-many-lines

import asyncio
import datetime
Expand Down Expand Up @@ -95,6 +95,7 @@
from utils.suid import get_suid, normalize_conversation_id
from utils.token_counter import TokenCounter
from utils.types import ReferencedDocument, ResponsesApiParams, TurnSummary
from utils.user_memory import build_instructions_with_preferences, user_memory
from utils.vector_search import build_rag_context

logger = get_logger(__name__)
Expand Down Expand Up @@ -219,6 +220,19 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
inline_rag_context=inline_rag_context.context_text,
)

# Extract user preferences from conversation history
user_preferences = await user_memory(
user_id,
client,
responses_params.model,
_skip_userid_check,
new_conversation=not query_request.conversation_id,
)
if user_preferences:
responses_params.instructions = build_instructions_with_preferences(
responses_params.instructions, user_preferences
)

# Handle Azure token refresh if needed
if (
responses_params.model.startswith("azure")
Expand Down
20 changes: 20 additions & 0 deletions src/models/database/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,26 @@ class UserConversation(Base): # pylint: disable=too-few-public-methods
topic_summary: Mapped[str] = mapped_column(default="")


class UserMemory(Base): # pylint: disable=too-few-public-methods
"""Model for storing cached user preference extractions."""

__tablename__ = "user_memory"

# One row per user
user_id: Mapped[str] = mapped_column(primary_key=True)

# The extracted preferences string (empty string if none found)
preferences: Mapped[str] = mapped_column(default="")

# Conversation count at the time of extraction (used for cache invalidation)
conversation_count: Mapped[int] = mapped_column(default=0)

updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(), # pylint: disable=not-callable
)


class UserTurn(Base): # pylint: disable=too-few-public-methods
"""Model for storing turn-level metadata."""

Expand Down
Loading
Loading