diff --git a/fast_cache_middleware/controller.py b/fast_cache_middleware/controller.py index 1078c1d..27a23f6 100644 --- a/fast_cache_middleware/controller.py +++ b/fast_cache_middleware/controller.py @@ -4,8 +4,10 @@ from hashlib import blake2b from typing import Optional +from starlette.concurrency import run_in_threadpool from starlette.requests import Request from starlette.responses import Response +from starlette.routing import is_async_callable from .exceptions import NotFoundStorageError, TTLExpiredStorageError from .schemas import CacheConfiguration @@ -140,20 +142,13 @@ async def is_cachable_response(self, response: Response) -> bool: async def generate_cache_key( self, request: Request, cache_configuration: CacheConfiguration ) -> str: - """Generates cache key for request. - - Args: - request: HTTP request - cache_config: Cache configuration - - Returns: - str: Cache key - """ - # Use custom key generation function if available if cache_configuration.key_func: - return cache_configuration.key_func(request) + kf = cache_configuration.key_func + + if is_async_callable(kf): + return await kf(request) # type: ignore[no-any-return] + return await run_in_threadpool(kf, request) # type: ignore[arg-type] - # Use standard function return generate_key(request) async def cache_response( diff --git a/fast_cache_middleware/depends.py b/fast_cache_middleware/depends.py index 88756f4..a4b8a81 100644 --- a/fast_cache_middleware/depends.py +++ b/fast_cache_middleware/depends.py @@ -1,9 +1,11 @@ import re -from typing import Callable, Optional +from typing import Awaitable, Callable, Optional, Union from fastapi import params from starlette.requests import Request +SyncOrAsync = Union[Callable[[Request], str], Callable[[Request], Awaitable[str]]] + class BaseCacheConfigDepends(params.Depends): """Base class for cache configuration via ASGI scope extensions. @@ -29,7 +31,7 @@ class CacheConfig(BaseCacheConfigDepends): def __init__( self, max_age: int = 5 * 60, - key_func: Optional[Callable[[Request], str]] = None, + key_func: Optional[SyncOrAsync] = None, ) -> None: self.max_age = max_age self.key_func = key_func diff --git a/fast_cache_middleware/schemas.py b/fast_cache_middleware/schemas.py index d4d9e4a..abe80b1 100644 --- a/fast_cache_middleware/schemas.py +++ b/fast_cache_middleware/schemas.py @@ -1,5 +1,5 @@ import re -from typing import Any, Callable +from typing import Any from pydantic import ( BaseModel, @@ -9,10 +9,9 @@ field_validator, model_validator, ) -from starlette.requests import Request from starlette.routing import Route -from .depends import CacheConfig, CacheDropConfig +from .depends import SyncOrAsync class CacheConfiguration(BaseModel): @@ -22,7 +21,7 @@ class CacheConfiguration(BaseModel): default=None, description="Cache lifetime in seconds. If None, caching is disabled.", ) - key_func: Callable[[Request], str] | None = Field( + key_func: SyncOrAsync | None = Field( default=None, description="Custom cache key generation function. If None, default key generation is used.", )