From 988f9839913bd36e10f0d8d31b3ceec700a27896 Mon Sep 17 00:00:00 2001 From: Zhaoyang-Chu Date: Mon, 15 Sep 2025 18:09:15 +0800 Subject: [PATCH 1/3] feat: implement memory storage --- athena/app/api/main.py | 11 +- athena/app/dependencies.py | 16 +- athena/app/services/database_service.py | 35 +++- athena/app/services/embedding_service.py | 32 +++ athena/app/services/memory_service.py | 133 +++++++------ athena/app/services/memory_storage_service.py | 185 ++++++++++++++++++ athena/scripts/offline_ingest_hf.py | 47 +++++ docker-compose.yml | 2 +- example.env | 4 + pyproject.toml | 6 +- 10 files changed, 411 insertions(+), 60 deletions(-) create mode 100644 athena/app/services/embedding_service.py create mode 100644 athena/app/services/memory_storage_service.py create mode 100644 athena/scripts/offline_ingest_hf.py diff --git a/athena/app/api/main.py b/athena/app/api/main.py index eef7723..f1722f1 100644 --- a/athena/app/api/main.py +++ b/athena/app/api/main.py @@ -1,3 +1,12 @@ -from fastapi import APIRouter +from fastapi import APIRouter, Request api_router = APIRouter() + + +@api_router.get("/memory/search", tags=["memory"]) +async def search_memory(request: Request, q: str, field: str = "task_state", limit: int = 10): + services = getattr(request.app.state, "service", {}) + memory_service = services.get("memory_service") + assert memory_service is not None, "Memory service not initialized" + results = await memory_service.search_memory(q, limit=limit, field=field) + return [m.model_dump() for m in results] diff --git a/athena/app/dependencies.py b/athena/app/dependencies.py index 92f94c6..1495754 100644 --- a/athena/app/dependencies.py +++ b/athena/app/dependencies.py @@ -4,8 +4,10 @@ from athena.app.services.base_service import BaseService from athena.app.services.database_service import DatabaseService +from athena.app.services.embedding_service import EmbeddingService from athena.app.services.llm_service import LLMService from athena.app.services.memory_service 
import MemoryService +from athena.app.services.memory_storage_service import MemoryStorageService from athena.configuration.config import settings @@ -41,7 +43,19 @@ def initialize_services() -> Dict[str, BaseService]: settings.GEMINI_API_KEY, ) - memory_service = MemoryService(storage_backend=settings.MEMORY_STORAGE_BACKEND) + embedding_service = None + if settings.EMBEDDINGS_MODEL and settings.EMBEDDINGS_API_KEY and settings.EMBEDDINGS_BASE_URL: + embedding_service = EmbeddingService( + model=settings.EMBEDDINGS_MODEL, + api_key=settings.EMBEDDINGS_API_KEY, + base_url=settings.EMBEDDINGS_BASE_URL, + ) + + memory_store = MemoryStorageService(database_service.get_sessionmaker(), embedding_service) + + memory_service = MemoryService( + storage_backend=settings.MEMORY_STORAGE_BACKEND, store=memory_store + ) return { "llm_service": llm_service, diff --git a/athena/app/services/database_service.py b/athena/app/services/database_service.py index 1303aa5..69726fd 100644 --- a/athena/app/services/database_service.py +++ b/athena/app/services/database_service.py @@ -1,4 +1,5 @@ -from sqlalchemy.ext.asyncio import create_async_engine +from pgvector.sqlalchemy import register_vector +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlmodel import SQLModel from athena.app.services.base_service import BaseService @@ -8,12 +9,40 @@ class DatabaseService(BaseService): def __init__(self, DATABASE_URL: str): self.engine = create_async_engine(DATABASE_URL, echo=True) + self.sessionmaker = async_sessionmaker( + self.engine, expire_on_commit=False, class_=AsyncSession + ) self._logger = get_logger(__name__) # Create the database and tables async def create_db_and_tables(self): async with self.engine.begin() as conn: + # Try to enable pgvector if available (safe to ignore errors on non-Postgres or missing extension) + try: + await conn.exec_driver_sql("CREATE EXTENSION IF NOT EXISTS vector") + except Exception: + # Extension not available; 
from typing import Iterable, List

import requests

from athena.app.services.base_service import BaseService


class EmbeddingService(BaseService):
    """
    Simple OpenAI-compatible embedding client.

    Works with providers exposing POST /v1/embeddings, including Codestral embed
    deployments behind OpenAI-format gateways. Configure via model, api_key, base_url.
    """

    def __init__(self, model: str, api_key: str, base_url: str):
        # Model identifier sent verbatim in the request payload.
        self.model = model
        self.api_key = api_key
        # Strip trailing slashes so the URL join below never doubles a "/".
        self.base_url = base_url.rstrip("/")

    def embed(self, inputs: Iterable[str]) -> List[List[float]]:
        """Return one embedding vector per input string, aligned with input order.

        Args:
            inputs: Texts to embed; consumed exactly once (generators accepted).

        Returns:
            A list of float vectors in the same order as ``inputs``. An empty
            input yields an empty list without any network call.

        Raises:
            requests.HTTPError: If the provider responds with a non-2xx status.
        """
        texts = list(inputs)
        if not texts:
            # Avoid a pointless request; some providers reject an empty "input".
            return []
        data = {"model": self.model, "input": texts}
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        url = f"{self.base_url}/v1/embeddings"
        resp = requests.post(url, json=data, headers=headers, timeout=60)
        resp.raise_for_status()
        payload = resp.json()
        # OpenAI-format responses carry an explicit per-item "index"; sort on it
        # rather than assuming the provider preserves input order.
        items = sorted(payload.get("data", []), key=lambda item: item.get("index", 0))
        return [item["embedding"] for item in items]
class MemoryService(BaseService):
    """
    Facade over a pluggable memory store.

    Delegates persistence and retrieval of MemoryUnit objects to an optional
    MemoryStorageService. All read paths degrade to empty results and all write
    paths become no-ops when no store is configured, so callers never have to
    special-case a memory-less deployment.
    """

    def __init__(
        self, storage_backend: str = "in_memory", store: Optional[MemoryStorageService] = None
    ):
        """
        Initialize the Memory Service with a specified storage backend.

        Args:
            storage_backend: One of "in_memory", "database", "vector".
            store: Backend implementation; required for database/vector backends.
        """
        self.storage_backend = storage_backend
        self._store = store

    async def _call_store_hook(self, hook_name: str) -> None:
        # The store's start()/close() may be sync or async; dispatch accordingly.
        hook = getattr(self._store, hook_name, None)
        if hook is None:
            return
        if inspect.iscoroutinefunction(hook):
            await hook()
        else:
            hook()

    async def start(self):
        """Validate configuration and start the underlying store, if any."""
        if self.storage_backend in {"database", "vector"} and self._store is None:
            raise RuntimeError(
                "MemoryService requires a storage store for database/vector backends"
            )
        if self._store is not None:
            await self._call_store_hook("start")

    async def close(self):
        """Close the underlying store, if it exposes a close hook."""
        if self._store is not None:
            await self._call_store_hook("close")

    async def store_memory(self, memory_unit: MemoryUnit) -> None:
        """
        Store a single memory unit in the memory service.

        Args:
            memory_unit: The MemoryUnit object containing task, state, action,
                        and result information to be stored.

        Deduplication is handled at the database layer on memory_id via upsert.
        """
        # Single-unit store is just a bulk store of one element.
        await self.store_memories([memory_unit])

    async def store_memories(self, memory_units: Sequence[MemoryUnit]) -> None:
        """Bulk store multiple memory units efficiently.

        Failures are deliberately swallowed: memory persistence is best-effort
        and must never break the caller's main workflow.
        """
        if self._store is None or not memory_units:
            return
        try:
            await self._store.upsert(list(memory_units))
        except Exception:
            # Best-effort by design; storage errors are not propagated.
            return

    async def search_memory(
        self, query: str, limit: int = 10, field: str = "task_state"
    ) -> List[MemoryUnit]:
        """
        Search for relevant memory units based on a query.

        Args:
            query: Free-text query embedded and matched against stored vectors.
            limit: Maximum number of results to return.
            field: Embedding column to search ("task", "state" or "task_state").

        Returns:
            Most relevant memories first; empty list when no store is
            configured or the search fails.
        """
        if self._store is None:
            return []
        try:
            return await self._store.search_by_text(query, field=field, limit=limit)
        except Exception:
            return []

    async def get_memory_by_key(self, key: str) -> Optional[MemoryUnit]:
        """
        Retrieve a specific memory unit by its canonical key.

        Args:
            key: The canonical key of the memory unit; treated as memory_id here.

        Returns:
            The MemoryUnit object if found, None otherwise.
        """
        if self._store is None:
            return None
        try:
            return await self._store.get_by_memory_id(key)
        except Exception:
            return None

    async def get_all_memories(self) -> List[MemoryUnit]:
        """
        Retrieve all memory units stored in the service.

        Returns:
            List of all MemoryUnit objects in the service.
        """
        if self._store is None:
            return []
        try:
            return await self._store.list_all()
        except Exception:
            return []

    async def clear_memories(self) -> None:
        """
        Clear all memory units from the service.

        This method is primarily for testing and should be used with
        caution in production environments.
        """
        if self._store is None:
            return
        try:
            await self._store.clear_all()
        except Exception:
            return
""" - # TODO: Implement memory clearing functionality - pass + if self._store is None: + return + try: + await self._store.clear_all() # type: ignore[arg-type] + except Exception: + return diff --git a/athena/app/services/memory_storage_service.py b/athena/app/services/memory_storage_service.py new file mode 100644 index 0000000..bf0b657 --- /dev/null +++ b/athena/app/services/memory_storage_service.py @@ -0,0 +1,185 @@ +from typing import List, Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from sqlmodel import col + +from athena.app.services.base_service import BaseService +from athena.app.services.embedding_service import EmbeddingService +from athena.configuration.config import settings +from athena.entity.memory import MemoryUnitDB +from athena.models.memory import MemoryUnit + + +def _ensure_dim(vec: List[float]) -> List[float]: + dim = settings.EMBEDDINGS_DIM or len(vec) + if len(vec) == dim: + return vec + if len(vec) > dim: + return vec[:dim] + # pad with zeros + return vec + [0.0] * (dim - len(vec)) + + +class MemoryStorageService(BaseService): + """ + Postgres-backed memory store using SQLModel and optional embeddings for semantic search. + + Embeddings are stored as float arrays in columns `task_embedding` and `state_embedding`. + Cosine similarity is computed in Python for portability unless pgvector is wired in later. 
+ """ + + def __init__( + self, + sessionmaker: async_sessionmaker[AsyncSession], + embedding_service: Optional[EmbeddingService] = None, + max_stored_units: int = 100000, + ): + self._sessionmaker = sessionmaker + self._embeddings = embedding_service + self._max = max_stored_units + + async def upsert(self, units: List[MemoryUnit]) -> None: + if not units: + return + async with self._sessionmaker() as session: + for u in units: + await self._upsert_one(session, u) + await session.commit() + + async def _upsert_one(self, session: AsyncSession, unit: MemoryUnit) -> None: + # Prepare embeddings if configured + task_text = self._serialize_task(unit) + state_text = self._serialize_state(unit) + task_vec: Optional[List[float]] = None + state_vec: Optional[List[float]] = None + task_state_vec: Optional[List[float]] = None + if self._embeddings is not None: + combined_text = self._serialize_task_state(unit) + vecs = self._embeddings.embed([task_text, state_text, combined_text]) + if len(vecs) >= 3: + task_vec, state_vec, task_state_vec = ( + _ensure_dim(vecs[0]), + _ensure_dim(vecs[1]), + _ensure_dim(vecs[2]), + ) + + # Check existing + existing = await session.scalar( + select(MemoryUnitDB).where(col(MemoryUnitDB.memory_id) == unit.memory_id) + ) + + if existing is None: + row = MemoryUnitDB.from_memory_unit(unit) + row.task_embedding = task_vec + row.state_embedding = state_vec + row.task_state_embedding = task_state_vec + session.add(row) + else: + # Update fields + fresh = MemoryUnitDB.from_memory_unit(unit) + for attr in ( + "memory_source_name", + "memory_run_id", + "memory_created_at", + "memory_updated_at", + "memory_invalid_at", + "memory_metadata", + "task_issue_title", + "task_issue_body", + "task_issue_comments", + "task_issue_type", + "task_repository", + "state_done", + "state_todo", + "state_open_file", + "state_working_dir", + "state_extra_environment", + "action_name", + "action_description", + "action_target", + "action_tool", + "result_type", + 
"result_description", + "result_exit_code", + ): + setattr(existing, attr, getattr(fresh, attr)) + if task_vec is not None: + existing.task_embedding = task_vec + if state_vec is not None: + existing.state_embedding = state_vec + if task_state_vec is not None: + existing.task_state_embedding = task_state_vec + session.add(existing) + + async def search_by_text( + self, text: str, field: str = "task_state", limit: int = 10 + ) -> List[MemoryUnit]: + if self._embeddings is None: + return [] + q_vec = _ensure_dim(self._embeddings.embed([text])[0]) + async with self._sessionmaker() as session: + # Choose column by field + col_expr = { + "task": MemoryUnitDB.task_embedding, + "state": MemoryUnitDB.state_embedding, + "task_state": MemoryUnitDB.task_state_embedding, + }.get(field, MemoryUnitDB.task_state_embedding) + + # Order by cosine distance using pgvector `<=>` + res = await session.execute( + select(MemoryUnitDB) + .where(col_expr.is_not(None)) + .order_by(col_expr.cosine_distance(q_vec)) + .limit(limit) + ) + rows: List[MemoryUnitDB] = list(res.scalars()) + return [r.to_memory_unit() for r in rows] + + async def get_by_memory_id(self, memory_id: str) -> Optional[MemoryUnit]: + async with self._sessionmaker() as session: + row = await session.scalar( + select(MemoryUnitDB).where(col(MemoryUnitDB.memory_id) == memory_id) + ) + return None if row is None else row.to_memory_unit() + + async def list_all(self, limit: Optional[int] = None) -> List[MemoryUnit]: + async with self._sessionmaker() as session: + stmt = select(MemoryUnitDB).order_by(MemoryUnitDB.id.desc()) + if limit is not None and limit > 0: + stmt = stmt.limit(limit) + res = await session.execute(stmt) + rows: List[MemoryUnitDB] = list(res.scalars()) + return [r.to_memory_unit() for r in rows] + + async def clear_all(self) -> None: + async with self._sessionmaker() as session: + # Delete all rows from table + await session.execute("DELETE FROM memory_units") + await session.commit() + + @staticmethod + def 
import asyncio
import os

from athena.app.services.database_service import DatabaseService
from athena.app.services.embedding_service import EmbeddingService
from athena.app.services.llm_service import LLMService
from athena.app.services.memory_extraction_service import MemoryExtractionService
from athena.app.services.memory_storage_service import MemoryStorageService
from athena.configuration.config import settings


def _build_embedding_service():
    """Return an EmbeddingService when fully configured in settings, else None."""
    if settings.EMBEDDINGS_MODEL and settings.EMBEDDINGS_API_KEY and settings.EMBEDDINGS_BASE_URL:
        return EmbeddingService(
            model=settings.EMBEDDINGS_MODEL,
            api_key=settings.EMBEDDINGS_API_KEY,
            base_url=settings.EMBEDDINGS_BASE_URL,
        )
    return None


async def main():
    """Offline ingestion: extract memories from a Hugging Face trajectory
    dataset and persist them into the Postgres-backed memory store.

    HF_REPO / HF_SPLIT environment variables override the default dataset.
    """
    db = DatabaseService(settings.DATABASE_URL)
    await db.start()
    try:
        store = MemoryStorageService(db.get_sessionmaker(), _build_embedding_service())

        llm = LLMService(
            model_name=settings.MODEL_NAME,
            model_temperature=settings.MODEL_TEMPERATURE,
            model_max_input_tokens=settings.MODEL_MAX_INPUT_TOKENS,
            model_max_output_tokens=settings.MODEL_MAX_OUTPUT_TOKENS or 4096,
            openai_format_api_key=settings.OPENAI_FORMAT_API_KEY,
            openai_format_base_url=settings.OPENAI_FORMAT_BASE_URL,
            anthropic_api_key=settings.ANTHROPIC_API_KEY,
            gemini_api_key=settings.GEMINI_API_KEY,
        )

        extractor = MemoryExtractionService(llm_service=llm, memory_store=store)

        repo = os.environ.get("HF_REPO", "SWE-Gym/OpenHands-SFT-Trajectories")
        split = os.environ.get("HF_SPLIT", "train.success.oss")
        extractor.extract_from_huggingface_trajectory_repository(repo, split)
    finally:
        # Always dispose the engine's connection pool, even when extraction
        # raises; the original fell through without closing on error.
        await db.close()


if __name__ == "__main__":
    asyncio.run(main())
from fastapi import APIRouter, HTTPException, Query, Request

from athena.configuration.config import settings

api_router = APIRouter()

# Embedding columns a caller may search against.
_VALID_FIELDS = {"task_state", "task", "state"}


def _require_memory_service(request: Request):
    """Return the app's memory service, or abort with 503 when not wired up."""
    services = getattr(request.app.state, "service", {})
    service = services.get("memory_service")
    if service is None:
        raise HTTPException(status_code=503, detail="Memory service not initialized")
    return service


@api_router.get("/memory/search", tags=["memory"])
async def search_memory(
    request: Request,
    q: str,
    field: str = "task_state",
    limit: int = Query(default=settings.MEMORY_SEARCH_LIMIT, ge=1, le=100),
):
    """Semantic search over stored memories; returns serialized memory units."""
    memory_service = _require_memory_service(request)
    if field not in _VALID_FIELDS:
        raise HTTPException(status_code=400, detail=f"Invalid field: {field}")
    hits = await memory_service.search_memory(q, limit=limit, field=field)
    return [hit.model_dump() for hit in hits]


@api_router.get("/memory/{memory_id}", tags=["memory"])
async def get_memory(request: Request, memory_id: str):
    """Fetch one memory unit by id; 404 when it does not exist."""
    memory_service = _require_memory_service(request)
    found = await memory_service.get_memory_by_key(memory_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Memory not found")
    return found.model_dump()
@@ -44,11 +42,13 @@ def initialize_services() -> Dict[str, BaseService]: ) embedding_service = None - if settings.EMBEDDINGS_MODEL and settings.EMBEDDINGS_API_KEY and settings.EMBEDDINGS_BASE_URL: + api_key = settings.EMBEDDING_API_KEY or settings.MISTRAL_API_KEY + if settings.EMBEDDING_MODEL and api_key and settings.EMBEDDING_BASE_URL: embedding_service = EmbeddingService( - model=settings.EMBEDDINGS_MODEL, - api_key=settings.EMBEDDINGS_API_KEY, - base_url=settings.EMBEDDINGS_BASE_URL, + model=settings.EMBEDDING_MODEL, + api_key=api_key, + base_url=settings.EMBEDDING_BASE_URL, + embed_dim=settings.EMBEDDING_DIM or 1024, ) memory_store = MemoryStorageService(database_service.get_sessionmaker(), embedding_service) diff --git a/athena/app/main.py b/athena/app/main.py index 610e817..c0a84d0 100644 --- a/athena/app/main.py +++ b/athena/app/main.py @@ -45,7 +45,8 @@ def custom_generate_unique_id(route: APIRoute) -> str: """ Custom function to generate unique IDs for API routes based on their tags and names. 
""" - return f"{route.tags[0]}-{route.name}" + tag = route.tags[0] if getattr(route, "tags", None) else "default" + return f"{tag}-{route.name}" app = FastAPI( @@ -74,5 +75,13 @@ def custom_generate_unique_id(route: APIRoute) -> str: @app.get("/health", tags=["health"]) -def health_check(): - return {"status": "healthy", "timestamp": datetime.now(timezone.utc).isoformat()} +async def health_check(): + services = getattr(app.state, "service", {}) + db = services.get("database_service") + db_ok = await db.health_check() if db is not None else False + status = "healthy" if db_ok else "degraded" + return { + "status": status, + "database": db_ok, + "timestamp": datetime.now(timezone.utc).isoformat(), + } diff --git a/athena/app/services/database_service.py b/athena/app/services/database_service.py index 69726fd..d5190c8 100644 --- a/athena/app/services/database_service.py +++ b/athena/app/services/database_service.py @@ -1,3 +1,5 @@ +import asyncio + from pgvector.sqlalchemy import register_vector from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlmodel import SQLModel @@ -7,12 +9,14 @@ class DatabaseService(BaseService): - def __init__(self, DATABASE_URL: str): + def __init__(self, DATABASE_URL: str, max_retries: int = 5, initial_backoff: float = 1.0): self.engine = create_async_engine(DATABASE_URL, echo=True) self.sessionmaker = async_sessionmaker( self.engine, expire_on_commit=False, class_=AsyncSession ) self._logger = get_logger(__name__) + self._max_retries = max_retries + self._initial_backoff = initial_backoff # Create the database and tables async def create_db_and_tables(self): @@ -49,8 +53,26 @@ async def start(self): Start the database service by creating the database and tables. This method is called when the service is initialized. 
""" - await self.create_db_and_tables() - self._logger.info("Database and tables created successfully.") + attempt = 0 + backoff = self._initial_backoff + while True: + try: + await self.create_db_and_tables() + self._logger.info("Database and tables created successfully.") + break + except Exception as exc: + attempt += 1 + if attempt > self._max_retries: + self._logger.error( + f"Database start failed after {self._max_retries} retries: {exc}" + ) + raise + self._logger.warning( + f"Database start failed (attempt {attempt}/{self._max_retries}): {exc}. " + f"Retrying in {backoff:.1f}s..." + ) + await asyncio.sleep(backoff) + backoff *= 2 async def close(self): """ @@ -62,3 +84,13 @@ async def close(self): def get_sessionmaker(self) -> async_sessionmaker[AsyncSession]: """Return the async sessionmaker for dependency injection.""" return self.sessionmaker + + async def health_check(self) -> bool: + """Perform a lightweight connectivity check (SELECT 1).""" + try: + async with self.engine.connect() as conn: + await conn.exec_driver_sql("SELECT 1") + return True + except Exception as exc: + self._logger.warning(f"Database health_check failed: {exc}") + return False diff --git a/athena/app/services/embedding_service.py b/athena/app/services/embedding_service.py index d4e43ec..872394d 100644 --- a/athena/app/services/embedding_service.py +++ b/athena/app/services/embedding_service.py @@ -13,13 +13,14 @@ class EmbeddingService(BaseService): behind OpenAI-format gateways. Configure via model, api_key, base_url. 
""" - def __init__(self, model: str, api_key: str, base_url: str): + def __init__(self, model: str, api_key: str, base_url: str, embed_dim: int): self.model = model self.api_key = api_key self.base_url = base_url.rstrip("/") + self.embed_dim = embed_dim def embed(self, inputs: Iterable[str]) -> List[List[float]]: - data = {"model": self.model, "input": list(inputs)} + data = {"model": self.model, "input": list(inputs), "output_dimension": self.embed_dim} headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", diff --git a/athena/app/services/memory_service.py b/athena/app/services/memory_service.py index de41fd4..af740ac 100644 --- a/athena/app/services/memory_service.py +++ b/athena/app/services/memory_service.py @@ -114,10 +114,10 @@ async def search_memory( async def get_memory_by_key(self, key: str) -> Optional[MemoryUnit]: """ - Retrieve a specific memory unit by its canonical key. + Retrieve a specific memory unit by its memory id. Args: - key: The canonical key of the memory unit (generated by MemoryUnit.key()) + key: The id of the memory unit Returns: The MemoryUnit object if found, None otherwise. 
diff --git a/athena/app/services/memory_storage_service.py b/athena/app/services/memory_storage_service.py index bf0b657..1d2e308 100644 --- a/athena/app/services/memory_storage_service.py +++ b/athena/app/services/memory_storage_service.py @@ -12,7 +12,7 @@ def _ensure_dim(vec: List[float]) -> List[float]: - dim = settings.EMBEDDINGS_DIM or len(vec) + dim = settings.EMBEDDING_DIM or len(vec) if len(vec) == dim: return vec if len(vec) > dim: @@ -79,11 +79,11 @@ async def _upsert_one(self, session: AsyncSession, unit: MemoryUnit) -> None: # Update fields fresh = MemoryUnitDB.from_memory_unit(unit) for attr in ( - "memory_source_name", - "memory_run_id", "memory_created_at", "memory_updated_at", "memory_invalid_at", + "memory_source_name", + "memory_run_id", "memory_metadata", "task_issue_title", "task_issue_body", diff --git a/athena/configuration/config.py b/athena/configuration/config.py index 7c169ac..d48430a 100644 --- a/athena/configuration/config.py +++ b/athena/configuration/config.py @@ -24,6 +24,7 @@ class Settings(BaseSettings): # API Keys ANTHROPIC_API_KEY: Optional[str] = None GEMINI_API_KEY: Optional[str] = None + MISTRAL_API_KEY: Optional[str] = None OPENAI_FORMAT_BASE_URL: Optional[str] = None OPENAI_FORMAT_API_KEY: Optional[str] = None @@ -41,10 +42,10 @@ class Settings(BaseSettings): MEMORY_SEARCH_LIMIT: int = 10 # Default search result limit # Embeddings (OpenAI-format, works with Codestral embed via a compatible gateway) - EMBEDDINGS_MODEL: Optional[str] = None - EMBEDDINGS_API_KEY: Optional[str] = None - EMBEDDINGS_BASE_URL: Optional[str] = None - EMBEDDINGS_DIM: Optional[int] = None + EMBEDDING_MODEL: Optional[str] = None + EMBEDDING_API_KEY: Optional[str] = None + EMBEDDING_BASE_URL: Optional[str] = None + EMBEDDING_DIM: Optional[int] = None settings = Settings() diff --git a/athena/scripts/offline_ingest_hf.py b/athena/scripts/offline_ingest_hf.py index 48ffd70..2baf223 100644 --- a/athena/scripts/offline_ingest_hf.py +++ 
b/athena/scripts/offline_ingest_hf.py @@ -1,3 +1,4 @@ +import argparse import asyncio import os @@ -10,15 +11,22 @@ async def main(): + parser = argparse.ArgumentParser(description="Offline ingest from HF trajectories") + parser.add_argument("--repo", default=os.environ.get("HF_REPO")) + parser.add_argument("--split", default=os.environ.get("HF_SPLIT")) + args = parser.parse_args() + db = DatabaseService(settings.DATABASE_URL) await db.start() embedding_service = None - if settings.EMBEDDINGS_MODEL and settings.EMBEDDINGS_API_KEY and settings.EMBEDDINGS_BASE_URL: + api_key = settings.EMBEDDING_API_KEY or settings.MISTRAL_API_KEY + if settings.EMBEDDING_MODEL and api_key and settings.EMBEDDING_BASE_URL: embedding_service = EmbeddingService( - model=settings.EMBEDDINGS_MODEL, - api_key=settings.EMBEDDINGS_API_KEY, - base_url=settings.EMBEDDINGS_BASE_URL, + model=settings.EMBEDDING_MODEL, + api_key=api_key, + base_url=settings.EMBEDDING_BASE_URL, + embed_dim=settings.EMBEDDING_DIM or 1024, ) store = MemoryStorageService(db.get_sessionmaker(), embedding_service) @@ -36,9 +44,7 @@ async def main(): extractor = MemoryExtractionService(llm_service=llm, memory_store=store) - repo = os.environ.get("HF_REPO", "SWE-Gym/OpenHands-SFT-Trajectories") - split = os.environ.get("HF_SPLIT", "train.success.oss") - extractor.extract_from_huggingface_trajectory_repository(repo, split) + extractor.extract_from_huggingface_trajectory_repository(args.repo, args.split) await db.close() diff --git a/example.env b/example.env index e07d8ec..951e4ea 100644 --- a/example.env +++ b/example.env @@ -11,6 +11,7 @@ ATHENA_MODEL_NAME=gpt-4o # API keys for various LLM providers ATHENA_ANTHROPIC_API_KEY=anthropic_api_key ATHENA_GEMINI_API_KEY=gemini_api_key +ATHENA_MISTRAL_API_KEY=mistral_api_key ATHENA_OPENAI_FORMAT_BASE_URL=https://api.openai.com/v1 ATHENA_OPENAI_FORMAT_API_KEY=your_api_key @@ -19,9 +20,11 @@ ATHENA_MODEL_MAX_INPUT_TOKENS=64000 ATHENA_MODEL_TEMPERATURE=0.3 
ATHENA_MODEL_MAX_OUTPUT_TOKENS=15000 +# Embedding model settings +ATHENA_EMBEDDING_MODEL=codestral-embed-2505 +ATHENA_EMBEDDING_API_KEY=embedding_api_key +ATHENA_EMBEDDING_BASE_URL=https://api.mistral.ai +ATHENA_EMBEDDING_DIM=1024 + # Database settings ATHENA_DATABASE_URL=postgresql+asyncpg://postgres:password@postgres:5432/postgres -ATHENA_EMBEDDINGS_MODEL=codestral-embed-2505 -ATHENA_EMBEDDINGS_API_KEY=... -ATHENA_EMBEDDINGS_BASE_URL=https://your-openai-compatible-gateway -ATHENA_EMBEDDINGS_DIM=1024 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cb38909..a72eac0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,18 +6,20 @@ build-backend = "hatchling.build" name = "Athena" version = "0.0.1" dependencies = [ - "langchain==0.3.3", + "langchain==0.3.27", "fastapi[standard]>=0.115.2", "dynaconf>=3.2.6", "sqlmodel==0.0.24", "asyncpg", "pgvector==0.3.4", "requests>=2.31.0", - "langchain-anthropic==0.3.0", - "langchain-openai==0.2.8", - "langchain-google-genai==2.0.4", - "langchain_community==0.3.2", - "langchain_google_vertexai==2.1.0", + "langchain-core==0.3.76", + "langchain-anthropic==0.3.20", + "langchain-openai==0.3.33", + "langchain-google-genai==2.1.10", + "langchain-community==0.3.29", + "langchain-google-vertexai==2.1.0", + "langchain-mistralai==0.2.11", "datasets>=2.20.0", "tqdm>=4.66.0", ] From cd1ea86e584e11f869bf550c4503812ad768ae92 Mon Sep 17 00:00:00 2001 From: Zhaoyang-Chu Date: Mon, 15 Sep 2025 23:15:13 +0800 Subject: [PATCH 3/3] feat: refine memory storage --- athena/app/dependencies.py | 1 + athena/app/services/database_service.py | 20 ++++++++------------ athena/app/services/llm_service.py | 4 ++++ athena/configuration/config.py | 2 ++ athena/entity/memory.py | 12 ++++++++---- example.env | 2 ++ 6 files changed, 25 insertions(+), 16 deletions(-) diff --git a/athena/app/dependencies.py b/athena/app/dependencies.py index dc61c2e..d292359 100644 --- a/athena/app/dependencies.py +++ b/athena/app/dependencies.py @@ 
-39,6 +39,7 @@ def initialize_services() -> Dict[str, BaseService]: settings.OPENAI_FORMAT_BASE_URL, settings.ANTHROPIC_API_KEY, settings.GEMINI_API_KEY, + settings.GOOGLE_APPLICATION_CREDENTIALS, ) embedding_service = None diff --git a/athena/app/services/database_service.py b/athena/app/services/database_service.py index d5190c8..57c903e 100644 --- a/athena/app/services/database_service.py +++ b/athena/app/services/database_service.py @@ -1,10 +1,11 @@ import asyncio -from pgvector.sqlalchemy import register_vector +from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlmodel import SQLModel from athena.app.services.base_service import BaseService +from athena.configuration.config import settings from athena.utils.logger_manager import get_logger @@ -21,28 +22,23 @@ def __init__(self, DATABASE_URL: str, max_retries: int = 5, initial_backoff: flo # Create the database and tables async def create_db_and_tables(self): async with self.engine.begin() as conn: - # Try to enable pgvector if available (safe to ignore errors on non-Postgres or missing extension) + # Ensure pgvector extension exists (safe to ignore if unavailable) try: - await conn.exec_driver_sql("CREATE EXTENSION IF NOT EXISTS vector") - except Exception: - # Extension not available; proceed without it - pass - # Ensure vector types are registered with SQLAlchemy - try: - register_vector(conn.sync_connection) + await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) except Exception: pass await conn.run_sync(SQLModel.metadata.create_all) # Create ivfflat indexes for vector columns (if extension present) try: + lists = settings.EMBEDDING_IVFFLAT_LISTS or 100 await conn.exec_driver_sql( - "CREATE INDEX IF NOT EXISTS idx_memory_units_task_embedding ON memory_units USING ivfflat (task_embedding vector_cosine_ops) WITH (lists = 100)" + f"CREATE INDEX IF NOT EXISTS idx_memory_units_task_embedding ON memory_units USING ivfflat 
(task_embedding vector_cosine_ops) WITH (lists = {lists})" ) await conn.exec_driver_sql( - "CREATE INDEX IF NOT EXISTS idx_memory_units_state_embedding ON memory_units USING ivfflat (state_embedding vector_cosine_ops) WITH (lists = 100)" + f"CREATE INDEX IF NOT EXISTS idx_memory_units_state_embedding ON memory_units USING ivfflat (state_embedding vector_cosine_ops) WITH (lists = {lists})" ) await conn.exec_driver_sql( - "CREATE INDEX IF NOT EXISTS idx_memory_units_task_state_embedding ON memory_units USING ivfflat (task_state_embedding vector_cosine_ops) WITH (lists = 100)" + f"CREATE INDEX IF NOT EXISTS idx_memory_units_task_state_embedding ON memory_units USING ivfflat (task_state_embedding vector_cosine_ops) WITH (lists = {lists})" ) except Exception: # Index creation failed (likely no pgvector). Continue without indexes. diff --git a/athena/app/services/llm_service.py b/athena/app/services/llm_service.py index fe39fbd..419f016 100644 --- a/athena/app/services/llm_service.py +++ b/athena/app/services/llm_service.py @@ -20,6 +20,7 @@ def __init__( openai_format_base_url: Optional[str] = None, anthropic_api_key: Optional[str] = None, gemini_api_key: Optional[str] = None, + google_application_credentials: Optional[str] = None, ): self.model = get_model( model_name, @@ -30,6 +31,7 @@ def __init__( openai_format_base_url=openai_format_base_url, anthropic_api_key=anthropic_api_key, gemini_api_key=gemini_api_key, + google_application_credentials=google_application_credentials, ) @@ -42,6 +44,7 @@ def get_model( openai_format_base_url: Optional[str] = None, anthropic_api_key: Optional[str] = None, gemini_api_key: Optional[str] = None, + google_application_credentials: Optional[str] = None, ) -> BaseChatModel: if "claude" in model_name: return ChatAnthropic( @@ -61,6 +64,7 @@ def get_model( temperature=temperature, max_output_tokens=max_output_tokens, max_retries=3, + credentials=google_application_credentials, ) elif "gemini" in model_name: return 
ChatGoogleGenerativeAI( diff --git a/athena/configuration/config.py b/athena/configuration/config.py index d48430a..c0024e7 100644 --- a/athena/configuration/config.py +++ b/athena/configuration/config.py @@ -27,6 +27,7 @@ class Settings(BaseSettings): MISTRAL_API_KEY: Optional[str] = None OPENAI_FORMAT_BASE_URL: Optional[str] = None OPENAI_FORMAT_API_KEY: Optional[str] = None + GOOGLE_APPLICATION_CREDENTIALS: Optional[str] = None # Model parameters MODEL_MAX_INPUT_TOKENS: int @@ -46,6 +47,7 @@ class Settings(BaseSettings): EMBEDDING_API_KEY: Optional[str] = None EMBEDDING_BASE_URL: Optional[str] = None EMBEDDING_DIM: Optional[int] = None + EMBEDDING_IVFFLAT_LISTS: Optional[int] = None settings = Settings() diff --git a/athena/entity/memory.py b/athena/entity/memory.py index a816010..eb2730c 100644 --- a/athena/entity/memory.py +++ b/athena/entity/memory.py @@ -6,6 +6,7 @@ from sqlalchemy import Column, DateTime, Text from sqlmodel import Field, SQLModel, UniqueConstraint +from athena.configuration.config import settings from athena.models import Action, MemorySource, MemoryTimestamp, MemoryUnit, Result, State, Task @@ -60,10 +61,13 @@ class MemoryUnitDB(SQLModel, table=True): result_description: str = Field(sa_column=Column(Text)) result_exit_code: Optional[int] = None - # Embeddings for semantic retrieval (optional, pgvector recommended). Stored as float arrays - task_embedding: Optional[list[float]] = Field(default=None, sa_column=Column(Vector())) - state_embedding: Optional[list[float]] = Field(default=None, sa_column=Column(Vector())) - task_state_embedding: Optional[list[float]] = Field(default=None, sa_column=Column(Vector())) + # Embeddings for semantic retrieval (pgvector). 
Fix dimension via settings, default 1024 + _vec_dim: int = settings.EMBEDDING_DIM or 1024 # type: ignore[assignment] + task_embedding: Optional[list[float]] = Field(default=None, sa_column=Column(Vector(_vec_dim))) + state_embedding: Optional[list[float]] = Field(default=None, sa_column=Column(Vector(_vec_dim))) + task_state_embedding: Optional[list[float]] = Field( + default=None, sa_column=Column(Vector(_vec_dim)) + ) @classmethod def from_memory_unit(cls, memory_unit: MemoryUnit) -> "MemoryUnitDB": diff --git a/example.env b/example.env index 951e4ea..ab66400 100644 --- a/example.env +++ b/example.env @@ -14,6 +14,7 @@ ATHENA_GEMINI_API_KEY=gemini_api_key ATHENA_MISTRAL_API_KEY=mistral_api_key ATHENA_OPENAI_FORMAT_BASE_URL=https://api.openai.com/v1 ATHENA_OPENAI_FORMAT_API_KEY=your_api_key +ATHENA_GOOGLE_APPLICATION_CREDENTIALS=/abs/path/to/your-service-account.json # Model settings ATHENA_MODEL_MAX_INPUT_TOKENS=64000 @@ -25,6 +26,7 @@ ATHENA_EMBEDDING_MODEL=codestral-embed-2505 ATHENA_EMBEDDING_API_KEY=embedding_api_key ATHENA_EMBEDDING_BASE_URL=https://api.mistral.ai ATHENA_EMBEDDING_DIM=1024 +ATHENA_EMBEDDING_IVFFLAT_LISTS=100 # Database settings ATHENA_DATABASE_URL=postgresql+asyncpg://postgres:password@postgres:5432/postgres