diff --git a/fast_cache_middleware/__init__.py b/fast_cache_middleware/__init__.py
index 71a8b8c..5c27625 100644
--- a/fast_cache_middleware/__init__.py
+++ b/fast_cache_middleware/__init__.py
@@ -13,6 +13,7 @@
 from .controller import Controller
 from .depends import BaseCacheConfigDepends, CacheConfig, CacheDropConfig
 from .middleware import FastCacheMiddleware
+from .serializers import BaseSerializer, JSONSerializer
 from .storages import BaseStorage, InMemoryStorage, RedisStorage
 
 __version__ = "1.0.0"
@@ -31,5 +32,5 @@
     "RedisStorage",
     # Serialization
     "BaseSerializer",
-    "DefaultSerializer",
+    "JSONSerializer",
 ]
diff --git a/fast_cache_middleware/_helpers.py b/fast_cache_middleware/_helpers.py
index e9809c1..4c73a1e 100644
--- a/fast_cache_middleware/_helpers.py
+++ b/fast_cache_middleware/_helpers.py
@@ -1,4 +1,7 @@
+import typing as tp
+
 from fastapi import FastAPI, routing
+from starlette.routing import Mount
 
 from .depends import CacheConfig
 
@@ -26,3 +29,53 @@ def set_cache_age_in_openapi_schema(app: FastAPI) -> None:
     app.openapi_schema = openapi_schema
 
     return None
+
+
+def get_app_routes(app: FastAPI) -> tp.List[routing.APIRoute]:
+    """Gets all routes from FastAPI application.
+
+    Recursively traverses all application routers and collects their routes.
+
+    Args:
+        app: FastAPI application
+
+    Returns:
+        List of all application routes
+    """
+    routes = []
+
+    # Get routes from main application router
+    routes.extend(get_routes(app.router))
+
+    # Traverse all nested routers
+    for route in app.router.routes:
+        if isinstance(route, Mount):
+            if isinstance(route.app, routing.APIRouter):
+                routes.extend(get_routes(route.app))
+
+    return routes
+
+
+def get_routes(router: routing.APIRouter) -> list[routing.APIRoute]:
+    """Recursively gets all routes from router.
+
+    Traverses all routes in router and its sub-routers, collecting them into a single list.
+
+    Args:
+        router: APIRouter to traverse
+
+    Returns:
+        List of all routes from router and its sub-routers
+    """
+    routes = []
+
+    # Get all routes from current router
+    for route in router.routes:
+        if isinstance(route, routing.APIRoute):
+            routes.append(route)
+        elif isinstance(route, Mount):
+            # Recursively traverse sub-routers
+            if isinstance(route.app, routing.APIRouter):
+                routes.extend(get_routes(route.app))
+
+    return routes
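The helpers above collect only `APIRoute` entries and recurse into `Mount`ed `APIRouter`s. A minimal sketch of how they might be exercised (the example app, router, and endpoints are invented for illustration):

```python
from fastapi import APIRouter, FastAPI

from fast_cache_middleware._helpers import get_app_routes

api = APIRouter()


@api.get("/items")
async def list_items() -> list[str]:
    return ["a", "b"]


app = FastAPI()
app.include_router(api, prefix="/v1")


@app.get("/health")
async def health() -> dict[str, str]:
    return {"status": "ok"}


# Collects the APIRoute objects registered above; FastAPI's built-in /docs and
# /openapi.json endpoints are plain starlette Routes, so they are filtered out.
for route in get_app_routes(app):
    print(route.path, route.methods)
# e.g. /v1/items {'GET'} and /health {'GET'}
```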
diff --git a/fast_cache_middleware/middleware/__init__.py b/fast_cache_middleware/middleware/__init__.py
new file mode 100644
index 0000000..f4c6c77
--- /dev/null
+++ b/fast_cache_middleware/middleware/__init__.py
@@ -0,0 +1 @@
+from .middleware import FastCacheMiddleware
diff --git a/fast_cache_middleware/middleware/base.py b/fast_cache_middleware/middleware/base.py
new file mode 100644
index 0000000..ab7a57d
--- /dev/null
+++ b/fast_cache_middleware/middleware/base.py
@@ -0,0 +1,92 @@
+import logging
+import typing as tp
+
+from starlette.responses import Response
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+logger = logging.getLogger(__name__)
+
+
+class BaseMiddleware:
+    def __init__(
+        self,
+        app: ASGIApp,
+    ) -> None:
+        self.app = app
+
+        self.executors_map = {
+            "lifespan": self.on_lifespan,
+            "http": self.on_http,
+        }
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        scope_type = scope["type"]
+        try:
+            is_request_processed = await self.executors_map[scope_type](
+                scope, receive, send
+            )
+        except KeyError:
+            logger.debug("Not supported scope type: %s", scope_type)
+            is_request_processed = False
+
+        if not is_request_processed:
+            await self.app(scope, receive, send)
+
+    async def on_lifespan(
+        self, scope: Scope, receive: Receive, send: Send
+    ) -> bool | None:
+        pass
+
+    async def on_http(self, scope: Scope, receive: Receive, send: Send) -> bool | None:
+        pass
+
+
+class BaseSendWrapper:
+    def __init__(self, app: ASGIApp, scope: Scope, receive: Receive, send: Send):
+        self.app = app
+        self.scope = scope
+        self.receive = receive
+        self.send = send
+
+        self._response_status: int = 200
+        self._response_headers: dict[str, str] = dict()
+        self._response_body: bytes = b""
+
+        self.executors_map = {
+            "http.response.start": self.on_response_start,
+            "http.response.body": self.on_response_body,
+        }
+
+    async def __call__(self) -> None:
+        return await self.app(self.scope, self.receive, self._message_processor)
+
+    async def _message_processor(self, message: tp.MutableMapping[str, tp.Any]) -> None:
+        try:
+            executor = self.executors_map[message["type"]]
+        except KeyError:
+            logger.error("Not found executor for %s message type", message["type"])
+        else:
+            await executor(message)
+
+        await self.send(message)
+
+    async def on_response_start(self, message: tp.MutableMapping[str, tp.Any]) -> None:
+        self._response_status = message["status"]
+        self._response_headers = {
+            k.decode(): v.decode() for k, v in message.get("headers", [])
+        }
+
+    async def on_response_body(self, message: tp.MutableMapping[str, tp.Any]) -> None:
+        self._response_body += message.get("body", b"")
+
+        # this is the last chunk
+        if not message.get("more_body", False):
+            response = Response(
+                content=self._response_body,
+                status_code=self._response_status,
+                headers=self._response_headers,
+            )
+            await self.on_response_ready(response)
+
+    async def on_response_ready(self, response: Response) -> None:
+        pass
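To make the contract of these base classes concrete: `BaseMiddleware` dispatches on the scope type and falls through to the wrapped app unless a handler reports the request as processed, while `BaseSendWrapper` buffers `http.response.start`/`http.response.body` messages and hands the assembled `Response` to `on_response_ready` after the final chunk. A small sketch of a subclass pair (the `Logging*` names are invented for illustration; running it needs `httpx` for the test client):

```python
import logging

from starlette.applications import Starlette
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.types import Receive, Scope, Send

from fast_cache_middleware.middleware.base import BaseMiddleware, BaseSendWrapper

logger = logging.getLogger(__name__)


class LoggingSendWrapper(BaseSendWrapper):
    async def on_response_ready(self, response) -> None:
        # Called once, after the final "http.response.body" chunk is buffered.
        logger.info("response %s, %d bytes", response.status_code, len(response.body))


class LoggingMiddleware(BaseMiddleware):
    async def on_http(self, scope: Scope, receive: Receive, send: Send) -> bool:
        await LoggingSendWrapper(self.app, scope, receive, send)()
        return True  # request handled; BaseMiddleware skips the fallback call


async def homepage(request):
    return PlainTextResponse("hello")


app = Starlette(routes=[Route("/", homepage)])
app.add_middleware(LoggingMiddleware)

if __name__ == "__main__":
    from starlette.testclient import TestClient

    TestClient(app).get("/")  # logs: response 200, 5 bytes
```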
diff --git a/fast_cache_middleware/middleware.py b/fast_cache_middleware/middleware/middleware.py
similarity index 56%
rename from fast_cache_middleware/middleware.py
rename to fast_cache_middleware/middleware/middleware.py
index 859c09a..f64b1c3 100644
--- a/fast_cache_middleware/middleware.py
+++ b/fast_cache_middleware/middleware/middleware.py
@@ -3,191 +3,31 @@
 import re
 import typing as tp
 
-from fastapi import FastAPI, routing
+from fastapi import routing
 from starlette.requests import Request
-from starlette.responses import Response
-from starlette.routing import Match, Mount, compile_path, get_name
+from starlette.routing import Match, compile_path, get_name
 from starlette.types import ASGIApp, Receive, Scope, Send
 
-from ._helpers import set_cache_age_in_openapi_schema
-from .controller import Controller
-from .depends import BaseCacheConfigDepends, CacheConfig, CacheDropConfig
-from .schemas import CacheConfiguration, RouteInfo
-from .storages import BaseStorage, InMemoryStorage
+from fast_cache_middleware._helpers import (
+    get_app_routes,
+    get_routes,
+    set_cache_age_in_openapi_schema,
+)
+from fast_cache_middleware.controller import Controller
+from fast_cache_middleware.depends import (
+    BaseCacheConfigDepends,
+    CacheConfig,
+    CacheDropConfig,
+)
+from fast_cache_middleware.schemas import CacheConfiguration, RouteInfo
+from fast_cache_middleware.storages import BaseStorage, InMemoryStorage
+
+from .base import BaseMiddleware
+from .send_wrapper import CacheSendWrapper
 
 logger = logging.getLogger(__name__)
 
 
-class BaseMiddleware:
-    def __init__(
-        self,
-        app: ASGIApp,
-    ) -> None:
-        self.app = app
-
-        self.executors_map = {
-            "lifespan": self.on_lifespan,
-            "http": self.on_http,
-        }
-
-    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
-        scope_type = scope["type"]
-        try:
-            is_request_processed = await self.executors_map[scope_type](
-                scope, receive, send
-            )
-        except KeyError:
-            logger.debug("Not supported scope type: %s", scope_type)
-            is_request_processed = False
-
-        if not is_request_processed:
-            await self.app(scope, receive, send)
-
-    async def on_lifespan(
-        self, scope: Scope, receive: Receive, send: Send
-    ) -> bool | None:
-        pass
-
-    async def on_http(self, scope: Scope, receive: Receive, send: Send) -> bool | None:
-        pass
-
-
-class BaseSendWrapper:
-    def __init__(self, app: ASGIApp, scope: Scope, receive: Receive, send: Send):
-        self.app = app
-        self.scope = scope
-        self.receive = receive
-        self.send = send
-
-        self._response_status: int = 200
-        self._response_headers: dict[str, str] = dict()
-        self._response_body: bytes = b""
-
-        self.executors_map = {
-            "http.response.start": self.on_response_start,
-            "http.response.body": self.on_response_body,
-        }
-
-    async def __call__(self) -> None:
-        return await self.app(self.scope, self.receive, self._message_processor)
-
-    async def _message_processor(self, message: tp.MutableMapping[str, tp.Any]) -> None:
-        try:
-            executor = self.executors_map[message["type"]]
-        except KeyError:
-            logger.error("Not found executor for %s message type", message["type"])
-        else:
-            await executor(message)
-
-        await self.send(message)
-
-    async def on_response_start(self, message: tp.MutableMapping[str, tp.Any]) -> None:
-        self._response_status = message["status"]
-        self._response_headers = {
-            k.decode(): v.decode() for k, v in message.get("headers", [])
-        }
-
-    async def on_response_body(self, message: tp.MutableMapping[str, tp.Any]) -> None:
-        self._response_body += message.get("body", b"")
-
-        # this is the last chunk
-        if not message.get("more_body", False):
-            response = Response(
-                content=self._response_body,
-                status_code=self._response_status,
-                headers=self._response_headers,
-            )
-            await self.on_response_ready(response)
-
-    async def on_response_ready(self, response: Response) -> None:
-        pass
-
-
-class CacheSendWrapper(BaseSendWrapper):
-    def __init__(
-        self,
-        controller: Controller,
-        storage: BaseStorage,
-        request: Request,
-        cache_key: str,
-        ttl: int,
-        app: ASGIApp,
-        scope: Scope,
-        receive: Receive,
-        send: Send,
-    ) -> None:
-        super().__init__(app, scope, receive, send)
-
-        self.controller = controller
-        self.storage = storage
-        self.request = request
-        self.cache_key = cache_key
-        self.ttl = ttl
-
-    async def on_response_start(self, message: tp.MutableMapping[str, tp.Any]) -> None:
-        message.get("headers", []).append(("X-Cache-Status".encode(), "MISS".encode()))
-        return await super().on_response_start(message)
-
-    async def on_response_ready(self, response: Response) -> None:
-        await self.controller.cache_response(
-            cache_key=self.cache_key,
-            request=self.request,
-            response=response,
-            storage=self.storage,
-            ttl=self.ttl,
-        )
-
-
-def get_app_routes(app: FastAPI) -> tp.List[routing.APIRoute]:
-    """Gets all routes from FastAPI application.
-
-    Recursively traverses all application routers and collects their routes.
-
-    Args:
-        app: FastAPI application
-
-    Returns:
-        List of all application routes
-    """
-    routes = []
-
-    # Get routes from main application router
-    routes.extend(get_routes(app.router))
-
-    # Traverse all nested routers
-    for route in app.router.routes:
-        if isinstance(route, Mount):
-            if isinstance(route.app, routing.APIRouter):
-                routes.extend(get_routes(route.app))
-
-    return routes
-
-
-def get_routes(router: routing.APIRouter) -> list[routing.APIRoute]:
-    """Recursively gets all routes from router.
-
-    Traverses all routes in router and its sub-routers, collecting them into a single list.
-
-    Args:
-        router: APIRouter to traverse
-
-    Returns:
-        List of all routes from router and its sub-routers
-    """
-    routes = []
-
-    # Get all routes from current router
-    for route in router.routes:
-        if isinstance(route, routing.APIRoute):
-            routes.append(route)
-        elif isinstance(route, Mount):
-            # Recursively traverse sub-routers
-            if isinstance(route.app, routing.APIRouter):
-                routes.extend(get_routes(route.app))
-
-    return routes
-
-
 class FastCacheMiddleware(BaseMiddleware):
     """Middleware for caching responses in ASGI applications.
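Worth noting for reviewers: the public import path is unchanged, since the new `middleware/__init__.py` re-exports `FastCacheMiddleware`, while the pieces extracted here now live in dedicated modules. A quick check of the post-refactor paths (not part of the diff, just a sketch):

```python
# Public entry point, unchanged by the move.
from fast_cache_middleware import FastCacheMiddleware

# New internal locations introduced by this change.
from fast_cache_middleware._helpers import get_app_routes, get_routes
from fast_cache_middleware.middleware.base import BaseMiddleware, BaseSendWrapper
from fast_cache_middleware.middleware.send_wrapper import CacheSendWrapper
```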
diff --git a/fast_cache_middleware/middleware/send_wrapper.py b/fast_cache_middleware/middleware/send_wrapper.py
new file mode 100644
index 0000000..bd1abbb
--- /dev/null
+++ b/fast_cache_middleware/middleware/send_wrapper.py
@@ -0,0 +1,45 @@
+import typing as tp
+
+from starlette.requests import Request
+from starlette.responses import Response
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+from fast_cache_middleware.controller import Controller
+from fast_cache_middleware.storages import BaseStorage
+
+from .base import BaseSendWrapper
+
+
+class CacheSendWrapper(BaseSendWrapper):
+    def __init__(
+        self,
+        controller: Controller,
+        storage: BaseStorage,
+        request: Request,
+        cache_key: str,
+        ttl: int,
+        app: ASGIApp,
+        scope: Scope,
+        receive: Receive,
+        send: Send,
+    ) -> None:
+        super().__init__(app, scope, receive, send)
+
+        self.controller = controller
+        self.storage = storage
+        self.request = request
+        self.cache_key = cache_key
+        self.ttl = ttl
+
+    async def on_response_start(self, message: tp.MutableMapping[str, tp.Any]) -> None:
+        message.get("headers", []).append(("X-Cache-Status".encode(), "MISS".encode()))
+        return await super().on_response_start(message)
+
+    async def on_response_ready(self, response: Response) -> None:
+        await self.controller.cache_response(
+            cache_key=self.cache_key,
+            request=self.request,
+            response=response,
+            storage=self.storage,
+            ttl=self.ttl,
+        )
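To see the send-wrapper flow end to end, here is a standalone sketch that drives `CacheSendWrapper` against a tiny ASGI app. `StubController` and the `"GET:/"` key are invented for illustration (the real `Controller` and `BaseStorage` constructors are not shown in this diff); the stub only mirrors the `cache_response(...)` call made in `on_response_ready`:

```python
import asyncio

from starlette.requests import Request
from starlette.responses import PlainTextResponse

from fast_cache_middleware.middleware.send_wrapper import CacheSendWrapper


class StubController:
    """Stand-in for Controller, used only to observe the cache_response call."""

    async def cache_response(self, *, cache_key, request, response, storage, ttl):
        print(f"would cache {cache_key!r} for {ttl}s: {response.body!r}")


async def app(scope, receive, send):
    await PlainTextResponse("hello")(scope, receive, send)


async def main() -> None:
    scope = {"type": "http", "method": "GET", "path": "/", "headers": [], "query_string": b""}
    sent_messages = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent_messages.append(message)

    wrapper = CacheSendWrapper(
        controller=StubController(),
        storage=None,  # a real BaseStorage instance would go here
        request=Request(scope, receive),
        cache_key="GET:/",  # invented key format, illustration only
        ttl=60,
        app=app,
        scope=scope,
        receive=receive,
        send=send,
    )
    await wrapper()

    # on_response_start appended the MISS marker before forwarding the message.
    print(sent_messages[0]["headers"][-1])  # (b'X-Cache-Status', b'MISS')


asyncio.run(main())
```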