Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 1 addition & 41 deletions app/api/schemas/provider_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,46 +9,6 @@

logger = get_logger(name="provider_key")

# Constants for API key masking
API_KEY_MASK_PREFIX_LENGTH = 2
API_KEY_MASK_SUFFIX_LENGTH = 4
# Minimum length to apply the full prefix + suffix mask (e.g., pr******fix)
# This means if length is > (PREFIX + SUFFIX), we can apply the full rule.
MIN_KEY_LENGTH_FOR_FULL_MASK_LOGIC = (
API_KEY_MASK_PREFIX_LENGTH + API_KEY_MASK_SUFFIX_LENGTH
)


# Helper function for masking API keys
def _mask_api_key_value(value: str | None) -> str | None:
if not value:
return None

length = len(value)

if length == 0:
return ""

# If key is too short for any meaningful prefix/suffix masking
if length <= API_KEY_MASK_PREFIX_LENGTH:
return "*" * length

# If key is long enough for prefix, but not for prefix + suffix
# e.g., length is 3, 4, 5, 6. For these, show prefix and mask the rest.
if length <= MIN_KEY_LENGTH_FOR_FULL_MASK_LOGIC:
return value[:API_KEY_MASK_PREFIX_LENGTH] + "*" * (
length - API_KEY_MASK_PREFIX_LENGTH
)

# If key is long enough for the full prefix + ... + suffix mask
# number of asterisks = length - prefix_length - suffix_length
num_asterisks = length - API_KEY_MASK_PREFIX_LENGTH - API_KEY_MASK_SUFFIX_LENGTH
return (
value[:API_KEY_MASK_PREFIX_LENGTH]
+ "*" * num_asterisks
+ value[-API_KEY_MASK_SUFFIX_LENGTH:]
)


class ProviderKeyBase(BaseModel):
provider_name: str = Field(..., min_length=1)
Expand Down Expand Up @@ -109,7 +69,7 @@ def api_key(self) -> str | None:
api_key, _ = provider_adapter_cls.deserialize_api_key_config(
decrypted_value
)
return _mask_api_key_value(api_key)
return provider_adapter_cls.mask_api_key(api_key)
except Exception as e:
logger.error(
f"Error deserializing API key for provider {self.provider_name}: {e}"
Expand Down
107 changes: 105 additions & 2 deletions app/core/async_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import asyncio
import functools
import hashlib
import os
import time
from collections.abc import Callable
Expand Down Expand Up @@ -156,6 +157,8 @@ async def wrapper(*args, **kwargs):
async_provider_service_cache: "AsyncCache" = _AsyncBackend(
ttl_seconds=3600
) # 1-hour TTL
# OAuth2 token caching (no TTL - uses token's own expiration with smart cleanup)
async_oauth_token_cache: "AsyncCache" = _AsyncBackend(ttl_seconds=None)


# User-specific functions
Expand Down Expand Up @@ -336,6 +339,98 @@ async def invalidate_provider_service_cache_async(user_id: int) -> None:
)


# OAuth2 token caching functions
async def get_cached_oauth_token_async(api_key: str) -> dict[str, Any] | None:
    """Look up a cached OAuth2 token for *api_key*.

    The cache key is the SHA-256 digest of the API key, so raw keys never
    appear in cache storage.  Entries missing an ``expires_at`` field are
    treated as malformed and evicted; expired entries are evicted and also
    trigger a small opportunistic sweep of other stale tokens.
    """
    if not api_key:
        return None

    digest = hashlib.sha256(api_key.encode()).hexdigest()
    cache_key = f"token:{digest}"

    entry = await async_oauth_token_cache.get(cache_key)
    if not entry:
        return None

    expiry = entry.get("expires_at")
    if not expiry:
        # Malformed entry — drop it rather than serve a token we cannot
        # validate against its expiration.
        await async_oauth_token_cache.delete(cache_key)
        return None

    now = time.time()
    if expiry > now:
        return entry

    # Expired: evict this entry and take the chance to clean up a couple
    # of other expired tokens while we are here.
    await async_oauth_token_cache.delete(cache_key)
    await _opportunistic_cleanup(now, max_items=2)
    return None


async def cache_oauth_token_async(api_key: str, token_data: dict[str, Any]) -> None:
    """Store an OAuth2 token keyed by the SHA-256 digest of *api_key*.

    No-ops on empty input.  The token must carry an ``expires_at``
    timestamp — expiry is enforced by the reader rather than a cache TTL,
    so a token without one would never expire and is refused.
    """
    if not api_key or not token_data:
        return

    if "expires_at" not in token_data:
        logger.warning("OAuth token cached without expires_at - skipping")
        return

    digest = hashlib.sha256(api_key.encode()).hexdigest()
    await async_oauth_token_cache.set(f"token:{digest}", token_data)


async def invalidate_oauth_token_cache_async(api_key: str) -> None:
    """Drop any cached OAuth2 token for *api_key* (no-op on empty input)."""
    if not api_key:
        return

    digest = hashlib.sha256(api_key.encode()).hexdigest()
    cache_key = f"token:{digest}"
    await async_oauth_token_cache.delete(cache_key)

    if DEBUG_CACHE:
        logger.debug(f"Cache: Invalidated OAuth2 token cache for key: {cache_key[:16]}...")


