Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fast_cache_middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .controller import Controller
from .depends import BaseCacheConfigDepends, CacheConfig, CacheDropConfig
from .middleware import FastCacheMiddleware
from .serializers import BaseSerializer, JSONSerializer
from .storages import BaseStorage, InMemoryStorage, RedisStorage

__version__ = "1.0.0"
Expand All @@ -31,5 +32,5 @@
"RedisStorage",
# Serialization
"BaseSerializer",
"DefaultSerializer",
"JSONSerializer",
]
53 changes: 53 additions & 0 deletions fast_cache_middleware/_helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import typing as tp

from fastapi import FastAPI, routing
from starlette.routing import Mount

from .depends import CacheConfig

Expand Down Expand Up @@ -26,3 +29,53 @@ def set_cache_age_in_openapi_schema(app: FastAPI) -> None:

app.openapi_schema = openapi_schema
return None


def get_app_routes(app: FastAPI) -> tp.List[routing.APIRoute]:
"""Gets all routes from FastAPI application.

Recursively traverses all application routers and collects their routes.

Args:
app: FastAPI application

Returns:
List of all application routes
"""
routes = []

# Get routes from main application router
routes.extend(get_routes(app.router))

# Traverse all nested routers
for route in app.router.routes:
if isinstance(route, Mount):
if isinstance(route.app, routing.APIRouter):
routes.extend(get_routes(route.app))

return routes


def get_routes(router: routing.APIRouter) -> list[routing.APIRoute]:
"""Recursively gets all routes from router.

Traverses all routes in router and its sub-routers, collecting them into a single list.

Args:
router: APIRouter to traverse

Returns:
List of all routes from router and its sub-routers
"""
routes = []

# Get all routes from current router
for route in router.routes:
if isinstance(route, routing.APIRoute):
routes.append(route)
elif isinstance(route, Mount):
# Recursively traverse sub-routers
if isinstance(route.app, routing.APIRouter):
routes.extend(get_routes(route.app))

return routes
1 change: 1 addition & 0 deletions fast_cache_middleware/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .middleware import FastCacheMiddleware
92 changes: 92 additions & 0 deletions fast_cache_middleware/middleware/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import logging
import typing as tp

from starlette.responses import Response
from starlette.types import ASGIApp, Receive, Scope, Send

logger = logging.getLogger(__name__)


class BaseMiddleware:
def __init__(
self,
app: ASGIApp,
) -> None:
self.app = app

self.executors_map = {
"lifespan": self.on_lifespan,
"http": self.on_http,
}

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope_type = scope["type"]
try:
is_request_processed = await self.executors_map[scope_type](
scope, receive, send
)
except KeyError:
logger.debug("Not supported scope type: %s", scope_type)
is_request_processed = False

if not is_request_processed:
await self.app(scope, receive, send)

async def on_lifespan(
self, scope: Scope, receive: Receive, send: Send
) -> bool | None:
pass

async def on_http(self, scope: Scope, receive: Receive, send: Send) -> bool | None:
pass


class BaseSendWrapper:
def __init__(self, app: ASGIApp, scope: Scope, receive: Receive, send: Send):
self.app = app
self.scope = scope
self.receive = receive
self.send = send

self._response_status: int = 200
self._response_headers: dict[str, str] = dict()
self._response_body: bytes = b""

self.executors_map = {
"http.response.start": self.on_response_start,
"http.response.body": self.on_response_body,
}

async def __call__(self) -> None:
return await self.app(self.scope, self.receive, self._message_processor)

async def _message_processor(self, message: tp.MutableMapping[str, tp.Any]) -> None:
try:
executor = self.executors_map[message["type"]]
except KeyError:
logger.error("Not found executor for %s message type", message["type"])
else:
await executor(message)

await self.send(message)

async def on_response_start(self, message: tp.MutableMapping[str, tp.Any]) -> None:
self._response_status = message["status"]
self._response_headers = {
k.decode(): v.decode() for k, v in message.get("headers", [])
}

async def on_response_body(self, message: tp.MutableMapping[str, tp.Any]) -> None:
self._response_body += message.get("body", b"")

# this is the last chunk
if not message.get("more_body", False):
response = Response(
content=self._response_body,
status_code=self._response_status,
headers=self._response_headers,
)
await self.on_response_ready(response)

async def on_response_ready(self, response: Response) -> None:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -3,191 +3,31 @@
import re
import typing as tp

from fastapi import FastAPI, routing
from fastapi import routing
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Match, Mount, compile_path, get_name
from starlette.routing import Match, compile_path, get_name
from starlette.types import ASGIApp, Receive, Scope, Send

