diff --git a/examples/basic.py b/examples/basic.py index d5929e3..b5f39e3 100644 --- a/examples/basic.py +++ b/examples/basic.py @@ -115,6 +115,19 @@ async def get_users() -> tp.List[UserResponse]: ] +@app.get("/orgs/{org_id}/users/{user_id}", dependencies=[CacheConfig(max_age=300)]) +async def get_user_in_org(org_id: int, user_id: int) -> UserResponse: + """Получение пользователя в конкретной организации. + + Пример более сложного пути с несколькими параметрами. + """ + user = _USERS_STORAGE.get(user_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + return UserResponse(user_id=user_id, name=user.name, email=user.email) + + @app.post("/users/{user_id}", dependencies=[CacheDropConfig(paths=["/users"])]) async def create_user(user_id: int, user_data: User) -> UserResponse: """Создание пользователя с инвалидацией кеша. @@ -137,7 +150,10 @@ async def update_user(user_id: int, user_data: User) -> UserResponse: return UserResponse(user_id=user_id, name=user_data.name, email=user_data.email) -@app.delete("/users/{user_id}", dependencies=[CacheDropConfig(paths=["/users"])]) +@app.delete( + "/users/{user_id}", + dependencies=[CacheDropConfig(methods=[get_user, get_user_in_org])], +) async def delete_user(user_id: int) -> UserResponse: """Удаление пользователя с инвалидацией кеша.""" user = _USERS_STORAGE.get(user_id) diff --git a/fast_cache_middleware/depends.py b/fast_cache_middleware/depends.py index 9b4528a..88756f4 100644 --- a/fast_cache_middleware/depends.py +++ b/fast_cache_middleware/depends.py @@ -46,9 +46,14 @@ class CacheDropConfig(BaseCacheConfigDepends): that matches the beginning of request path. """ - def __init__(self, paths: list[str | re.Pattern]) -> None: + def __init__( + self, + paths: list[str | re.Pattern] | None = None, + methods: list[Callable] | None = None, + ) -> None: self.paths: list[re.Pattern] = [ - p if isinstance(p, re.Pattern) else re.compile(f"^{p}") for p in paths + p if isinstance(p, re.Pattern) else re.compile(f"^{p}") for p in paths or [] ] + self.methods: list[Callable] = methods or [] self.dependency = self diff --git a/fast_cache_middleware/middleware.py b/fast_cache_middleware/middleware.py index ac009c2..859c09a 100644 --- a/fast_cache_middleware/middleware.py +++ b/fast_cache_middleware/middleware.py @@ -1,11 +1,12 @@ import copy import logging +import re import typing as tp from fastapi import FastAPI, routing from starlette.requests import Request from starlette.responses import Response -from starlette.routing import Match, Mount +from starlette.routing import Match, Mount, compile_path, get_name from starlette.types import ASGIApp, Receive, Scope, Send from ._helpers import set_cache_age_in_openapi_schema @@ -294,12 +295,19 @@ def _extract_routes_info(self, routes: list[routing.APIRoute]) -> list[RouteInfo routes: List of routes to analyze """ routes_info = [] + route_names = {route.name: route.path for route in routes} + for route in routes: ( cache_config, cache_drop_config, ) = self._extract_cache_configs_from_route(route) + paths = self._convert_methods_to_path(route_names, cache_drop_config) + + if cache_drop_config and paths is not None: + cache_drop_config.paths.extend(paths) + if cache_config or cache_drop_config: cache_configuration = CacheConfiguration( max_age=cache_config.max_age if cache_config else None, @@ -308,7 +316,6 @@ def _extract_routes_info(self, routes: list[routing.APIRoute]) -> list[RouteInfo cache_drop_config.paths if cache_drop_config else None ), ) - route_info = RouteInfo( route=route, cache_config=cache_configuration, @@ -348,6 +355,30 @@ def _extract_cache_configs_from_route( return cache_config, cache_drop_config + def _convert_methods_to_path( + self, + route_names: dict[str, str], + cache_drop_config: CacheDropConfig | None, + ) -> list[re.Pattern] | None: + if not cache_drop_config: + return None + + unique: dict[str, re.Pattern] = {} + + for method in cache_drop_config.methods: + name = get_name(method) + route = route_names.get(name) + if not route: + continue + + regex = compile_path(route)[0] + key = regex.pattern + + if key not in unique: + unique[key] = regex + + return list(unique.values()) + def _find_matching_route( self, request: Request, routes_info: list[RouteInfo] ) -> tp.Optional[RouteInfo]: diff --git a/tests/conftest.py b/tests/conftest.py index 7be5a19..5bb9af7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -85,7 +85,7 @@ def app() -> FastAPI: app.router.add_api_route( "/users/{user_id}", delete_user, - dependencies=[CacheDropConfig(paths=["/users/"])], + dependencies=[CacheDropConfig(methods=[get_user])], methods={HTTPMethod.DELETE.value}, ) app.router.add_api_route(