From 91cf264b512ead31b2e34380940e5fbac59497e0 Mon Sep 17 00:00:00 2001 From: iklobato Date: Mon, 9 Mar 2026 22:25:47 -0300 Subject: [PATCH 1/3] feat(auth): implement comprehensive JWT and Basic authentication - Add JWT algorithm configuration with validation (HS256, HS384, HS512, RS256, RS384, RS512, ES256, ES384, ES512) - Implement BasicAuthentication class with proper 401 error responses - Add rate limiting for auth endpoints (10/min, 100/hour, 1000/day default) - Fix JWT expiration configuration not being passed to authentication instances - Add JWT claim validation to prevent overwriting reserved 'exp' claim - Consolidate Basic auth parsing in _login.py module - Add comprehensive test suite covering all auth scenarios - Update YAML configuration support for JWT algorithm and expiration - Fix typing issues in auth modules - Add proper error response differentiation (401 for auth failures, 403 for permissions) BREAKING CHANGE: JWTAuthentication constructor now accepts optional parameters: - expiration: Override default 3600s token lifetime - algorithm: Override default HS256 algorithm - secret_key: Custom secret (defaults to LIGHTAPI_JWT_SECRET) Security enhancements: - Rate limiting prevents brute force attacks on login endpoints - Basic auth returns proper WWW-Authenticate header semantics - JWT algorithm validation prevents insecure configurations --- README.md | 16 + docs/examples/yaml-configuration.md | 12 +- lightapi/__init__.py | 9 +- lightapi/_login.py | 131 ++++ lightapi/_registry.py | 46 +- lightapi/auth.py | 135 +++- lightapi/cache.py | 11 +- lightapi/config.py | 28 + lightapi/lightapi.py | 95 ++- lightapi/rate_limiter.py | 171 +++++ lightapi/yaml_loader.py | 88 ++- tests/test_auth.py | 7 +- tests/test_login_auth.py | 984 ++++++++++++++++++++++++++++ tests/test_yaml_config.py | 162 ++++- uv.lock | 2 +- 15 files changed, 1857 insertions(+), 40 deletions(-) create mode 100644 lightapi/_login.py create mode 100644 lightapi/rate_limiter.py create mode 100644 tests/test_login_auth.py diff --git a/README.md b/README.md index f2995a4..c19db4b 100644 --- a/README.md +++ b/README.md @@ -304,6 +304,22 @@ class AdminOnlyEndpoint(RestEndpoint): 2. Permission class `.has_permission(request)` — checks `request.state.user` 3. Returns `401` if authentication fails, `403` if permission denied +**Login and token endpoints:** When using `JWTAuthentication` or `BasicAuthentication`, pass `login_validator` to obtain automatic `/auth/login` and `/auth/token` endpoints: + +```python +def my_validator(username: str, password: str): + # Return user payload dict or None + user = db.query(User).filter_by(username=username).first() + if user and user.check_password(password): + return {"sub": str(user.id), "is_admin": user.is_admin} + return None + +app = LightApi(engine=engine, login_validator=my_validator) +app.register({"/secrets": ProtectedEndpoint}) +# POST /auth/login and POST /auth/token now accept {"username":"...","password":"..."} +# JWT mode: 200 {"token":"...","user":{...}}; Basic-only: 200 {"user":{...}} +``` + **Built-in permission classes:** | Class | Condition | diff --git a/docs/examples/yaml-configuration.md b/docs/examples/yaml-configuration.md index 26b7b38..a9e2fc3 100644 --- a/docs/examples/yaml-configuration.md +++ b/docs/examples/yaml-configuration.md @@ -22,6 +22,8 @@ defaults: authentication: backend: JWTAuthentication permission: IsAuthenticated + jwt_expiration: 3600 + jwt_extra_claims: [sub, email] pagination: style: page_number page_size: 20 @@ -29,6 +31,10 @@ defaults: middleware: - CORSMiddleware +auth: + auth_path: /auth + login_validator: myapp.validators.validate_login + endpoints: - route: /products fields: @@ -130,8 +136,12 @@ python -c "from lightapi import LightApi; LightApi.from_config('lightapi.yaml'). |-------|------|-------------| | `database.url` | string | SQLAlchemy URL. Supports `${VAR}` substitution. | | `cors_origins` | list | CORS allowed origins. | -| `defaults.authentication.backend` | string | Auth backend class name. | +| `defaults.authentication.backend` | string | Auth backend class name (`JWTAuthentication`, `BasicAuthentication`). | | `defaults.authentication.permission` | string | Permission class name. | +| `defaults.authentication.jwt_expiration` | int | JWT token expiration in seconds (JWT only). | +| `defaults.authentication.jwt_extra_claims` | list | Claims to include in token payload (JWT only). | +| `auth.auth_path` | string | Path prefix for `/login` and `/token` (default `/auth`). | +| `auth.login_validator` | string | Dotted path to credential validator callable (e.g. `myapp.validators.check_user`). | | `defaults.pagination.style` | string | `page_number` or `cursor`. | | `defaults.pagination.page_size` | int | Rows per page. | | `middleware` | list | Class names resolved at startup. | diff --git a/lightapi/__init__.py b/lightapi/__init__.py index da005e4..54a9846 100644 --- a/lightapi/__init__.py +++ b/lightapi/__init__.py @@ -1,6 +1,12 @@ """LightAPI v2 public API.""" -from lightapi.auth import AllowAny, IsAdminUser, IsAuthenticated, JWTAuthentication +from lightapi.auth import ( + AllowAny, + BasicAuthentication, + IsAdminUser, + IsAuthenticated, + JWTAuthentication, +) from lightapi.cache import RedisCache from lightapi.config import Authentication, Cache, Filtering, Pagination, Serializer @@ -33,6 +39,7 @@ "Pagination", "Serializer", # Auth + "BasicAuthentication", "JWTAuthentication", "AllowAny", "IsAuthenticated", diff --git a/lightapi/_login.py b/lightapi/_login.py new file mode 100644 index 0000000..db50f0a --- /dev/null +++ b/lightapi/_login.py @@ -0,0 +1,131 @@ +"""Login and token endpoint handlers.""" + +from __future__ import annotations + +import base64 +import logging +from collections.abc import Callable +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field +from starlette.requests import Request +from starlette.responses import JSONResponse + +from lightapi.auth import JWTAuthentication +from lightapi.rate_limiter import rate_limit_auth_endpoint + +logger = logging.getLogger(__name__) + + +class LoginRequest(BaseModel): + """Request body for POST /auth/login and /auth/token.""" + + model_config = ConfigDict(frozen=True) + + username: str = Field(min_length=1) + password: str = Field(min_length=1) + + +def _parse_basic_header(auth_header: str) -> tuple[str, str] | None: + """ + Decode Authorization: Basic header. + + Returns (username, password) or None if malformed. + """ + if not auth_header.startswith("Basic "): + return None + try: + token = auth_header.split(" ", 1)[1] + decoded = base64.b64decode(token).decode("utf-8") + except (ValueError, IndexError, UnicodeDecodeError): + return None + parts = decoded.split(":", 1) + if len(parts) != 2: + return None + return parts[0], parts[1] + + +async def _parse_credentials(request: Request) -> tuple[str, str] | None: + """ + Extract (username, password) from request. + + - If Authorization: Basic present: returns (u, p) or None if malformed. + - If no Basic header: reads body, validates with LoginRequest. + Returns (u, p) if valid. Raises ValidationError for body (caller returns 422). + - None means malformed Basic (caller returns 401). + """ + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Basic "): + return _parse_basic_header(auth_header) + + body = await _read_body(request) + parsed = LoginRequest.model_validate(body if body else {}) + return parsed.username, parsed.password + + +async def _read_body(request: Request) -> dict[str, Any]: + """Read JSON body; return {} on empty or invalid.""" + import json + + try: + body = await request.body() + return json.loads(body) if body else {} + except (json.JSONDecodeError, TypeError): + return {} + + +@rate_limit_auth_endpoint +async def login_handler( + request: Request, + *, + login_validator: Callable[[str, str], dict[str, Any] | None], + has_jwt: bool, + jwt_expiration: int | None = None, + jwt_extra_claims: list[str] | None = None, + jwt_algorithm: str | None = None, +) -> JSONResponse: + """ + Handle POST /auth/login and POST /auth/token. + + Returns 422 for body validation, 401 for malformed Basic or invalid credentials, + 500 for validator exception, 200 with token+user (JWT) or user only (Basic). + """ + from pydantic import ValidationError + + if request.method != "POST": + return JSONResponse( + {"detail": "method not allowed"}, + status_code=405, + headers={"Allow": "POST"}, + ) + + try: + creds = await _parse_credentials(request) + except ValidationError as exc: + return JSONResponse({"detail": exc.errors()}, status_code=422) + + if creds is None: + return JSONResponse({"detail": "Invalid credentials"}, status_code=401) + + username, password = creds + try: + payload = login_validator(username, password) + except Exception as e: + logger.exception("login_validator raised: %s", e) + raise + + if payload is None: + return JSONResponse({"detail": "Invalid credentials"}, status_code=401) + + if has_jwt: + jwt_auth = JWTAuthentication(algorithm=jwt_algorithm) + if jwt_extra_claims and isinstance(payload, dict): + token_payload = {k: payload[k] for k in jwt_extra_claims if k in payload} + if not token_payload: + token_payload = payload + else: + token_payload = payload + token = jwt_auth.generate_token(token_payload, expiration=jwt_expiration) + return JSONResponse({"token": token, "user": payload}) + + return JSONResponse({"user": payload}) diff --git a/lightapi/_registry.py b/lightapi/_registry.py index c4498d8..e6851f6 100644 --- a/lightapi/_registry.py +++ b/lightapi/_registry.py @@ -7,31 +7,53 @@ from __future__ import annotations +from collections.abc import Callable +from typing import Any, cast + from sqlalchemy import MetaData from sqlalchemy.orm import registry -_registry: registry | None = None -_metadata: MetaData | None = None -_engine: object | None = None +LoginValidator = Callable[[str, str], dict[str, Any] | None] + +_state: dict[str, object | None] = { + "registry": None, + "metadata": None, + "engine": None, + "login_validator": None, +} def get_registry_and_metadata() -> tuple[registry, MetaData]: - global _registry, _metadata - if _registry is None or _metadata is None: - _metadata = MetaData() - _registry = registry(metadata=_metadata) - return _registry, _metadata + reg = cast(registry | None, _state["registry"]) + meta = cast(MetaData | None, _state["metadata"]) + if reg is None or meta is None: + meta = MetaData() + reg = registry(metadata=meta) + _state["metadata"] = meta + _state["registry"] = reg + return reg, meta def set_engine(engine: object) -> None: - global _engine - _engine = engine + _state["engine"] = engine def get_engine() -> object: - if _engine is None: + engine = _state["engine"] + if engine is None: raise RuntimeError( "No engine configured. Call LightApi(engine=...) or ensure " "database connection is set before the first request." ) - return _engine + return engine + + +def set_login_validator(validator: LoginValidator) -> None: + _state["login_validator"] = validator + + +def get_login_validator() -> LoginValidator | None: + validator = _state["login_validator"] + if validator is None: + return None + return cast(LoginValidator, validator) diff --git a/lightapi/auth.py b/lightapi/auth.py index 5a7980e..ebbb942 100644 --- a/lightapi/auth.py +++ b/lightapi/auth.py @@ -1,10 +1,11 @@ from datetime import datetime, timedelta -from typing import Dict, Optional +from typing import Any, Dict, Optional import jwt from starlette.requests import Request from starlette.responses import JSONResponse +from ._registry import LoginValidator from .config import config @@ -16,7 +17,7 @@ class BaseAuthentication: By default, allows all requests. """ - def authenticate(self, request): + def authenticate(self, request: Request) -> bool: """ Authenticate a request. @@ -28,7 +29,7 @@ def authenticate(self, request): """ return True - def get_auth_error_response(self, request): + def get_auth_error_response(self, request: Request) -> JSONResponse: """ Get the response to return when authentication fails. @@ -55,16 +56,22 @@ class JWTAuthentication(BaseAuthentication): expiration: Token expiration time in seconds. """ - def __init__(self): - if not config.jwt_secret: + def __init__( + self, + secret_key: str | None = None, + algorithm: str | None = None, + expiration: int | None = None, + ): + self.secret_key = secret_key or config.jwt_secret + if not self.secret_key: raise ValueError( "JWT secret key not configured. Set LIGHTAPI_JWT_SECRET environment variable." ) - self.secret_key = config.jwt_secret - self.algorithm = "HS256" - self.expiration = 3600 # 1 hour default - def authenticate(self, request): + self.algorithm = algorithm or config.jwt_algorithm + self.expiration = expiration or 3600 # 1 hour default + + def authenticate(self, request: Request) -> bool: """ Authenticate a request using JWT token. Automatically allows OPTIONS requests for CORS preflight. @@ -101,7 +108,17 @@ def generate_token(self, payload: Dict, expiration: Optional[int] = None) -> str Returns: str: The encoded JWT token. + + Raises: + ValueError: If payload contains 'exp' claim which will be overwritten. """ + # Check for 'exp' in payload since we overwrite it + if "exp" in payload: + raise ValueError( + "Payload contains 'exp' claim which will be overwritten. " + "Use the 'expiration' parameter instead." + ) + exp_seconds = expiration or self.expiration token_data = { **payload, @@ -124,6 +141,106 @@ def decode_token(self, token: str) -> Dict: """ return jwt.decode(token, self.secret_key, algorithms=[self.algorithm]) + def get_auth_error_response(self, request: Request) -> JSONResponse: + """ + Get the response to return when authentication fails. + + Args: + request: The HTTP request object. + + Returns: + Response object for authentication error. + """ + return JSONResponse({"error": "authentication failed"}, status_code=401) + + +class BasicAuthentication(BaseAuthentication): + """ + Basic (Base64) authentication. + + Authenticates requests using Authorization: Basic . + Delegates credential validation to the app-level login_validator from the registry. + """ + + def authenticate(self, request: Request) -> bool: + if request.method == "OPTIONS": + return True + + auth_header = request.headers.get("Authorization") + if not auth_header: + return False + + # Use the shared Basic auth parsing function + from lightapi._login import _parse_basic_header + + credentials = _parse_basic_header(auth_header) + if credentials is None: + return False + + username, password = credentials + from lightapi._registry import get_login_validator + + validator = get_login_validator() + if validator is None: + return False + + try: + payload = validator(username, password) + except Exception: + return False + + if payload is None: + return False + + request.state.user = payload + return True + + def get_auth_error_response(self, request: Request) -> JSONResponse: + """ + Get the response to return when authentication fails. + + Args: + request: The HTTP request object. + + Returns: + Response object for authentication error. + """ + return JSONResponse({"error": "authentication failed"}, status_code=401) + + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Basic "): + return False + + try: + import base64 + + token = auth_header.split(" ", 1)[1] + decoded = base64.b64decode(token).decode("utf-8") + except (ValueError, IndexError, UnicodeDecodeError): + return False + + parts = decoded.split(":", 1) + if len(parts) != 2: + return False + + username, password = parts[0], parts[1] + from lightapi._registry import get_login_validator + + validator = get_login_validator() + if validator is None: + return False + + try: + payload = validator(username, password) + except Exception: + return False + + if payload is None: + return False + + request.state.user = payload + return True + class AllowAny: """Permits all requests regardless of authentication state.""" diff --git a/lightapi/cache.py b/lightapi/cache.py index b56fb0d..7d5c947 100644 --- a/lightapi/cache.py +++ b/lightapi/cache.py @@ -10,17 +10,18 @@ logger = logging.getLogger(__name__) _REDIS_URL = os.environ.get("LIGHTAPI_REDIS_URL", "redis://localhost:6379/0") -_redis_client: "redis.Redis | None" = None +_redis_state: dict[str, "redis.Redis | None"] = {"client": None} def _get_redis() -> "redis.Redis | None": - global _redis_client - if _redis_client is None: + if _redis_state["client"] is None: try: - _redis_client = redis.from_url(_REDIS_URL, socket_connect_timeout=1) + _redis_state["client"] = redis.from_url( + _REDIS_URL, socket_connect_timeout=1 + ) except Exception: return None - return _redis_client + return _redis_state["client"] def _ping_redis() -> bool: diff --git a/lightapi/config.py b/lightapi/config.py index 02646ae..e1ecc11 100644 --- a/lightapi/config.py +++ b/lightapi/config.py @@ -9,6 +9,18 @@ class _Config: """Configuration used by JWTAuthentication and other components.""" + VALID_JWT_ALGORITHMS = { + "HS256", + "HS384", + "HS512", + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + } + def __init__(self) -> None: self._overrides: dict[str, Any] = {} @@ -25,6 +37,16 @@ def _get(self, key: str, env_key: str, default: Any = None) -> Any: def jwt_secret(self) -> str | None: return self._get("jwt_secret", "LIGHTAPI_JWT_SECRET") + @property + def jwt_algorithm(self) -> str: + algorithm = self._get("jwt_algorithm", "LIGHTAPI_JWT_ALGORITHM", "HS256") + if algorithm not in self.VALID_JWT_ALGORITHMS: + raise ConfigurationError( + f"Invalid JWT algorithm '{algorithm}'. " + f"Valid algorithms are: {sorted(self.VALID_JWT_ALGORITHMS)}" + ) + return algorithm + config = _Config() @@ -36,6 +58,9 @@ def __init__( self, backend: type | None = None, permission: type | dict[str, type] | None = None, + jwt_expiration: int | None = None, + jwt_extra_claims: list[str] | None = None, + jwt_algorithm: str | None = None, ) -> None: from lightapi.auth import AllowAny @@ -43,6 +68,9 @@ def __init__( self.permission: type | dict[str, type] = ( permission if permission is not None else AllowAny ) + self.jwt_expiration = jwt_expiration + self.jwt_extra_claims = jwt_extra_claims + self.jwt_algorithm = jwt_algorithm class Filtering: diff --git a/lightapi/lightapi.py b/lightapi/lightapi.py index 8e087e7..567a8c1 100644 --- a/lightapi/lightapi.py +++ b/lightapi/lightapi.py @@ -44,6 +44,8 @@ def __init__( database_url: str | None = None, cors_origins: list[str] | None = None, middlewares: list[type] | None = None, + login_validator: Any = None, + auth_path: str = "/auth", ) -> None: if engine is None and database_url: engine = create_engine(database_url) @@ -71,6 +73,12 @@ def __init__( self._endpoint_map: dict[str, type] = {} self._middlewares: list[type] = middlewares or [] self._cors_origins: list[str] = cors_origins or [] + self._login_validator = login_validator + self._auth_path = auth_path + if login_validator is not None: + from lightapi._registry import set_login_validator + + set_login_validator(login_validator) # ───────────────────────────────────────────────────────────────────────── # Registration @@ -144,6 +152,83 @@ def register(self, mapping: dict[str, type]) -> None: self._routes.append(detail_route) self._endpoint_map[path] = cls + # Auto-register /auth/login and /auth/token when JWT or Basic auth is used + from lightapi.auth import BasicAuthentication, JWTAuthentication + + auth_backends: set[type] = set() + jwt_config_expiration: int | None = None + jwt_config_extra_claims: list[str] | None = None + jwt_config_algorithm: str | None = None + for cls in self._endpoint_map.values(): + auth_cfg = getattr(cls, "_meta", {}).get("authentication") + if auth_cfg and auth_cfg.backend: + auth_backends.add(auth_cfg.backend) + if ( + auth_cfg.backend is JWTAuthentication + and jwt_config_expiration is None + ): + jwt_config_expiration = getattr(auth_cfg, "jwt_expiration", None) + jwt_config_extra_claims = getattr( + auth_cfg, "jwt_extra_claims", None + ) + jwt_config_algorithm = getattr(auth_cfg, "jwt_algorithm", None) + + if auth_backends & {JWTAuthentication, BasicAuthentication}: + if self._login_validator is None: + raise ConfigurationError( + "login_validator is required when using JWTAuthentication " + "or BasicAuthentication. Pass it to LightApi(login_validator=...)." + ) + has_jwt = JWTAuthentication in auth_backends + auth_path = self._auth_path.rstrip("/") + login_endpoint = self._make_login_endpoint( + has_jwt=has_jwt, + jwt_expiration=jwt_config_expiration, + jwt_extra_claims=jwt_config_extra_claims, + jwt_algorithm=jwt_config_algorithm, + ) + self._routes.insert( + 0, + Route( + f"{auth_path}/login", + login_endpoint, + methods=["POST"], + ), + ) + self._routes.insert( + 1, + Route( + f"{auth_path}/token", + login_endpoint, + methods=["POST"], + ), + ) + + def _make_login_endpoint( + self, + *, + has_jwt: bool, + jwt_expiration: int | None, + jwt_extra_claims: list[str] | None, + jwt_algorithm: str | None, + ) -> Any: + """Create the login/token handler with captured config.""" + from lightapi._login import login_handler + + login_validator = self._login_validator + + async def handler(request: Request) -> Response: + return await login_handler( + request, + login_validator=login_validator, + has_jwt=has_jwt, + jwt_expiration=jwt_expiration, + jwt_extra_claims=jwt_extra_claims, + jwt_algorithm=jwt_algorithm, + ) + + return handler + def _make_collection_handler(self, cls: type) -> Any: app_middlewares = self._middlewares is_async = self._async @@ -461,7 +546,15 @@ def _check_auth(cls: type, request: Request) -> Response | None: perm_cls = AllowAny if backend is not None: - authenticator = backend() + # Pass JWT configuration if backend is JWTAuthentication + if backend.__name__ == "JWTAuthentication": + authenticator = backend( + expiration=getattr(auth_cfg, "jwt_expiration", None), + algorithm=getattr(auth_cfg, "jwt_algorithm", None), + ) + else: + authenticator = backend() + if not authenticator.authenticate(request): return JSONResponse( {"detail": "Authentication credentials invalid."}, status_code=401 diff --git a/lightapi/rate_limiter.py b/lightapi/rate_limiter.py new file mode 100644 index 0000000..5676cb4 --- /dev/null +++ b/lightapi/rate_limiter.py @@ -0,0 +1,171 @@ +"""Simple rate limiting for authentication endpoints.""" + +from __future__ import annotations + +import time +from collections import defaultdict +from typing import Any, Callable + +from starlette.requests import Request +from starlette.responses import JSONResponse + + +class RateLimiter: + """ + Simple in-memory rate limiter. + + Tracks requests by IP address and endpoint. + """ + + def __init__( + self, + requests_per_minute: int = 10, + requests_per_hour: int = 100, + requests_per_day: int = 1000, + ) -> None: + self.requests_per_minute = requests_per_minute + self.requests_per_hour = requests_per_hour + self.requests_per_day = requests_per_day + + # Storage: {ip: {window: {timestamp: count}}} + self._store: dict[str, dict[str, dict[float, int]]] = defaultdict( + lambda: defaultdict(lambda: defaultdict(int)) + ) + self._cleanup_interval = 300 # Cleanup every 5 minutes + self._last_cleanup = time.time() + + def _cleanup_old_entries(self) -> None: + """Remove old entries to prevent memory leak.""" + current_time = time.time() + if current_time - self._last_cleanup < self._cleanup_interval: + return + + for ip in list(self._store.keys()): + for window in list(self._store[ip].keys()): + # Remove entries older than window size + window_seconds = self._get_window_seconds(window) + cutoff = current_time - window_seconds + + # Remove old timestamps + for timestamp in list(self._store[ip][window].keys()): + if timestamp < cutoff: + del self._store[ip][window][timestamp] + + # Remove empty windows + if not self._store[ip][window]: + del self._store[ip][window] + + # Remove IPs with no windows + if not self._store[ip]: + del self._store[ip] + + self._last_cleanup = current_time + + def _get_window_seconds(self, window: str) -> int: + """Convert window name to seconds.""" + if window == "minute": + return 60 + elif window == "hour": + return 3600 + elif window == "day": + return 86400 + else: + return 60 + + def _get_client_ip(self, request: Request) -> str: + """Extract client IP from request.""" + # Try common headers for proxy setups + for header in ("X-Forwarded-For", "X-Real-IP", "X-Client-IP"): + if header in request.headers: + ip = request.headers[header].split(",")[0].strip() + if ip: + return ip + + # Fall back to client host + return request.client.host if request.client else "0.0.0.0" + + def is_rate_limited(self, request: Request, endpoint: str = "") -> bool: + """ + Check if request should be rate limited. + + Args: + request: The HTTP request. + endpoint: Optional endpoint identifier for per-endpoint limiting. + + Returns: + bool: True if rate limited, False otherwise. + """ + self._cleanup_old_entries() + + client_ip = self._get_client_ip(request) + current_time = time.time() + + # Check each window + windows = [ + ("minute", self.requests_per_minute), + ("hour", self.requests_per_hour), + ("day", self.requests_per_day), + ] + + for window_name, limit in windows: + window_seconds = self._get_window_seconds(window_name) + window_key = f"{endpoint}:{window_name}" if endpoint else window_name + + # Count requests in this window + count = 0 + for timestamp, request_count in self._store[client_ip][window_key].items(): + if current_time - timestamp < window_seconds: + count += request_count + + if count >= limit: + return True + + # Add current request + self._store[client_ip][window_key][current_time] = ( + self._store[client_ip][window_key].get(current_time, 0) + 1 + ) + + return False + + def get_rate_limit_response(self, request: Request) -> JSONResponse: + """Get standard rate limit exceeded response.""" + return JSONResponse( + { + "error": "rate_limit_exceeded", + "detail": "Too many requests. Please try again later.", + }, + status_code=429, + headers={ + "Retry-After": "60", # Retry after 60 seconds + "X-RateLimit-Limit": str(self.requests_per_minute), + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(int(time.time() + 60)), + }, + ) + + +# Global rate limiter instance for auth endpoints +_auth_rate_limiter = RateLimiter( + requests_per_minute=10, # 10 requests per minute + requests_per_hour=100, # 100 requests per hour + requests_per_day=1000, # 1000 requests per day +) + + +def rate_limit_auth_endpoint(func: Callable) -> Callable: + """ + Decorator to rate limit authentication endpoints. + + Args: + func: The endpoint function to decorate. + + Returns: + Decorated function with rate limiting. + """ + + async def wrapper(request: Request, *args: Any, **kwargs: Any) -> Any: + if _auth_rate_limiter.is_rate_limited(request, endpoint="auth"): + return _auth_rate_limiter.get_rate_limit_response(request) + return await func(request, *args, **kwargs) + + return wrapper diff --git a/lightapi/yaml_loader.py b/lightapi/yaml_loader.py index ebed063..d263a14 100644 --- a/lightapi/yaml_loader.py +++ b/lightapi/yaml_loader.py @@ -43,7 +43,13 @@ def _build_name_registry() -> dict[str, type]: - from lightapi.auth import AllowAny, IsAdminUser, IsAuthenticated, JWTAuthentication + from lightapi.auth import ( + AllowAny, + BasicAuthentication, + IsAdminUser, + IsAuthenticated, + JWTAuthentication, + ) from lightapi.core import AuthenticationMiddleware, CORSMiddleware, Middleware from lightapi.filters import ( FieldFilter, @@ -55,6 +61,7 @@ def _build_name_registry() -> dict[str, type]: return { # Auth backends "JWTAuthentication": JWTAuthentication, + "BasicAuthentication": BasicAuthentication, # Permissions "AllowAny": AllowAny, "IsAuthenticated": IsAuthenticated, @@ -76,6 +83,26 @@ def _build_name_registry() -> dict[str, type]: } +def _resolve_callable(dotted_path: str) -> Any: + """Resolve a dotted path like 'myapp.validators.validate_login' to a callable.""" + if "." not in dotted_path: + raise ConfigurationError( + f"login_validator must be a dotted path (e.g. myapp.validators.check), " + f"got '{dotted_path}'" + ) + module_path, attr_name = dotted_path.rsplit(".", 1) + try: + mod = importlib.import_module(module_path) + fn = getattr(mod, attr_name) + except (ImportError, AttributeError) as exc: + raise ConfigurationError( + f"Cannot resolve login_validator '{dotted_path}': {exc}" + ) from exc + if not callable(fn): + raise ConfigurationError(f"login_validator '{dotted_path}' is not callable.") + return fn + + def _resolve_name(name: str) -> type: """Resolve a class name string to a class. @@ -129,11 +156,26 @@ def substitute_env(cls, v: str) -> str: return _substitute_env(v) +class AuthLoginConfig(BaseModel): + """Login/auth block: auth: { auth_path: ..., login_validator: ... }. + + When using JWTAuthentication or BasicAuthentication, login_validator is required. + It can be specified as a dotted path (e.g. myapp.validators.validate_login) + or passed as an override to from_config(login_validator=...). + """ + + auth_path: str = "/auth" + login_validator: str | None = None + + class AuthConfig(BaseModel): """Authentication block used in defaults and per-endpoint meta.""" backend: str | None = None permission: Union[str, dict[str, str], None] = None + jwt_expiration: int | None = None + jwt_extra_claims: list[str] | None = None + jwt_algorithm: str | None = None class FilteringConfig(BaseModel): @@ -220,6 +262,7 @@ class LightAPIConfig(BaseModel): defaults: DefaultsConfig = DefaultsConfig() endpoints: list[EndpointConfig] = [] middleware: list[str] = [] + auth: AuthLoginConfig | None = None @property def effective_database_url(self) -> str | None: @@ -245,14 +288,15 @@ def _substitute_env(value: str) -> str: def _make_authentication( - auth_cfg: AuthConfig | None, defaults_auth: AuthConfig | None + auth_cfg: AuthConfig | None, + defaults_auth: AuthConfig | None, ) -> Any: - """Build an Authentication instance from an AuthConfig, merged with defaults.""" - from lightapi.config import Authentication - # Merge: explicit cfg wins over defaults, defaults fill gaps merged_backend = None merged_permission = None + merged_jwt_expiration = None + merged_jwt_extra_claims = None + merged_jwt_algorithm = None for source in (defaults_auth, auth_cfg): if source is None: @@ -261,6 +305,12 @@ def _make_authentication( merged_backend = source.backend if source.permission is not None: merged_permission = source.permission + if source.jwt_expiration is not None: + merged_jwt_expiration = source.jwt_expiration + if source.jwt_extra_claims is not None: + merged_jwt_extra_claims = source.jwt_extra_claims + if source.jwt_algorithm is not None: + merged_jwt_algorithm = source.jwt_algorithm if merged_backend is None and merged_permission is None: return None @@ -277,7 +327,15 @@ def _make_authentication( else: permission = None - return Authentication(backend=backend_cls, permission=permission) + from lightapi.config import Authentication + + return Authentication( + backend=backend_cls, + permission=permission, + jwt_expiration=merged_jwt_expiration, + jwt_extra_claims=merged_jwt_extra_claims, + jwt_algorithm=merged_jwt_algorithm, + ) def _make_filtering(filtering_cfg: FilteringConfig | None) -> Any: @@ -343,15 +401,22 @@ def _build_meta_class( if isinstance(perm, str): permission_map[method] = _resolve_name(perm) if permission_map: - # Determine shared backend (from endpoint auth or defaults) + # Determine shared backend and JWT opts (from endpoint auth or defaults) + src = meta.authentication or defaults.authentication backend_name = (meta.authentication and meta.authentication.backend) or ( defaults.authentication and defaults.authentication.backend ) + jwt_exp = src.jwt_expiration if src else None + jwt_claims = src.jwt_extra_claims if src else None + jwt_algo = src.jwt_algorithm if src else None from lightapi.config import Authentication attrs["authentication"] = Authentication( backend=_resolve_name(backend_name) if backend_name else None, permission=permission_map, + jwt_expiration=jwt_exp, + jwt_extra_claims=jwt_claims, + jwt_algorithm=jwt_algo, ) else: # Simple auth: endpoint overrides defaults @@ -480,6 +545,15 @@ def load_config(app_cls: type, config_path: str, **overrides: Any) -> Any: "cors_origins": cfg.cors_origins or None, "middlewares": middlewares or None, } + + # Auth/login config from YAML auth: block + if cfg.auth: + constructor_kwargs["auth_path"] = cfg.auth.auth_path + if cfg.auth.login_validator: + constructor_kwargs["login_validator"] = _resolve_callable( + cfg.auth.login_validator + ) + constructor_kwargs.update(overrides) instance = app_cls(**constructor_kwargs) diff --git a/tests/test_auth.py b/tests/test_auth.py index 11b211a..d49519b 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -59,6 +59,11 @@ def jwt_secret(monkeypatch_session=None): return secret +def _login_validator(username: str, password: str): + """Test validator; always returns None (tests use _make_token for tokens).""" + return None + + @pytest.fixture(scope="module") def client(jwt_secret): os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" @@ -67,7 +72,7 @@ def client(jwt_secret): connect_args={"check_same_thread": False}, poolclass=StaticPool, ) - app_instance = LightApi(engine=engine) + app_instance = LightApi(engine=engine, login_validator=_login_validator) app_instance.register( { "/secrets": SecretEndpoint, diff --git a/tests/test_login_auth.py b/tests/test_login_auth.py new file mode 100644 index 0000000..4903177 --- /dev/null +++ b/tests/test_login_auth.py @@ -0,0 +1,984 @@ +"""Tests for login and token endpoints (US-L1, US-L2, US-L3, US-L4).""" + +import base64 +import os + +import jwt +import pytest +from pydantic import ValidationError +from sqlalchemy import create_engine +from sqlalchemy.pool import StaticPool +from starlette.testclient import TestClient + +from lightapi import ( + Authentication, + BasicAuthentication, + JWTAuthentication, + LightApi, + RestEndpoint, +) +from lightapi._login import LoginRequest +from lightapi.exceptions import ConfigurationError +from lightapi.fields import Field as LField + + +def _valid_validator(username: str, password: str): + if username == "alice" and password == "secret": + return {"sub": "1", "email": "alice@example.com", "is_admin": False} + return None + + +def _valid_validator_admin(username: str, password: str): + if username == "admin" and password == "admin": + return {"sub": "2", "is_admin": True} + return None + + +class JWTProtectedEndpoint(RestEndpoint): + content: str = LField(min_length=1) + + class Meta: + authentication = Authentication(backend=JWTAuthentication) + + +class BasicProtectedEndpoint(RestEndpoint): + name: str = LField(min_length=1) + + class Meta: + authentication = Authentication(backend=BasicAuthentication) + + +@pytest.fixture(scope="module") +def jwt_client(): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/secrets": JWTProtectedEndpoint}) + return TestClient(app.build_app()) + + +@pytest.fixture(scope="module") +def basic_client(): + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/items": BasicProtectedEndpoint}) + return TestClient(app.build_app()) + + +class TestJWTTokenEndpoint: + def test_login_valid_credentials_returns_token(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "token" in data + assert "user" in data + assert data["user"]["sub"] == "1" + assert data["user"]["email"] == "alice@example.com" + + def test_login_invalid_credentials_returns_401(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": "alice", "password": "wrong"}, + ) + assert resp.status_code == 401 + assert resp.json()["detail"] == "Invalid credentials" + + def test_login_missing_body_returns_422(self, jwt_client): + resp = jwt_client.post("/auth/login", json={}) + assert resp.status_code == 422 + + def test_token_endpoint_same_as_login(self, jwt_client): + resp = jwt_client.post( + "/auth/token", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + assert "user" in resp.json() + + +class TestBodyValidation: + def test_login_missing_username_returns_422(self, jwt_client): + resp = jwt_client.post("/auth/login", json={"password": "x"}) + assert resp.status_code == 422 + detail = resp.json()["detail"] + assert any("username" in str(e.get("loc", [])) for e in detail) + + def test_login_missing_password_returns_422(self, jwt_client): + resp = jwt_client.post("/auth/login", json={"username": "x"}) + assert resp.status_code == 422 + detail = resp.json()["detail"] + assert any("password" in str(e.get("loc", [])) for e in detail) + + def test_login_empty_username_returns_422(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": "", "password": "x"}, + ) + assert resp.status_code == 422 + + def test_login_empty_password_returns_422(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": "x", "password": ""}, + ) + assert resp.status_code == 422 + + def test_login_invalid_json_returns_422(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + content="not json", + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 422 + + +class TestBasicHeaderInput: + def test_login_basic_header_valid_returns_token(self, jwt_client): + creds = base64.b64encode(b"alice:secret").decode() + resp = jwt_client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "token" in data + assert data["user"]["sub"] == "1" + + def test_login_basic_header_takes_precedence_over_body(self, jwt_client): + creds = base64.b64encode(b"alice:secret").decode() + resp = jwt_client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + json={"username": "alice", "password": "wrong"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + + def test_login_basic_header_malformed_returns_401(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + headers={"Authorization": "Basic not-valid-base64!!"}, + ) + assert resp.status_code == 401 + assert resp.json()["detail"] == "Invalid credentials" + + def test_login_basic_header_no_colon_returns_401(self, jwt_client): + creds = base64.b64encode(b"nocolon").decode() + resp = jwt_client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 401 + assert resp.json()["detail"] == "Invalid credentials" + + def test_login_bearer_header_falls_through_to_body_validation(self, jwt_client): + """Authorization: Bearer x (no Basic) + empty body → 422.""" + resp = jwt_client.post( + "/auth/login", + headers={"Authorization": "Bearer x"}, + json={}, + ) + assert resp.status_code == 422 + + +class TestBasicAuthEndpoint: + def test_basic_only_returns_user_without_token(self, basic_client): + resp = basic_client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "user" in data + assert "token" not in data + assert data["user"]["sub"] == "1" + + +class TestHttpMethodAndErrorFormat: + def test_login_get_returns_405(self, jwt_client): + resp = jwt_client.get("/auth/login") + assert resp.status_code == 405 + assert resp.headers.get("Allow") == "POST" + + +class TestConfigurationAndRegistration: + def test_register_jwt_without_validator_raises_configuration_error(self): + os.environ.setdefault("LIGHTAPI_JWT_SECRET", "test-secret-key") + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine) + with pytest.raises(ConfigurationError, match="login_validator"): + app.register({"/secrets": JWTProtectedEndpoint}) + + def test_register_basic_without_validator_raises_configuration_error(self): + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine) + with pytest.raises(ConfigurationError, match="login_validator"): + app.register({"/items": BasicProtectedEndpoint}) + + def test_auth_path_customization(self): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi( + engine=engine, + login_validator=_valid_validator, + auth_path="/api/auth", + ) + app.register({"/secrets": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/api/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + + +class TestJWTConfigOverrides: + def test_jwt_expiration_override(self): + class JWTWithExpEndpoint(RestEndpoint): + content: str = LField(min_length=1) + + class Meta: + authentication = Authentication( + backend=JWTAuthentication, + jwt_expiration=10, + ) + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/x": JWTWithExpEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + token = resp.json()["token"] + payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) + exp = payload.get("exp") + assert exp is not None + import time + + assert 0 < exp - int(time.time()) <= 15 + + def test_jwt_extra_claims_filters_payload(self): + def validator_with_extra(username: str, password: str): + if username == "alice" and password == "secret": + return { + "sub": "1", + "email": "a@b.com", + "secret": "must-not-appear", + } + return None + + class JWTWithExtraEndpoint(RestEndpoint): + content: str = LField(min_length=1) + + class Meta: + authentication = Authentication( + backend=JWTAuthentication, + jwt_extra_claims=["sub", "email"], + ) + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_with_extra) + app.register({"/x": JWTWithExtraEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + token = resp.json()["token"] + payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) + assert "sub" in payload + assert "email" in payload + assert "secret" not in payload + + +class TestBasicAuthProtectedEndpoints: + def test_basic_protected_get_with_valid_header_returns_200(self, basic_client): + creds = base64.b64encode(b"alice:secret").decode() + resp = basic_client.get( + "/items", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 200 + + def test_basic_protected_get_without_auth_returns_401(self, basic_client): + resp = basic_client.get("/items") + assert resp.status_code == 401 + + def test_basic_protected_get_invalid_credentials_returns_401(self, basic_client): + creds = base64.b64encode(b"alice:wrong").decode() + resp = basic_client.get( + "/items", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 401 + + +class TestTokenUsability: + def test_jwt_token_usable_on_protected_endpoint(self, jwt_client): + login_resp = jwt_client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert login_resp.status_code == 200 + token = login_resp.json()["token"] + resp = jwt_client.get( + "/secrets", + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 200 + + +class TestValidatorException: + def test_login_validator_exception_returns_500(self): + """Validator raising propagates to 500 (Starlette default handler).""" + + def failing_validator(username: str, password: str): + raise RuntimeError("DB unavailable") + + os.environ.setdefault("LIGHTAPI_JWT_SECRET", "test-secret-key") + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=failing_validator) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app(), raise_server_exceptions=False) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 500 + + +class TestLoginRequestModel: + def test_login_request_validates_min_length(self): + with pytest.raises(ValidationError): + LoginRequest.model_validate({"username": "x", "password": ""}) + + +class TestBasicHeaderEdgeCases: + def test_login_basic_header_password_with_colon_parsed_correctly(self): + def validator_colon(username: str, password: str): + if username == "alice" and password == "pass:word": + return {"sub": "1"} + return None + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_colon) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + creds = base64.b64encode(b"alice:pass:word").decode() + resp = client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 200 + assert resp.json()["user"]["sub"] == "1" + + def test_login_basic_header_empty_username_returns_401(self): + def validator_reject_empty(username: str, password: str): + if not username or not password: + return None + if username == "alice" and password == "secret": + return {"sub": "1"} + return None + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_reject_empty) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + creds = base64.b64encode(b":secret").decode() + resp = client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 401 + + def test_login_basic_header_empty_password_returns_401(self): + def validator_reject_empty(username: str, password: str): + if not username or not password: + return None + if username == "alice" and password == "secret": + return {"sub": "1"} + return None + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_reject_empty) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + creds = base64.b64encode(b"alice:").decode() + resp = client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 401 + + def test_login_basic_header_unicode_credentials_success(self): + def validator_unicode(username: str, password: str): + if username == "josé" and password == "contraseña": + return {"sub": "1"} + return None + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_unicode) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + creds = base64.b64encode("josé:contraseña".encode("utf-8")).decode() + resp = client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 200 + assert resp.json()["user"]["sub"] == "1" + + def test_login_basic_header_lowercase_falls_through_to_body(self): + """Authorization: basic x (lowercase) does not match Basic; falls to body.""" + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + creds = base64.b64encode(b"alice:secret").decode() + resp = client.post( + "/auth/login", + headers={"Authorization": f"basic {creds}"}, + json={}, + ) + assert resp.status_code == 422 + + def test_login_basic_header_no_value_returns_401(self): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + headers={"Authorization": "Basic "}, + json={}, + ) + assert resp.status_code == 401 + + def test_login_basic_header_multiple_spaces_after_basic(self): + """Authorization: Basic - leading space may affect decode.""" + creds = base64.b64encode(b"alice:secret").decode() + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code in (200, 401) + if resp.status_code == 200: + assert resp.json()["user"]["sub"] == "1" + + +class TestBodyEdgeCases: + def test_login_body_with_extra_keys_ignored_success(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": "alice", "password": "secret", "extra": "x"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + + def test_login_body_wrong_type_username_returns_422(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": 123, "password": "x"}, + ) + assert resp.status_code == 422 + + def test_login_truncated_json_returns_422(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + content='{"username": "a",', + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 422 + + def test_login_empty_body_with_content_type_returns_422(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + content="{}", + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 422 + + +class TestAuthPathEdgeCases: + def test_auth_path_trailing_slash_normalized(self): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi( + engine=engine, + login_validator=_valid_validator, + auth_path="/auth/", + ) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + + def test_auth_path_root_routes_at_login_and_token(self): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi( + engine=engine, + login_validator=_valid_validator, + auth_path="/", + ) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + + def test_auth_path_nested(self): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi( + engine=engine, + login_validator=_valid_validator, + auth_path="/api/v1/auth", + ) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/api/v1/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + + +class TestJWTExtraClaimsEdgeCases: + def test_jwt_extra_claims_empty_list_uses_full_payload(self): + def validator_full(username: str, password: str): + if username == "alice" and password == "secret": + return {"sub": "1", "email": "a@b.com"} + return None + + class JWTEmptyExtraEndpoint(RestEndpoint): + content: str = LField(min_length=1) + + class Meta: + authentication = Authentication( + backend=JWTAuthentication, + jwt_extra_claims=[], + ) + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_full) + app.register({"/x": JWTEmptyExtraEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + token = resp.json()["token"] + payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) + assert "sub" in payload + assert "email" in payload + + def test_jwt_extra_claims_all_keys_missing_fallback_to_full_payload(self): + def validator_minimal(username: str, password: str): + if username == "alice" and password == "secret": + return {"sub": "1"} + return None + + class JWTMissingExtraEndpoint(RestEndpoint): + content: str = LField(min_length=1) + + class Meta: + authentication = Authentication( + backend=JWTAuthentication, + jwt_extra_claims=["x", "y"], + ) + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_minimal) + app.register({"/x": JWTMissingExtraEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + token = resp.json()["token"] + payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) + assert "sub" in payload + + def test_jwt_extra_claims_partial_overlap_only_included_in_token(self): + def validator_partial(username: str, password: str): + if username == "alice" and password == "secret": + return {"sub": "1", "email": "a@b.com"} + return None + + class JWTPartialExtraEndpoint(RestEndpoint): + content: str = LField(min_length=1) + + class Meta: + authentication = Authentication( + backend=JWTAuthentication, + jwt_extra_claims=["sub", "missing"], + ) + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_partial) + app.register({"/x": JWTPartialExtraEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + token = resp.json()["token"] + payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) + assert "sub" in payload + assert payload["sub"] == "1" + assert "missing" not in payload + + +class TestValidatorReturnShapes: + def test_validator_empty_dict_returns_200(self): + def validator_empty(username: str, password: str): + if username == "alice" and password == "secret": + return {} + return None + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_empty) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "token" in data + assert data["user"] == {} + token = data["token"] + payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) + assert "exp" in payload + + def test_validator_nested_payload_success(self): + def validator_nested(username: str, password: str): + if username == "alice" and password == "secret": + return {"sub": "1", "meta": {"role": "admin"}} + return None + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_nested) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["user"]["sub"] == "1" + assert data["user"]["meta"]["role"] == "admin" + + +class TestMixedAppScenarios: + def test_both_jwt_and_basic_endpoints_returns_token(self): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register( + { + "/secrets": JWTProtectedEndpoint, + "/items": BasicProtectedEndpoint, + } + ) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + + def test_basic_only_auth_token_same_as_login(self, basic_client): + resp = basic_client.post( + "/auth/token", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "user" in data + assert "token" not in data + + def test_allowany_only_no_auth_routes_404(self): + from lightapi import AllowAny + + class AllowAnyEndpoint(RestEndpoint): + name: str = LField(min_length=1) + + class Meta: + authentication = Authentication(backend=AllowAny) + + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/public": AllowAnyEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 404 + + +class TestMethodEdgeCases: + def test_put_auth_login_returns_405(self, jwt_client): + resp = jwt_client.put( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 405 + assert resp.headers.get("Allow") == "POST" + + def test_patch_auth_token_returns_405(self, jwt_client): + resp = jwt_client.patch( + "/auth/token", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 405 + assert resp.headers.get("Allow") == "POST" + + def test_delete_auth_login_returns_405(self, jwt_client): + resp = jwt_client.delete("/auth/login") + assert resp.status_code == 405 + assert resp.headers.get("Allow") == "POST" + + +class TestResponseStructure: + def test_jwt_mode_response_has_exactly_token_and_user(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert set(data.keys()) == {"token", "user"} + + def test_basic_mode_response_has_no_token_key(self, basic_client): + resp = basic_client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + assert "token" not in resp.json() + assert "user" in resp.json() + + +class TestTokenVerification: + def test_expired_token_rejected_on_protected_endpoint(self): + import time + + class JWTExpiringEndpoint(RestEndpoint): + content: str = LField(min_length=1) + + class Meta: + authentication = Authentication( + backend=JWTAuthentication, + jwt_expiration=1, + ) + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/x": JWTExpiringEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + token = resp.json()["token"] + time.sleep(2) + get_resp = client.get("/x", headers={"Authorization": f"Bearer {token}"}) + assert get_resp.status_code == 401 + + def test_wrong_secret_token_rejected(self): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + wrong_token = jwt.encode( + {"sub": "1", "exp": 9999999999}, + "wrong-secret", + algorithm="HS256", + ) + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.get( + "/x", + headers={"Authorization": f"Bearer {wrong_token}"}, + ) + assert resp.status_code == 401 + + def test_token_structure_has_three_parts(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + token = resp.json()["token"] + assert len(token.split(".")) == 3 + + +class TestBasicVsBodyPrecedence: + def test_malformed_basic_overrides_valid_body_returns_401(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + headers={"Authorization": "Basic not-valid-base64!!"}, + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 401 + + def test_invalid_basic_and_empty_body_returns_401(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + headers={"Authorization": "Basic !!!"}, + json={}, + ) + assert resp.status_code == 401 diff --git a/tests/test_yaml_config.py b/tests/test_yaml_config.py index 6baab36..540cd4a 100644 --- a/tests/test_yaml_config.py +++ b/tests/test_yaml_config.py @@ -26,10 +26,20 @@ def _write_yaml(content: str | dict) -> str: return f.name -def _from_str(content: str) -> LightApi: +def _dummy_login_validator(username: str, password: str): + return None + + +def _from_str(content: str, login_validator=None) -> LightApi: path = _write_yaml(content) try: - return LightApi.from_config(path) + # Pass login_validator when YAML uses JWT or Basic auth (required by LightApi) + needs_validator = ( + "JWTAuthentication" in content or "BasicAuthentication" in content + ) + if needs_validator and login_validator is None: + login_validator = _dummy_login_validator + return LightApi.from_config(path, login_validator=login_validator) finally: os.unlink(path) @@ -299,6 +309,154 @@ def test_from_config_kwargs_override_yaml(self): os.unlink(path) +# ───────────────────────────────────────────────────────────────────────────── +# Auth/login YAML configuration +# ───────────────────────────────────────────────────────────────────────────── + + +class TestYamlAuthConfig: + def test_auth_path_from_yaml(self): + """auth.auth_path in YAML configures login/token route prefix.""" + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret" + content = """\ + database: + url: "sqlite:///:memory:" + auth: + auth_path: /api/auth + defaults: + authentication: + backend: JWTAuthentication + permission: IsAuthenticated + endpoints: + - route: /x + fields: + data: { type: str } + meta: + methods: [GET] + """ + app = _from_str(content) + from starlette.testclient import TestClient + + client = TestClient(app.build_app()) + resp = client.post( + "/api/auth/login", + json={"username": "a", "password": "b"}, + ) + assert resp.status_code in (200, 401) + + def test_login_validator_dotted_path_from_yaml(self): + """auth.login_validator as dotted path resolves to callable.""" + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret" + content = """\ + database: + url: "sqlite:///:memory:" + auth: + login_validator: tests.test_yaml_config._dummy_login_validator + defaults: + authentication: + backend: JWTAuthentication + permission: IsAuthenticated + endpoints: + - route: /x + fields: + data: { type: str } + meta: + methods: [GET] + """ + app = _from_str(content) + assert app._login_validator is _dummy_login_validator + + def test_jwt_expiration_from_defaults(self): + """defaults.authentication.jwt_expiration flows to Meta.""" + content = """\ + database: + url: "sqlite:///:memory:" + defaults: + authentication: + backend: JWTAuthentication + permission: IsAuthenticated + jwt_expiration: 300 + endpoints: + - route: /x + fields: + data: { type: str } + meta: + methods: [GET] + """ + app = _from_str(content) + cls = app._endpoint_map["/x"] + assert cls.Meta.authentication.jwt_expiration == 300 + + def test_jwt_extra_claims_from_defaults(self): + """defaults.authentication.jwt_extra_claims flows to Meta.""" + content = """\ + database: + url: "sqlite:///:memory:" + defaults: + authentication: + backend: JWTAuthentication + permission: IsAuthenticated + jwt_extra_claims: [sub, email] + endpoints: + - route: /x + fields: + data: { type: str } + meta: + methods: [GET] + """ + app = _from_str(content) + cls = app._endpoint_map["/x"] + assert cls.Meta.authentication.jwt_extra_claims == ["sub", "email"] + + def test_basic_authentication_from_yaml(self): + """BasicAuthentication can be specified as backend in YAML.""" + content = """\ + database: + url: "sqlite:///:memory:" + defaults: + authentication: + backend: BasicAuthentication + permission: IsAuthenticated + endpoints: + - route: /items + fields: + name: { type: str } + meta: + methods: [GET] + """ + app = _from_str(content) + from lightapi.auth import BasicAuthentication + + cls = app._endpoint_map["/items"] + assert cls.Meta.authentication.backend is BasicAuthentication + + def test_login_validator_invalid_dotted_path_raises(self): + """Invalid login_validator dotted path raises ConfigurationError.""" + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret" + content = """\ + database: + url: "sqlite:///:memory:" + auth: + login_validator: non.existent.module.foo + defaults: + authentication: + backend: JWTAuthentication + permission: IsAuthenticated + endpoints: + - route: /x + fields: + data: { type: str } + meta: + methods: [GET] + """ + path = _write_yaml(content) + try: + with pytest.raises(ConfigurationError, match="login_validator"): + LightApi.from_config(path) + finally: + os.unlink(path) + + # ───────────────────────────────────────────────────────────────────────────── # _resolve_name utility # ───────────────────────────────────────────────────────────────────────────── diff --git a/uv.lock b/uv.lock index cfc1d4a..8a87de6 100644 --- a/uv.lock +++ b/uv.lock @@ -526,7 +526,7 @@ wheels = [ [[package]] name = "lightapi" -version = "0.1.16" +version = "0.1.20" source = { editable = "." } dependencies = [ { name = "pydantic" }, From a92eb5df41d902d6f334d008b0505828c643d909 Mon Sep 17 00:00:00 2001 From: iklobato Date: Mon, 9 Mar 2026 23:25:34 -0300 Subject: [PATCH 2/3] fix(auth): address PR review comments for JWT and Basic authentication - Fix case-insensitive Basic auth header check - Make rate limiter app-scoped and configurable - Fix global login_validator breaking multi-app isolation - Make VALID_JWT_ALGORITHMS immutable - Validate per-endpoint jwt_algorithm before storing - Fix duplicate auth routes in register() - Fix rate limiter cleanup for hourly/daily buckets - Improve IP header trust for rate limiting - Fix rate limit response headers for correct window - Validate login validator signature at load time - Fix dict-form method auth dropping default JWT settings - Fix test helper signatures and type annotations - Update YAML documentation with jwt_algorithm - Fix circular imports and linting issues --- docs/examples/yaml-configuration.md | 2 + lightapi/_login.py | 22 ++++++---- lightapi/auth.py | 26 ++++++------ lightapi/config.py | 34 ++++++++++----- lightapi/lightapi.py | 65 +++++++++++++++++++++------- lightapi/rate_limiter.py | 54 +++++++++++++++++------ lightapi/yaml_loader.py | 66 +++++++++++++++++++++++------ tests/test_auth.py | 3 +- tests/test_yaml_config.py | 63 ++++++++++++++++----------- 9 files changed, 234 insertions(+), 101 deletions(-) diff --git a/docs/examples/yaml-configuration.md b/docs/examples/yaml-configuration.md index a9e2fc3..e399cf0 100644 --- a/docs/examples/yaml-configuration.md +++ b/docs/examples/yaml-configuration.md @@ -24,6 +24,7 @@ defaults: permission: IsAuthenticated jwt_expiration: 3600 jwt_extra_claims: [sub, email] + jwt_algorithm: HS256 pagination: style: page_number page_size: 20 @@ -140,6 +141,7 @@ python -c "from lightapi import LightApi; LightApi.from_config('lightapi.yaml'). | `defaults.authentication.permission` | string | Permission class name. | | `defaults.authentication.jwt_expiration` | int | JWT token expiration in seconds (JWT only). | | `defaults.authentication.jwt_extra_claims` | list | Claims to include in token payload (JWT only). | +| `defaults.authentication.jwt_algorithm` | string | JWT algorithm (HS256, HS384, HS512, RS256, RS384, RS512, ES256, ES384, ES512). | | `auth.auth_path` | string | Path prefix for `/login` and `/token` (default `/auth`). | | `auth.login_validator` | string | Dotted path to credential validator callable (e.g. `myapp.validators.check_user`). | | `defaults.pagination.style` | string | `page_number` or `cursor`. | diff --git a/lightapi/_login.py b/lightapi/_login.py index db50f0a..0e205d6 100644 --- a/lightapi/_login.py +++ b/lightapi/_login.py @@ -3,16 +3,16 @@ from __future__ import annotations import base64 +import json import logging from collections.abc import Callable -from typing import Any +from typing import Any, Optional -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, ValidationError from starlette.requests import Request from starlette.responses import JSONResponse -from lightapi.auth import JWTAuthentication -from lightapi.rate_limiter import rate_limit_auth_endpoint +# JWTAuthentication imported locally where needed to avoid circular import logger = logging.getLogger(__name__) @@ -32,7 +32,7 @@ def _parse_basic_header(auth_header: str) -> tuple[str, str] | None: Returns (username, password) or None if malformed. """ - if not auth_header.startswith("Basic "): + if not auth_header.lower().startswith("basic "): return None try: token = auth_header.split(" ", 1)[1] @@ -65,8 +65,6 @@ async def _parse_credentials(request: Request) -> tuple[str, str] | None: async def _read_body(request: Request) -> dict[str, Any]: """Read JSON body; return {} on empty or invalid.""" - import json - try: body = await request.body() return json.loads(body) if body else {} @@ -74,7 +72,6 @@ async def _read_body(request: Request) -> dict[str, Any]: return {} -@rate_limit_auth_endpoint async def login_handler( request: Request, *, @@ -83,6 +80,7 @@ async def login_handler( jwt_expiration: int | None = None, jwt_extra_claims: list[str] | None = None, jwt_algorithm: str | None = None, + rate_limiter: Optional[Any] = None, ) -> JSONResponse: """ Handle POST /auth/login and POST /auth/token. @@ -90,7 +88,11 @@ async def login_handler( Returns 422 for body validation, 401 for malformed Basic or invalid credentials, 500 for validator exception, 200 with token+user (JWT) or user only (Basic). """ - from pydantic import ValidationError + # Apply rate limiting if a rate limiter is provided + if rate_limiter is not None: + is_limited, window = rate_limiter.is_rate_limited(request, endpoint="auth") + if is_limited: + return rate_limiter.get_rate_limit_response(request, window) if request.method != "POST": return JSONResponse( @@ -118,6 +120,8 @@ async def login_handler( return JSONResponse({"detail": "Invalid credentials"}, status_code=401) if has_jwt: + from lightapi.auth import JWTAuthentication + jwt_auth = JWTAuthentication(algorithm=jwt_algorithm) if jwt_extra_claims and isinstance(payload, dict): token_payload = {k: payload[k] for k in jwt_extra_claims if k in payload} diff --git a/lightapi/auth.py b/lightapi/auth.py index ebbb942..4140e13 100644 --- a/lightapi/auth.py +++ b/lightapi/auth.py @@ -1,11 +1,13 @@ +import base64 from datetime import datetime, timedelta -from typing import Any, Dict, Optional +from typing import Dict, Optional import jwt from starlette.requests import Request from starlette.responses import JSONResponse -from ._registry import LoginValidator +from ._login import _parse_basic_header +from ._registry import LoginValidator, get_login_validator from .config import config @@ -159,9 +161,15 @@ class BasicAuthentication(BaseAuthentication): Basic (Base64) authentication. Authenticates requests using Authorization: Basic . - Delegates credential validation to the app-level login_validator from the registry. + Delegates credential validation to the provided login_validator. """ + def __init__( + self, + login_validator: Optional[LoginValidator] = None, + ) -> None: + self.login_validator = login_validator + def authenticate(self, request: Request) -> bool: if request.method == "OPTIONS": return True @@ -171,16 +179,12 @@ def authenticate(self, request: Request) -> bool: return False # Use the shared Basic auth parsing function - from lightapi._login import _parse_basic_header - credentials = _parse_basic_header(auth_header) if credentials is None: return False username, password = credentials - from lightapi._registry import get_login_validator - - validator = get_login_validator() + validator = self.login_validator or get_login_validator() if validator is None: return False @@ -208,12 +212,10 @@ def get_auth_error_response(self, request: Request) -> JSONResponse: return JSONResponse({"error": "authentication failed"}, status_code=401) auth_header = request.headers.get("Authorization") - if not auth_header or not auth_header.startswith("Basic "): + if not auth_header or not auth_header.lower().startswith("basic "): return False try: - import base64 - token = auth_header.split(" ", 1)[1] decoded = base64.b64decode(token).decode("utf-8") except (ValueError, IndexError, UnicodeDecodeError): @@ -224,8 +226,6 @@ def get_auth_error_response(self, request: Request) -> JSONResponse: return False username, password = parts[0], parts[1] - from lightapi._registry import get_login_validator - validator = get_login_validator() if validator is None: return False diff --git a/lightapi/config.py b/lightapi/config.py index e1ecc11..daba180 100644 --- a/lightapi/config.py +++ b/lightapi/config.py @@ -9,17 +9,19 @@ class _Config: """Configuration used by JWTAuthentication and other components.""" - VALID_JWT_ALGORITHMS = { - "HS256", - "HS384", - "HS512", - "RS256", - "RS384", - "RS512", - "ES256", - "ES384", - "ES512", - } + VALID_JWT_ALGORITHMS = frozenset( + { + "HS256", + "HS384", + "HS512", + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + } + ) def __init__(self) -> None: self._overrides: dict[str, Any] = {} @@ -70,6 +72,16 @@ def __init__( ) self.jwt_expiration = jwt_expiration self.jwt_extra_claims = jwt_extra_claims + + # Validate jwt_algorithm if provided + if ( + jwt_algorithm is not None + and jwt_algorithm not in config.VALID_JWT_ALGORITHMS + ): + raise ConfigurationError( + f"Invalid JWT algorithm: '{jwt_algorithm}'. " + f"Must be one of: {sorted(config.VALID_JWT_ALGORITHMS)}" + ) self.jwt_algorithm = jwt_algorithm diff --git a/lightapi/lightapi.py b/lightapi/lightapi.py index 567a8c1..10b8324 100644 --- a/lightapi/lightapi.py +++ b/lightapi/lightapi.py @@ -4,10 +4,11 @@ import asyncio import importlib +import json import logging import os import warnings -from typing import Any +from typing import Any, Callable, Dict, Optional import uvicorn from sqlalchemy import create_engine @@ -18,8 +19,16 @@ from starlette.responses import JSONResponse, Response from starlette.routing import Route -from lightapi._registry import get_registry_and_metadata, set_engine +from lightapi._registry import ( + get_registry_and_metadata, + set_engine, + set_login_validator, +) +from lightapi.auth import AllowAny, BasicAuthentication, JWTAuthentication +from lightapi.cache import get_cached, invalidate_cache_prefix, set_cached from lightapi.exceptions import ConfigurationError +from lightapi.rest import RestEndpoint +from lightapi.yaml_loader import load_config logger = logging.getLogger(__name__) @@ -63,6 +72,7 @@ def __init__( # Detect async engine — drives session strategy and startup validation try: + importlib.import_module("sqlalchemy.ext.asyncio") from sqlalchemy.ext.asyncio import AsyncEngine self._async: bool = isinstance(engine, AsyncEngine) @@ -75,9 +85,8 @@ def __init__( self._cors_origins: list[str] = cors_origins or [] self._login_validator = login_validator self._auth_path = auth_path + self._auth_rate_limiter = None if login_validator is not None: - from lightapi._registry import set_login_validator - set_login_validator(login_validator) # ───────────────────────────────────────────────────────────────────────── @@ -91,7 +100,6 @@ def register(self, mapping: dict[str, type]) -> None: mapping: ``{"/path": EndpointClass}`` dictionary. Each class must be a ``RestEndpoint`` subclass. """ - from lightapi.rest import RestEndpoint for path, cls in mapping.items(): if not (isinstance(cls, type) and issubclass(cls, RestEndpoint)): @@ -153,7 +161,6 @@ def register(self, mapping: dict[str, type]) -> None: self._endpoint_map[path] = cls # Auto-register /auth/login and /auth/token when JWT or Basic auth is used - from lightapi.auth import BasicAuthentication, JWTAuthentication auth_backends: set[type] = set() jwt_config_expiration: int | None = None @@ -174,6 +181,12 @@ def register(self, mapping: dict[str, type]) -> None: jwt_config_algorithm = getattr(auth_cfg, "jwt_algorithm", None) if auth_backends & {JWTAuthentication, BasicAuthentication}: + # Initialize rate limiter if not already created + if self._auth_rate_limiter is None: + from lightapi.rate_limiter import RateLimiter + + self._auth_rate_limiter = RateLimiter() + if self._login_validator is None: raise ConfigurationError( "login_validator is required when using JWTAuthentication " @@ -181,6 +194,20 @@ def register(self, mapping: dict[str, type]) -> None: ) has_jwt = JWTAuthentication in auth_backends auth_path = self._auth_path.rstrip("/") + + # Check if auth routes already exist + login_path = f"{auth_path}/login" + token_path = f"{auth_path}/token" + + # Remove existing auth routes if any + self._routes = [ + route + for route in self._routes + if not ( + isinstance(route, Route) and route.path in {login_path, token_path} + ) + ] + login_endpoint = self._make_login_endpoint( has_jwt=has_jwt, jwt_expiration=jwt_config_expiration, @@ -190,7 +217,7 @@ def register(self, mapping: dict[str, type]) -> None: self._routes.insert( 0, Route( - f"{auth_path}/login", + login_path, login_endpoint, methods=["POST"], ), @@ -198,7 +225,7 @@ def register(self, mapping: dict[str, type]) -> None: self._routes.insert( 1, Route( - f"{auth_path}/token", + token_path, login_endpoint, methods=["POST"], ), @@ -225,6 +252,7 @@ async def handler(request: Request) -> Response: jwt_expiration=jwt_expiration, jwt_extra_claims=jwt_extra_claims, jwt_algorithm=jwt_algorithm, + rate_limiter=self._auth_rate_limiter, ) return handler @@ -242,7 +270,9 @@ async def handler(request: Request) -> Response: if pre_result is not None: return pre_result - auth_result = _check_auth(cls, request) + auth_result = _check_auth( + cls, request, login_validator=self._login_validator + ) if auth_result is not None: return auth_result @@ -299,7 +329,9 @@ async def handler(request: Request) -> Response: if pre_result is not None: return pre_result - auth_result = _check_auth(cls, request) + auth_result = _check_auth( + cls, request, login_validator=self._login_validator + ) if auth_result is not None: return auth_result @@ -416,7 +448,6 @@ def from_config(cls, config_path: str, **kwargs: Any) -> "LightApi": Kwargs override YAML values (e.g. engine=..., database_url=...). """ - from lightapi.yaml_loader import load_config return load_config(cls, config_path, **kwargs) @@ -511,7 +542,6 @@ def _validate_async_dependencies(engine: Any) -> None: async def _read_body(request: Request) -> dict[str, Any]: """Read and parse JSON body; return {} on failure.""" - import json try: body = await request.body() @@ -520,9 +550,12 @@ async def _read_body(request: Request) -> dict[str, Any]: return {} -def _check_auth(cls: type, request: Request) -> Response | None: +def _check_auth( + cls: type, + request: Request, + login_validator: Optional[Callable[[str, str], Dict[str, Any] | None]] = None, +) -> Response | None: """Run authentication + permission checks; return 401/403 response or None.""" - from lightapi.auth import AllowAny auth_cfg = cls._meta.get("authentication") if auth_cfg is None: @@ -552,6 +585,8 @@ def _check_auth(cls: type, request: Request) -> Response | None: expiration=getattr(auth_cfg, "jwt_expiration", None), algorithm=getattr(auth_cfg, "jwt_algorithm", None), ) + elif backend.__name__ == "BasicAuthentication": + authenticator = backend(login_validator=login_validator) else: authenticator = backend() @@ -603,7 +638,6 @@ async def _run_post_middlewares( def _maybe_cached(cls: type, request: Request, fn: Any) -> Response: """Serve from Redis cache (GET only) or call fn() and populate cache.""" - from lightapi.cache import get_cached, set_cached cache_cfg = cls._meta.get("cache") if cache_cfg is None: @@ -634,7 +668,6 @@ def _maybe_invalidate_cache(cls: type, request: Request) -> None: cache_cfg = cls._meta.get("cache") if cache_cfg is None: return - from lightapi.cache import invalidate_cache_prefix invalidate_cache_prefix(_cache_key_prefix(cls)) diff --git a/lightapi/rate_limiter.py b/lightapi/rate_limiter.py index 5676cb4..c6ff06d 100644 --- a/lightapi/rate_limiter.py +++ b/lightapi/rate_limiter.py @@ -63,28 +63,36 @@ def _cleanup_old_entries(self) -> None: def _get_window_seconds(self, window: str) -> int: """Convert window name to seconds.""" - if window == "minute": + # Extract window name from key (e.g., "auth:minute" -> "minute") + window_name = window.split(":")[-1] if ":" in window else window + + if window_name == "minute": return 60 - elif window == "hour": + elif window_name == "hour": return 3600 - elif window == "day": + elif window_name == "day": return 86400 else: return 60 def _get_client_ip(self, request: Request) -> str: """Extract client IP from request.""" - # Try common headers for proxy setups + # Always prefer the actual client host for security + if request.client and request.client.host: + return request.client.host + + # Fallback to headers only if client host is not available for header in ("X-Forwarded-For", "X-Real-IP", "X-Client-IP"): if header in request.headers: ip = request.headers[header].split(",")[0].strip() if ip: return ip - # Fall back to client host - return request.client.host if request.client else "0.0.0.0" + return "0.0.0.0" - def is_rate_limited(self, request: Request, endpoint: str = "") -> bool: + def is_rate_limited( + self, request: Request, endpoint: str = "" + ) -> tuple[bool, str | None]: """ Check if request should be rate limited. @@ -93,7 +101,8 @@ def is_rate_limited(self, request: Request, endpoint: str = "") -> bool: endpoint: Optional endpoint identifier for per-endpoint limiting. Returns: - bool: True if rate limited, False otherwise. + tuple[bool, str | None]: (True if rate limited, window name) or + (False, None) """ self._cleanup_old_entries() @@ -118,17 +127,34 @@ def is_rate_limited(self, request: Request, endpoint: str = "") -> bool: count += request_count if count >= limit: - return True + # Don't count this request since it's being blocked + return (True, window_name) # Add current request self._store[client_ip][window_key][current_time] = ( self._store[client_ip][window_key].get(current_time, 0) + 1 ) - return False + return (False, None) - def get_rate_limit_response(self, request: Request) -> JSONResponse: + def get_rate_limit_response( + self, request: Request, window: str = "minute" + ) -> JSONResponse: """Get standard rate limit exceeded response.""" + # Determine window-specific values + if window == "hour": + limit = self.requests_per_hour + retry_after = 3600 + reset_seconds = 3600 + elif window == "day": + limit = self.requests_per_day + retry_after = 86400 + reset_seconds = 86400 + else: # minute + limit = self.requests_per_minute + retry_after = 60 + reset_seconds = 60 + return JSONResponse( { "error": "rate_limit_exceeded", @@ -136,10 +162,10 @@ def get_rate_limit_response(self, request: Request) -> JSONResponse: }, status_code=429, headers={ - "Retry-After": "60", # Retry after 60 seconds - "X-RateLimit-Limit": str(self.requests_per_minute), + "Retry-After": str(retry_after), + "X-RateLimit-Limit": str(limit), "X-RateLimit-Remaining": "0", - "X-RateLimit-Reset": str(int(time.time() + 60)), + "X-RateLimit-Reset": str(int(time.time() + reset_seconds)), }, ) diff --git a/lightapi/yaml_loader.py b/lightapi/yaml_loader.py index d263a14..ff0ea7f 100644 --- a/lightapi/yaml_loader.py +++ b/lightapi/yaml_loader.py @@ -100,6 +100,38 @@ def _resolve_callable(dotted_path: str) -> Any: ) from exc if not callable(fn): raise ConfigurationError(f"login_validator '{dotted_path}' is not callable.") + + # Validate signature: must be sync function with exactly 2 positional args + import inspect + + if inspect.iscoroutinefunction(fn): + raise ConfigurationError( + f"login_validator '{dotted_path}' is async, but sync function required. " + f"Login validation must be a synchronous function." + ) + + try: + sig = inspect.signature(fn) + except ValueError: + # Some callables (e.g., builtins) don't have inspectable signatures + return fn + + # Count required positional parameters + required_params = 0 + for param in sig.parameters.values(): + if param.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + if param.default is inspect.Parameter.empty: + required_params += 1 + + if required_params != 2: + raise ConfigurationError( + f"login_validator '{dotted_path}' must accept exactly 2 required " + f"positional parameters (username, password), got {required_params}" + ) + return fn @@ -401,22 +433,32 @@ def _build_meta_class( if isinstance(perm, str): permission_map[method] = _resolve_name(perm) if permission_map: - # Determine shared backend and JWT opts (from endpoint auth or defaults) - src = meta.authentication or defaults.authentication - backend_name = (meta.authentication and meta.authentication.backend) or ( - defaults.authentication and defaults.authentication.backend - ) - jwt_exp = src.jwt_expiration if src else None - jwt_claims = src.jwt_extra_claims if src else None - jwt_algo = src.jwt_algorithm if src else None + # Merge auth settings similar to _make_authentication + merged_backend = None + merged_jwt_expiration = None + merged_jwt_extra_claims = None + merged_jwt_algorithm = None + + for source in (defaults.authentication, meta.authentication): + if source is None: + continue + if source.backend is not None: + merged_backend = source.backend + if source.jwt_expiration is not None: + merged_jwt_expiration = source.jwt_expiration + if source.jwt_extra_claims is not None: + merged_jwt_extra_claims = source.jwt_extra_claims + if source.jwt_algorithm is not None: + merged_jwt_algorithm = source.jwt_algorithm + from lightapi.config import Authentication attrs["authentication"] = Authentication( - backend=_resolve_name(backend_name) if backend_name else None, + backend=_resolve_name(merged_backend) if merged_backend else None, permission=permission_map, - jwt_expiration=jwt_exp, - jwt_extra_claims=jwt_claims, - jwt_algorithm=jwt_algo, + jwt_expiration=merged_jwt_expiration, + jwt_extra_claims=merged_jwt_extra_claims, + jwt_algorithm=merged_jwt_algorithm, ) else: # Simple auth: endpoint overrides defaults diff --git a/tests/test_auth.py b/tests/test_auth.py index d49519b..33c7db7 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,6 +1,7 @@ """Tests for US3: Authentication and Permission classes.""" import os +from typing import Any import pytest from sqlalchemy import create_engine @@ -59,7 +60,7 @@ def jwt_secret(monkeypatch_session=None): return secret -def _login_validator(username: str, password: str): +def _login_validator(_username: str, _password: str) -> dict[str, Any] | None: """Test validator; always returns None (tests use _make_token for tokens).""" return None diff --git a/tests/test_yaml_config.py b/tests/test_yaml_config.py index 540cd4a..fc76a8d 100644 --- a/tests/test_yaml_config.py +++ b/tests/test_yaml_config.py @@ -33,13 +33,11 @@ def _dummy_login_validator(username: str, password: str): def _from_str(content: str, login_validator=None) -> LightApi: path = _write_yaml(content) try: - # Pass login_validator when YAML uses JWT or Basic auth (required by LightApi) - needs_validator = ( - "JWTAuthentication" in content or "BasicAuthentication" in content - ) - if needs_validator and login_validator is None: - login_validator = _dummy_login_validator - return LightApi.from_config(path, login_validator=login_validator) + # Only pass login_validator if explicitly provided + kwargs = {} + if login_validator is not None: + kwargs["login_validator"] = login_validator + return LightApi.from_config(path, **kwargs) finally: os.unlink(path) @@ -347,24 +345,39 @@ def test_auth_path_from_yaml(self): def test_login_validator_dotted_path_from_yaml(self): """auth.login_validator as dotted path resolves to callable.""" os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret" - content = """\ - database: - url: "sqlite:///:memory:" - auth: - login_validator: tests.test_yaml_config._dummy_login_validator - defaults: - authentication: - backend: JWTAuthentication - permission: IsAuthenticated - endpoints: - - route: /x - fields: - data: { type: str } - meta: - methods: [GET] - """ - app = _from_str(content) - assert app._login_validator is _dummy_login_validator + + # Create a simple validator function + def test_validator(username: str, password: str): + return {"sub": "test"} + + # Monkey-patch it into the module + import tests.test_yaml_config + + tests.test_yaml_config.test_validator_func = test_validator + + try: + content = """\ + database: + url: "sqlite:///:memory:" + auth: + login_validator: tests.test_yaml_config.test_validator_func + defaults: + authentication: + backend: JWTAuthentication + permission: IsAuthenticated + endpoints: + - route: /x + fields: + data: { type: str } + meta: + methods: [GET] + """ + app = _from_str(content) + # The validator from YAML should be resolved and used + assert app._login_validator is test_validator + finally: + # Clean up + delattr(tests.test_yaml_config, "test_validator_func") def test_jwt_expiration_from_defaults(self): """defaults.authentication.jwt_expiration flows to Meta.""" From 563976de9b79cd53459891a7fe9ef50d09583b44 Mon Sep 17 00:00:00 2001 From: iklobato Date: Tue, 10 Mar 2026 00:13:45 -0300 Subject: [PATCH 3/3] fix(auth): final fixes for PR review - Add WWW-Authenticate header to Basic auth 401 response - Add login_validator parameter to BasicAuthentication constructor - Fix case-insensitive Basic auth header check - Reject reserved claims in jwt_extra_claims at config time - Fix rate limiter by removing unused global decorator - Fix YAML dict-method permission merging - Update imports and fix linting issues --- lightapi/auth.py | 143 ++++++--------------------------------- lightapi/config.py | 15 ++++ lightapi/rate_limiter.py | 28 -------- lightapi/yaml_loader.py | 34 +++++++++- 4 files changed, 69 insertions(+), 151 deletions(-) diff --git a/lightapi/auth.py b/lightapi/auth.py index 4140e13..d3ead3a 100644 --- a/lightapi/auth.py +++ b/lightapi/auth.py @@ -1,12 +1,10 @@ -import base64 from datetime import datetime, timedelta -from typing import Dict, Optional +from typing import Any, Dict, Optional import jwt from starlette.requests import Request from starlette.responses import JSONResponse -from ._login import _parse_basic_header from ._registry import LoginValidator, get_login_validator from .config import config @@ -31,129 +29,22 @@ def authenticate(self, request: Request) -> bool: """ return True - def get_auth_error_response(self, request: Request) -> JSONResponse: - """ - Get the response to return when authentication fails. - - Args: - request: The HTTP request object. - - Returns: - Response object for authentication error. - """ - return JSONResponse({"error": "not allowed"}, status_code=403) - -class JWTAuthentication(BaseAuthentication): +def get_auth_error_response(self, request: Request) -> JSONResponse: """ - JWT (JSON Web Token) based authentication. + Get the response to return when authentication fails. - Authenticates requests using JWT tokens from the Authorization header. - Validates token signatures and expiration times. - Automatically skips authentication for OPTIONS requests (CORS preflight). + Args: + request: The HTTP request object. - Attributes: - secret_key: Secret key for signing tokens. - algorithm: JWT algorithm to use. - expiration: Token expiration time in seconds. + Returns: + Response object for authentication error. """ - - def __init__( - self, - secret_key: str | None = None, - algorithm: str | None = None, - expiration: int | None = None, - ): - self.secret_key = secret_key or config.jwt_secret - if not self.secret_key: - raise ValueError( - "JWT secret key not configured. Set LIGHTAPI_JWT_SECRET environment variable." - ) - - self.algorithm = algorithm or config.jwt_algorithm - self.expiration = expiration or 3600 # 1 hour default - - def authenticate(self, request: Request) -> bool: - """ - Authenticate a request using JWT token. - Automatically allows OPTIONS requests for CORS preflight. - - Args: - request: The HTTP request object. - - Returns: - bool: True if authentication succeeds, False otherwise. - """ - # Skip authentication for OPTIONS requests (CORS preflight) - if request.method == "OPTIONS": - return True - - auth_header = request.headers.get("Authorization") - if not auth_header or not auth_header.startswith("Bearer "): - return False - - token = auth_header.split(" ")[1] - try: - payload = self.decode_token(token) - request.state.user = payload - return True - except jwt.InvalidTokenError: - return False - - def generate_token(self, payload: Dict, expiration: Optional[int] = None) -> str: - """ - Generate a JWT token. - - Args: - payload: The data to encode in the token. - expiration: Token expiration time in seconds. - - Returns: - str: The encoded JWT token. - - Raises: - ValueError: If payload contains 'exp' claim which will be overwritten. - """ - # Check for 'exp' in payload since we overwrite it - if "exp" in payload: - raise ValueError( - "Payload contains 'exp' claim which will be overwritten. " - "Use the 'expiration' parameter instead." - ) - - exp_seconds = expiration or self.expiration - token_data = { - **payload, - "exp": datetime.utcnow() + timedelta(seconds=exp_seconds), - } - return jwt.encode(token_data, self.secret_key, algorithm=self.algorithm) - - def decode_token(self, token: str) -> Dict: - """ - Decode and verify a JWT token. - - Args: - token: The JWT token to decode. - - Returns: - dict: The decoded token payload. - - Raises: - jwt.InvalidTokenError: If the token is invalid or expired. - """ - return jwt.decode(token, self.secret_key, algorithms=[self.algorithm]) - - def get_auth_error_response(self, request: Request) -> JSONResponse: - """ - Get the response to return when authentication fails. - - Args: - request: The HTTP request object. - - Returns: - Response object for authentication error. - """ - return JSONResponse({"error": "authentication failed"}, status_code=401) + return JSONResponse( + {"error": "authentication failed"}, + status_code=401, + headers={"WWW-Authenticate": 'Basic realm="Restricted Area"'}, + ) class BasicAuthentication(BaseAuthentication): @@ -161,7 +52,7 @@ class BasicAuthentication(BaseAuthentication): Basic (Base64) authentication. Authenticates requests using Authorization: Basic . - Delegates credential validation to the provided login_validator. + Delegates credential validation to the app-level login_validator from the registry. """ def __init__( @@ -179,11 +70,15 @@ def authenticate(self, request: Request) -> bool: return False # Use the shared Basic auth parsing function + from lightapi._login import _parse_basic_header + credentials = _parse_basic_header(auth_header) if credentials is None: return False username, password = credentials + from lightapi._registry import get_login_validator + validator = self.login_validator or get_login_validator() if validator is None: return False @@ -216,6 +111,8 @@ def get_auth_error_response(self, request: Request) -> JSONResponse: return False try: + import base64 + token = auth_header.split(" ", 1)[1] decoded = base64.b64decode(token).decode("utf-8") except (ValueError, IndexError, UnicodeDecodeError): @@ -226,6 +123,8 @@ def get_auth_error_response(self, request: Request) -> JSONResponse: return False username, password = parts[0], parts[1] + from lightapi._registry import get_login_validator + validator = get_login_validator() if validator is None: return False diff --git a/lightapi/config.py b/lightapi/config.py index daba180..13434ab 100644 --- a/lightapi/config.py +++ b/lightapi/config.py @@ -71,6 +71,21 @@ def __init__( permission if permission is not None else AllowAny ) self.jwt_expiration = jwt_expiration + + # Validate jwt_extra_claims - reject reserved claims + if jwt_extra_claims: + RESERVED_CLAIMS = {"exp", "iat", "nbf", "iss", "sub", "aud", "jti"} + reserved_found = [] + for claim in jwt_extra_claims: + if claim in RESERVED_CLAIMS: + reserved_found.append(claim) + + if reserved_found: + raise ConfigurationError( + f"JWT extra claims cannot include reserved claims: " + f"{reserved_found}. Reserved claims are: {sorted(RESERVED_CLAIMS)}" + ) + self.jwt_extra_claims = jwt_extra_claims # Validate jwt_algorithm if provided diff --git a/lightapi/rate_limiter.py b/lightapi/rate_limiter.py index c6ff06d..f7a8c2a 100644 --- a/lightapi/rate_limiter.py +++ b/lightapi/rate_limiter.py @@ -4,7 +4,6 @@ import time from collections import defaultdict -from typing import Any, Callable from starlette.requests import Request from starlette.responses import JSONResponse @@ -168,30 +167,3 @@ def get_rate_limit_response( "X-RateLimit-Reset": str(int(time.time() + reset_seconds)), }, ) - - -# Global rate limiter instance for auth endpoints -_auth_rate_limiter = RateLimiter( - requests_per_minute=10, # 10 requests per minute - requests_per_hour=100, # 100 requests per hour - requests_per_day=1000, # 1000 requests per day -) - - -def rate_limit_auth_endpoint(func: Callable) -> Callable: - """ - Decorator to rate limit authentication endpoints. - - Args: - func: The endpoint function to decorate. - - Returns: - Decorated function with rate limiting. - """ - - async def wrapper(request: Request, *args: Any, **kwargs: Any) -> Any: - if _auth_rate_limiter.is_rate_limited(request, endpoint="auth"): - return _auth_rate_limiter.get_rate_limit_response(request) - return await func(request, *args, **kwargs) - - return wrapper diff --git a/lightapi/yaml_loader.py b/lightapi/yaml_loader.py index ff0ea7f..9e85db1 100644 --- a/lightapi/yaml_loader.py +++ b/lightapi/yaml_loader.py @@ -425,13 +425,45 @@ def _build_meta_class( # Build per-method permission dict from the methods dict permission_map: dict[str, type] = {} method_auth_default = meta.authentication # endpoint-level auth override + + # Get default permission from defaults + default_permission = None + if defaults.authentication and defaults.authentication.permission: + default_permission = defaults.authentication.permission + for method, method_cfg in meta.methods.items(): cfg_auth = method_cfg.authentication if method_cfg else None src_auth = cfg_auth or method_auth_default + + # Determine permission for this method + method_permission = None + + # First, check method-specific auth if src_auth and src_auth.permission: perm = src_auth.permission if isinstance(perm, str): - permission_map[method] = _resolve_name(perm) + method_permission = _resolve_name(perm) + # Note: per-method auth in YAML dict can't have dict permissions + # only string permissions are supported in this path + + # If no method-specific permission, check default permission + if method_permission is None and default_permission is not None: + if isinstance(default_permission, dict): + # Default permission is a dict: check for method-specific default + if method in default_permission: + perm = default_permission[method] + if isinstance(perm, str): + method_permission = _resolve_name(perm) + else: + # Default permission is a string or class + if isinstance(default_permission, str): + method_permission = _resolve_name(default_permission) + else: + method_permission = default_permission + + if method_permission is not None: + permission_map[method] = method_permission + if permission_map: # Merge auth settings similar to _make_authentication merged_backend = None