from ._helpers import set_cache_age_in_openapi_schema
from .controller import Controller
from .depends import BaseCacheConfigDepends, CacheConfig, CacheDropConfig
from .schemas import CacheConfiguration, RouteInfo
from .storages import BaseStorage, InMemoryStorage
from fast_cache_middleware._helpers import (
get_app_routes,
get_routes,
set_cache_age_in_openapi_schema,
)
from fast_cache_middleware.controller import Controller
from fast_cache_middleware.depends import (
BaseCacheConfigDepends,
CacheConfig,
CacheDropConfig,
)
from fast_cache_middleware.schemas import CacheConfiguration, RouteInfo
from fast_cache_middleware.storages import BaseStorage, InMemoryStorage

from .base import BaseMiddleware
from .send_wrapper import CacheSendWrapper

logger = logging.getLogger(__name__)


class BaseMiddleware:
def __init__(
self,
app: ASGIApp,
) -> None:
self.app = app

self.executors_map = {
"lifespan": self.on_lifespan,
"http": self.on_http,
}

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope_type = scope["type"]
try:
is_request_processed = await self.executors_map[scope_type](
scope, receive, send
)
except KeyError:
logger.debug("Not supported scope type: %s", scope_type)
is_request_processed = False

if not is_request_processed:
await self.app(scope, receive, send)

async def on_lifespan(
self, scope: Scope, receive: Receive, send: Send
) -> bool | None:
pass

async def on_http(self, scope: Scope, receive: Receive, send: Send) -> bool | None:
pass


class BaseSendWrapper:
def __init__(self, app: ASGIApp, scope: Scope, receive: Receive, send: Send):
self.app = app
self.scope = scope
self.receive = receive
self.send = send

self._response_status: int = 200
self._response_headers: dict[str, str] = dict()
self._response_body: bytes = b""

self.executors_map = {
"http.response.start": self.on_response_start,
"http.response.body": self.on_response_body,
}

async def __call__(self) -> None:
return await self.app(self.scope, self.receive, self._message_processor)

async def _message_processor(self, message: tp.MutableMapping[str, tp.Any]) -> None:
try:
executor = self.executors_map[message["type"]]
except KeyError:
logger.error("Not found executor for %s message type", message["type"])
else:
await executor(message)

await self.send(message)

async def on_response_start(self, message: tp.MutableMapping[str, tp.Any]) -> None:
self._response_status = message["status"]
self._response_headers = {
k.decode(): v.decode() for k, v in message.get("headers", [])
}

async def on_response_body(self, message: tp.MutableMapping[str, tp.Any]) -> None:
self._response_body += message.get("body", b"")

# this is the last chunk
if not message.get("more_body", False):
response = Response(
content=self._response_body,
status_code=self._response_status,
headers=self._response_headers,
)
await self.on_response_ready(response)

async def on_response_ready(self, response: Response) -> None:
pass


class CacheSendWrapper(BaseSendWrapper):
def __init__(
self,
controller: Controller,
storage: BaseStorage,
request: Request,
cache_key: str,
ttl: int,
app: ASGIApp,
scope: Scope,
receive: Receive,
send: Send,
) -> None:
super().__init__(app, scope, receive, send)

self.controller = controller
self.storage = storage
self.request = request
self.cache_key = cache_key
self.ttl = ttl

async def on_response_start(self, message: tp.MutableMapping[str, tp.Any]) -> None:
message.get("headers", []).append(("X-Cache-Status".encode(), "MISS".encode()))
return await super().on_response_start(message)

async def on_response_ready(self, response: Response) -> None:
await self.controller.cache_response(
cache_key=self.cache_key,
request=self.request,
response=response,
storage=self.storage,
ttl=self.ttl,
)


def get_app_routes(app: FastAPI) -> tp.List[routing.APIRoute]:
"""Gets all routes from FastAPI application.

Recursively traverses all application routers and collects their routes.

Args:
app: FastAPI application

Returns:
List of all application routes
"""
routes = []

# Get routes from main application router
routes.extend(get_routes(app.router))

# Traverse all nested routers
for route in app.router.routes:
if isinstance(route, Mount):
if isinstance(route.app, routing.APIRouter):
routes.extend(get_routes(route.app))

return routes


def get_routes(router: routing.APIRouter) -> list[routing.APIRoute]:
"""Recursively gets all routes from router.

Traverses all routes in router and its sub-routers, collecting them into a single list.

Args:
router: APIRouter to traverse

Returns:
List of all routes from router and its sub-routers
"""
routes = []

# Get all routes from current router
for route in router.routes:
if isinstance(route, routing.APIRoute):
routes.append(route)
elif isinstance(route, Mount):
# Recursively traverse sub-routers
if isinstance(route.app, routing.APIRouter):
routes.extend(get_routes(route.app))

return routes


class FastCacheMiddleware(BaseMiddleware):
"""Middleware for caching responses in ASGI applications.

Expand Down
Loading