async def _opportunistic_cleanup(current_time: float, max_items: int = 2) -> None:
    """Opportunistically clean up expired OAuth tokens from cache.

    Called piggy-backed on a cache miss (see get_cached_oauth_token_async)
    so that stale entries in the no-TTL token cache are gradually removed
    without a dedicated background task.  At most ``max_items`` entries are
    evicted per call to keep the pass cheap.

    The backend is detected by attribute probing:
      - ``cache`` attribute  -> in-memory dict backend
      - ``client`` attribute -> Redis backend
    Any other backend results in a silent no-op.
    """
    cleaned = 0

    # Case 1: in-memory backend exposes .cache dict
    if hasattr(async_oauth_token_cache, "cache"):
        # NOTE(review): delete() is awaited while holding .lock — if the
        # backend's delete() acquires the same (non-reentrant) asyncio lock,
        # this would deadlock.  Confirm delete()'s locking before relying
        # on this path.
        async with async_oauth_token_cache.lock:
            # Snapshot via list() so eviction doesn't mutate the dict
            # while we iterate it.
            for key, value in list(async_oauth_token_cache.cache.items()):
                if cleaned >= max_items:
                    break
                if key.startswith("token:"):
                    # Assumes every "token:" value is a dict — TODO confirm
                    # nothing else writes that prefix.
                    expires_at = value.get("expires_at")
                    if expires_at and expires_at <= current_time:
                        await async_oauth_token_cache.delete(key)
                        cleaned += 1
                        if DEBUG_CACHE:
                            logger.debug(f"Cache: Cleaned up expired token: {key[:16]}...")

    # Case 2: Redis backend
    elif hasattr(async_oauth_token_cache, "client"):
        try:
            # scan_iter with a small count bounds the per-call Redis work.
            pattern = f"{os.getenv('REDIS_PREFIX', 'forge')}:token:*"
            async for redis_key in async_oauth_token_cache.client.scan_iter(match=pattern, count=10):
                if cleaned >= max_items:
                    break
                key_str = redis_key.decode() if isinstance(redis_key, bytes) else redis_key
                # Strip the "<prefix>:" namespace to get the cache-level key.
                # NOTE(review): split(":", 1) assumes REDIS_PREFIX itself
                # contains no ":" — verify.
                internal_key = key_str.split(":", 1)[-1]
                cached_data = await async_oauth_token_cache.get(internal_key)
                if cached_data:
                    expires_at = cached_data.get("expires_at")
                    if expires_at and expires_at <= current_time:
                        await async_oauth_token_cache.delete(internal_key)
                        cleaned += 1
                        if DEBUG_CACHE:
                            logger.debug(f"Cache: Cleaned up expired token: {internal_key[:16]}...")
        except Exception as exc:
            # Best-effort pass: a Redis hiccup here must never break the
            # caller's token lookup.
            if DEBUG_CACHE:
                logger.warning(f"Failed to perform opportunistic cleanup: {exc}")

    if DEBUG_CACHE and cleaned > 0:
        logger.debug(f"Cache: Opportunistic cleanup removed {cleaned} expired tokens")


async def invalidate_provider_models_cache_async(provider_name: str) -> None:
"""Invalidate model cache for a specific provider asynchronously"""
if not provider_name:
Expand Down Expand Up @@ -387,6 +482,7 @@ async def invalidate_all_caches_async() -> None:
"""Invalidate all caches in the system asynchronously"""
await async_user_cache.clear()
await async_provider_service_cache.clear()
await async_oauth_token_cache.clear()

if DEBUG_CACHE:
logger.debug("Cache: Invalidated all caches")
Expand Down Expand Up @@ -429,6 +525,7 @@ async def get_cache_stats_async() -> dict[str, dict[str, Any]]:
return {
"user_cache": await async_user_cache.stats(),
"provider_service_cache": await async_provider_service_cache.stats(),
"oauth_token_cache": await async_oauth_token_cache.stats(),
}


Expand All @@ -437,9 +534,15 @@ async def monitor_cache_performance_async() -> dict[str, Any]:
stats = await get_cache_stats_async()

# Calculate overall hit rates
total_hits = stats["user_cache"]["hits"] + stats["provider_service_cache"]["hits"]
total_hits = (
stats["user_cache"]["hits"]
+ stats["provider_service_cache"]["hits"]
+ stats["oauth_token_cache"]["hits"]
)
total_requests = (
stats["user_cache"]["total"] + stats["provider_service_cache"]["total"]
stats["user_cache"]["total"]
+ stats["provider_service_cache"]["total"]
+ stats["oauth_token_cache"]["total"]
)
overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0

Expand Down
41 changes: 39 additions & 2 deletions app/core/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def stats(self) -> dict[str, Any]:
# Expose the global cache instances
user_cache: "Cache" = _CacheBackend(ttl_seconds=300) # 5-minute TTL for users
provider_service_cache: "Cache" = _CacheBackend(ttl_seconds=3600) # 1-hour TTL
# OAuth2 token caching (55-min TTL with 5-min safety buffer before 1-hour token expiry)
oauth_token_cache: "Cache" = _CacheBackend(ttl_seconds=3300) # 55-min TTL


