Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion athena/app/api/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,34 @@
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException, Query, Request

from athena.configuration.config import settings

api_router = APIRouter()


@api_router.get("/memory/search", tags=["memory"])
async def search_memory(
request: Request,
q: str,
field: str = "task_state",
limit: int = Query(default=settings.MEMORY_SEARCH_LIMIT, ge=1, le=100),
):
services = getattr(request.app.state, "service", {})
memory_service = services.get("memory_service")
if memory_service is None:
raise HTTPException(status_code=503, detail="Memory service not initialized")
if field not in {"task_state", "task", "state"}:
raise HTTPException(status_code=400, detail=f"Invalid field: {field}")
results = await memory_service.search_memory(q, limit=limit, field=field)
return [m.model_dump() for m in results]


@api_router.get("/memory/{memory_id}", tags=["memory"])
async def get_memory(request: Request, memory_id: str):
services = getattr(request.app.state, "service", {})
memory_service = services.get("memory_service")
if memory_service is None:
raise HTTPException(status_code=503, detail="Memory service not initialized")
result = await memory_service.get_memory_by_key(memory_id)
if result is None:
raise HTTPException(status_code=404, detail="Memory not found")
return result.model_dump()
29 changes: 22 additions & 7 deletions athena/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

from athena.app.services.base_service import BaseService
from athena.app.services.database_service import DatabaseService
from athena.app.services.embedding_service import EmbeddingService
from athena.app.services.llm_service import LLMService
from athena.app.services.memory_service import MemoryService
from athena.app.services.memory_storage_service import MemoryStorageService
from athena.configuration.config import settings


Expand All @@ -19,12 +21,10 @@ def initialize_services() -> Dict[str, BaseService]:
Note:
This function assumes all required settings are properly configured in
the settings module using Dynaconf. The following settings are required:
- NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD
- LITELLM_MODEL
- NEO4J_BATCH_SIZE
- KNOWLEDGE_GRAPH_MAX_AST_DEPTH
- WORKING_DIRECTORY
- GITHUB_ACCESS_TOKEN
- LLM_MODEL
- EMBEDDING_MODEL
- DATABASE_URL
- MEMORY_STORAGE_BACKEND

Returns:
A fully configured ServiceCoordinator instance managing all services.
Expand All @@ -39,9 +39,24 @@ def initialize_services() -> Dict[str, BaseService]:
settings.OPENAI_FORMAT_BASE_URL,
settings.ANTHROPIC_API_KEY,
settings.GEMINI_API_KEY,
settings.GOOGLE_APPLICATION_CREDENTIALS,
)

memory_service = MemoryService(storage_backend=settings.MEMORY_STORAGE_BACKEND)
embedding_service = None
api_key = settings.EMBEDDING_API_KEY or settings.MISTRAL_API_KEY
if settings.EMBEDDING_MODEL and api_key and settings.EMBEDDING_BASE_URL:
embedding_service = EmbeddingService(
model=settings.EMBEDDING_MODEL,
api_key=api_key,
base_url=settings.EMBEDDING_BASE_URL,
embed_dim=settings.EMBEDDING_DIM or 1024,
)

memory_store = MemoryStorageService(database_service.get_sessionmaker(), embedding_service)

memory_service = MemoryService(
storage_backend=settings.MEMORY_STORAGE_BACKEND, store=memory_store
)

