Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion examples/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,19 @@ async def get_users() -> tp.List[UserResponse]:
]


@app.get("/orgs/{org_id}/users/{user_id}", dependencies=[CacheConfig(max_age=300)])
async def get_user_in_org(org_id: int, user_id: int) -> UserResponse:
"""Получение пользователя в конкретной организации.

Пример более сложного пути с несколькими параметрами.
"""
user = _USERS_STORAGE.get(user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")

return UserResponse(user_id=user_id, name=user.name, email=user.email)


@app.post("/users/{user_id}", dependencies=[CacheDropConfig(paths=["/users"])])
async def create_user(user_id: int, user_data: User) -> UserResponse:
"""Создание пользователя с инвалидацией кеша.
Expand All @@ -137,7 +150,10 @@ async def update_user(user_id: int, user_data: User) -> UserResponse:
return UserResponse(user_id=user_id, name=user_data.name, email=user_data.email)


@app.delete("/users/{user_id}", dependencies=[CacheDropConfig(paths=["/users"])])
@app.delete(
"/users/{user_id}",
dependencies=[CacheDropConfig(methods=[get_user, get_user_in_org])],
)
async def delete_user(user_id: int) -> UserResponse:
"""Удаление пользователя с инвалидацией кеша."""
user = _USERS_STORAGE.get(user_id)
Expand Down
9 changes: 7 additions & 2 deletions fast_cache_middleware/depends.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,14 @@ class CacheDropConfig(BaseCacheConfigDepends):
that matches the beginning of request path.
"""

def __init__(self, paths: list[str | re.Pattern]) -> None:
def __init__(
self,
paths: list[str | re.Pattern] | None = None,
methods: list[Callable] | None = None,
) -> None:
self.paths: list[re.Pattern] = [
p if isinstance(p, re.Pattern) else re.compile(f"^{p}") for p in paths
p if isinstance(p, re.Pattern) else re.compile(f"^{p}") for p in paths or []
]
self.methods: list[Callable] = methods or []

self.dependency = self
35 changes: 33 additions & 2 deletions fast_cache_middleware/middleware.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import copy
import logging
import re
import typing as tp

from fastapi import FastAPI, routing
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Match, Mount
from starlette.routing import Match, Mount, compile_path, get_name
from starlette.types import ASGIApp, Receive, Scope, Send

from ._helpers import set_cache_age_in_openapi_schema
Expand Down Expand Up @@ -294,12 +295,19 @@ def _extract_routes_info(self, routes: list[routing.APIRoute]) -> list[RouteInfo
routes: List of routes to analyze
"""
routes_info = []
route_names = {route.name: route.path for route in routes}

for route in routes:
(
cache_config,
cache_drop_config,
) = self._extract_cache_configs_from_route(route)

paths = self._convert_methods_to_path(route_names, cache_drop_config)

if cache_drop_config and paths is not None:
cache_drop_config.paths.extend(paths)

if cache_config or cache_drop_config:
cache_configuration = CacheConfiguration(
max_age=cache_config.max_age if cache_config else None,
Expand All @@ -308,7 +316,6 @@ def _extract_routes_info(self, routes: list[routing.APIRoute]) -> list[RouteInfo
cache_drop_config.paths if cache_drop_config else None
),
)

route_info = RouteInfo(
route=route,
cache_config=cache_configuration,
Expand Down Expand Up @@ -348,6 +355,30 @@ def _extract_cache_configs_from_route(

return cache_config, cache_drop_config

def _convert_methods_to_path(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

это самописный метод? или выдрал из фастапи/старлета?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Самописный с использованием встроенных методов из старлета: compile_path, get_name

self,
route_names: dict[str, str],
cache_drop_config: CacheDropConfig | None,
) -> list[re.Pattern] | None:
if not cache_drop_config:
return None

unique: dict[str, re.Pattern] = {}

for method in cache_drop_config.methods:
name = get_name(method)
route = route_names.get(name)
if not route:
continue

regex = compile_path(route)[0]
key = regex.pattern

if key not in unique:
unique[key] = regex

return list(unique.values())

def _find_matching_route(
self, request: Request, routes_info: list[RouteInfo]
) -> tp.Optional[RouteInfo]:
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def app() -> FastAPI:
app.router.add_api_route(
"/users/{user_id}",
delete_user,
dependencies=[CacheDropConfig(paths=["/users/"])],
dependencies=[CacheDropConfig(methods=[get_user])],
methods={HTTPMethod.DELETE.value},
)
app.router.add_api_route(
Expand Down