def cached(cache_instance: Cache, key_func: Callable[[Any], str] = None):
Expand Down Expand Up @@ -238,6 +240,33 @@ def invalidate_provider_service_cache(user_id: int) -> None:
)


# OAuth2 token caching functions
def get_cached_oauth_token(api_key: str) -> dict[str, Any] | None:
    """Return the cached OAuth2 token for *api_key*, or None on a miss.

    NOTE(review): unlike the async variant, the raw API key is embedded in
    the cache key here — consider hashing for parity.
    """
    if not api_key:
        return None
    return oauth_token_cache.get(f"token:{api_key}") or None


def cache_oauth_token(api_key: str, token_data: dict[str, Any]) -> None:
    """Cache *token_data* under *api_key*; no-op when either is empty."""
    if api_key and token_data:
        oauth_token_cache.set(f"token:{api_key}", token_data)


def invalidate_oauth_token_cache(api_key: str) -> None:
    """Drop the cached OAuth2 token associated with *api_key*, if any."""
    if api_key:
        oauth_token_cache.delete(f"token:{api_key}")
        if DEBUG_CACHE:
            shown = api_key[:8]
            logger.debug(f"Cache: Invalidated OAuth2 token cache for key: {shown}...")


def invalidate_user_cache_by_id(user_id: int) -> None:
"""Invalidate all cache entries for a specific user ID"""
if not user_id:
Expand Down Expand Up @@ -325,6 +354,7 @@ def invalidate_all_caches() -> None:
"""Invalidate all caches in the system"""
user_cache.clear()
provider_service_cache.clear()
oauth_token_cache.clear()

if DEBUG_CACHE:
logger.debug("Cache: Invalidated all caches")
Expand Down Expand Up @@ -362,6 +392,7 @@ def get_cache_stats() -> dict[str, dict[str, Any]]:
return {
"user_cache": user_cache.stats(),
"provider_service_cache": provider_service_cache.stats(),
"oauth_token_cache": oauth_token_cache.stats(),
}


Expand All @@ -370,9 +401,15 @@ def monitor_cache_performance() -> dict[str, Any]:
stats = get_cache_stats()

# Calculate overall hit rates
total_hits = stats["user_cache"]["hits"] + stats["provider_service_cache"]["hits"]
total_hits = (
stats["user_cache"]["hits"]
+ stats["provider_service_cache"]["hits"]
+ stats["oauth_token_cache"]["hits"]
)
total_requests = (
stats["user_cache"]["total"] + stats["provider_service_cache"]["total"]
stats["user_cache"]["total"]
+ stats["provider_service_cache"]["total"]
+ stats["oauth_token_cache"]["total"]
)
overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0

Expand Down
2 changes: 1 addition & 1 deletion app/exceptions/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,4 @@ def __init__(self, error: Exception):
class InvalidForgeKeyException(BaseInvalidForgeKeyException):
    """Exception raised when a Forge key is invalid.

    Construction is inherited from the base class; this subclass exists
    only to give callers a distinct exception type to catch.  (The
    original defined an ``__init__`` that merely forwarded to
    ``super().__init__`` — a redundant override, removed here.)
    """
2 changes: 1 addition & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def provider_authentication_exception_handler(request: Request, exc: Provi
async def invalid_provider_exception_handler(request: Request, exc: InvalidProviderException):
    """Render an InvalidProviderException as a 400 JSON error response.

    NOTE(review): the original *returned* an ``HTTPException`` instance
    from the handler.  Starlette/FastAPI exception handlers must return a
    ``Response`` object — a returned ``HTTPException`` is not rendered as
    an HTTP response, so the client would not receive the intended 400
    body.  Returning a ``JSONResponse`` with the same ``detail`` payload
    fixes that while keeping the wire format HTTPException would have had.
    """
    # Local import keeps this fix self-contained in the handler.
    from fastapi.responses import JSONResponse

    return JSONResponse(
        status_code=400,
        content={
            "detail": f"{str(exc)}. Please verify your provider and model details by calling the /models endpoint or visiting https://tensorblock.co/api-docs/model-ids, and ensure you're using a valid provider name, model name, and model ID."
        },
    )

# Add exception handler for BaseInvalidProviderSetupException
Expand Down
4 changes: 4 additions & 0 deletions app/services/providers/adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .perplexity_adapter import PerplexityAdapter
from .tensorblock_adapter import TensorblockAdapter
from .zhipu_adapter import ZhipuAdapter
from .vertex_adapter import VertexAdapter


class ProviderAdapterFactory:
Expand Down Expand Up @@ -185,6 +186,9 @@ class ProviderAdapterFactory:
"bedrock": {
"adapter": BedrockAdapter,
},
"vertex": {
"adapter": VertexAdapter,
},
"customized": {
"adapter": OpenAIAdapter,
},
Expand Down
Loading
Loading