diff --git a/backend/middleware/rate_limiter.py b/backend/middleware/rate_limiter.py index e0abc7e..22600e9 100644 --- a/backend/middleware/rate_limiter.py +++ b/backend/middleware/rate_limiter.py @@ -1,8 +1,16 @@ +import os import time -from collections import defaultdict from fastapi import Request, HTTPException -from utils import get_client_ip +TRUST_PROXY = os.getenv("TRUST_PROXY", "false").lower() == "true" + + +def get_client_ip(request: Request) -> str: + if TRUST_PROXY: + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",")[0].strip() + return request.client.host if request.client else "unknown" class RateLimiter: @@ -16,14 +24,34 @@ class RateLimiter: def __init__(self, requests_per_minute: int = 15): self.requests_per_minute = requests_per_minute self.window_seconds = 60 - self.requests: dict[str, list[float]] = defaultdict(list) + self.requests: dict[str, list[float]] = {} + self.last_cleanup = time.time() def _cleanup_old_requests(self, client_ip: str, current_time: float) -> None: """Remove requests outside the sliding window.""" + if client_ip not in self.requests: + return + cutoff = current_time - self.window_seconds self.requests[client_ip] = [ ts for ts in self.requests[client_ip] if ts > cutoff ] + if not self.requests[client_ip]: + del self.requests[client_ip] + + def _cleanup_stale_entries(self, current_time: float) -> None: + """Periodically remove stale client entries from memory.""" + if current_time - self.last_cleanup < self.window_seconds: + return + + cutoff = current_time - self.window_seconds + for client_ip, timestamps in list(self.requests.items()): + fresh = [ts for ts in timestamps if ts > cutoff] + if fresh: + self.requests[client_ip] = fresh + else: + del self.requests[client_ip] + self.last_cleanup = current_time def check_rate_limit(self, request: Request) -> None: """ @@ -35,16 +63,19 @@ def check_rate_limit(self, request: Request) -> None: client_ip = get_client_ip(request) current_time = time.time() + self._cleanup_stale_entries(current_time) self._cleanup_old_requests(client_ip, current_time) - if len(self.requests[client_ip]) >= self.requests_per_minute: + client_requests = self.requests.setdefault(client_ip, []) + + if len(client_requests) >= self.requests_per_minute: raise HTTPException( status_code=429, detail=f"Rate limit exceeded. Maximum {self.requests_per_minute} requests per minute.", headers={"Retry-After": "60"} ) - self.requests[client_ip].append(current_time) + client_requests.append(current_time) # Global rate limiter instance diff --git a/backend/utils/__init__.py b/backend/utils/__init__.py index 984ccef..bf71320 100644 --- a/backend/utils/__init__.py +++ b/backend/utils/__init__.py @@ -1,3 +1,15 @@ -from .request import get_client_ip +import os + +from fastapi import Request + +TRUST_PROXY = os.getenv("TRUST_PROXY", "false").lower() == "true" + + +def get_client_ip(request: Request) -> str: + if TRUST_PROXY: + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",")[0].strip() + return request.client.host if request.client else "unknown" __all__ = ["get_client_ip"] diff --git a/backend/utils/request.py b/backend/utils/request.py index f345182..19b6ac5 100644 --- a/backend/utils/request.py +++ b/backend/utils/request.py @@ -1,22 +1,13 @@ -from fastapi import Request - +import os -def get_client_ip(request: Request) -> str: - """ - Extract the real client IP from a request. +from fastapi import Request - Checks X-Forwarded-For header first (for requests behind proxies/load balancers), - then falls back to the direct client host. +TRUST_PROXY = os.getenv("TRUST_PROXY", "false").lower() == "true" - Args: - request: FastAPI Request object - Returns: - Client IP address string, or "unknown" if not available - """ - forwarded = request.headers.get("x-forwarded-for") - if forwarded: - # X-Forwarded-For can contain multiple IPs: client, proxy1, proxy2, ... - # The first one is the original client IP - return forwarded.split(",")[0].strip() +def get_client_ip(request: Request) -> str: + if TRUST_PROXY: + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",")[0].strip() return request.client.host if request.client else "unknown"