Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 105 additions & 2 deletions app/core/async_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import asyncio
import functools
import hashlib
import os
import time
from collections.abc import Callable
Expand Down Expand Up @@ -156,6 +157,8 @@ async def wrapper(*args, **kwargs):
async_provider_service_cache: "AsyncCache" = _AsyncBackend(
ttl_seconds=3600
) # 1-hour TTL
# OAuth2 token caching (no TTL - uses token's own expiration with smart cleanup).
# The backend is constructed with ttl_seconds=None; the OAuth accessor functions in
# this module compare each entry's "expires_at" field against time.time() and delete
# stale entries themselves.
# NOTE(review): confirm that _AsyncBackend treats ttl_seconds=None as "never expire".
async_oauth_token_cache: "AsyncCache" = _AsyncBackend(ttl_seconds=None)


# User-specific functions
Expand Down Expand Up @@ -336,6 +339,98 @@ async def invalidate_provider_service_cache_async(user_id: int) -> None:
)


# OAuth2 token caching functions
async def get_cached_oauth_token_async(api_key: str) -> dict[str, Any] | None:
"""Get a cached OAuth2 token by API key asynchronously"""
if not api_key:
return None

cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}"
cached_data = await async_oauth_token_cache.get(cache_key)
if not cached_data:
return None

expires_at = cached_data.get("expires_at")
if not expires_at:
await async_oauth_token_cache.delete(cache_key)
return None

current_time = time.time()
if expires_at <= current_time:
await async_oauth_token_cache.delete(cache_key)
await _opportunistic_cleanup(current_time, max_items=2)
return None

return cached_data


async def cache_oauth_token_async(api_key: str, token_data: dict[str, Any]) -> None:
    """Store ``token_data`` in the OAuth2 token cache under ``api_key``.

    The cache key is ``token:`` followed by the SHA-256 hex digest of the API
    key. Nothing is stored when either argument is falsy, or when
    ``token_data`` has no ``"expires_at"`` key (a warning is logged in that
    case).

    Args:
        api_key: API key the token belongs to.
        token_data: Token payload; must contain an ``"expires_at"`` key to be
            cached.
    """
    if not (api_key and token_data):
        return

    digest = hashlib.sha256(api_key.encode()).hexdigest()

    if "expires_at" in token_data:
        await async_oauth_token_cache.set(f"token:{digest}", token_data)
    else:
        logger.warning("OAuth token cached without expires_at - skipping")


async def invalidate_oauth_token_cache_async(api_key: str) -> None:
    """Remove the cached OAuth2 token associated with ``api_key``.

    The entry is addressed by ``token:`` followed by the SHA-256 hex digest of
    the API key. A falsy ``api_key`` is a no-op.

    Args:
        api_key: API key whose cached token should be dropped.
    """
    if api_key:
        digest = hashlib.sha256(api_key.encode()).hexdigest()
        cache_key = f"token:{digest}"
        await async_oauth_token_cache.delete(cache_key)
        if DEBUG_CACHE:
            # Only the first 16 characters of the cache key are logged.
            logger.debug(f"Cache: Invalidated OAuth2 token cache for key: {cache_key[:16]}...")


async def _opportunistic_cleanup(current_time: float, max_items: int = 2) -> None:
    """Best-effort eviction of a few expired OAuth tokens from the token cache.

    Two backend shapes are recognised by duck typing:

    * an in-memory backend exposing ``.cache`` (a dict) and ``.lock``;
    * a Redis backend exposing ``.client`` with an async ``scan_iter``.

    Args:
        current_time: Unix timestamp to compare each entry's ``expires_at``
            against.
        max_items: Upper bound on the number of entries removed per call, so
            the cleanup stays cheap on the request path.
    """
    cleaned = 0

    # Case 1: in-memory backend exposes .cache dict
    if hasattr(async_oauth_token_cache, "cache"):
        # Select victims while holding the lock, but delete them only after it
        # has been released. asyncio.Lock is not reentrant, so awaiting
        # delete() here would hang forever if delete() acquires the same lock
        # (NOTE(review): delete() is not visible from this block - confirm).
        expired_keys = []
        async with async_oauth_token_cache.lock:
            # No await inside this loop, so the dict cannot change underneath us.
            for key, value in async_oauth_token_cache.cache.items():
                if len(expired_keys) >= max_items:
                    break
                if not key.startswith("token:"):
                    continue
                # Skip anything that is not a plain token dict instead of
                # raising from a best-effort cleanup.
                # NOTE(review): assumes the backend stores the raw token dict
                # as the value - confirm against the backend implementation.
                if not isinstance(value, dict):
                    continue
                expires_at = value.get("expires_at")
                if expires_at and expires_at <= current_time:
                    expired_keys.append(key)

        # An entry could be refreshed between the scan and the delete; losing
        # it only costs one extra token fetch, which is acceptable for a cache.
        for key in expired_keys:
            await async_oauth_token_cache.delete(key)
            cleaned += 1
            if DEBUG_CACHE:
                logger.debug(f"Cache: Cleaned up expired token: {key[:16]}...")

    # Case 2: Redis backend
    elif hasattr(async_oauth_token_cache, "client"):
        try:
            prefix = f"{os.getenv('REDIS_PREFIX', 'forge')}:"
            pattern = f"{prefix}token:*"
            async for redis_key in async_oauth_token_cache.client.scan_iter(match=pattern, count=10):
                if cleaned >= max_items:
                    break
                key_str = redis_key.decode() if isinstance(redis_key, bytes) else redis_key
                # Strip exactly the prefix used to build the scan pattern, so a
                # REDIS_PREFIX that itself contains ':' still yields the
                # internal "token:..." key.
                internal_key = key_str.removeprefix(prefix)
                cached_data = await async_oauth_token_cache.get(internal_key)
                if cached_data:
                    expires_at = cached_data.get("expires_at")
                    if expires_at and expires_at <= current_time:
                        await async_oauth_token_cache.delete(internal_key)
                        cleaned += 1
                        if DEBUG_CACHE:
                            logger.debug(f"Cache: Cleaned up expired token: {internal_key[:16]}...")
        except Exception as exc:
            # Cleanup is opportunistic: a backend failure must not break the
            # token lookup that triggered it.
            if DEBUG_CACHE:
                logger.warning(f"Failed to perform opportunistic cleanup: {exc}")

    if DEBUG_CACHE and cleaned > 0:
        logger.debug(f"Cache: Opportunistic cleanup removed {cleaned} expired tokens")


