diff --git a/athena/app/api/main.py b/athena/app/api/main.py index 9b884b2..73cf0b2 100644 --- a/athena/app/api/main.py +++ b/athena/app/api/main.py @@ -1,34 +1,11 @@ -from fastapi import APIRouter, HTTPException, Query, Request +from fastapi import APIRouter -from athena.configuration.config import settings +from athena.app.api.routes import episodic_memory, semantic_memory api_router = APIRouter() - - -@api_router.get("/memory/search", tags=["memory"]) -async def search_memory( - request: Request, - q: str, - field: str = "task_state", - limit: int = Query(default=settings.MEMORY_SEARCH_LIMIT, ge=1, le=100), -): - services = getattr(request.app.state, "service", {}) - memory_service = services.get("memory_service") - if memory_service is None: - raise HTTPException(status_code=503, detail="Memory service not initialized") - if field not in {"task_state", "task", "state"}: - raise HTTPException(status_code=400, detail=f"Invalid field: {field}") - results = await memory_service.search_memory(q, limit=limit, field=field) - return [m.model_dump() for m in results] - - -@api_router.get("/memory/{memory_id}", tags=["memory"]) -async def get_memory(request: Request, memory_id: str): - services = getattr(request.app.state, "service", {}) - memory_service = services.get("memory_service") - if memory_service is None: - raise HTTPException(status_code=503, detail="Memory service not initialized") - result = await memory_service.get_memory_by_key(memory_id) - if result is None: - raise HTTPException(status_code=404, detail="Memory not found") - return result.model_dump() +api_router.include_router( + episodic_memory.router, prefix="/episodic-memory", tags=["episodic-memory"] +) +api_router.include_router( + semantic_memory.router, prefix="/semantic-memory", tags=["semantic-memory"] +) diff --git a/athena/app/api/routes/__init__.py b/athena/app/api/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/athena/app/api/routes/episodic_memory.py b/athena/app/api/routes/episodic_memory.py new file mode 100644 index 0000000..44e49de --- /dev/null +++ b/athena/app/api/routes/episodic_memory.py @@ -0,0 +1,36 @@ +from fastapi import APIRouter, HTTPException, Query, Request + +from athena.configuration.config import settings + +router = APIRouter() + + +@router.get("/search/", summary="Search Episodic Memory from database") +async def search_episodic_memory( + request: Request, + q: str, + field: str = "task_state", + limit: int = Query(default=settings.MEMORY_SEARCH_LIMIT, ge=1, le=100), +): + # Get the episodic memory service from the app state + episodic_memory_service = request.app.state.service["episodic_memory_service"] + + # Validate field + if field not in {"task_state", "task", "state"}: + raise HTTPException(status_code=400, detail=f"Invalid field: {field}") + + # Perform the search + results = await episodic_memory_service.search_memory(q, limit=limit, field=field) + return [m.model_dump() for m in results] + + +@router.get("/{memory_id}/", summary="Get Episodic Memory from database") +async def get_episodic_memory(request: Request, memory_id: str): + # Get the episodic memory service from the app state + episodic_memory_service = request.app.state.service["episodic_memory_service"] + + # Fetch the memory by ID + result = await episodic_memory_service.get_memory_by_key(memory_id) + if result is None: + raise HTTPException(status_code=404, detail="Episodic memory not found") + return result.model_dump() diff --git a/athena/app/api/routes/semantic_memory.py 
b/athena/app/api/routes/semantic_memory.py new file mode 100644 index 0000000..a6a1bf3 --- /dev/null +++ b/athena/app/api/routes/semantic_memory.py @@ -0,0 +1,129 @@ +from typing import List + +from fastapi import APIRouter, Request + +from athena.app.services.semantic_memory_service import SemanticMemoryService +from athena.models.query import Query +from athena.models.requests.semantic_memory import ( + StoreSemanticMemoryRequest, +) +from athena.models.responses.response import Response +from athena.models.responses.semantic_memory import SemanticMemoryResponse +from athena.models.semantic_memory import SemanticMemoryUnit + +router = APIRouter() + + +@router.post("/store/", summary="Store semantic memory", response_model=Response) +async def store_semantic_memory( + request: Request, + body: StoreSemanticMemoryRequest, +): + """ + Store a semantic memory unit with vector embeddings. + + Args: + request: FastAPI request object + body: Request containing repository_id, query, and contexts + + Returns: + Success message + """ + # Get the semantic memory service from the app state + semantic_memory_service: SemanticMemoryService = request.app.state.service[ + "semantic_memory_service" + ] + + # Create SemanticMemoryUnit from request + memory_unit = SemanticMemoryUnit(query=body.query, contexts=body.contexts) + + # Store the memory + await semantic_memory_service.store_memory( + repository_id=body.repository_id, + memory=memory_unit, + ) + + return Response() + + +@router.get( + "/retrieve/{repository_id}/", + response_model=Response[List[SemanticMemoryResponse]], + summary="Retrieve semantic memories", +) +async def retrieve_semantic_memory( + request: Request, + repository_id: int, + essential_query: str, + extra_requirements: str = "", + purpose: str = "", +): + """ + Retrieve semantic memories similar to the given query. + + Uses weighted multi-vector similarity search across: + - Essential query (50% weight) + - Extra requirements (25% weight) + - Purpose (25% weight) + + Args: + request: FastAPI request object + repository_id: Repository identifier + essential_query: The main query string + extra_requirements: Additional constraints or filters for the query + purpose: The intent or context behind the query + + Returns: + List of semantic memory units ordered by similarity + """ + # Get the semantic memory service from the app state + semantic_memory_service = request.app.state.service["semantic_memory_service"] + + # Retrieve memories + results = await semantic_memory_service.retrieve_memory( + repository_id=repository_id, + query=Query( + essential_query=essential_query, extra_requirements=extra_requirements, purpose=purpose + ), + ) + + # Convert to response models + return Response( + data=[ + SemanticMemoryResponse( + query_essential_query=memory.query.essential_query, + query_extra_requirements=memory.query.extra_requirements, + query_purpose=memory.query.purpose, + memory_context_contexts=memory.contexts, + ) + for memory in results + ] + ) + + +@router.delete( + "/{repository_id}/", summary="Delete semantic memories by repository", response_model=Response +) +async def delete_semantic_memories_by_repository( + request: Request, + repository_id: int, +): + """ + Delete all semantic memories for a given repository.
+ + Args: + request: FastAPI request object + repository_id: Repository identifier to filter memories for deletion + + Returns: + Success message + """ + # Get the semantic memory service from the app state + semantic_memory_service: SemanticMemoryService = request.app.state.service[ + "semantic_memory_service" + ] + + # Delete memories + await semantic_memory_service.delete_memories_by_repository(repository_id=repository_id) + + return Response() diff --git a/athena/app/dependencies.py b/athena/app/dependencies.py index d292359..c5a92a5 100644 --- a/athena/app/dependencies.py +++ b/athena/app/dependencies.py @@ -5,9 +5,10 @@ from athena.app.services.base_service import BaseService from athena.app.services.database_service import DatabaseService from athena.app.services.embedding_service import EmbeddingService +from athena.app.services.episodic_memory_service import EpisodicMemoryService +from athena.app.services.episodic_memory_storage_service import EpisodicMemoryStorageService from athena.app.services.llm_service import LLMService -from athena.app.services.memory_service import MemoryService -from athena.app.services.memory_storage_service import MemoryStorageService +from athena.app.services.semantic_memory_service import SemanticMemoryService from athena.configuration.config import settings @@ -33,8 +34,6 @@ def initialize_services() -> Dict[str, BaseService]: llm_service = LLMService( settings.MODEL_NAME, settings.MODEL_TEMPERATURE, - settings.MODEL_MAX_INPUT_TOKENS, - settings.MODEL_MAX_OUTPUT_TOKENS, settings.OPENAI_FORMAT_API_KEY, settings.OPENAI_FORMAT_BASE_URL, settings.ANTHROPIC_API_KEY, @@ -42,24 +41,29 @@ def initialize_services() -> Dict[str, BaseService]: settings.GOOGLE_APPLICATION_CREDENTIALS, ) - embedding_service = None - api_key = settings.EMBEDDING_API_KEY or settings.MISTRAL_API_KEY - if settings.EMBEDDING_MODEL and api_key and settings.EMBEDDING_BASE_URL: - embedding_service = EmbeddingService( - model=settings.EMBEDDING_MODEL, - api_key=api_key, - base_url=settings.EMBEDDING_BASE_URL, - embed_dim=settings.EMBEDDING_DIM or 1024, - ) + embedding_service = EmbeddingService( + model=settings.EMBEDDING_MODEL, + api_key=settings.EMBEDDING_API_KEY, + base_url=settings.EMBEDDING_BASE_URL, + embed_dim=settings.EMBEDDING_DIM, + ) - memory_store = MemoryStorageService(database_service.get_sessionmaker(), embedding_service) + episodic_memory_store = EpisodicMemoryStorageService( + database_service.get_sessionmaker(), embedding_service + ) - memory_service = MemoryService( - storage_backend=settings.MEMORY_STORAGE_BACKEND, store=memory_store + episodic_memory_service = EpisodicMemoryService( + database_service=database_service, + storage_backend=settings.MEMORY_STORAGE_BACKEND, + store=episodic_memory_store, + ) + semantic_memory_service = SemanticMemoryService( + database_service=database_service, embedding_service=embedding_service ) return { "llm_service": llm_service, "database_service": database_service, - "memory_service": memory_service, + "episodic_memory_service": episodic_memory_service, + "semantic_memory_service": semantic_memory_service, } diff --git a/athena/app/main.py b/athena/app/main.py index c0a84d0..879184f 100644 --- a/athena/app/main.py +++ b/athena/app/main.py @@ -75,13 +75,5 @@ def custom_generate_unique_id(route: APIRoute) -> str: @app.get("/health", tags=["health"]) -async def health_check(): - services = getattr(app.state, "service", {}) - db = services.get("database_service") - db_ok = await db.health_check() if db is not None else False - status = 
"healthy" if db_ok else "degraded" - return { - "status": status, - "database": db_ok, - "timestamp": datetime.now(timezone.utc).isoformat(), - } +def health_check(): + return {"status": "healthy", "timestamp": datetime.now(timezone.utc).isoformat()} diff --git a/athena/app/services/database_service.py b/athena/app/services/database_service.py index 57c903e..4393e55 100644 --- a/athena/app/services/database_service.py +++ b/athena/app/services/database_service.py @@ -1,74 +1,41 @@ -import asyncio - from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlmodel import SQLModel from athena.app.services.base_service import BaseService -from athena.configuration.config import settings from athena.utils.logger_manager import get_logger class DatabaseService(BaseService): - def __init__(self, DATABASE_URL: str, max_retries: int = 5, initial_backoff: float = 1.0): + def __init__(self, DATABASE_URL: str): self.engine = create_async_engine(DATABASE_URL, echo=True) self.sessionmaker = async_sessionmaker( self.engine, expire_on_commit=False, class_=AsyncSession ) self._logger = get_logger(__name__) - self._max_retries = max_retries - self._initial_backoff = initial_backoff + + async def create_vector_extension( + self, + ): + async with self.engine.begin() as conn: + # Ensure pgvector extension exists + await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) + self._logger.info("pgvector extension ensured.") # Create the database and tables async def create_db_and_tables(self): async with self.engine.begin() as conn: - # Ensure pgvector extension exists (safe to ignore if unavailable) - try: - await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) - except Exception: - pass + # Create all tables await conn.run_sync(SQLModel.metadata.create_all) - # Create ivfflat indexes for vector columns (if extension present) - try: - lists = settings.EMBEDDING_IVFFLAT_LISTS or 100 - await conn.exec_driver_sql( - f"CREATE INDEX IF NOT EXISTS idx_memory_units_task_embedding ON memory_units USING ivfflat (task_embedding vector_cosine_ops) WITH (lists = {lists})" - ) - await conn.exec_driver_sql( - f"CREATE INDEX IF NOT EXISTS idx_memory_units_state_embedding ON memory_units USING ivfflat (state_embedding vector_cosine_ops) WITH (lists = {lists})" - ) - await conn.exec_driver_sql( - f"CREATE INDEX IF NOT EXISTS idx_memory_units_task_state_embedding ON memory_units USING ivfflat (task_state_embedding vector_cosine_ops) WITH (lists = {lists})" - ) - except Exception: - # Index creation failed (likely no pgvector). Continue without indexes. - pass + self._logger.info("Database and tables created.") async def start(self): """ Start the database service by creating the database and tables. This method is called when the service is initialized. """ - attempt = 0 - backoff = self._initial_backoff - while True: - try: - await self.create_db_and_tables() - self._logger.info("Database and tables created successfully.") - break - except Exception as exc: - attempt += 1 - if attempt > self._max_retries: - self._logger.error( - f"Database start failed after {self._max_retries} retries: {exc}" - ) - raise - self._logger.warning( - f"Database start failed (attempt {attempt}/{self._max_retries}): {exc}. " - f"Retrying in {backoff:.1f}s..." 
- ) - await asyncio.sleep(backoff) - backoff *= 2 + await self.create_vector_extension() + await self.create_db_and_tables() async def close(self): """ @@ -80,13 +47,3 @@ async def close(self): def get_sessionmaker(self) -> async_sessionmaker[AsyncSession]: """Return the async sessionmaker for dependency injection.""" return self.sessionmaker - - async def health_check(self) -> bool: - """Perform a lightweight connectivity check (SELECT 1).""" - try: - async with self.engine.connect() as conn: - await conn.exec_driver_sql("SELECT 1") - return True - except Exception as exc: - self._logger.warning(f"Database health_check failed: {exc}") - return False diff --git a/athena/app/services/embedding_service.py b/athena/app/services/embedding_service.py index 872394d..01781b9 100644 --- a/athena/app/services/embedding_service.py +++ b/athena/app/services/embedding_service.py @@ -1,6 +1,6 @@ from typing import Iterable, List -import requests +import httpx from athena.app.services.base_service import BaseService @@ -18,16 +18,21 @@ def __init__(self, model: str, api_key: str, base_url: str, embed_dim: int): self.api_key = api_key self.base_url = base_url.rstrip("/") self.embed_dim = embed_dim + self.client = httpx.AsyncClient(timeout=60.0) - def embed(self, inputs: Iterable[str]) -> List[List[float]]: + async def embed(self, inputs: Iterable[str]) -> List[List[float]]: data = {"model": self.model, "input": list(inputs), "output_dimension": self.embed_dim} headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } url = f"{self.base_url}/v1/embeddings" - resp = requests.post(url, json=data, headers=headers, timeout=60) + resp = await self.client.post(url, json=data, headers=headers) resp.raise_for_status() payload = resp.json() vectors = [item["embedding"] for item in payload.get("data", [])] return vectors + + async def close(self): + """Close the httpx client.""" + await self.client.aclose() diff --git a/athena/app/services/memory_extraction_service.py b/athena/app/services/episodic_memory_extraction_service.py similarity index 96% rename from athena/app/services/memory_extraction_service.py rename to athena/app/services/episodic_memory_extraction_service.py index b6a5071..66b5936 100644 --- a/athena/app/services/memory_extraction_service.py +++ b/athena/app/services/episodic_memory_extraction_service.py @@ -2,7 +2,6 @@ import html import json import re -import uuid from datetime import datetime, timezone from typing import Any, Dict, List, Optional @@ -10,19 +9,19 @@ from tqdm import tqdm from athena.app.services.base_service import BaseService +from athena.app.services.episodic_memory_storage_service import EpisodicMemoryStorageService from athena.app.services.llm_service import LLMService -from athena.app.services.memory_storage_service import MemoryStorageService -from athena.models import ( +from athena.models.episodic_memory import ( Action, - MemorySource, - MemoryTimestamp, - MemoryUnit, - Message, + EpisodicMemorySource, + EpisodicMemoryTimestamp, + EpisodicMemoryUnit, Result, State, Task, ) -from athena.prompts.memory_extraction import ( +from athena.models.message import Message +from athena.prompts.episodic_memory_extraction import ( ACTION_EXTRACTION_PROMPT, ACTION_JUDGE_PROMPT, RESULT_EXTRACTION_PROMPT, @@ -43,7 +42,7 @@ def __init__(self, source_name: str, run_id: str): super().__init__(f"Failed to extract memory units from {run_id} in {source_name}") -class MemoryExtractionService(BaseService): +class EpisodicMemoryExtractionService(BaseService): 
""" Service for extracting structured memory units from interaction trajectories. @@ -64,12 +63,12 @@ class MemoryExtractionService(BaseService): The service follows a pipeline pattern: 1. Data Source -> Load trajectories 2. Extraction Strategy -> Extract components (task, action, result, state) - 3. Memory Unit Assembly -> Combine components into MemoryUnit + 3. Memory Unit Assembly -> Combine components into EpisodicMemoryUnit 4. Validation -> Ensure data quality and consistency 5. Storage -> Persist to memory system Usage: - service = MemoryExtractionService(llm_service) + service = EpisodicMemoryExtractionService(llm_service) memory_units = service.extract_from_trajectories(trajectory_data) """ @@ -78,7 +77,7 @@ def __init__( llm_service: LLMService, batch_size: int = 100, max_retries: int = 3, - memory_store: Optional[MemoryStorageService] = None, + memory_store: Optional[EpisodicMemoryStorageService] = None, ): """ Initialize the Memory Extraction Service. @@ -91,7 +90,7 @@ def __init__( self.llm_service = llm_service self.batch_size = batch_size self.max_retries = max_retries - self._extraction_cache: Dict[str, MemoryUnit] = {} + self._extraction_cache: Dict[str, EpisodicMemoryUnit] = {} self.memory_store = memory_store # self._logger = get_logger(__name__) @@ -106,7 +105,7 @@ def close(self): def extract_from_huggingface_trajectory_repository( # TODO: batch extraction self, repo_name: str, split: str - ) -> List[MemoryUnit]: + ) -> List[EpisodicMemoryUnit]: """ Extract memory units from a HuggingFace trajectory repository. @@ -185,7 +184,7 @@ def _pick(d: Dict[str, Any], keys) -> Optional[Any]: def _extract_memory_source( self, source: str, run_id: str, metadata: Optional[Dict[str, Any]] = None - ) -> MemorySource: + ) -> EpisodicMemorySource: """ Extract memory source for a trajectory. @@ -195,9 +194,9 @@ def _extract_memory_source( metadata: Optional additional metadata Returns: - MemorySource object + EpisodicMemorySource object """ - return MemorySource( + return EpisodicMemorySource( source_name=source, run_id=run_id, metadata=metadata or {}, @@ -206,9 +205,9 @@ def _extract_memory_source( def _extract_memory_units_by_action_windows( self, messages: List[Message], - memory_source: MemorySource, - ) -> List[MemoryUnit]: - ordered_memory_units: List[MemoryUnit] = [] + memory_source: EpisodicMemorySource, + ) -> List[EpisodicMemoryUnit]: + ordered_memory_units: List[EpisodicMemoryUnit] = [] window_msgs: List[Message] = [] window_first_action: Optional[Message] = None @@ -285,11 +284,11 @@ def _is_action_message(self, message: Message) -> bool: def _create_memory_unit( self, - source: MemorySource, + source: EpisodicMemorySource, task: Task, window_msgs: List[Message], - prior_units: List[MemoryUnit], - ) -> MemoryUnit: + prior_units: List[EpisodicMemoryUnit], + ) -> EpisodicMemoryUnit: """ Extract a single memory unit from a window of messages. 
- Synthesize state.done from prior actions @@ -301,9 +300,8 @@ def _create_memory_unit( state_done = self._synthesize_state_done_from_context(prior_units, task) state_todo = self._synthesize_state_todo_from_window(window_msgs, task, state_done, action) result = self._extract_result_from_window(window_msgs, action) - return MemoryUnit( - memory_id=str(uuid.uuid4()), - timestamp=MemoryTimestamp( + return EpisodicMemoryUnit( + timestamp=EpisodicMemoryTimestamp( created_at=datetime.now(timezone.utc), updated_at=None, invalid_at=None, @@ -384,7 +382,7 @@ def s(x): def _synthesize_state_from_context( self, - prior_units: List[MemoryUnit], + prior_units: List[EpisodicMemoryUnit], task: Task, window_msgs: List[Message], current_action: Optional[Action] = None, @@ -398,7 +396,9 @@ def _synthesize_state_from_context( ) return State(done=state_done, todo=state_todo) # TODO: open_file, working_dir - def _synthesize_state_done_from_context(self, prior_units: List[MemoryUnit], task: Task) -> str: + def _synthesize_state_done_from_context( + self, prior_units: List[EpisodicMemoryUnit], task: Task + ) -> str: """ Summarize previous context into a concise `state.done` string. Summary of what has ALREADY BEEN COMPLETED (no plans). @@ -868,7 +868,7 @@ def _repair_json_with_llm(self, raw: str, expect_key: Optional[str]) -> str: if __name__ == "__main__": - service = MemoryExtractionService( + service = EpisodicMemoryExtractionService( llm_service=LLMService( model_name="vertex:gemini-2.5-flash", model_temperature=0.0, diff --git a/athena/app/services/memory_service.py b/athena/app/services/episodic_memory_service.py similarity index 67% rename from athena/app/services/memory_service.py rename to athena/app/services/episodic_memory_service.py index af740ac..bd82736 100644 --- a/athena/app/services/memory_service.py +++ b/athena/app/services/episodic_memory_service.py @@ -2,11 +2,13 @@ from typing import List, Optional, Sequence from athena.app.services.base_service import BaseService -from athena.app.services.memory_storage_service import MemoryStorageService -from athena.models.memory import MemoryUnit +from athena.app.services.database_service import DatabaseService +from athena.app.services.episodic_memory_storage_service import EpisodicMemoryStorageService +from athena.configuration.config import settings +from athena.models.episodic_memory import EpisodicMemoryUnit -class MemoryService(BaseService): +class EpisodicMemoryService(BaseService): """ Memory Service for managing memory information for software engineering agents. @@ -17,7 +19,7 @@ class MemoryService(BaseService): Design Principles: - Store complete execution trajectories as memory units - - Support semantic search for relevant past experiences + - Support semantic search for relevant experience - Enable agents to learn from previous successes and failures - Provide scalable storage backend options (in-memory, database, vector store) - Support deduplication and relevance ranking of memories @@ -27,7 +29,10 @@ class MemoryService(BaseService): """ def __init__( - self, storage_backend: str = "in_memory", store: Optional[MemoryStorageService] = None + self, + database_service: DatabaseService, + storage_backend: str = "in_memory", + store: Optional[EpisodicMemoryStorageService] = None, ): """ Initialize the Memory Service with a specified storage backend. 
@@ -38,14 +43,33 @@ def __init__( - "database": Persistent database storage - "vector": Vector database for semantic search """ + self.database_service = database_service self.storage_backend = storage_backend self._store = store + async def create_vector_field_index(self): + async with self.database_service.engine.begin() as conn: + # Create ivfflat indexes for vector columns (if extension present) + lists = settings.EPISODE_MEMORY_EMBEDDING_IVFFLAT_LISTS + await conn.exec_driver_sql( + f"CREATE INDEX IF NOT EXISTS idx_episodic_memories_task_embedding ON episodic_memories USING ivfflat (task_embedding vector_cosine_ops) WITH (lists = {lists})" + ) + await conn.exec_driver_sql( + f"CREATE INDEX IF NOT EXISTS idx_episodic_memories_state_embedding ON episodic_memories USING ivfflat (state_embedding vector_cosine_ops) WITH (lists = {lists})" + ) + await conn.exec_driver_sql( + f"CREATE INDEX IF NOT EXISTS idx_episodic_memories_task_state_embedding ON episodic_memories USING ivfflat (task_state_embedding vector_cosine_ops) WITH (lists = {lists})" + ) + async def start(self): """Initialize the storage backend if needed and validate configuration.""" + # Create vector field indexes + await self.create_vector_field_index() + + # Validate storage backend configuration if self.storage_backend in {"database", "vector"} and self._store is None: raise RuntimeError( - "MemoryService requires a storage store for database/vector backends" + "EpisodicMemoryService requires a storage store for database/vector backends" ) if self._store is not None and hasattr(self._store, "start"): if inspect.iscoroutinefunction(self._store.start): @@ -61,12 +85,12 @@ async def close(self): else: self._store.close() # type: ignore[misc] - async def store_memory(self, memory_unit: MemoryUnit) -> None: + async def store_memory(self, memory_unit: EpisodicMemoryUnit) -> None: """ Store a memory unit in the memory service. Args: - memory_unit: The MemoryUnit object containing task, state, + memory_unit: The EpisodicMemoryUnit object containing task, state, action, and result information to be stored. Deduplication is handled at the database layer on memory_id via upsert. @@ -78,7 +102,7 @@ async def store_memory(self, memory_unit: MemoryUnit) -> None: except Exception: return - async def store_memories(self, memory_units: Sequence[MemoryUnit]) -> None: + async def store_memories(self, memory_units: Sequence[EpisodicMemoryUnit]) -> None: """Bulk store multiple memory units efficiently.""" if self._store is None or not memory_units: return @@ -89,7 +113,7 @@ async def store_memories(self, memory_units: Sequence[MemoryUnit]) -> None: async def search_memory( self, query: str, limit: int = 10, field: str = "task_state" - ) -> List[MemoryUnit]: + ) -> List[EpisodicMemoryUnit]: """ Search for relevant memory units based on a query. @@ -98,7 +122,7 @@ limit: Maximum number of results to return (default: 10) Returns: - List of MemoryUnit objects matching the search query, + List of EpisodicMemoryUnit objects matching the search query, ordered by relevance. This method should support semantic search across multiple @@ -112,7 +136,7 @@ except Exception: return [] - async def get_memory_by_key(self, key: str) -> Optional[MemoryUnit]: + async def get_memory_by_key(self, key: str) -> Optional[EpisodicMemoryUnit]: """ Retrieve a specific memory unit by its memory id.
@@ -120,7 +144,7 @@ async def get_memory_by_key(self, key: str) -> Optional[MemoryUnit]: key: The id of the memory unit Returns: - The MemoryUnit object if found, None otherwise. + The EpisodicMemoryUnit object if found, None otherwise. """ # Here we treat key as memory_id for simplicity if self._store is None: @@ -130,12 +154,12 @@ async def get_memory_by_key(self, key: str) -> Optional[MemoryUnit]: except Exception: return None - async def get_all_memories(self) -> List[MemoryUnit]: + async def get_all_memories(self) -> List[EpisodicMemoryUnit]: """ Retrieve all memory units stored in the service. Returns: - List of all MemoryUnit objects in the service. + List of all EpisodicMemoryUnit objects in the service. """ if self._store is None: return [] diff --git a/athena/app/services/memory_storage_service.py b/athena/app/services/episodic_memory_storage_service.py similarity index 75% rename from athena/app/services/memory_storage_service.py rename to athena/app/services/episodic_memory_storage_service.py index 1d2e308..2a83f76 100644 --- a/athena/app/services/memory_storage_service.py +++ b/athena/app/services/episodic_memory_storage_service.py @@ -7,8 +7,8 @@ from athena.app.services.base_service import BaseService from athena.app.services.embedding_service import EmbeddingService from athena.configuration.config import settings -from athena.entity.memory import MemoryUnitDB -from athena.models.memory import MemoryUnit +from athena.entity.episodic_memory import EpisodicMemoryUnitDB +from athena.models.episodic_memory import EpisodicMemoryUnit def _ensure_dim(vec: List[float]) -> List[float]: @@ -21,7 +21,7 @@ def _ensure_dim(vec: List[float]) -> List[float]: return vec + [0.0] * (dim - len(vec)) -class MemoryStorageService(BaseService): +class EpisodicMemoryStorageService(BaseService): """ Postgres-backed memory store using SQLModel and optional embeddings for semantic search. 
@@ -39,7 +39,7 @@ def __init__( self._embeddings = embedding_service self._max = max_stored_units - async def upsert(self, units: List[MemoryUnit]) -> None: + async def upsert(self, units: List[EpisodicMemoryUnit]) -> None: if not units: return async with self._sessionmaker() as session: @@ -47,7 +47,7 @@ async def upsert(self, units: List[MemoryUnit]) -> None: await self._upsert_one(session, u) await session.commit() - async def _upsert_one(self, session: AsyncSession, unit: MemoryUnit) -> None: + async def _upsert_one(self, session: AsyncSession, unit: EpisodicMemoryUnit) -> None: # Prepare embeddings if configured task_text = self._serialize_task(unit) state_text = self._serialize_state(unit) @@ -56,7 +56,7 @@ async def _upsert_one(self, session: AsyncSession, unit: MemoryUnit) -> None: task_state_vec: Optional[List[float]] = None if self._embeddings is not None: combined_text = self._serialize_task_state(unit) - vecs = self._embeddings.embed([task_text, state_text, combined_text]) + vecs = await self._embeddings.embed([task_text, state_text, combined_text]) if len(vecs) >= 3: task_vec, state_vec, task_state_vec = ( _ensure_dim(vecs[0]), @@ -66,18 +66,20 @@ async def _upsert_one(self, session: AsyncSession, unit: MemoryUnit) -> None: # Check existing existing = await session.scalar( - select(MemoryUnitDB).where(col(MemoryUnitDB.memory_id) == unit.memory_id) + select(EpisodicMemoryUnitDB).where( + col(EpisodicMemoryUnitDB.memory_id) == unit.memory_id + ) ) if existing is None: - row = MemoryUnitDB.from_memory_unit(unit) + row = EpisodicMemoryUnitDB.from_memory_unit(unit) row.task_embedding = task_vec row.state_embedding = state_vec row.task_state_embedding = task_state_vec session.add(row) else: # Update fields - fresh = MemoryUnitDB.from_memory_unit(unit) + fresh = EpisodicMemoryUnitDB.from_memory_unit(unit) for attr in ( "memory_created_at", "memory_updated_at", @@ -114,42 +116,44 @@ async def _upsert_one(self, session: AsyncSession, unit: MemoryUnit) -> None: async def search_by_text( self, text: str, field: str = "task_state", limit: int = 10 - ) -> List[MemoryUnit]: + ) -> List[EpisodicMemoryUnit]: if self._embeddings is None: return [] - q_vec = _ensure_dim(self._embeddings.embed([text])[0]) + + vectors = await self._embeddings.embed([text]) + q_vec = _ensure_dim(vectors[0]) async with self._sessionmaker() as session: # Choose column by field col_expr = { - "task": MemoryUnitDB.task_embedding, - "state": MemoryUnitDB.state_embedding, - "task_state": MemoryUnitDB.task_state_embedding, - }.get(field, MemoryUnitDB.task_state_embedding) + "task": EpisodicMemoryUnitDB.task_embedding, + "state": EpisodicMemoryUnitDB.state_embedding, + "task_state": EpisodicMemoryUnitDB.task_state_embedding, + }.get(field, EpisodicMemoryUnitDB.task_state_embedding) # Order by cosine distance using pgvector `<=>` res = await session.execute( - select(MemoryUnitDB) + select(EpisodicMemoryUnitDB) .where(col_expr.is_not(None)) .order_by(col_expr.cosine_distance(q_vec)) .limit(limit) ) - rows: List[MemoryUnitDB] = list(res.scalars()) + rows: List[EpisodicMemoryUnitDB] = list(res.scalars()) return [r.to_memory_unit() for r in rows] - async def get_by_memory_id(self, memory_id: str) -> Optional[MemoryUnit]: + async def get_by_memory_id(self, memory_id: str) -> Optional[EpisodicMemoryUnit]: async with self._sessionmaker() as session: row = await session.scalar( - select(MemoryUnitDB).where(col(MemoryUnitDB.memory_id) == memory_id) + select(EpisodicMemoryUnitDB).where(col(EpisodicMemoryUnitDB.memory_id) == 
memory_id) ) return None if row is None else row.to_memory_unit() - async def list_all(self, limit: Optional[int] = None) -> List[MemoryUnit]: + async def list_all(self, limit: Optional[int] = None) -> List[EpisodicMemoryUnit]: async with self._sessionmaker() as session: - stmt = select(MemoryUnitDB).order_by(MemoryUnitDB.id.desc()) + stmt = select(EpisodicMemoryUnitDB).order_by(EpisodicMemoryUnitDB.id.desc()) if limit is not None and limit > 0: stmt = stmt.limit(limit) res = await session.execute(stmt) - rows: List[MemoryUnitDB] = list(res.scalars()) + rows: List[EpisodicMemoryUnitDB] = list(res.scalars()) return [r.to_memory_unit() for r in rows] async def clear_all(self) -> None: @@ -159,13 +163,13 @@ async def clear_all(self) -> None: await session.commit() @staticmethod - def _serialize_task(u: MemoryUnit) -> str: + def _serialize_task(u: EpisodicMemoryUnit) -> str: t = u.task parts = [t.issue_title, t.issue_type, t.repository, t.issue_body, t.issue_comments] return "\n".join([p for p in parts if p]) @staticmethod - def _serialize_state(u: MemoryUnit) -> str: + def _serialize_state(u: EpisodicMemoryUnit) -> str: s = u.state parts = [s.done, s.todo] if s.open_file: @@ -177,9 +181,11 @@ def _serialize_state(u: MemoryUnit) -> str: return "\n".join([str(p) for p in parts if p]) @staticmethod - def _serialize_task_state(u: MemoryUnit) -> str: # TODO: text templates for static embeddings + def _serialize_task_state( + u: EpisodicMemoryUnit, + ) -> str: # TODO: text templates for static embeddings return ( - MemoryStorageService._serialize_task(u) + EpisodicMemoryStorageService._serialize_task(u) + "\n\n" - + MemoryStorageService._serialize_state(u) + + EpisodicMemoryStorageService._serialize_state(u) ) diff --git a/athena/app/services/llm_service.py b/athena/app/services/llm_service.py index 419f016..3926ef7 100644 --- a/athena/app/services/llm_service.py +++ b/athena/app/services/llm_service.py @@ -1,5 +1,6 @@ from typing import Optional +from google.oauth2.service_account import Credentials from langchain_anthropic import ChatAnthropic from langchain_core.language_models.chat_models import BaseChatModel from langchain_google_genai import ChatGoogleGenerativeAI @@ -14,8 +15,6 @@ def __init__( self, model_name: str, model_temperature: float, - model_max_input_tokens: int, - model_max_output_tokens: int, openai_format_api_key: Optional[str] = None, openai_format_base_url: Optional[str] = None, anthropic_api_key: Optional[str] = None, @@ -25,8 +24,6 @@ def __init__( self.model = get_model( model_name, temperature=model_temperature, - max_input_tokens=model_max_input_tokens, - max_output_tokens=model_max_output_tokens, openai_format_api_key=openai_format_api_key, openai_format_base_url=openai_format_base_url, anthropic_api_key=anthropic_api_key, @@ -38,8 +35,6 @@ def __init__( def get_model( model_name: str, temperature: float, - max_input_tokens: int, - max_output_tokens: int, openai_format_api_key: Optional[str] = None, openai_format_base_url: Optional[str] = None, anthropic_api_key: Optional[str] = None, @@ -51,39 +46,45 @@ def get_model( model_name=model_name, api_key=anthropic_api_key, temperature=temperature, - max_tokens_to_sample=max_output_tokens, max_retries=3, ) elif model_name.startswith("vertex:"): # example: model_name="vertex:gemini-2.5-pro" vertex_model = model_name.split("vertex:", 1)[1] + if isinstance(google_application_credentials, dict): + creds = Credentials.from_service_account_info( + google_application_credentials, +
scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + else: + creds = Credentials.from_service_account_file( + google_application_credentials, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) return ChatVertexAI( model_name=vertex_model, project="prometheus-code-agent", location="us-central1", temperature=temperature, - max_output_tokens=max_output_tokens, max_retries=3, - credentials=google_application_credentials, + credentials=creds, ) elif "gemini" in model_name: return ChatGoogleGenerativeAI( model=model_name, api_key=gemini_api_key, temperature=temperature, - max_tokens=max_output_tokens, max_retries=3, ) else: """ - Use tiktoken_counter to ensure that the input messages do not exceed the maximum token limit. + Custom OpenAI chat model with specific configuration. """ return CustomChatOpenAI( - max_input_tokens=max_input_tokens, model=model_name, api_key=openai_format_api_key, base_url=openai_format_base_url, temperature=temperature, - max_tokens=max_output_tokens, max_retries=3, ) diff --git a/athena/app/services/semantic_memory_service.py b/athena/app/services/semantic_memory_service.py new file mode 100644 index 0000000..5466fa5 --- /dev/null +++ b/athena/app/services/semantic_memory_service.py @@ -0,0 +1,282 @@ +""" +Semantic Memory Service for storing and retrieving code knowledge. + +This service manages semantic memories which represent structured knowledge about code, +including queries and their corresponding memory contexts. It uses vector embeddings +for efficient semantic similarity search across three query components: +- Essential query: The core question or search intent +- Extra requirements: Additional constraints or specifications +- Purpose: The intended use case or goal + +The service leverages pgvector's IVFFlat indexing for fast approximate nearest neighbor +search on high-dimensional embeddings. +""" + +from typing import Sequence + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import delete, select + +from athena.app.services.base_service import BaseService +from athena.app.services.database_service import DatabaseService +from athena.app.services.embedding_service import EmbeddingService +from athena.configuration.config import settings +from athena.entity.context import ContextDB +from athena.entity.semantic_memory import SemanticMemoryUnitDB +from athena.models.context import Context +from athena.models.query import Query +from athena.models.semantic_memory import SemanticMemoryUnit + + +class SemanticMemoryService(BaseService): + """ + Manages semantic memory storage and retrieval using vector embeddings. + + This service provides functionality to: + - Store semantic memory units with vector embeddings + - Retrieve relevant memories using weighted multi-vector similarity search + - Maintain vector indexes for efficient similarity queries + + Attributes: + engine: SQLAlchemy async engine for database operations + embedding_service: Service for generating vector embeddings from text + """ + + def __init__(self, database_service: DatabaseService, embedding_service: EmbeddingService): + """ + Initialize the semantic memory service. + + Args: + database_service: Database service providing engine and session management + embedding_service: Service for generating embeddings from text queries + """ + self.engine = database_service.engine + self.embedding_service = embedding_service + + async def create_semantic_memory_vector_indexes(self): + """ + Create vector indexes for semantic memory table using IVFFlat algorithm. 
+ + Creates three separate IVFFlat indexes for efficient cosine similarity search: + 1. Essential query embedding index (highest priority) + 2. Extra requirements embedding index (medium priority) + 3. Purpose embedding index (medium priority) + + The number of IVFFlat lists is configured via SEMANTIC_MEMORY_EMBEDDING_IVFFLAT_LISTS. + More lists improve query speed but increase index build time and memory usage. + """ + async with self.engine.begin() as conn: + # Get lists parameter from settings (default 100) + lists = settings.SEMANTIC_MEMORY_EMBEDDING_IVFFLAT_LISTS + + # Create ivfflat indexes for each query component embedding + # Essential query embedding (highest priority, most important for retrieval) + await conn.exec_driver_sql( + f"CREATE INDEX IF NOT EXISTS idx_semantic_memories_essential_query_embedding " + f"ON semantic_memories USING ivfflat (essential_query_embedding vector_cosine_ops) " + f"WITH (lists = {lists})" + ) + + # Extra requirements embedding (medium priority) + await conn.exec_driver_sql( + f"CREATE INDEX IF NOT EXISTS idx_semantic_memories_extra_requirements_embedding " + f"ON semantic_memories USING ivfflat (extra_requirements_embedding vector_cosine_ops) " + f"WITH (lists = {lists})" + ) + + # Purpose embedding (medium priority) + await conn.exec_driver_sql( + f"CREATE INDEX IF NOT EXISTS idx_semantic_memories_purpose_embedding " + f"ON semantic_memories USING ivfflat (purpose_embedding vector_cosine_ops) " + f"WITH (lists = {lists})" + ) + + async def start(self): + """Initialize the service by creating vector indexes.""" + await self.create_semantic_memory_vector_indexes() + + async def get_query_embeddings( + self, query: Query + ) -> tuple[list[float], list[float], list[float]]: + """ + Generate embeddings for all three components of a query. + + Args: + query: Query object containing essential_query, extra_requirements, and purpose + + Returns: + Tuple of three embedding vectors in order: + (essential_query_embedding, extra_requirements_embedding, purpose_embedding) + """ + texts = [query.essential_query, query.extra_requirements, query.purpose] + embeddings = await self.embedding_service.embed(texts) + return embeddings[0], embeddings[1], embeddings[2] + + async def store_memory(self, repository_id: int, memory: SemanticMemoryUnit): + """ + Store a semantic memory unit with its vector embeddings and contexts. + + This method: + 1. Generates embeddings for the query components + 2. Creates a database record with query and overview information + 3. Stores each context separately in the contexts table + 4. 
Persists everything to the database in a transaction + + Args: + repository_id: Repository identifier + memory: SemanticMemoryUnit containing query and memory context + """ + # Get embeddings for the query + ( + essential_query_embedding, + extra_requirements_embedding, + purpose_embedding, + ) = await self.get_query_embeddings(memory.query) + + # Create a new SemanticMemoryUnitDB instance with embeddings + async with AsyncSession(self.engine) as session: + semantic_memory = SemanticMemoryUnitDB( + repository_id=repository_id, + query_essential_query=memory.query.essential_query, + query_extra_requirements=memory.query.extra_requirements, + query_purpose=memory.query.purpose, + essential_query_embedding=essential_query_embedding, + extra_requirements_embedding=extra_requirements_embedding, + purpose_embedding=purpose_embedding, + ) + session.add(semantic_memory) + await session.flush() # Flush to get the semantic_memory.id + + # Store each context separately + for context in memory.contexts: + context_db = ContextDB( + semantic_memory_id=semantic_memory.id, + relative_path=context.relative_path, + content=context.content, + start_line_number=context.start_line_number, + end_line_number=context.end_line_number, + ) + session.add(context_db) + + await session.commit() + + async def retrieve_memory( + self, repository_id: int, query: Query + ) -> Sequence[SemanticMemoryUnit]: + """ + Retrieve semantic memories most similar to the given query. + + Uses a weighted combination of cosine similarity across three query components: + - Essential query: 50% weight (most important for relevance) + - Extra requirements: 25% weight (filters and constraints) + - Purpose: 25% weight (intent and context) + + The combined similarity score ranks results, with higher scores indicating + better matches. Only results with similarity >= SEMANTIC_MEMORY_MIN_SIMILARITY + are returned. 
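+ + For example, component cosine distances of 0.2, 0.4, and 0.4 yield component similarities of 0.8, 0.6, and 0.6, for a combined score of 0.5 * 0.8 + 0.25 * 0.6 + 0.25 * 0.6 = 0.70.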
+ + Args: + request: FastAPI request object + repository_id: Filter memories by this repository identifier + query: Query object to search for similar memories + + Returns: + List of SemanticMemoryUnit objects with contexts loaded, ordered by similarity + (most similar first), limited to SEMANTIC_MEMORY_MAX_RESULTS and filtered by + minimum similarity threshold + """ + # Weights for similarity components: essential query gets highest weight + w1, w2, w3 = 0.5, 0.25, 0.25 + top_k = settings.SEMANTIC_MEMORY_MAX_RESULTS + min_similarity = settings.SEMANTIC_MEMORY_MIN_SIMILARITY + + # Get embeddings for the query components + ( + essential_query_embedding, + extra_requirements_embedding, + purpose_embedding, + ) = await self.get_query_embeddings(query) + + # Compute weighted similarity score + # Cosine distance is converted to similarity (1 - distance) + similarity = ( + w1 + * ( + 1 + - SemanticMemoryUnitDB.essential_query_embedding.cosine_distance( + essential_query_embedding + ) + ) + + w2 + * ( + 1 + - SemanticMemoryUnitDB.extra_requirements_embedding.cosine_distance( + extra_requirements_embedding + ) + ) + + w3 * (1 - SemanticMemoryUnitDB.purpose_embedding.cosine_distance(purpose_embedding)) + ) + + # Build query: filter by repository_id and similarity threshold, then order by similarity + stmt = ( + select(SemanticMemoryUnitDB) + .where( + SemanticMemoryUnitDB.repository_id == repository_id, + similarity >= min_similarity, + ) + .order_by(similarity.desc()) + .limit(top_k) + ) + + async with AsyncSession(self.engine) as session: + result = await session.execute(stmt) + semantic_memories = result.scalars().all() + + # Load contexts for each semantic memory + memory_units = [] + for sem_mem in semantic_memories: + # Query contexts for this semantic memory + contexts_stmt = select(ContextDB).where(ContextDB.semantic_memory_id == sem_mem.id) + contexts_result = await session.execute(contexts_stmt) + contexts_db = contexts_result.scalars().all() + + # Convert to Context model objects + contexts = [ + Context( + relative_path=ctx.relative_path, + content=ctx.content, + start_line_number=ctx.start_line_number, + end_line_number=ctx.end_line_number, + ) + for ctx in contexts_db + ] + + # Create SemanticMemoryUnit + memory_unit = SemanticMemoryUnit( + query=Query( + essential_query=sem_mem.query_essential_query, + extra_requirements=sem_mem.query_extra_requirements, + purpose=sem_mem.query_purpose, + ), + contexts=contexts, + ) + memory_units.append(memory_unit) + + return memory_units + + async def delete_memories_by_repository(self, repository_id: int): + """ + Delete all semantic memories for a given repository.
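+ + Rows in the contexts table that reference the deleted memories are removed automatically by the database through the ON DELETE CASCADE foreign key on contexts.semantic_memory_id.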
+ + Args: + repository_id: Repository identifier to filter memories for deletion + """ + stmt = delete(SemanticMemoryUnitDB).where( + SemanticMemoryUnitDB.repository_id == repository_id + ) + + async with AsyncSession(self.engine) as session: + await session.execute(stmt) + await session.commit() diff --git a/athena/chat_models/custom_chat_openai.py b/athena/chat_models/custom_chat_openai.py index a459ec7..f1ac826 100644 --- a/athena/chat_models/custom_chat_openai.py +++ b/athena/chat_models/custom_chat_openai.py @@ -1,20 +1,18 @@ -from typing import Any, List, Optional +import logging +import threading +from typing import Any, Optional from langchain_core.language_models import LanguageModelInput -from langchain_core.messages import BaseMessage, trim_messages +from langchain_core.messages import BaseMessage from langchain_core.runnables import RunnableConfig from langchain_openai import ChatOpenAI -from pydantic import PrivateAttr - -from athena.utils.llm_util import tiktoken_counter class CustomChatOpenAI(ChatOpenAI): - _max_input_tokens: int = PrivateAttr() - - def __init__(self, max_input_tokens: int, *args: Any, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) - self._max_input_tokens = max_input_tokens + self.max_retries = 3 # Set the maximum number of retries + self._logger = logging.getLogger(f"thread-{threading.get_ident()}.{__name__}") def bind_tools(self, tools, tool_choice=None, **kwargs): kwargs["parallel_tool_calls"] = False @@ -25,19 +23,11 @@ def invoke( self, input: LanguageModelInput, config: Optional[RunnableConfig] = None, *, - stop: Optional[List[str]] = None, + stop: Optional[list[str]] = None, **kwargs: Any, ) -> BaseMessage: return super().invoke( - input=trim_messages( - input, - token_counter=tiktoken_counter, - strategy="last", - max_tokens=self._max_input_tokens, - start_on="human", - end_on=("human", "tool"), - include_system=True, - ), + input=input, config=config, stop=stop, **kwargs, diff --git a/athena/configuration/config.py b/athena/configuration/config.py index c0024e7..9a7a0d3 100644 --- a/athena/configuration/config.py +++ b/athena/configuration/config.py @@ -30,9 +30,7 @@ class Settings(BaseSettings): GOOGLE_APPLICATION_CREDENTIALS: Optional[str] = None # Model parameters - MODEL_MAX_INPUT_TOKENS: int MODEL_TEMPERATURE: Optional[float] = None - MODEL_MAX_OUTPUT_TOKENS: Optional[int] = None # Database DATABASE_URL: str @@ -43,11 +41,18 @@ class Settings(BaseSettings): MEMORY_SEARCH_LIMIT: int = 10 # Default search result limit # Embeddings (OpenAI-format, works with Codestral embed via a compatible gateway) - EMBEDDING_MODEL: Optional[str] = None - EMBEDDING_API_KEY: Optional[str] = None - EMBEDDING_BASE_URL: Optional[str] = None - EMBEDDING_DIM: Optional[int] = None - EMBEDDING_IVFFLAT_LISTS: Optional[int] = None + EMBEDDING_MODEL: str + EMBEDDING_API_KEY: str + EMBEDDING_BASE_URL: str + EMBEDDING_DIM: Optional[int] = 1024 # Dimension of the embedding vectors + + # Embedding index settings + EPISODE_MEMORY_EMBEDDING_IVFFLAT_LISTS: Optional[int] = 100 + SEMANTIC_MEMORY_EMBEDDING_IVFFLAT_LISTS: Optional[int] = 10 + + # Semantic Memory settings + SEMANTIC_MEMORY_MAX_RESULTS: int # Max semantic memory results to retrieve + SEMANTIC_MEMORY_MIN_SIMILARITY: float # Minimum similarity threshold (0-1) settings = Settings() diff --git a/athena/entity/__init__.py b/athena/entity/__init__.py index f5b71d1..e69de29 100644 --- a/athena/entity/__init__.py +++
b/athena/entity/__init__.py @@ -1,3 +0,0 @@ -from .memory import MemoryUnitDB - -__all__ = ["MemoryUnitDB"] diff --git a/athena/entity/context.py b/athena/entity/context.py new file mode 100644 index 0000000..f0b9cdc --- /dev/null +++ b/athena/entity/context.py @@ -0,0 +1,34 @@ +from typing import Optional + +from sqlalchemy import Column, ForeignKey, Integer, Text +from sqlmodel import Field, SQLModel + + +class ContextDB(SQLModel, table=True): + """Database model for storing code context information.""" + + __tablename__ = "contexts" + + id: int = Field(primary_key=True, description="ID") + + # Foreign key to semantic_memories + semantic_memory_id: int = Field( + sa_column=Column( + Integer, + ForeignKey("semantic_memories.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), + description="Foreign key to semantic memory", + ) + + relative_path: str = Field( + sa_column=Column(Text, nullable=False), description="Relative file path" + ) + content: str = Field(sa_column=Column(Text, nullable=False), description="Code content") + start_line_number: Optional[int] = Field( + default=None, sa_column=Column(Integer), description="Starting line number" + ) + end_line_number: Optional[int] = Field( + default=None, sa_column=Column(Integer), description="Ending line number" + ) diff --git a/athena/entity/memory.py b/athena/entity/episodic_memory.py similarity index 67% rename from athena/entity/memory.py rename to athena/entity/episodic_memory.py index eb2730c..659031e 100644 --- a/athena/entity/memory.py +++ b/athena/entity/episodic_memory.py @@ -7,30 +7,38 @@ from sqlmodel import Field, SQLModel, UniqueConstraint from athena.configuration.config import settings -from athena.models import Action, MemorySource, MemoryTimestamp, MemoryUnit, Result, State, Task +from athena.models.episodic_memory import ( + Action, + EpisodicMemorySource, + EpisodicMemoryTimestamp, + EpisodicMemoryUnit, + Result, + State, + Task, +) # Database models for persistent storage -class MemoryUnitDB(SQLModel, table=True): - """Database model for persistent storage of memory units.""" +class EpisodicMemoryUnitDB(SQLModel, table=True): + """Database model for persistent storage of episodic memory units.""" - __tablename__ = "memory_units" - __table_args__ = (UniqueConstraint("memory_id", name="uq_memory_id"),) + __tablename__ = "episodic_memories" + __table_args__ = (UniqueConstraint("memory_id", name="uq_episodic_memory_id"),) id: Optional[int] = Field(default=None, primary_key=True) # Source information memory_id: str memory_source_name: str memory_run_id: str memory_created_at: datetime = Field(sa_column=Column(DateTime(timezone=True))) memory_updated_at: Optional[datetime] = Field( default=None, sa_column=Column(DateTime(timezone=True)) ) memory_invalid_at: Optional[datetime] = Field( default=None, sa_column=Column(DateTime(timezone=True)) ) memory_metadata: str = Field( default="{}", sa_column=Column(Text) ) # JSON string for Dict[str, Any] @@ -70,16 +78,16 @@ class MemoryUnitDB(SQLModel, table=True): ) @classmethod - def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB": - """Create a database model from a MemoryUnit.""" + def from_memory_unit(cls, memory_unit: EpisodicMemoryUnit) -> "EpisodicMemoryUnitDB": + """Create a database model from an EpisodicMemoryUnit.""" return cls( memory_id=memory_unit.memory_id, memory_source_name=memory_unit.source.source_name, memory_run_id=memory_unit.source.run_id, memory_created_at=memory_unit.timestamp.created_at, memory_updated_at=memory_unit.timestamp.updated_at, memory_invalid_at=memory_unit.timestamp.invalid_at, memory_metadata=json.dumps(memory_unit.source.metadata) if memory_unit.source.metadata else "{}", task_issue_title=memory_unit.task.issue_title, @@ -103,20 +111,20 @@ def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB": result_exit_code=memory_unit.result.exit_code, ) - def to_memory_unit(self) -> MemoryUnit: - """Convert database model back to MemoryUnit.""" - return MemoryUnit( + def to_memory_unit(self) -> EpisodicMemoryUnit: + """Convert database model back to EpisodicMemoryUnit.""" + return EpisodicMemoryUnit( memory_id=self.memory_id, - timestamp=MemoryTimestamp( + timestamp=EpisodicMemoryTimestamp( created_at=self.memory_created_at, updated_at=self.memory_updated_at, invalid_at=self.memory_invalid_at, ), - source=MemorySource( + source=EpisodicMemorySource( source_name=self.memory_source_name, run_id=self.memory_run_id, metadata=json.loads(self.memory_metadata) if self.memory_metadata not in (None, "", "null") else {}, ), task=Task( diff --git a/athena/entity/semantic_memory.py b/athena/entity/semantic_memory.py new file mode 100644 index 0000000..062addf --- /dev/null +++ b/athena/entity/semantic_memory.py @@ -0,0 +1,34 @@ +from pgvector.sqlalchemy import Vector +from sqlalchemy import Column, Text +from sqlmodel import Field, SQLModel + +from athena.configuration.config import settings + + +# Database models for persistent storage +class SemanticMemoryUnitDB(SQLModel, table=True): + """Database model for persistent storage of semantic memory units.""" + + __tablename__ = "semantic_memories" + + id: int = Field(primary_key=True, description="ID") + + # Source information + repository_id: int = Field(index=True, description="Repository identifier") + + # Query information (for retrieval) + query_essential_query: str = Field(sa_column=Column(Text, nullable=False)) + query_extra_requirements: str = Field(sa_column=Column(Text, nullable=False)) + query_purpose: str = Field(sa_column=Column(Text, nullable=False)) + + # Embeddings for semantic retrieval (pgvector) + _vec_dim: int = settings.EMBEDDING_DIM or 1024 # type: ignore[assignment] + + # Query embeddings for weighted retrieval + essential_query_embedding: list[float] = Field( + sa_column=Column(Vector(_vec_dim), nullable=False) + ) +
extra_requirements_embedding: list[float] = Field( + sa_column=Column(Vector(_vec_dim), nullable=False) + ) + purpose_embedding: list[float] = Field(sa_column=Column(Vector(_vec_dim), nullable=False)) diff --git a/athena/models/__init__.py b/athena/models/__init__.py index 5228135..e69de29 100644 --- a/athena/models/__init__.py +++ b/athena/models/__init__.py @@ -1,21 +0,0 @@ -from .memory import ( - Action, - MemorySource, - MemoryTimestamp, - MemoryUnit, - Result, - State, - Task, -) -from .message import Message - -__all__ = [ - "Message", - "MemoryUnit", - "MemorySource", - "MemoryTimestamp", - "Task", - "State", - "Action", - "Result", -] diff --git a/athena/models/context.py b/athena/models/context.py new file mode 100644 index 0000000..20b0107 --- /dev/null +++ b/athena/models/context.py @@ -0,0 +1,12 @@ +from typing import Optional + +from pydantic import BaseModel + + +class Context(BaseModel): + """Represents a code context with file location and content.""" + + relative_path: str + content: str + start_line_number: Optional[int] = None + end_line_number: Optional[int] = None diff --git a/athena/models/memory.py b/athena/models/episodic_memory.py similarity index 79% rename from athena/models/memory.py rename to athena/models/episodic_memory.py index 4776665..cb8b0c9 100644 --- a/athena/models/memory.py +++ b/athena/models/episodic_memory.py @@ -5,23 +5,23 @@ from pydantic import BaseModel, Field -class MemoryTimestamp(BaseModel): - """Lifecycle timestamps for a memory unit.""" +class EpisodicMemoryTimestamp(BaseModel): + """Lifecycle timestamps for an episodic memory unit.""" created_at: datetime = Field( default_factory=lambda: datetime.now(timezone.utc), - description="When the memory unit was first created", + description="When the episodic memory unit was first created", ) updated_at: Optional[datetime] = Field( - None, description="When the memory was last updated/refreshed" + None, description="When the episodic memory was last updated/refreshed" ) invalid_at: Optional[datetime] = Field( - None, description="When the memory was invalidated or expired" + None, description="When the episodic memory was invalidated or expired" ) -class MemorySource(BaseModel): - """Source information for a memory unit.""" +class EpisodicMemorySource(BaseModel): + """Source information for an episodic memory unit.""" source_name: str = Field( ..., description="Memory source, e.g., agent name, model name, dataset name, or file path" @@ -93,14 +93,14 @@ class Result(BaseModel): ) -class MemoryUnit(BaseModel): +class EpisodicMemoryUnit(BaseModel): """ - Core memory unit capturing one action of agent execution. + Core episodic memory unit capturing one action of agent execution. 
This includes: - - The memory id (memory_id) - - The memory timestamp (timestamp) - - The memory source (source_name, run_id, metadata) + - The episodic memory id (memory_id) + - The episodic memory timestamp (timestamp) + - The episodic memory source (source_name, run_id, metadata) - The task being worked on (issue and repository details) - The current state (what's done, what's todo) - The action taken by the agent @@ -109,10 +109,10 @@ class MemoryUnit(BaseModel): memory_id: str = Field( default_factory=lambda: str(uuid.uuid4()), - description="Unique identifier for this memory unit", + description="Unique identifier for this episodic memory unit", ) - timestamp: MemoryTimestamp = Field(..., description="Timestamp of the memory") - source: MemorySource + timestamp: EpisodicMemoryTimestamp = Field(..., description="Timestamp of the episodic memory") + source: EpisodicMemorySource task: Task state: State action: Action diff --git a/athena/models/query.py b/athena/models/query.py new file mode 100644 index 0000000..e58ea41 --- /dev/null +++ b/athena/models/query.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class Query(BaseModel): + essential_query: str + extra_requirements: str + purpose: str diff --git a/athena/models/requests/__init__.py b/athena/models/requests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/athena/models/requests/semantic_memory.py b/athena/models/requests/semantic_memory.py new file mode 100644 index 0000000..84b153c --- /dev/null +++ b/athena/models/requests/semantic_memory.py @@ -0,0 +1,14 @@ +from typing import List + +from pydantic import BaseModel, Field + +from athena.models.context import Context +from athena.models.query import Query + + +class StoreSemanticMemoryRequest(BaseModel): + """Request model for storing semantic memory.""" + + repository_id: int = Field(description="Repository identifier") + query: Query + contexts: List[Context] diff --git a/athena/models/responses/__init__.py b/athena/models/responses/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/athena/models/responses/response.py b/athena/models/responses/response.py new file mode 100644 index 0000000..09c7e59 --- /dev/null +++ b/athena/models/responses/response.py @@ -0,0 +1,15 @@ +from typing import Generic, TypeVar + +from pydantic import BaseModel + +T = TypeVar("T") + + +class Response(BaseModel, Generic[T]): + """ + Generic response model for API responses. 
+ """ + + code: int = 200 + message: str = "success" + data: T | None = None diff --git a/athena/models/responses/semantic_memory.py b/athena/models/responses/semantic_memory.py new file mode 100644 index 0000000..07c8121 --- /dev/null +++ b/athena/models/responses/semantic_memory.py @@ -0,0 +1,14 @@ +from typing import List + +from pydantic import BaseModel + +from athena.models.context import Context + + +class SemanticMemoryResponse(BaseModel): + """Response model for semantic memory retrieval.""" + + query_essential_query: str + query_extra_requirements: str + query_purpose: str + memory_context_contexts: List[Context] diff --git a/athena/models/semantic_memory.py b/athena/models/semantic_memory.py new file mode 100644 index 0000000..ec240a7 --- /dev/null +++ b/athena/models/semantic_memory.py @@ -0,0 +1,11 @@ +from typing import List + +from pydantic import BaseModel + +from athena.models.context import Context +from athena.models.query import Query + + +class SemanticMemoryUnit(BaseModel): + query: Query + contexts: List[Context] diff --git a/athena/prompts/memory_extraction.py b/athena/prompts/episodic_memory_extraction.py similarity index 100% rename from athena/prompts/memory_extraction.py rename to athena/prompts/episodic_memory_extraction.py diff --git a/athena/scripts/offline_ingest_hf.py b/athena/scripts/offline_ingest_hf.py index 2baf223..7c3cd14 100644 --- a/athena/scripts/offline_ingest_hf.py +++ b/athena/scripts/offline_ingest_hf.py @@ -4,9 +4,9 @@ from athena.app.services.database_service import DatabaseService from athena.app.services.embedding_service import EmbeddingService +from athena.app.services.episodic_memory_extraction_service import EpisodicMemoryExtractionService +from athena.app.services.episodic_memory_storage_service import EpisodicMemoryStorageService from athena.app.services.llm_service import LLMService -from athena.app.services.memory_extraction_service import MemoryExtractionService -from athena.app.services.memory_storage_service import MemoryStorageService from athena.configuration.config import settings @@ -29,7 +29,7 @@ async def main(): embed_dim=settings.EMBEDDING_DIM or 1024, ) - store = MemoryStorageService(db.get_sessionmaker(), embedding_service) + store = EpisodicMemoryStorageService(db.get_sessionmaker(), embedding_service) llm = LLMService( model_name=settings.MODEL_NAME, @@ -40,9 +40,10 @@ async def main(): openai_format_base_url=settings.OPENAI_FORMAT_BASE_URL, anthropic_api_key=settings.ANTHROPIC_API_KEY, gemini_api_key=settings.GEMINI_API_KEY, + google_application_credentials=settings.GOOGLE_APPLICATION_CREDENTIALS, ) - extractor = MemoryExtractionService(llm_service=llm, memory_store=store) + extractor = EpisodicMemoryExtractionService(llm_service=llm, memory_store=store) extractor.extract_from_huggingface_trajectory_repository(args.repo, args.split) diff --git a/athena/utils/llm_util.py b/athena/utils/llm_util.py deleted file mode 100644 index db02477..0000000 --- a/athena/utils/llm_util.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Sequence - -import tiktoken -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage -from langchain_core.output_parsers import StrOutputParser - - -def str_token_counter(text: str) -> int: - """Counts the number of tokens in a string using tiktoken's o200k_base encoding. - - Args: - text: The input string to count tokens for. - - Returns: - The number of tokens in the input string. 
- """ - enc = tiktoken.get_encoding("o200k_base") - return len(enc.encode(text)) - - -def tiktoken_counter(messages: Sequence[BaseMessage]) -> int: - """Counts tokens across multiple message types using tiktoken tokenization. - - Approximately reproduces the token counting methodology from OpenAI's cookbook: - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - - Args: - messages: A sequence of BaseMessage objects (HumanMessage, AIMessage, - ToolMessage, or SystemMessage) to count tokens for. - - Returns: - The total number of tokens across all messages, including overhead tokens. - - Raises: - ValueError: If an unsupported message type is encountered. - - Notes: - - Uses a fixed overhead of 3 tokens for reply priming - - Adds 3 tokens per message for message formatting - - Adds 1 token per message name if present - - For simplicity, only supports string message contents - """ - output_parser = StrOutputParser() - num_tokens = 3 # every reply is primed with <|start|>assistant<|message|> - tokens_per_message = 3 - tokens_per_name = 1 - for msg in messages: - if isinstance(msg, HumanMessage): - role = "user" - elif isinstance(msg, AIMessage): - role = "assistant" - elif isinstance(msg, ToolMessage): - role = "tool" - elif isinstance(msg, SystemMessage): - role = "system" - else: - raise ValueError(f"Unsupported messages type {msg.__class__}") - msg_content = output_parser.invoke(msg) - num_tokens += tokens_per_message + str_token_counter(role) + str_token_counter(msg_content) - if msg.name: - num_tokens += tokens_per_name + str_token_counter(msg.name) - return num_tokens diff --git a/docker-compose.yml b/docker-compose.yml index b7d03c5..6a440b3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,7 +5,7 @@ networks: services: postgres: image: pgvector/pgvector:pg16 - container_name: postgres_container + container_name: athena_postgres_container networks: - athena_network environment: @@ -47,13 +47,15 @@ services: - ATHENA_OPENAI_FORMAT_BASE_URL=${ATHENA_OPENAI_FORMAT_BASE_URL} # Model settings - - ATHENA_MODEL_MAX_INPUT_TOKENS=${ATHENA_MODEL_MAX_INPUT_TOKENS} - ATHENA_MODEL_TEMPERATURE=${ATHENA_MODEL_TEMPERATURE} - - ATHENA_MODEL_MAX_OUTPUT_TOKENS=${ATHENA_MODEL_MAX_OUTPUT_TOKENS} # Database settings - ATHENA_DATABASE_URL=${ATHENA_DATABASE_URL} + # Semantic memory settings + - ATHENA_SEMANTIC_MEMORY_MAX_RESULTS=${ATHENA_SEMANTIC_MEMORY_MAX_RESULTS} + - ATHENA_SEMANTIC_MEMORY_MIN_SIMILARITY=${ATHENA_SEMANTIC_MEMORY_MIN_SIMILARITY} + volumes: - .:/app - /var/run/docker.sock:/var/run/docker.sock diff --git a/example.env b/example.env index ab66400..d9637fa 100644 --- a/example.env +++ b/example.env @@ -30,3 +30,7 @@ ATHENA_EMBEDDING_IVFFLAT_LISTS=100 # Database settings ATHENA_DATABASE_URL=postgresql+asyncpg://postgres:password@postgres:5432/postgres + +# Semantic memory settings +ATHENA_SEMANTIC_MEMORY_MAX_RESULTS=5 +ATHENA_SEMANTIC_MEMORY_MIN_SIMILARITY=0.85