return {
"llm_service": llm_service,
Expand Down
15 changes: 12 additions & 3 deletions athena/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def custom_generate_unique_id(route: APIRoute) -> str:
"""
Custom function to generate unique IDs for API routes based on their tags and names.
"""
return f"{route.tags[0]}-{route.name}"
tag = route.tags[0] if getattr(route, "tags", None) else "default"
return f"{tag}-{route.name}"


app = FastAPI(
Expand Down Expand Up @@ -74,5 +75,13 @@ def custom_generate_unique_id(route: APIRoute) -> str:


@app.get("/health", tags=["health"])
def health_check():
return {"status": "healthy", "timestamp": datetime.now(timezone.utc).isoformat()}
async def health_check():
services = getattr(app.state, "service", {})
db = services.get("database_service")
db_ok = await db.health_check() if db is not None else False
status = "healthy" if db_ok else "degraded"
return {
"status": status,
"database": db_ok,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
69 changes: 65 additions & 4 deletions athena/app/services/database_service.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,92 @@
from sqlalchemy.ext.asyncio import create_async_engine
import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlmodel import SQLModel

from athena.app.services.base_service import BaseService
from athena.configuration.config import settings
from athena.utils.logger_manager import get_logger


class DatabaseService(BaseService):
    """Async database service: engine/session management, schema bootstrap with
    retries, pgvector index creation, and a lightweight health probe."""

    def __init__(self, DATABASE_URL: str, max_retries: int = 5, initial_backoff: float = 1.0):
        """
        Args:
            DATABASE_URL: SQLAlchemy async URL (e.g. ``postgresql+asyncpg://...``).
            max_retries: Attempts ``start()`` makes before giving up.
            initial_backoff: First retry delay in seconds; doubles each attempt.
        """
        # NOTE(review): echo=True logs every SQL statement — consider gating this
        # on a debug setting before production use.
        self.engine = create_async_engine(DATABASE_URL, echo=True)
        self.sessionmaker = async_sessionmaker(
            self.engine, expire_on_commit=False, class_=AsyncSession
        )
        self._logger = get_logger(__name__)
        self._max_retries = max_retries
        self._initial_backoff = initial_backoff

    async def create_db_and_tables(self):
        """Create the pgvector extension (best effort), all SQLModel tables, and
        ivfflat indexes on the memory_units embedding columns."""
        async with self.engine.begin() as conn:
            # Ensure pgvector extension exists (safe to ignore if unavailable).
            try:
                await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
            except Exception:
                pass
            await conn.run_sync(SQLModel.metadata.create_all)
            # Create ivfflat indexes for vector columns (if extension present).
            try:
                # int() guards against a non-numeric setting leaking into raw SQL.
                lists = int(settings.EMBEDDING_IVFFLAT_LISTS or 100)
                for column in ("task_embedding", "state_embedding", "task_state_embedding"):
                    await conn.exec_driver_sql(
                        f"CREATE INDEX IF NOT EXISTS idx_memory_units_{column} "
                        f"ON memory_units USING ivfflat ({column} vector_cosine_ops) "
                        f"WITH (lists = {lists})"
                    )
            except Exception:
                # Index creation failed (likely no pgvector). Continue without indexes.
                pass

    async def start(self):
        """
        Start the database service by creating the database and tables.

        Retries with exponential backoff (initial_backoff, doubling each time)
        up to max_retries attempts, then re-raises the last exception.
        """
        attempt = 0
        backoff = self._initial_backoff
        while True:
            try:
                await self.create_db_and_tables()
                self._logger.info("Database and tables created successfully.")
                break
            except Exception as exc:
                attempt += 1
                if attempt > self._max_retries:
                    self._logger.error(
                        f"Database start failed after {self._max_retries} retries: {exc}"
                    )
                    raise
                self._logger.warning(
                    f"Database start failed (attempt {attempt}/{self._max_retries}): {exc}. "
                    f"Retrying in {backoff:.1f}s..."
                )
                await asyncio.sleep(backoff)
                backoff *= 2

    async def close(self):
        """
        Close the database connection and release any resources.
        """
        await self.engine.dispose()
        self._logger.info("Database connection closed.")

    def get_sessionmaker(self) -> async_sessionmaker[AsyncSession]:
        """Return the async sessionmaker for dependency injection."""
        return self.sessionmaker

    async def health_check(self) -> bool:
        """Perform a lightweight connectivity check (SELECT 1)."""
        try:
            async with self.engine.connect() as conn:
                await conn.exec_driver_sql("SELECT 1")
            return True
        except Exception as exc:
            self._logger.warning(f"Database health_check failed: {exc}")
            return False
33 changes: 33 additions & 0 deletions athena/app/services/embedding_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Iterable, List

import requests

from athena.app.services.base_service import BaseService


class EmbeddingService(BaseService):
    """
    Simple OpenAI-compatible embedding client.

    Works with providers exposing POST /v1/embeddings, including Codestral embed
    deployments behind OpenAI-format gateways. Configure via model, api_key,
    base_url.
    """

    def __init__(self, model: str, api_key: str, base_url: str, embed_dim: int):
        self.model = model
        self.api_key = api_key
        # Normalize so joining with "/v1/embeddings" never produces "//".
        self.base_url = base_url.rstrip("/")
        self.embed_dim = embed_dim

    def embed(self, inputs: Iterable[str]) -> List[List[float]]:
        """Embed a batch of strings; returns one vector per input, in input order.

        Raises requests.HTTPError on a non-2xx response.
        """
        batch = list(inputs)
        if not batch:
            # Avoid a pointless network round-trip for an empty batch.
            return []
        data = {"model": self.model, "input": batch, "output_dimension": self.embed_dim}
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        url = f"{self.base_url}/v1/embeddings"
        resp = requests.post(url, json=data, headers=headers, timeout=60)
        resp.raise_for_status()
        payload = resp.json()
        items = payload.get("data", [])
        # OpenAI-format responses carry an "index" per item; sort defensively so
        # vectors align with inputs even if the provider returns them out of order.
        items = sorted(items, key=lambda item: item.get("index", 0))
        return [item["embedding"] for item in items]
4 changes: 4 additions & 0 deletions athena/app/services/llm_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
openai_format_base_url: Optional[str] = None,
anthropic_api_key: Optional[str] = None,
gemini_api_key: Optional[str] = None,
google_application_credentials: Optional[str] = None,
):
self.model = get_model(
model_name,
Expand All @@ -30,6 +31,7 @@ def __init__(
openai_format_base_url=openai_format_base_url,
anthropic_api_key=anthropic_api_key,
gemini_api_key=gemini_api_key,
google_application_credentials=google_application_credentials,
)


Expand All @@ -42,6 +44,7 @@ def get_model(
openai_format_base_url: Optional[str] = None,
anthropic_api_key: Optional[str] = None,
gemini_api_key: Optional[str] = None,
google_application_credentials: Optional[str] = None,
) -> BaseChatModel:
if "claude" in model_name:
return ChatAnthropic(
Expand All @@ -61,6 +64,7 @@ def get_model(
temperature=temperature,
max_output_tokens=max_output_tokens,
max_retries=3,
credentials=google_application_credentials,
)
elif "gemini" in model_name:
return ChatGoogleGenerativeAI(
Expand Down
Loading