From 13b423fe2a6d9ae35d9c040adc4dc272eb75da9a Mon Sep 17 00:00:00 2001
From: Dokujaa
Date: Fri, 25 Jul 2025 21:04:57 -0400
Subject: [PATCH] Add OAuth2 token caching for Vertex AI

---
 app/core/async_cache.py                  | 107 ++++++++++++++++++++++-
 app/services/providers/vertex_adapter.py |  27 ++++++
 2 files changed, 132 insertions(+), 2 deletions(-)

diff --git a/app/core/async_cache.py b/app/core/async_cache.py
index c3ba536..3775e83 100644
--- a/app/core/async_cache.py
+++ b/app/core/async_cache.py
@@ -5,6 +5,7 @@
 import asyncio
 import functools
+import hashlib
 import os
 import time
 from collections.abc import Callable
@@ -156,6 +157,8 @@ async def wrapper(*args, **kwargs):
 async_provider_service_cache: "AsyncCache" = _AsyncBackend(
     ttl_seconds=3600
 )  # 1-hour TTL
+# OAuth2 token caching (no backend TTL - entries expire via each token's own expires_at, plus opportunistic cleanup)
+async_oauth_token_cache: "AsyncCache" = _AsyncBackend(ttl_seconds=None)
 
 
 # User-specific functions
@@ -336,6 +339,98 @@ async def invalidate_provider_service_cache_async(user_id: int) -> None:
     )
 
 
+# OAuth2 token caching functions
+async def get_cached_oauth_token_async(api_key: str) -> dict[str, Any] | None:
+    """Get a cached OAuth2 token by API key asynchronously."""
+    if not api_key:
+        return None
+
+    cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}"
+    cached_data = await async_oauth_token_cache.get(cache_key)
+    if not cached_data:
+        return None
+
+    # Entries without an expiration are malformed; drop them rather than serve them
+    expires_at = cached_data.get("expires_at")
+    if not expires_at:
+        await async_oauth_token_cache.delete(cache_key)
+        return None
+
+    current_time = time.time()
+    if expires_at <= current_time:
+        await async_oauth_token_cache.delete(cache_key)
+        await _opportunistic_cleanup(current_time, max_items=2)
+        return None
+
+    return cached_data
+
+
+async def cache_oauth_token_async(api_key: str, token_data: dict[str, Any]) -> None:
+    """Cache an OAuth2 token by API key asynchronously."""
+    if not api_key or not token_data:
+        return
+
+    cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}"
+    if "expires_at" not in token_data:
+        logger.warning("OAuth token has no expires_at - refusing to cache it")
+        return
+
+    await async_oauth_token_cache.set(cache_key, token_data)
+
+
+async def invalidate_oauth_token_cache_async(api_key: str) -> None:
+    """Invalidate the cached OAuth2 token for a specific API key asynchronously."""
+    if not api_key:
+        return
+
+    cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}"
+    await async_oauth_token_cache.delete(cache_key)
+    if DEBUG_CACHE:
+        logger.debug(f"Cache: Invalidated OAuth2 token cache for key: {cache_key[:16]}...")
+
+
+async def _opportunistic_cleanup(current_time: float, max_items: int = 2) -> None:
+    """Opportunistically clean up a few expired OAuth tokens from the cache."""
+    cleaned = 0
+
+    # Case 1: in-memory backend exposes a .cache dict
+    if hasattr(async_oauth_token_cache, "cache"):
+        async with async_oauth_token_cache.lock:
+            for key, value in list(async_oauth_token_cache.cache.items()):
+                if cleaned >= max_items:
+                    break
+                if key.startswith("token:"):
+                    expires_at = value.get("expires_at")
+                    if expires_at and expires_at <= current_time:
+                        # Pop directly while holding the lock; delete() may try to
+                        # re-acquire the same non-reentrant lock and deadlock
+                        async_oauth_token_cache.cache.pop(key, None)
+                        cleaned += 1
+                        if DEBUG_CACHE:
+                            logger.debug(f"Cache: Cleaned up expired token: {key[:16]}...")
+
+    # Case 2: Redis backend
+    elif hasattr(async_oauth_token_cache, "client"):
+        try:
+            pattern = f"{os.getenv('REDIS_PREFIX', 'forge')}:token:*"
+            async for redis_key in async_oauth_token_cache.client.scan_iter(match=pattern, count=10):
+                if cleaned >= max_items:
+                    break
+                key_str = redis_key.decode() if isinstance(redis_key, bytes) else redis_key
+                # Strip the Redis prefix to recover the cache's internal "token:<hash>" key
+                internal_key = key_str.split(":", 1)[-1]
+                cached_data = await async_oauth_token_cache.get(internal_key)
+                if cached_data:
+                    expires_at = cached_data.get("expires_at")
+                    if expires_at and expires_at <= current_time:
+                        await async_oauth_token_cache.delete(internal_key)
+                        cleaned += 1
+                        if DEBUG_CACHE:
+                            logger.debug(f"Cache: Cleaned up expired token: {internal_key[:16]}...")
+        except Exception as exc:
+            if DEBUG_CACHE:
+                logger.warning(f"Failed to perform opportunistic cleanup: {exc}")
+
+    if DEBUG_CACHE and cleaned > 0:
+        logger.debug(f"Cache: Opportunistic cleanup removed {cleaned} expired tokens")
+
+
 async def invalidate_provider_models_cache_async(provider_name: str) -> None:
     """Invalidate model cache for a specific provider asynchronously"""
     if not provider_name:
         return
@@ -387,6 +482,7 @@ async def invalidate_all_caches_async() -> None:
     """Invalidate all caches in the system asynchronously"""
     await async_user_cache.clear()
     await async_provider_service_cache.clear()
+    await async_oauth_token_cache.clear()
 
     if DEBUG_CACHE:
         logger.debug("Cache: Invalidated all caches")
@@ -429,6 +525,7 @@ async def get_cache_stats_async() -> dict[str, dict[str, Any]]:
     return {
         "user_cache": await async_user_cache.stats(),
         "provider_service_cache": await async_provider_service_cache.stats(),
+        "oauth_token_cache": await async_oauth_token_cache.stats(),
     }
@@ -437,9 +534,15 @@ async def monitor_cache_performance_async() -> dict[str, Any]:
     stats = await get_cache_stats_async()
 
     # Calculate overall hit rates
-    total_hits = stats["user_cache"]["hits"] + stats["provider_service_cache"]["hits"]
+    total_hits = (
+        stats["user_cache"]["hits"]
+        + stats["provider_service_cache"]["hits"]
+        + stats["oauth_token_cache"]["hits"]
+    )
     total_requests = (
-        stats["user_cache"]["total"] + stats["provider_service_cache"]["total"]
+        stats["user_cache"]["total"]
+        + stats["provider_service_cache"]["total"]
+        + stats["oauth_token_cache"]["total"]
     )
     overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0
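A note on the key scheme above: hashing the API key with SHA-256 keeps the raw credential (for Vertex, an entire service-account JSON) out of cache keys and debug logs, while a short prefix of the digest still works as a stable identifier for log correlation. A minimal illustration, using only the stdlib (the key value here is made up):

    import hashlib

    api_key = '{"type": "service_account", "private_key": "..."}'  # stands in for a raw Vertex credential
    cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}"
    print(cache_key[:16])  # "token:" plus the first hex digits - safe to emit in debug logs
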
diff --git a/app/services/providers/vertex_adapter.py b/app/services/providers/vertex_adapter.py
index 2023c71..089f966 100644
--- a/app/services/providers/vertex_adapter.py
+++ b/app/services/providers/vertex_adapter.py
@@ -1,12 +1,15 @@
 import asyncio
 import json
+import time
 from collections.abc import AsyncGenerator
+from datetime import timezone
 from typing import Any
 
 import aiohttp
 from google.oauth2 import service_account
 from google.auth.transport.requests import Request
 from app.exceptions.exceptions import ProviderAuthenticationException, InvalidProviderConfigException, InvalidProviderAPIKeyException, ProviderAPIException
+from app.core.async_cache import get_cached_oauth_token_async, cache_oauth_token_async
 from app.core.logger import get_logger
 
 from .base import ProviderAdapter
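On the expiry handling in the hunk below: google-auth exposes `credentials.expiry` as a naive `datetime` in UTC, so calling `.timestamp()` on it directly would interpret it in the server's local timezone and skew the cached `expires_at` by the UTC offset. A standalone sketch of the conversion plus the 5-minute buffer (values are illustrative):

    from datetime import datetime, timedelta, timezone

    expiry = datetime.utcnow() + timedelta(hours=1)  # naive UTC, as google-auth returns it
    expires_at = expiry.replace(tzinfo=timezone.utc).timestamp() - 5 * 60  # attach UTC, then apply the buffer
    remaining_min = (expires_at - datetime.now(timezone.utc).timestamp()) / 60
    print(round(remaining_min))  # ~55: one hour of lifetime minus the 5-minute buffer
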
@@ -125,12 +128,36 @@ async def vertex_authentication(self, api_key: str) -> str:
         # validate api key
         self.parse_api_key(api_key)
 
+        # check the cache first for an existing valid token
+        cached_token = await get_cached_oauth_token_async(api_key)
+        if cached_token:
+            access_token = cached_token.get("access_token")
+            if access_token:
+                return access_token
+
         # load credentials within scope
         try:
             credentials = service_account.Credentials.from_service_account_info(self.cred_json, scopes=["https://www.googleapis.com/auth/cloud-platform"])
 
             # refresh token - run in thread pool to avoid blocking
             await asyncio.to_thread(credentials.refresh, Request())
+
+            # cache the token with expiry information
+            if credentials.token and credentials.expiry:
+                # 5-minute safety buffer; credentials.expiry is a naive UTC datetime,
+                # so attach UTC explicitly before converting to a Unix timestamp
+                safety_buffer_seconds = 5 * 60
+                expires_at_with_buffer = credentials.expiry.replace(tzinfo=timezone.utc).timestamp() - safety_buffer_seconds
+
+                token_data = {
+                    "access_token": credentials.token,
+                    "token_type": "Bearer",
+                    "expires_at": expires_at_with_buffer,  # Unix timestamp, buffer included
+                    "scope": "https://www.googleapis.com/auth/cloud-platform",
+                    "cached_at": time.time(),  # For debugging
+                    "provider": "vertex",  # Helpful for multi-provider systems
+                }
+                await cache_oauth_token_async(api_key, token_data)
+
             return credentials.token
         except Exception as e:
             logger.error(f"Error authenticating with Vertex API: {e}")
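Taken together, a round trip through the new helpers looks like the sketch below. This is a minimal illustration of the intended behavior rather than adapter code: the keys and token values are made up, and the only assumption is the API added in this patch.

    import asyncio
    import time

    from app.core.async_cache import cache_oauth_token_async, get_cached_oauth_token_async

    async def demo() -> None:
        # A token with remaining lifetime is served from the cache
        await cache_oauth_token_async("fresh-key", {
            "access_token": "ya29.example",
            "expires_at": time.time() + 3300,  # 55 minutes from now
        })
        hit = await get_cached_oauth_token_async("fresh-key")
        assert hit is not None and hit["access_token"] == "ya29.example"

        # An entry past its expires_at is deleted on read and misses
        await cache_oauth_token_async("stale-key", {"access_token": "x", "expires_at": time.time() - 1})
        assert await get_cached_oauth_token_async("stale-key") is None

    asyncio.run(demo())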