Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions backend/middleware/rate_limiter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import os
import time
from collections import defaultdict
from fastapi import Request, HTTPException

from utils import get_client_ip
TRUST_PROXY = os.getenv("TRUST_PROXY", "false").lower() == "true"


def get_client_ip(request: Request) -> str:
if TRUST_PROXY:
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"


class RateLimiter:
Expand All @@ -16,14 +24,34 @@ class RateLimiter:
def __init__(self, requests_per_minute: int = 15):
self.requests_per_minute = requests_per_minute
self.window_seconds = 60
self.requests: dict[str, list[float]] = defaultdict(list)
self.requests: dict[str, list[float]] = {}
self.last_cleanup = time.time()

def _cleanup_old_requests(self, client_ip: str, current_time: float) -> None:
"""Remove requests outside the sliding window."""
if client_ip not in self.requests:
return

cutoff = current_time - self.window_seconds
self.requests[client_ip] = [
ts for ts in self.requests[client_ip] if ts > cutoff
]
if not self.requests[client_ip]:
del self.requests[client_ip]

def _cleanup_stale_entries(self, current_time: float) -> None:
"""Periodically remove stale client entries from memory."""
if current_time - self.last_cleanup < self.window_seconds:
return

cutoff = current_time - self.window_seconds
for client_ip, timestamps in list(self.requests.items()):
fresh = [ts for ts in timestamps if ts > cutoff]
if fresh:
self.requests[client_ip] = fresh
else:
del self.requests[client_ip]
self.last_cleanup = current_time

def check_rate_limit(self, request: Request) -> None:
"""
Expand All @@ -35,16 +63,19 @@ def check_rate_limit(self, request: Request) -> None:
client_ip = get_client_ip(request)
current_time = time.time()

self._cleanup_stale_entries(current_time)
self._cleanup_old_requests(client_ip, current_time)

if len(self.requests[client_ip]) >= self.requests_per_minute:
client_requests = self.requests.setdefault(client_ip, [])

if len(client_requests) >= self.requests_per_minute:
raise HTTPException(
status_code=429,
detail=f"Rate limit exceeded. Maximum {self.requests_per_minute} requests per minute.",
headers={"Retry-After": "60"}
)

self.requests[client_ip].append(current_time)
client_requests.append(current_time)


# Global rate limiter instance
Expand Down
14 changes: 13 additions & 1 deletion backend/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
from .request import get_client_ip
import os

from fastapi import Request

TRUST_PROXY = os.getenv("TRUST_PROXY", "false").lower() == "true"


def get_client_ip(request: Request) -> str:
if TRUST_PROXY:
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"

__all__ = ["get_client_ip"]
25 changes: 8 additions & 17 deletions backend/utils/request.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,13 @@
from fastapi import Request

import os

def get_client_ip(request: Request) -> str:
"""
Extract the real client IP from a request.
from fastapi import Request

Checks X-Forwarded-For header first (for requests behind proxies/load balancers),
then falls back to the direct client host.
TRUST_PROXY = os.getenv("TRUST_PROXY", "false").lower() == "true"

Args:
request: FastAPI Request object

Returns:
Client IP address string, or "unknown" if not available
"""
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
# X-Forwarded-For can contain multiple IPs: client, proxy1, proxy2, ...
# The first one is the original client IP
return forwarded.split(",")[0].strip()
def get_client_ip(request: Request) -> str:
if TRUST_PROXY:
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"