async def invalidate_provider_models_cache_async(provider_name: str) -> None:
"""Invalidate model cache for a specific provider asynchronously"""
if not provider_name:
Expand Down Expand Up @@ -387,6 +482,7 @@ async def invalidate_all_caches_async() -> None:
"""Invalidate all caches in the system asynchronously"""
await async_user_cache.clear()
await async_provider_service_cache.clear()
await async_oauth_token_cache.clear()

if DEBUG_CACHE:
logger.debug("Cache: Invalidated all caches")
Expand Down Expand Up @@ -429,6 +525,7 @@ async def get_cache_stats_async() -> dict[str, dict[str, Any]]:
return {
"user_cache": await async_user_cache.stats(),
"provider_service_cache": await async_provider_service_cache.stats(),
"oauth_token_cache": await async_oauth_token_cache.stats(),
}


Expand All @@ -437,9 +534,15 @@ async def monitor_cache_performance_async() -> dict[str, Any]:
stats = await get_cache_stats_async()

# Calculate overall hit rates
total_hits = stats["user_cache"]["hits"] + stats["provider_service_cache"]["hits"]
total_hits = (
stats["user_cache"]["hits"]
+ stats["provider_service_cache"]["hits"]
+ stats["oauth_token_cache"]["hits"]
)
total_requests = (
stats["user_cache"]["total"] + stats["provider_service_cache"]["total"]
stats["user_cache"]["total"]
+ stats["provider_service_cache"]["total"]
+ stats["oauth_token_cache"]["total"]
)
overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0

Expand Down
27 changes: 27 additions & 0 deletions app/services/providers/vertex_adapter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import asyncio
import json
import time
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from typing import Any
import aiohttp
from google.oauth2 import service_account
from google.auth.transport.requests import Request
from app.exceptions.exceptions import ProviderAuthenticationException, InvalidProviderConfigException, InvalidProviderAPIKeyException, ProviderAPIException

from app.core.async_cache import get_cached_oauth_token_async, cache_oauth_token_async, invalidate_oauth_token_cache_async
from app.core.logger import get_logger

from .base import ProviderAdapter
Expand Down Expand Up @@ -125,12 +128,36 @@ async def vertex_authentication(self, api_key: str) -> str:
# validate api key
self.parse_api_key(api_key)

# check cache first for existing valid token
cached_token = await get_cached_oauth_token_async(api_key)
if cached_token:
access_token = cached_token.get("access_token")
if access_token:
return access_token

# load credentials within scope
try:
credentials = service_account.Credentials.from_service_account_info(self.cred_json, scopes=["https://www.googleapis.com/auth/cloud-platform"])

# refresh token - run in thread pool to avoid blocking
await asyncio.to_thread(credentials.refresh, Request())

# cache the token with expiry information
if credentials.token and credentials.expiry:
# Add 5-minute safety buffer to prevent using tokens too close to expiry
safety_buffer_seconds = 5 * 60 # 5 minutes
expires_at_with_buffer = credentials.expiry.timestamp() - safety_buffer_seconds

token_data = {
"access_token": credentials.token,
"token_type": "Bearer",
"expires_at": expires_at_with_buffer, # Unix timestamp with safety buffer
"scope": "https://www.googleapis.com/auth/cloud-platform",
"cached_at": time.time(), # For debugging
"provider": "vertex" # Helpful for multi-provider systems
}
await cache_oauth_token_async(api_key, token_data)

return credentials.token
except Exception as e:
logger.error(f"Error authenticating with Vertex API: {e}")
Expand Down
Loading