Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,10 @@ async def rate_limit_middleware(request: Request, call_next):

# AVA runs behind Caddy (trusted proxy). Caddy appends the real client IP to X-Forwarded-For.
# We take the last IP in the list as the client IP.
if xff := request.headers.get("X-Forwarded-For"):
client_ip = xff.split(",")[-1].strip()
# We use getlist() to combine all X-Forwarded-For headers, preventing spoofing via multiple headers.
xff_list = request.headers.getlist("X-Forwarded-For")
if xff_list:
client_ip = ",".join(xff_list).split(",")[-1].strip()

if not _rate_limiter.is_allowed(client_ip):
logger.warning(f"Rate limit exceeded for {client_ip}")
Expand Down
56 changes: 56 additions & 0 deletions tests/test_rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import unittest
from fastapi.testclient import TestClient
from app.main import app, _rate_limiter

class TestRateLimitBypass(unittest.TestCase):
def setUp(self):
# Reset the rate limiter before each test
_rate_limiter._hits.clear()
self.client = TestClient(app)

def test_multiple_x_forwarded_for_headers(self):
"""
Test that an attacker cannot bypass the rate limiter by sending
multiple X-Forwarded-For headers. Starlette's `request.headers.get()`
only returns the first header value, so we must use `getlist()`.
"""
# Make 35 requests where the "attacker" changes the first X-Forwarded-For
# header, but the proxy (simulated) appends the real IP in a second header.

real_ip = b"1.2.3.4"

for i in range(35):
# TestClient accepts headers as a list of tuples to allow multiple headers with the same name
headers = [
(b"x-forwarded-for", f"spoofed_{i}".encode("utf-8")),
(b"x-forwarded-for", real_ip)
]

response = self.client.get("/health", headers=headers)

# The rate limiter is set to 30 requests per minute.
# So the 31st request and beyond should be blocked.
if i < 30:
self.assertEqual(response.status_code, 200, f"Request {i+1} failed early with status {response.status_code}")
else:
self.assertEqual(response.status_code, 429, f"Request {i+1} bypassed rate limiter! Expected 429, got {response.status_code}")

def test_comma_separated_x_forwarded_for(self):
"""
Test that comma-separated X-Forwarded-For values correctly extract the last IP.
"""
for i in range(35):
# Simulate a single header with comma-separated values
headers = [
(b"x-forwarded-for", f"spoofed_{i}, 1.2.3.4".encode("utf-8")),
]

response = self.client.get("/health", headers=headers)

if i < 30:
self.assertEqual(response.status_code, 200)
else:
self.assertEqual(response.status_code, 429)

if __name__ == '__main__':
unittest.main()