diff --git "a/3.LLM\346\207\211\347\224\250\345\267\245\347\250\213/1.LLM \351\203\250\347\275\262/projects/fastapi-llm-api/app/main.py" "b/3.LLM\346\207\211\347\224\250\345\267\245\347\250\213/1.LLM \351\203\250\347\275\262/projects/fastapi-llm-api/app/main.py" index d80f8f4..e48a48f 100644 --- "a/3.LLM\346\207\211\347\224\250\345\267\245\347\250\213/1.LLM \351\203\250\347\275\262/projects/fastapi-llm-api/app/main.py" +++ "b/3.LLM\346\207\211\347\224\250\345\267\245\347\250\213/1.LLM \351\203\250\347\275\262/projects/fastapi-llm-api/app/main.py" @@ -61,12 +61,16 @@ async def lifespan(app: FastAPI): ) # CORS 中間件 +# 從環境變量讀取允許的來源,生產環境應設置具體域名 +import os +ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:8080").split(",") + app.add_middleware( CORSMiddleware, - allow_origins=["*"], # 生產環境應設置具體域名 + allow_origins=ALLOWED_ORIGINS, allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-Requested-With"], ) diff --git "a/3.LLM\346\207\211\347\224\250\345\267\245\347\250\213/2.LLM as API/examples/basic_apis/01_openai_basic.py" "b/3.LLM\346\207\211\347\224\250\345\267\245\347\250\213/2.LLM as API/examples/basic_apis/01_openai_basic.py" index 7bd4c14..4331f55 100644 --- "a/3.LLM\346\207\211\347\224\250\345\267\245\347\250\213/2.LLM as API/examples/basic_apis/01_openai_basic.py" +++ "b/3.LLM\346\207\211\347\224\250\345\267\245\347\250\213/2.LLM as API/examples/basic_apis/01_openai_basic.py" @@ -130,12 +130,51 @@ def get_weather(location: str, unit: str = "celsius") -> Dict[str, Any]: return data def calculate(expression: str) -> float: - """安全地計算數學表達式""" + """安全地計算數學表達式 + + 使用 ast.literal_eval 和 operator 模組的安全方法, + 避免 eval() 的安全風險。 + """ + import ast + import operator + + # 支援的安全運算符 + operators = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.Pow: operator.pow, + ast.USub: operator.neg, + ast.UAdd: operator.pos, + } + + def safe_eval(node): + """遞迴安全計算 AST 節點""" + if isinstance(node, ast.Constant): # 數字 + return node.value + elif isinstance(node, ast.BinOp): # 二元運算 + left = safe_eval(node.left) + right = safe_eval(node.right) + op = operators.get(type(node.op)) + if op is None: + raise ValueError(f"不支援的運算符: {type(node.op).__name__}") + return op(left, right) + elif isinstance(node, ast.UnaryOp): # 一元運算 + operand = safe_eval(node.operand) + op = operators.get(type(node.op)) + if op is None: + raise ValueError(f"不支援的運算符: {type(node.op).__name__}") + return op(operand) + else: + raise ValueError(f"不支援的表達式類型: {type(node).__name__}") + try: - # 注意:在生產環境中應該使用更安全的方法 - return eval(expression, {"__builtins__": {}}) - except: - return "計算錯誤" + # 解析表達式為 AST + tree = ast.parse(expression, mode='eval') + return safe_eval(tree.body) + except Exception as e: + return f"計算錯誤: {e}" # 可用的函數映射 available_functions = { diff --git "a/3.LLM\346\207\211\347\224\250\345\267\245\347\250\213/2.LLM as API/examples/frontend_integration/fastapi_backend/main.py" "b/3.LLM\346\207\211\347\224\250\345\267\245\347\250\213/2.LLM as API/examples/frontend_integration/fastapi_backend/main.py" index ff886af..ecb926a 100644 --- "a/3.LLM\346\207\211\347\224\250\345\267\245\347\250\213/2.LLM as API/examples/frontend_integration/fastapi_backend/main.py" +++ "b/3.LLM\346\207\211\347\224\250\345\267\245\347\250\213/2.LLM as API/examples/frontend_integration/fastapi_backend/main.py" @@ -177,11 +177,26 @@ def init_clients(): def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str: - """驗證 API Key""" - expected_key = os.getenv("API_KEY", "your-secret-key") + """驗證 API Key - if credentials.credentials != expected_key: - logger.warning(f"無效的 API Key 嘗試") + 使用 secrets.compare_digest 進行常數時間比較, + 防止時序攻擊(timing attack)。 + """ + import secrets + + expected_key = os.getenv("API_KEY") + + # 確保 API_KEY 已設置 + if not expected_key: + logger.error("API_KEY 環境變量未設置") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Server configuration error" + ) + + # 使用常數時間比較防止時序攻擊 + if not secrets.compare_digest(credentials.credentials, expected_key): + logger.warning("無效的 API Key 嘗試") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key" diff --git "a/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/RAG-ChatBot/main.py" "b/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/RAG-ChatBot/main.py" index 267dcc3..88a48b5 100644 --- "a/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/RAG-ChatBot/main.py" +++ "b/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/RAG-ChatBot/main.py" @@ -9,26 +9,50 @@ from fastapi.responses import HTMLResponse, StreamingResponse from pydantic import BaseModel from typing import List, Optional +from contextlib import asynccontextmanager import uvicorn import asyncio from rag_engine import RAGEngine +from middleware.rate_limiter import rate_limiter, rate_limit_middleware + + +# 生命週期管理 +@asynccontextmanager +async def lifespan(app: FastAPI): + """應用生命週期管理""" + # 啟動時 + await rate_limiter.start() + yield + # 關閉時 + await rate_limiter.stop() + # 創建 FastAPI 應用 app = FastAPI( title="RAG ChatBot API", description="檢索增強生成聊天機器人", - version="1.0.0" + version="1.0.0", + lifespan=lifespan ) -# CORS 配置 +# CORS 配置 - 從環境變量讀取允許的來源 +import os +ALLOWED_ORIGINS = os.getenv( + "ALLOWED_ORIGINS", + "http://localhost:3000,http://localhost:8080,http://127.0.0.1:3000" +).split(",") + app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=ALLOWED_ORIGINS, allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-Requested-With"], ) +# 速率限制中間件 +app.middleware("http")(rate_limit_middleware) + # 初始化 RAG 引擎 rag_engine = RAGEngine() diff --git "a/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/RAG-ChatBot/middleware/__init__.py" "b/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/RAG-ChatBot/middleware/__init__.py" new file mode 100644 index 0000000..5a210a5 --- /dev/null +++ "b/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/RAG-ChatBot/middleware/__init__.py" @@ -0,0 +1,19 @@ +""" +中間件模組 + +提供 FastAPI 應用的各種中間件功能。 +""" + +from .rate_limiter import ( + RateLimiter, + rate_limiter, + rate_limit, + rate_limit_middleware +) + +__all__ = [ + 'RateLimiter', + 'rate_limiter', + 'rate_limit', + 'rate_limit_middleware' +] diff --git "a/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/RAG-ChatBot/middleware/rate_limiter.py" "b/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/RAG-ChatBot/middleware/rate_limiter.py" new file mode 100644 index 0000000..a0819de --- /dev/null +++ "b/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/RAG-ChatBot/middleware/rate_limiter.py" @@ -0,0 +1,282 @@ +""" +速率限制中間件 + +提供基於 IP 和 API Key 的速率限制功能,防止 API 濫用。 +支持滑動窗口算法和令牌桶算法。 +""" + +import time +from collections import defaultdict +from typing import Callable, Optional +from functools import wraps +import asyncio +from fastapi import Request, HTTPException, status +from fastapi.responses import JSONResponse +import logging + +logger = logging.getLogger(__name__) + + +class RateLimiter: + """ + 速率限制器 + + 使用滑動窗口算法實現請求限制。 + + Attributes: + requests_per_minute: 每分鐘允許的請求數 + requests_per_hour: 每小時允許的請求數 + """ + + def __init__( + self, + requests_per_minute: int = 60, + requests_per_hour: int = 1000, + burst_limit: int = 10 + ): + self.requests_per_minute = requests_per_minute + self.requests_per_hour = requests_per_hour + self.burst_limit = burst_limit + + # 存儲請求記錄 {client_id: [(timestamp, count), ...]} + self._minute_requests: dict[str, list[float]] = defaultdict(list) + self._hour_requests: dict[str, list[float]] = defaultdict(list) + self._burst_requests: dict[str, list[float]] = defaultdict(list) + + # 清理任務 + self._cleanup_task: Optional[asyncio.Task] = None + + async def start(self): + """啟動清理任務""" + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + + async def stop(self): + """停止清理任務""" + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + + async def _cleanup_loop(self): + """定期清理過期的請求記錄""" + while True: + await asyncio.sleep(60) # 每分鐘清理一次 + self._cleanup_old_requests() + + def _cleanup_old_requests(self): + """清理過期請求記錄""" + current_time = time.time() + + # 清理分鐘級記錄 + for client_id in list(self._minute_requests.keys()): + self._minute_requests[client_id] = [ + t for t in self._minute_requests[client_id] + if current_time - t < 60 + ] + if not self._minute_requests[client_id]: + del self._minute_requests[client_id] + + # 清理小時級記錄 + for client_id in list(self._hour_requests.keys()): + self._hour_requests[client_id] = [ + t for t in self._hour_requests[client_id] + if current_time - t < 3600 + ] + if not self._hour_requests[client_id]: + del self._hour_requests[client_id] + + def _get_client_id(self, request: Request) -> str: + """ + 獲取客戶端標識符 + + 優先使用 API Key,否則使用 IP 地址 + """ + # 嘗試從 Header 獲取 API Key + api_key = request.headers.get("X-API-Key") or request.headers.get("Authorization") + if api_key: + # 移除 "Bearer " 前綴(如果有) + if api_key.startswith("Bearer "): + api_key = api_key[7:] + return f"key:{api_key[:16]}" # 只使用前 16 字符 + + # 獲取客戶端 IP + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + ip = forwarded.split(",")[0].strip() + else: + ip = request.client.host if request.client else "unknown" + + return f"ip:{ip}" + + async def check_rate_limit(self, request: Request) -> bool: + """ + 檢查請求是否超過速率限制 + + Returns: + True 如果請求被允許,否則拋出 HTTPException + """ + client_id = self._get_client_id(request) + current_time = time.time() + + # 檢查突發限制(每秒) + burst_requests = [ + t for t in self._burst_requests[client_id] + if current_time - t < 1 + ] + if len(burst_requests) >= self.burst_limit: + logger.warning(f"Rate limit exceeded (burst) for client: {client_id}") + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail={ + "error": "Rate limit exceeded", + "message": "Too many requests per second", + "retry_after": 1 + }, + headers={"Retry-After": "1"} + ) + + # 檢查分鐘限制 + minute_requests = [ + t for t in self._minute_requests[client_id] + if current_time - t < 60 + ] + if len(minute_requests) >= self.requests_per_minute: + retry_after = int(60 - (current_time - minute_requests[0])) + logger.warning(f"Rate limit exceeded (minute) for client: {client_id}") + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail={ + "error": "Rate limit exceeded", + "message": f"Too many requests per minute. Limit: {self.requests_per_minute}", + "retry_after": retry_after + }, + headers={"Retry-After": str(retry_after)} + ) + + # 檢查小時限制 + hour_requests = [ + t for t in self._hour_requests[client_id] + if current_time - t < 3600 + ] + if len(hour_requests) >= self.requests_per_hour: + retry_after = int(3600 - (current_time - hour_requests[0])) + logger.warning(f"Rate limit exceeded (hour) for client: {client_id}") + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail={ + "error": "Rate limit exceeded", + "message": f"Too many requests per hour. Limit: {self.requests_per_hour}", + "retry_after": retry_after + }, + headers={"Retry-After": str(retry_after)} + ) + + # 記錄請求 + self._burst_requests[client_id].append(current_time) + self._minute_requests[client_id].append(current_time) + self._hour_requests[client_id].append(current_time) + + return True + + def get_remaining_requests(self, request: Request) -> dict: + """獲取剩餘請求配額""" + client_id = self._get_client_id(request) + current_time = time.time() + + minute_requests = len([ + t for t in self._minute_requests[client_id] + if current_time - t < 60 + ]) + hour_requests = len([ + t for t in self._hour_requests[client_id] + if current_time - t < 3600 + ]) + + return { + "remaining_per_minute": max(0, self.requests_per_minute - minute_requests), + "remaining_per_hour": max(0, self.requests_per_hour - hour_requests), + "limit_per_minute": self.requests_per_minute, + "limit_per_hour": self.requests_per_hour + } + + +# 全局速率限制器實例 +rate_limiter = RateLimiter( + requests_per_minute=60, + requests_per_hour=1000, + burst_limit=10 +) + + +def rate_limit( + requests_per_minute: Optional[int] = None, + requests_per_hour: Optional[int] = None +): + """ + 速率限制裝飾器 + + 用於對特定端點應用自定義速率限制 + + Usage: + @app.get("/api/chat") + @rate_limit(requests_per_minute=10) + async def chat(): + ... + """ + def decorator(func: Callable): + @wraps(func) + async def wrapper(*args, **kwargs): + # 獲取 request 對象 + request = kwargs.get('request') + if not request: + for arg in args: + if isinstance(arg, Request): + request = arg + break + + if request: + # 使用自定義限制或默認限制 + custom_limiter = RateLimiter( + requests_per_minute=requests_per_minute or rate_limiter.requests_per_minute, + requests_per_hour=requests_per_hour or rate_limiter.requests_per_hour + ) + await custom_limiter.check_rate_limit(request) + + return await func(*args, **kwargs) + return wrapper + return decorator + + +async def rate_limit_middleware(request: Request, call_next): + """ + FastAPI 速率限制中間件 + + Usage: + app.middleware("http")(rate_limit_middleware) + """ + # 跳過健康檢查和靜態文件 + skip_paths = ["/health", "/api/health", "/docs", "/redoc", "/openapi.json"] + if any(request.url.path.startswith(path) for path in skip_paths): + return await call_next(request) + + try: + await rate_limiter.check_rate_limit(request) + response = await call_next(request) + + # 添加速率限制相關的響應頭 + remaining = rate_limiter.get_remaining_requests(request) + response.headers["X-RateLimit-Limit-Minute"] = str(remaining["limit_per_minute"]) + response.headers["X-RateLimit-Remaining-Minute"] = str(remaining["remaining_per_minute"]) + response.headers["X-RateLimit-Limit-Hour"] = str(remaining["limit_per_hour"]) + response.headers["X-RateLimit-Remaining-Hour"] = str(remaining["remaining_per_hour"]) + + return response + + except HTTPException as e: + return JSONResponse( + status_code=e.status_code, + content=e.detail, + headers=dict(e.headers) if e.headers else None + ) diff --git "a/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/RAG-ChatBot/tests/test_rate_limiter.py" "b/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/RAG-ChatBot/tests/test_rate_limiter.py" new file mode 100644 index 0000000..940d148 --- /dev/null +++ "b/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/RAG-ChatBot/tests/test_rate_limiter.py" @@ -0,0 +1,281 @@ +""" +速率限制中間件測試 +測試 rate_limiter.py 的功能 +""" + +import pytest +import asyncio +import time +from unittest.mock import Mock, AsyncMock, patch +from fastapi import HTTPException +import sys +import os + +# 添加父目錄到路徑 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from middleware.rate_limiter import RateLimiter, rate_limiter, rate_limit_middleware + + +class TestRateLimiter: + """RateLimiter 類測試""" + + @pytest.fixture + def limiter(self): + """創建測試用的限制器""" + return RateLimiter( + requests_per_minute=5, + requests_per_hour=20, + burst_limit=2 + ) + + @pytest.fixture + def mock_request(self): + """創建 Mock 請求對象""" + request = Mock() + request.headers = {} + request.client = Mock() + request.client.host = "192.168.1.1" + request.url = Mock() + request.url.path = "/api/chat" + return request + + def test_initialization(self, limiter): + """測試初始化""" + assert limiter.requests_per_minute == 5 + assert limiter.requests_per_hour == 20 + assert limiter.burst_limit == 2 + + def test_get_client_id_from_ip(self, limiter, mock_request): + """測試從 IP 獲取客戶端 ID""" + client_id = limiter._get_client_id(mock_request) + assert client_id == "ip:192.168.1.1" + + def test_get_client_id_from_api_key(self, limiter, mock_request): + """測試從 API Key 獲取客戶端 ID""" + mock_request.headers = {"X-API-Key": "test_api_key_12345678"} + client_id = limiter._get_client_id(mock_request) + assert client_id.startswith("key:") + assert "test_api_key_123" in client_id + + def test_get_client_id_from_bearer_token(self, limiter, mock_request): + """測試從 Bearer Token 獲取客戶端 ID""" + mock_request.headers = {"Authorization": "Bearer test_token_12345678"} + client_id = limiter._get_client_id(mock_request) + assert client_id.startswith("key:") + + def test_get_client_id_from_forwarded_header(self, limiter, mock_request): + """測試從 X-Forwarded-For 獲取 IP""" + mock_request.headers = {"X-Forwarded-For": "10.0.0.1, 192.168.1.1"} + client_id = limiter._get_client_id(mock_request) + assert client_id == "ip:10.0.0.1" + + @pytest.mark.asyncio + async def test_check_rate_limit_allows_request(self, limiter, mock_request): + """測試允許正常請求""" + result = await limiter.check_rate_limit(mock_request) + assert result is True + + @pytest.mark.asyncio + async def test_check_rate_limit_burst_exceeded(self, limiter, mock_request): + """測試超過突發限制""" + # 快速發送超過突發限制的請求 + await limiter.check_rate_limit(mock_request) + await limiter.check_rate_limit(mock_request) + + # 第三個請求應該被拒絕 + with pytest.raises(HTTPException) as exc_info: + await limiter.check_rate_limit(mock_request) + + assert exc_info.value.status_code == 429 + assert "Too many requests per second" in str(exc_info.value.detail) + + @pytest.mark.asyncio + async def test_check_rate_limit_minute_exceeded(self, limiter, mock_request): + """測試超過分鐘限制""" + # 模擬突發請求分散在時間上 + for i in range(5): + # 清除突發記錄以避免觸發突發限制 + limiter._burst_requests.clear() + await limiter.check_rate_limit(mock_request) + + # 清除突發記錄 + limiter._burst_requests.clear() + + # 第六個請求應該被拒絕 + with pytest.raises(HTTPException) as exc_info: + await limiter.check_rate_limit(mock_request) + + assert exc_info.value.status_code == 429 + assert "minute" in str(exc_info.value.detail).lower() + + @pytest.mark.asyncio + async def test_get_remaining_requests(self, limiter, mock_request): + """測試獲取剩餘請求配額""" + # 發送一個請求 + limiter._burst_requests.clear() + await limiter.check_rate_limit(mock_request) + + remaining = limiter.get_remaining_requests(mock_request) + + assert remaining["remaining_per_minute"] == 4 # 5 - 1 + assert remaining["remaining_per_hour"] == 19 # 20 - 1 + assert remaining["limit_per_minute"] == 5 + assert remaining["limit_per_hour"] == 20 + + def test_cleanup_old_requests(self, limiter): + """測試清理過期請求記錄""" + old_time = time.time() - 120 # 2 分鐘前 + current_time = time.time() + + limiter._minute_requests["test_client"] = [old_time, current_time] + limiter._hour_requests["test_client"] = [old_time, current_time] + + limiter._cleanup_old_requests() + + # 舊的分鐘記錄應該被清理 + assert len(limiter._minute_requests["test_client"]) == 1 + # 舊的小時記錄(2分鐘前)應該仍然保留 + assert len(limiter._hour_requests["test_client"]) == 2 + + @pytest.mark.asyncio + async def test_start_and_stop(self, limiter): + """測試啟動和停止清理任務""" + await limiter.start() + assert limiter._cleanup_task is not None + + await limiter.stop() + assert limiter._cleanup_task.cancelled() or limiter._cleanup_task.done() + + +class TestRateLimitMiddleware: + """rate_limit_middleware 測試""" + + @pytest.fixture + def mock_request(self): + """創建 Mock 請求""" + request = Mock() + request.headers = {} + request.client = Mock() + request.client.host = "192.168.1.1" + request.url = Mock() + request.url.path = "/api/chat" + return request + + @pytest.fixture + def mock_call_next(self): + """創建 Mock call_next""" + async def call_next(request): + response = Mock() + response.headers = {} + return response + return call_next + + @pytest.mark.asyncio + async def test_middleware_allows_request(self, mock_request, mock_call_next): + """測試中間件允許正常請求""" + # 重置全局限制器 + rate_limiter._minute_requests.clear() + rate_limiter._hour_requests.clear() + rate_limiter._burst_requests.clear() + + response = await rate_limit_middleware(mock_request, mock_call_next) + + assert response is not None + assert "X-RateLimit-Limit-Minute" in response.headers + + @pytest.mark.asyncio + async def test_middleware_skips_health_check(self, mock_request, mock_call_next): + """測試中間件跳過健康檢查路徑""" + mock_request.url.path = "/api/health" + + response = await rate_limit_middleware(mock_request, mock_call_next) + + # 健康檢查不應該有速率限制頭 + assert response is not None + + @pytest.mark.asyncio + async def test_middleware_skips_docs(self, mock_request, mock_call_next): + """測試中間件跳過文檔路徑""" + mock_request.url.path = "/docs" + + response = await rate_limit_middleware(mock_request, mock_call_next) + assert response is not None + + +class TestRateLimitDecorator: + """rate_limit 裝飾器測試""" + + @pytest.mark.asyncio + async def test_decorator_creates_custom_limiter(self): + """測試裝飾器創建自定義限制器""" + from middleware.rate_limiter import rate_limit + + @rate_limit(requests_per_minute=2) + async def test_endpoint(request=None): + return "success" + + mock_request = Mock() + mock_request.headers = {} + mock_request.client = Mock() + mock_request.client.host = "192.168.1.100" + + # 前兩個請求應該成功 + result = await test_endpoint(request=mock_request) + assert result == "success" + + +class TestDifferentClients: + """不同客戶端的測試""" + + @pytest.fixture + def limiter(self): + """創建限制器""" + return RateLimiter(requests_per_minute=2, requests_per_hour=10, burst_limit=2) + + @pytest.mark.asyncio + async def test_different_ips_have_separate_limits(self, limiter): + """測試不同 IP 有獨立的限制""" + request1 = Mock() + request1.headers = {} + request1.client = Mock() + request1.client.host = "192.168.1.1" + + request2 = Mock() + request2.headers = {} + request2.client = Mock() + request2.client.host = "192.168.1.2" + + # 客戶端1發送請求 + await limiter.check_rate_limit(request1) + await limiter.check_rate_limit(request1) + + # 客戶端2應該仍然可以發送請求 + result = await limiter.check_rate_limit(request2) + assert result is True + + @pytest.mark.asyncio + async def test_api_key_overrides_ip(self, limiter): + """測試 API Key 優先於 IP""" + # 相同 IP,不同 API Key + request1 = Mock() + request1.headers = {"X-API-Key": "key_user_1"} + request1.client = Mock() + request1.client.host = "192.168.1.1" + + request2 = Mock() + request2.headers = {"X-API-Key": "key_user_2"} + request2.client = Mock() + request2.client.host = "192.168.1.1" + + # 用戶1發送請求直到限制 + await limiter.check_rate_limit(request1) + await limiter.check_rate_limit(request1) + + # 用戶2(不同 API Key)應該仍然可以發送 + result = await limiter.check_rate_limit(request2) + assert result is True + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git "a/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/web-ui/app/layout.tsx" "b/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/web-ui/app/layout.tsx" index 5816d78..e63b1bd 100644 --- "a/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/web-ui/app/layout.tsx" +++ "b/5.AI\347\240\224\347\251\266\345\211\215\346\262\277_2024-2025/\345\257\246\346\210\260\351\240\205\347\233\256/web-ui/app/layout.tsx" @@ -1,6 +1,7 @@ import type { Metadata } from 'next' import { Inter } from 'next/font/google' import './globals.css' +import { ErrorBoundary } from '@/components/ErrorBoundary' const inter = Inter({ subsets: ['latin'] }) @@ -16,7 +17,11 @@ export default function RootLayout({ }) { return ( -
{children} + ++ 應用程序遇到了意外錯誤。請稍後重試。 +
+ + {process.env.NODE_ENV === 'development' && this.state.error && ( ++ {this.state.error.message} +
+ {this.state.errorInfo && ( +
+ {this.state.errorInfo.componentStack}
+
+ ( + WrappedComponent: React.ComponentType
,
+ fallback?: ReactNode
+) {
+ return function WithErrorBoundaryWrapper(props: P) {
+ return (
+ {subtitle}{title}
+ {subtitle && (
+