From feb878067dc1c5725b6c69e44d94dac0e13a1269 Mon Sep 17 00:00:00 2001 From: iklobato Date: Tue, 10 Mar 2026 15:06:52 -0300 Subject: [PATCH 1/5] feat: add login authentication and rate limiting features - Add _login.py module for login authentication - Add rate_limiter.py module for request rate limiting - Update auth.py with authentication improvements - Update yaml_loader.py and config.py for better configuration handling - Add test_login_auth.py for login authentication tests - Update existing tests and documentation - Update dependencies in uv.lock --- README.md | 16 + docs/examples/yaml-configuration.md | 14 +- lightapi/__init__.py | 9 +- lightapi/_login.py | 135 ++++ lightapi/_registry.py | 46 +- lightapi/auth.py | 119 +++- lightapi/cache.py | 11 +- lightapi/config.py | 58 ++ lightapi/lightapi.py | 150 ++++- lightapi/rate_limiter.py | 176 +++++ lightapi/yaml_loader.py | 179 ++++- tests/test_auth.py | 8 +- tests/test_login_auth.py | 984 ++++++++++++++++++++++++++++ tests/test_yaml_config.py | 175 ++++- uv.lock | 2 +- 15 files changed, 2023 insertions(+), 59 deletions(-) create mode 100644 lightapi/_login.py create mode 100644 lightapi/rate_limiter.py create mode 100644 tests/test_login_auth.py diff --git a/README.md b/README.md index f2995a4..c19db4b 100644 --- a/README.md +++ b/README.md @@ -304,6 +304,22 @@ class AdminOnlyEndpoint(RestEndpoint): 2. Permission class `.has_permission(request)` — checks `request.state.user` 3. Returns `401` if authentication fails, `403` if permission denied +**Login and token endpoints:** When using `JWTAuthentication` or `BasicAuthentication`, pass `login_validator` to obtain automatic `/auth/login` and `/auth/token` endpoints: + +```python +def my_validator(username: str, password: str): + # Return user payload dict or None + user = db.query(User).filter_by(username=username).first() + if user and user.check_password(password): + return {"sub": str(user.id), "is_admin": user.is_admin} + return None + +app = LightApi(engine=engine, login_validator=my_validator) +app.register({"/secrets": ProtectedEndpoint}) +# POST /auth/login and POST /auth/token now accept {"username":"...","password":"..."} +# JWT mode: 200 {"token":"...","user":{...}}; Basic-only: 200 {"user":{...}} +``` + **Built-in permission classes:** | Class | Condition | diff --git a/docs/examples/yaml-configuration.md b/docs/examples/yaml-configuration.md index 26b7b38..e399cf0 100644 --- a/docs/examples/yaml-configuration.md +++ b/docs/examples/yaml-configuration.md @@ -22,6 +22,9 @@ defaults: authentication: backend: JWTAuthentication permission: IsAuthenticated + jwt_expiration: 3600 + jwt_extra_claims: [sub, email] + jwt_algorithm: HS256 pagination: style: page_number page_size: 20 @@ -29,6 +32,10 @@ defaults: middleware: - CORSMiddleware +auth: + auth_path: /auth + login_validator: myapp.validators.validate_login + endpoints: - route: /products fields: @@ -130,8 +137,13 @@ python -c "from lightapi import LightApi; LightApi.from_config('lightapi.yaml'). |-------|------|-------------| | `database.url` | string | SQLAlchemy URL. Supports `${VAR}` substitution. | | `cors_origins` | list | CORS allowed origins. | -| `defaults.authentication.backend` | string | Auth backend class name. | +| `defaults.authentication.backend` | string | Auth backend class name (`JWTAuthentication`, `BasicAuthentication`). | | `defaults.authentication.permission` | string | Permission class name. | +| `defaults.authentication.jwt_expiration` | int | JWT token expiration in seconds (JWT only). | +| `defaults.authentication.jwt_extra_claims` | list | Claims to include in token payload (JWT only). | +| `defaults.authentication.jwt_algorithm` | string | JWT algorithm (HS256, HS384, HS512, RS256, RS384, RS512, ES256, ES384, ES512). | +| `auth.auth_path` | string | Path prefix for `/login` and `/token` (default `/auth`). | +| `auth.login_validator` | string | Dotted path to credential validator callable (e.g. `myapp.validators.check_user`). | | `defaults.pagination.style` | string | `page_number` or `cursor`. | | `defaults.pagination.page_size` | int | Rows per page. | | `middleware` | list | Class names resolved at startup. | diff --git a/lightapi/__init__.py b/lightapi/__init__.py index da005e4..54a9846 100644 --- a/lightapi/__init__.py +++ b/lightapi/__init__.py @@ -1,6 +1,12 @@ """LightAPI v2 public API.""" -from lightapi.auth import AllowAny, IsAdminUser, IsAuthenticated, JWTAuthentication +from lightapi.auth import ( + AllowAny, + BasicAuthentication, + IsAdminUser, + IsAuthenticated, + JWTAuthentication, +) from lightapi.cache import RedisCache from lightapi.config import Authentication, Cache, Filtering, Pagination, Serializer @@ -33,6 +39,7 @@ "Pagination", "Serializer", # Auth + "BasicAuthentication", "JWTAuthentication", "AllowAny", "IsAuthenticated", diff --git a/lightapi/_login.py b/lightapi/_login.py new file mode 100644 index 0000000..0136f54 --- /dev/null +++ b/lightapi/_login.py @@ -0,0 +1,135 @@ +"""Login and token endpoint handlers.""" + +from __future__ import annotations + +import base64 +import json +import logging +from collections.abc import Callable +from typing import Any, Optional + +from pydantic import BaseModel, ConfigDict, Field, ValidationError +from starlette.requests import Request +from starlette.responses import JSONResponse + +# JWTAuthentication imported locally where needed to avoid circular import + +logger = logging.getLogger(__name__) + + +class LoginRequest(BaseModel): + """Request body for POST /auth/login and /auth/token.""" + + model_config = ConfigDict(frozen=True) + + username: str = Field(min_length=1) + password: str = Field(min_length=1) + + +def _parse_basic_header(auth_header: str) -> tuple[str, str] | None: + """ + Decode Authorization: Basic header. + + Returns (username, password) or None if malformed. + """ + if not auth_header.lower().startswith("basic "): + return None + try: + token = auth_header.split(" ", 1)[1] + decoded = base64.b64decode(token).decode("utf-8") + except (ValueError, IndexError, UnicodeDecodeError): + return None + parts = decoded.split(":", 1) + if len(parts) != 2: + return None + return parts[0], parts[1] + + +async def _parse_credentials(request: Request) -> tuple[str, str] | None: + """ + Extract (username, password) from request. + + - If Authorization: Basic present: returns (u, p) or None if malformed. + - If no Basic header: reads body, validates with LoginRequest. + Returns (u, p) if valid. Raises ValidationError for body (caller returns 422). + - None means malformed Basic (caller returns 401). + """ + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Basic "): + return _parse_basic_header(auth_header) + + body = await _read_body(request) + parsed = LoginRequest.model_validate(body if body else {}) + return parsed.username, parsed.password + + +async def _read_body(request: Request) -> dict[str, Any]: + """Read JSON body; return {} on empty or invalid.""" + try: + body = await request.body() + return json.loads(body) if body else {} + except (json.JSONDecodeError, TypeError): + return {} + + +async def login_handler( + request: Request, + *, + login_validator: Callable[[str, str], dict[str, Any] | None], + has_jwt: bool, + jwt_expiration: int | None = None, + jwt_extra_claims: list[str] | None = None, + jwt_algorithm: str | None = None, + rate_limiter: Optional[Any] = None, +) -> JSONResponse: + """ + Handle POST /auth/login and POST /auth/token. + + Returns 422 for body validation, 401 for malformed Basic or invalid credentials, + 500 for validator exception, 200 with token+user (JWT) or user only (Basic). + """ + # Apply rate limiting if a rate limiter is provided + if rate_limiter is not None: + is_limited, window = rate_limiter.is_rate_limited(request, endpoint="auth") + if is_limited: + return rate_limiter.get_rate_limit_response(window) + + if request.method != "POST": + return JSONResponse( + {"detail": "method not allowed"}, + status_code=405, + headers={"Allow": "POST"}, + ) + + try: + creds = await _parse_credentials(request) + except ValidationError as exc: + return JSONResponse({"detail": exc.errors()}, status_code=422) + + if creds is None: + return JSONResponse({"detail": "Invalid credentials"}, status_code=401) + + username, password = creds + try: + payload = login_validator(username, password) + except Exception as e: + logger.exception("login_validator raised: %s", e) + raise + + if payload is None: + return JSONResponse({"detail": "Invalid credentials"}, status_code=401) + + if has_jwt: + from lightapi.auth import JWTAuthentication + + jwt_auth = JWTAuthentication(algorithm=jwt_algorithm) + if jwt_extra_claims and isinstance(payload, dict): + token_payload = {k: payload[k] for k in jwt_extra_claims if k in payload} + if not token_payload: + token_payload = payload + else: + token_payload = payload + token = jwt_auth.generate_token(token_payload, expiration=jwt_expiration) + return JSONResponse({"token": token, "user": payload}) + + return JSONResponse({"user": payload}) diff --git a/lightapi/_registry.py b/lightapi/_registry.py index c4498d8..e6851f6 100644 --- a/lightapi/_registry.py +++ b/lightapi/_registry.py @@ -7,31 +7,53 @@ from __future__ import annotations +from collections.abc import Callable +from typing import Any, cast + from sqlalchemy import MetaData from sqlalchemy.orm import registry -_registry: registry | None = None -_metadata: MetaData | None = None -_engine: object | None = None +LoginValidator = Callable[[str, str], dict[str, Any] | None] + +_state: dict[str, object | None] = { + "registry": None, + "metadata": None, + "engine": None, + "login_validator": None, +} def get_registry_and_metadata() -> tuple[registry, MetaData]: - global _registry, _metadata - if _registry is None or _metadata is None: - _metadata = MetaData() - _registry = registry(metadata=_metadata) - return _registry, _metadata + reg = cast(registry | None, _state["registry"]) + meta = cast(MetaData | None, _state["metadata"]) + if reg is None or meta is None: + meta = MetaData() + reg = registry(metadata=meta) + _state["metadata"] = meta + _state["registry"] = reg + return reg, meta def set_engine(engine: object) -> None: - global _engine - _engine = engine + _state["engine"] = engine def get_engine() -> object: - if _engine is None: + engine = _state["engine"] + if engine is None: raise RuntimeError( "No engine configured. Call LightApi(engine=...) or ensure " "database connection is set before the first request." ) - return _engine + return engine + + +def set_login_validator(validator: LoginValidator) -> None: + _state["login_validator"] = validator + + +def get_login_validator() -> LoginValidator | None: + validator = _state["login_validator"] + if validator is None: + return None + return cast(LoginValidator, validator) diff --git a/lightapi/auth.py b/lightapi/auth.py index 5a7980e..5a1080b 100644 --- a/lightapi/auth.py +++ b/lightapi/auth.py @@ -1,3 +1,4 @@ +import base64 from datetime import datetime, timedelta from typing import Dict, Optional @@ -5,6 +6,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse +from ._registry import LoginValidator, get_login_validator from .config import config @@ -16,29 +18,29 @@ class BaseAuthentication: By default, allows all requests. """ - def authenticate(self, request): + def authenticate(self, _request: Request) -> bool: """ Authenticate a request. Args: - request: The HTTP request to authenticate. + _request: The HTTP request to authenticate (unused in base class). Returns: bool: True if authentication succeeds, False otherwise. """ return True - def get_auth_error_response(self, request): + def get_auth_error_response(self, _request: Request) -> JSONResponse: """ Get the response to return when authentication fails. Args: - request: The HTTP request object. + _request: The HTTP request object (unused in base class). Returns: Response object for authentication error. """ - return JSONResponse({"error": "not allowed"}, status_code=403) + return JSONResponse({"error": "authentication failed"}, status_code=401) class JWTAuthentication(BaseAuthentication): @@ -55,16 +57,22 @@ class JWTAuthentication(BaseAuthentication): expiration: Token expiration time in seconds. """ - def __init__(self): - if not config.jwt_secret: + def __init__( + self, + secret_key: str | None = None, + algorithm: str | None = None, + expiration: int | None = None, + ): + self.secret_key = secret_key or config.jwt_secret + if not self.secret_key: raise ValueError( "JWT secret key not configured. Set LIGHTAPI_JWT_SECRET environment variable." ) - self.secret_key = config.jwt_secret - self.algorithm = "HS256" - self.expiration = 3600 # 1 hour default - def authenticate(self, request): + self.algorithm = algorithm or config.jwt_algorithm + self.expiration = expiration or 3600 # 1 hour default + + def authenticate(self, request: Request) -> bool: """ Authenticate a request using JWT token. Automatically allows OPTIONS requests for CORS preflight. @@ -101,7 +109,17 @@ def generate_token(self, payload: Dict, expiration: Optional[int] = None) -> str Returns: str: The encoded JWT token. + + Raises: + ValueError: If payload contains 'exp' claim which will be overwritten. """ + # Check for 'exp' in payload since we overwrite it + if "exp" in payload: + raise ValueError( + "Payload contains 'exp' claim which will be overwritten. " + "Use the 'expiration' parameter instead." + ) + exp_seconds = expiration or self.expiration token_data = { **payload, @@ -124,11 +142,88 @@ def decode_token(self, token: str) -> Dict: """ return jwt.decode(token, self.secret_key, algorithms=[self.algorithm]) + def get_auth_error_response(self, request: Request) -> JSONResponse: + """ + Get the response to return when authentication fails. + + Args: + request: The HTTP request object. + + Returns: + Response object for authentication error. + """ + return JSONResponse({"error": "authentication failed"}, status_code=401) + + +class BasicAuthentication(BaseAuthentication): + """ + Basic (Base64) authentication. + + Authenticates requests using Authorization: Basic . + Delegates credential validation to the provided login_validator. + """ + + def __init__( + self, + login_validator: Optional[LoginValidator] = None, + ) -> None: + self.login_validator = login_validator + + def authenticate(self, request: Request) -> bool: + if request.method == "OPTIONS": + return True + + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.lower().startswith("basic "): + return False + + try: + token = auth_header.split(" ", 1)[1] + decoded = base64.b64decode(token).decode("utf-8") + except (ValueError, IndexError, UnicodeDecodeError): + return False + + parts = decoded.split(":", 1) + if len(parts) != 2: + return False + + username, password = parts[0], parts[1] + validator = self.login_validator or get_login_validator() + if validator is None: + return False + + try: + payload = validator(username, password) + except (ValueError, TypeError, RuntimeError): + return False + + if payload is None: + return False + + request.state.user = payload + return True + + def get_auth_error_response(self, request: Request) -> JSONResponse: + """ + Get the response to return when authentication fails. + + Args: + request: The HTTP request object. + + Returns: + Response object for authentication error. + """ + return JSONResponse( + {"error": "authentication failed"}, + status_code=401, + headers={"WWW-Authenticate": 'Basic realm="Restricted Area"'}, + ) + class AllowAny: """Permits all requests regardless of authentication state.""" - def has_permission(self, request: Request) -> bool: + def has_permission(self, _request: Request) -> bool: return True diff --git a/lightapi/cache.py b/lightapi/cache.py index b56fb0d..7d5c947 100644 --- a/lightapi/cache.py +++ b/lightapi/cache.py @@ -10,17 +10,18 @@ logger = logging.getLogger(__name__) _REDIS_URL = os.environ.get("LIGHTAPI_REDIS_URL", "redis://localhost:6379/0") -_redis_client: "redis.Redis | None" = None +_redis_state: dict[str, "redis.Redis | None"] = {"client": None} def _get_redis() -> "redis.Redis | None": - global _redis_client - if _redis_client is None: + if _redis_state["client"] is None: try: - _redis_client = redis.from_url(_REDIS_URL, socket_connect_timeout=1) + _redis_state["client"] = redis.from_url( + _REDIS_URL, socket_connect_timeout=1 + ) except Exception: return None - return _redis_client + return _redis_state["client"] def _ping_redis() -> bool: diff --git a/lightapi/config.py b/lightapi/config.py index 02646ae..da66152 100644 --- a/lightapi/config.py +++ b/lightapi/config.py @@ -9,6 +9,20 @@ class _Config: """Configuration used by JWTAuthentication and other components.""" + VALID_JWT_ALGORITHMS = frozenset( + { + "HS256", + "HS384", + "HS512", + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + } + ) + def __init__(self) -> None: self._overrides: dict[str, Any] = {} @@ -25,6 +39,16 @@ def _get(self, key: str, env_key: str, default: Any = None) -> Any: def jwt_secret(self) -> str | None: return self._get("jwt_secret", "LIGHTAPI_JWT_SECRET") + @property + def jwt_algorithm(self) -> str: + algorithm = self._get("jwt_algorithm", "LIGHTAPI_JWT_ALGORITHM", "HS256") + if algorithm not in self.VALID_JWT_ALGORITHMS: + raise ConfigurationError( + f"Invalid JWT algorithm '{algorithm}'. " + f"Valid algorithms are: {sorted(self.VALID_JWT_ALGORITHMS)}" + ) + return algorithm + config = _Config() @@ -32,10 +56,16 @@ def jwt_secret(self) -> str | None: class Authentication: """Authentication configuration for a RestEndpoint.""" + # Standard JWT reserved claims that cannot be used as extra claims + RESERVED_CLAIMS = {"exp", "iat", "nbf", "iss", "sub", "aud", "jti"} + def __init__( self, backend: type | None = None, permission: type | dict[str, type] | None = None, + jwt_expiration: int | None = None, + jwt_extra_claims: list[str] | None = None, + jwt_algorithm: str | None = None, ) -> None: from lightapi.auth import AllowAny @@ -43,6 +73,34 @@ def __init__( self.permission: type | dict[str, type] = ( permission if permission is not None else AllowAny ) + self.jwt_expiration = jwt_expiration + + # Validate jwt_extra_claims - reject reserved claims + if jwt_extra_claims: + reserved_found = [] + for claim in jwt_extra_claims: + if claim in self.RESERVED_CLAIMS: + reserved_found.append(claim) + + if reserved_found: + raise ConfigurationError( + f"JWT extra claims cannot include reserved claims: " + f"{reserved_found}. Reserved claims are: " + f"{sorted(self.RESERVED_CLAIMS)}" + ) + + self.jwt_extra_claims = jwt_extra_claims + + # Validate jwt_algorithm if provided + if ( + jwt_algorithm is not None + and jwt_algorithm not in config.VALID_JWT_ALGORITHMS + ): + raise ConfigurationError( + f"Invalid JWT algorithm: '{jwt_algorithm}'. " + f"Must be one of: {sorted(config.VALID_JWT_ALGORITHMS)}" + ) + self.jwt_algorithm = jwt_algorithm class Filtering: diff --git a/lightapi/lightapi.py b/lightapi/lightapi.py index 8e087e7..10b8324 100644 --- a/lightapi/lightapi.py +++ b/lightapi/lightapi.py @@ -4,10 +4,11 @@ import asyncio import importlib +import json import logging import os import warnings -from typing import Any +from typing import Any, Callable, Dict, Optional import uvicorn from sqlalchemy import create_engine @@ -18,8 +19,16 @@ from starlette.responses import JSONResponse, Response from starlette.routing import Route -from lightapi._registry import get_registry_and_metadata, set_engine +from lightapi._registry import ( + get_registry_and_metadata, + set_engine, + set_login_validator, +) +from lightapi.auth import AllowAny, BasicAuthentication, JWTAuthentication +from lightapi.cache import get_cached, invalidate_cache_prefix, set_cached from lightapi.exceptions import ConfigurationError +from lightapi.rest import RestEndpoint +from lightapi.yaml_loader import load_config logger = logging.getLogger(__name__) @@ -44,6 +53,8 @@ def __init__( database_url: str | None = None, cors_origins: list[str] | None = None, middlewares: list[type] | None = None, + login_validator: Any = None, + auth_path: str = "/auth", ) -> None: if engine is None and database_url: engine = create_engine(database_url) @@ -61,6 +72,7 @@ def __init__( # Detect async engine — drives session strategy and startup validation try: + importlib.import_module("sqlalchemy.ext.asyncio") from sqlalchemy.ext.asyncio import AsyncEngine self._async: bool = isinstance(engine, AsyncEngine) @@ -71,6 +83,11 @@ def __init__( self._endpoint_map: dict[str, type] = {} self._middlewares: list[type] = middlewares or [] self._cors_origins: list[str] = cors_origins or [] + self._login_validator = login_validator + self._auth_path = auth_path + self._auth_rate_limiter = None + if login_validator is not None: + set_login_validator(login_validator) # ───────────────────────────────────────────────────────────────────────── # Registration @@ -83,7 +100,6 @@ def register(self, mapping: dict[str, type]) -> None: mapping: ``{"/path": EndpointClass}`` dictionary. Each class must be a ``RestEndpoint`` subclass. """ - from lightapi.rest import RestEndpoint for path, cls in mapping.items(): if not (isinstance(cls, type) and issubclass(cls, RestEndpoint)): @@ -144,6 +160,103 @@ def register(self, mapping: dict[str, type]) -> None: self._routes.append(detail_route) self._endpoint_map[path] = cls + # Auto-register /auth/login and /auth/token when JWT or Basic auth is used + + auth_backends: set[type] = set() + jwt_config_expiration: int | None = None + jwt_config_extra_claims: list[str] | None = None + jwt_config_algorithm: str | None = None + for cls in self._endpoint_map.values(): + auth_cfg = getattr(cls, "_meta", {}).get("authentication") + if auth_cfg and auth_cfg.backend: + auth_backends.add(auth_cfg.backend) + if ( + auth_cfg.backend is JWTAuthentication + and jwt_config_expiration is None + ): + jwt_config_expiration = getattr(auth_cfg, "jwt_expiration", None) + jwt_config_extra_claims = getattr( + auth_cfg, "jwt_extra_claims", None + ) + jwt_config_algorithm = getattr(auth_cfg, "jwt_algorithm", None) + + if auth_backends & {JWTAuthentication, BasicAuthentication}: + # Initialize rate limiter if not already created + if self._auth_rate_limiter is None: + from lightapi.rate_limiter import RateLimiter + + self._auth_rate_limiter = RateLimiter() + + if self._login_validator is None: + raise ConfigurationError( + "login_validator is required when using JWTAuthentication " + "or BasicAuthentication. Pass it to LightApi(login_validator=...)." + ) + has_jwt = JWTAuthentication in auth_backends + auth_path = self._auth_path.rstrip("/") + + # Check if auth routes already exist + login_path = f"{auth_path}/login" + token_path = f"{auth_path}/token" + + # Remove existing auth routes if any + self._routes = [ + route + for route in self._routes + if not ( + isinstance(route, Route) and route.path in {login_path, token_path} + ) + ] + + login_endpoint = self._make_login_endpoint( + has_jwt=has_jwt, + jwt_expiration=jwt_config_expiration, + jwt_extra_claims=jwt_config_extra_claims, + jwt_algorithm=jwt_config_algorithm, + ) + self._routes.insert( + 0, + Route( + login_path, + login_endpoint, + methods=["POST"], + ), + ) + self._routes.insert( + 1, + Route( + token_path, + login_endpoint, + methods=["POST"], + ), + ) + + def _make_login_endpoint( + self, + *, + has_jwt: bool, + jwt_expiration: int | None, + jwt_extra_claims: list[str] | None, + jwt_algorithm: str | None, + ) -> Any: + """Create the login/token handler with captured config.""" + from lightapi._login import login_handler + + login_validator = self._login_validator + + async def handler(request: Request) -> Response: + return await login_handler( + request, + login_validator=login_validator, + has_jwt=has_jwt, + jwt_expiration=jwt_expiration, + jwt_extra_claims=jwt_extra_claims, + jwt_algorithm=jwt_algorithm, + rate_limiter=self._auth_rate_limiter, + ) + + return handler + def _make_collection_handler(self, cls: type) -> Any: app_middlewares = self._middlewares is_async = self._async @@ -157,7 +270,9 @@ async def handler(request: Request) -> Response: if pre_result is not None: return pre_result - auth_result = _check_auth(cls, request) + auth_result = _check_auth( + cls, request, login_validator=self._login_validator + ) if auth_result is not None: return auth_result @@ -214,7 +329,9 @@ async def handler(request: Request) -> Response: if pre_result is not None: return pre_result - auth_result = _check_auth(cls, request) + auth_result = _check_auth( + cls, request, login_validator=self._login_validator + ) if auth_result is not None: return auth_result @@ -331,7 +448,6 @@ def from_config(cls, config_path: str, **kwargs: Any) -> "LightApi": Kwargs override YAML values (e.g. engine=..., database_url=...). """ - from lightapi.yaml_loader import load_config return load_config(cls, config_path, **kwargs) @@ -426,7 +542,6 @@ def _validate_async_dependencies(engine: Any) -> None: async def _read_body(request: Request) -> dict[str, Any]: """Read and parse JSON body; return {} on failure.""" - import json try: body = await request.body() @@ -435,9 +550,12 @@ async def _read_body(request: Request) -> dict[str, Any]: return {} -def _check_auth(cls: type, request: Request) -> Response | None: +def _check_auth( + cls: type, + request: Request, + login_validator: Optional[Callable[[str, str], Dict[str, Any] | None]] = None, +) -> Response | None: """Run authentication + permission checks; return 401/403 response or None.""" - from lightapi.auth import AllowAny auth_cfg = cls._meta.get("authentication") if auth_cfg is None: @@ -461,7 +579,17 @@ def _check_auth(cls: type, request: Request) -> Response | None: perm_cls = AllowAny if backend is not None: - authenticator = backend() + # Pass JWT configuration if backend is JWTAuthentication + if backend.__name__ == "JWTAuthentication": + authenticator = backend( + expiration=getattr(auth_cfg, "jwt_expiration", None), + algorithm=getattr(auth_cfg, "jwt_algorithm", None), + ) + elif backend.__name__ == "BasicAuthentication": + authenticator = backend(login_validator=login_validator) + else: + authenticator = backend() + if not authenticator.authenticate(request): return JSONResponse( {"detail": "Authentication credentials invalid."}, status_code=401 @@ -510,7 +638,6 @@ async def _run_post_middlewares( def _maybe_cached(cls: type, request: Request, fn: Any) -> Response: """Serve from Redis cache (GET only) or call fn() and populate cache.""" - from lightapi.cache import get_cached, set_cached cache_cfg = cls._meta.get("cache") if cache_cfg is None: @@ -541,7 +668,6 @@ def _maybe_invalidate_cache(cls: type, request: Request) -> None: cache_cfg = cls._meta.get("cache") if cache_cfg is None: return - from lightapi.cache import invalidate_cache_prefix invalidate_cache_prefix(_cache_key_prefix(cls)) diff --git a/lightapi/rate_limiter.py b/lightapi/rate_limiter.py new file mode 100644 index 0000000..e46bb22 --- /dev/null +++ b/lightapi/rate_limiter.py @@ -0,0 +1,176 @@ +"""Simple rate limiting for authentication endpoints.""" + +from __future__ import annotations + +import time +from collections import defaultdict + +from starlette.requests import Request +from starlette.responses import JSONResponse + + +class RateLimiter: + """ + Simple in-memory rate limiter. + + Tracks requests by IP address and endpoint. + + NOTE: This implementation uses process-local counters. In a multi-process + deployment (e.g., with multiple workers), rate limiting will not be shared + across processes. For production use with multiple workers, consider using + a shared storage backend like Redis. + """ + + def __init__( + self, + requests_per_minute: int = 10, + requests_per_hour: int = 100, + requests_per_day: int = 1000, + ) -> None: + self.requests_per_minute = requests_per_minute + self.requests_per_hour = requests_per_hour + self.requests_per_day = requests_per_day + + # Storage: {ip: {window: {timestamp: count}}} + self._store: dict[str, dict[str, dict[float, int]]] = defaultdict( + lambda: defaultdict(lambda: defaultdict(int)) + ) + self._cleanup_interval = 300 # Cleanup every 5 minutes + self._last_cleanup = time.time() + + def _cleanup_old_entries(self) -> None: + """Remove old entries to prevent memory leak.""" + current_time = time.time() + if current_time - self._last_cleanup < self._cleanup_interval: + return + + for ip in list(self._store.keys()): + for window in list(self._store[ip].keys()): + # Remove entries older than window size + window_seconds = self._get_window_seconds(window) + cutoff = current_time - window_seconds + + # Remove old timestamps + for timestamp in list(self._store[ip][window].keys()): + if timestamp < cutoff: + del self._store[ip][window][timestamp] + + # Remove empty windows + if not self._store[ip][window]: + del self._store[ip][window] + + # Remove IPs with no windows + if not self._store[ip]: + del self._store[ip] + + self._last_cleanup = current_time + + def _get_window_seconds(self, window: str) -> int: + """Convert window name to seconds.""" + # Extract window name from key (e.g., "auth:minute" -> "minute") + window_name = window.split(":")[-1] if ":" in window else window + + if window_name == "minute": + return 60 + elif window_name == "hour": + return 3600 + elif window_name == "day": + return 86400 + else: + return 60 + + def _get_client_ip(self, request: Request) -> str: + """Extract client IP from request.""" + # Always prefer the actual client host for security + if request.client and request.client.host: + return request.client.host + + # Fallback to headers only if client host is not available + for header in ("X-Forwarded-For", "X-Real-IP", "X-Client-IP"): + if header in request.headers: + ip = request.headers[header].split(",")[0].strip() + if ip: + return ip + + return "0.0.0.0" + + def is_rate_limited( + self, request: Request, endpoint: str = "" + ) -> tuple[bool, str | None]: + """ + Check if request should be rate limited. + + Args: + request: The HTTP request. + endpoint: Optional endpoint identifier for per-endpoint limiting. + + Returns: + tuple[bool, str | None]: (True if rate limited, window name) or + (False, None) + """ + self._cleanup_old_entries() + + client_ip = self._get_client_ip(request) + current_time = time.time() + + # Check all windows first before incrementing + windows = [ + ("minute", self.requests_per_minute), + ("hour", self.requests_per_hour), + ("day", self.requests_per_day), + ] + + # First pass: check all windows + for window_name, limit in windows: + window_seconds = self._get_window_seconds(window_name) + window_key = f"{endpoint}:{window_name}" if endpoint else window_name + + # Count requests in this window + count = 0 + for timestamp, request_count in self._store[client_ip][window_key].items(): + if current_time - timestamp < window_seconds: + count += request_count + + if count >= limit: + # Don't count this request since it's being blocked + return (True, window_name) + + # Second pass: increment all windows (only if request is allowed) + for window_name, _ in windows: + window_seconds = self._get_window_seconds(window_name) + window_key = f"{endpoint}:{window_name}" if endpoint else window_name + self._store[client_ip][window_key][current_time] = ( + self._store[client_ip][window_key].get(current_time, 0) + 1 + ) + + return (False, None) + + def get_rate_limit_response(self, window: str = "minute") -> JSONResponse: + """Get standard rate limit exceeded response.""" + # Determine window-specific values + if window == "hour": + limit = self.requests_per_hour + retry_after = 3600 + reset_seconds = 3600 + elif window == "day": + limit = self.requests_per_day + retry_after = 86400 + reset_seconds = 86400 + else: # minute + limit = self.requests_per_minute + retry_after = 60 + reset_seconds = 60 + + return JSONResponse( + { + "error": "rate_limit_exceeded", + "detail": "Too many requests. Please try again later.", + }, + status_code=429, + headers={ + "Retry-After": str(retry_after), + "X-RateLimit-Limit": str(limit), + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(int(time.time() + reset_seconds)), + }, + ) diff --git a/lightapi/yaml_loader.py b/lightapi/yaml_loader.py index ebed063..feb4339 100644 --- a/lightapi/yaml_loader.py +++ b/lightapi/yaml_loader.py @@ -43,7 +43,13 @@ def _build_name_registry() -> dict[str, type]: - from lightapi.auth import AllowAny, IsAdminUser, IsAuthenticated, JWTAuthentication + from lightapi.auth import ( + AllowAny, + BasicAuthentication, + IsAdminUser, + IsAuthenticated, + JWTAuthentication, + ) from lightapi.core import AuthenticationMiddleware, CORSMiddleware, Middleware from lightapi.filters import ( FieldFilter, @@ -55,6 +61,7 @@ def _build_name_registry() -> dict[str, type]: return { # Auth backends "JWTAuthentication": JWTAuthentication, + "BasicAuthentication": BasicAuthentication, # Permissions "AllowAny": AllowAny, "IsAuthenticated": IsAuthenticated, @@ -76,6 +83,65 @@ def _build_name_registry() -> dict[str, type]: } +def _resolve_callable(dotted_path: str) -> Any: + """Resolve a dotted path like 'myapp.validators.validate_login' to a callable.""" + if "." not in dotted_path: + raise ConfigurationError( + f"login_validator must be a dotted path (e.g. myapp.validators.check), " + f"got '{dotted_path}'" + ) + module_path, attr_name = dotted_path.rsplit(".", 1) + try: + mod = importlib.import_module(module_path) + fn = getattr(mod, attr_name) + except (ImportError, AttributeError) as exc: + raise ConfigurationError( + f"Cannot resolve login_validator '{dotted_path}': {exc}" + ) from exc + if not callable(fn): + raise ConfigurationError(f"login_validator '{dotted_path}' is not callable.") + + # Validate signature: must be sync function with exactly 2 positional args + import inspect + + if inspect.iscoroutinefunction(fn): + raise ConfigurationError( + f"login_validator '{dotted_path}' is async, but sync function required. " + f"Login validation must be a synchronous function." + ) + + try: + sig = inspect.signature(fn) + except (ValueError, TypeError) as exc: + # Only allow specific cases where signature inspection legitimately fails + if hasattr(fn, "__name__") and fn.__name__ in ("",): + # Lambdas can't be properly inspected in some Python versions + return fn + # For other cases, raise a clear error about the validation function + raise ValueError( + f"Login validation function {fn!r} cannot be inspected: {exc}. " + f"Ensure it's a regular Python function with inspectable signature." + ) from exc + + # Count required positional parameters + required_params = 0 + for param in sig.parameters.values(): + if param.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + if param.default is inspect.Parameter.empty: + required_params += 1 + + if required_params != 2: + raise ConfigurationError( + f"login_validator '{dotted_path}' must accept exactly 2 required " + f"positional parameters (username, password), got {required_params}" + ) + + return fn + + def _resolve_name(name: str) -> type: """Resolve a class name string to a class. @@ -129,11 +195,26 @@ def substitute_env(cls, v: str) -> str: return _substitute_env(v) +class AuthLoginConfig(BaseModel): + """Login/auth block: auth: { auth_path: ..., login_validator: ... }. + + When using JWTAuthentication or BasicAuthentication, login_validator is required. + It can be specified as a dotted path (e.g. myapp.validators.validate_login) + or passed as an override to from_config(login_validator=...). + """ + + auth_path: str = "/auth" + login_validator: str | None = None + + class AuthConfig(BaseModel): """Authentication block used in defaults and per-endpoint meta.""" backend: str | None = None permission: Union[str, dict[str, str], None] = None + jwt_expiration: int | None = None + jwt_extra_claims: list[str] | None = None + jwt_algorithm: str | None = None class FilteringConfig(BaseModel): @@ -220,6 +301,7 @@ class LightAPIConfig(BaseModel): defaults: DefaultsConfig = DefaultsConfig() endpoints: list[EndpointConfig] = [] middleware: list[str] = [] + auth: AuthLoginConfig | None = None @property def effective_database_url(self) -> str | None: @@ -245,14 +327,15 @@ def _substitute_env(value: str) -> str: def _make_authentication( - auth_cfg: AuthConfig | None, defaults_auth: AuthConfig | None + auth_cfg: AuthConfig | None, + defaults_auth: AuthConfig | None, ) -> Any: - """Build an Authentication instance from an AuthConfig, merged with defaults.""" - from lightapi.config import Authentication - # Merge: explicit cfg wins over defaults, defaults fill gaps merged_backend = None merged_permission = None + merged_jwt_expiration = None + merged_jwt_extra_claims = None + merged_jwt_algorithm = None for source in (defaults_auth, auth_cfg): if source is None: @@ -261,6 +344,12 @@ def _make_authentication( merged_backend = source.backend if source.permission is not None: merged_permission = source.permission + if source.jwt_expiration is not None: + merged_jwt_expiration = source.jwt_expiration + if source.jwt_extra_claims is not None: + merged_jwt_extra_claims = source.jwt_extra_claims + if source.jwt_algorithm is not None: + merged_jwt_algorithm = source.jwt_algorithm if merged_backend is None and merged_permission is None: return None @@ -277,7 +366,15 @@ def _make_authentication( else: permission = None - return Authentication(backend=backend_cls, permission=permission) + from lightapi.config import Authentication + + return Authentication( + backend=backend_cls, + permission=permission, + jwt_expiration=merged_jwt_expiration, + jwt_extra_claims=merged_jwt_extra_claims, + jwt_algorithm=merged_jwt_algorithm, + ) def _make_filtering(filtering_cfg: FilteringConfig | None) -> Any: @@ -335,23 +432,72 @@ def _build_meta_class( # Build per-method permission dict from the methods dict permission_map: dict[str, type] = {} method_auth_default = meta.authentication # endpoint-level auth override + + # Get default permission from defaults + default_permission = None + if defaults.authentication and defaults.authentication.permission: + default_permission = defaults.authentication.permission + for method, method_cfg in meta.methods.items(): cfg_auth = method_cfg.authentication if method_cfg else None src_auth = cfg_auth or method_auth_default + + # Determine permission for this method + method_permission = None + + # First, check method-specific auth if src_auth and src_auth.permission: perm = src_auth.permission if isinstance(perm, str): - permission_map[method] = _resolve_name(perm) + method_permission = _resolve_name(perm) + # Note: per-method auth in YAML dict can't have dict permissions + # only string permissions are supported in this path + + # If no method-specific permission, check default permission + if method_permission is None and default_permission is not None: + if isinstance(default_permission, dict): + # Default permission is a dict: check for method-specific default + if method in default_permission: + perm = default_permission[method] + if isinstance(perm, str): + method_permission = _resolve_name(perm) + else: + # Default permission is a string or class + if isinstance(default_permission, str): + method_permission = _resolve_name(default_permission) + else: + method_permission = default_permission + + if method_permission is not None: + permission_map[method] = method_permission + if permission_map: - # Determine shared backend (from endpoint auth or defaults) - backend_name = (meta.authentication and meta.authentication.backend) or ( - defaults.authentication and defaults.authentication.backend - ) + # Merge auth settings similar to _make_authentication + merged_backend = None + merged_jwt_expiration = None + merged_jwt_extra_claims = None + merged_jwt_algorithm = None + + for source in (defaults.authentication, meta.authentication): + if source is None: + continue + if source.backend is not None: + merged_backend = source.backend + if source.jwt_expiration is not None: + merged_jwt_expiration = source.jwt_expiration + if source.jwt_extra_claims is not None: + merged_jwt_extra_claims = source.jwt_extra_claims + if source.jwt_algorithm is not None: + merged_jwt_algorithm = source.jwt_algorithm + from lightapi.config import Authentication attrs["authentication"] = Authentication( - backend=_resolve_name(backend_name) if backend_name else None, + backend=_resolve_name(merged_backend) if merged_backend else None, permission=permission_map, + jwt_expiration=merged_jwt_expiration, + jwt_extra_claims=merged_jwt_extra_claims, + jwt_algorithm=merged_jwt_algorithm, ) else: # Simple auth: endpoint overrides defaults @@ -480,6 +626,15 @@ def load_config(app_cls: type, config_path: str, **overrides: Any) -> Any: "cors_origins": cfg.cors_origins or None, "middlewares": middlewares or None, } + + # Auth/login config from YAML auth: block + if cfg.auth: + constructor_kwargs["auth_path"] = cfg.auth.auth_path + if cfg.auth.login_validator: + constructor_kwargs["login_validator"] = _resolve_callable( + cfg.auth.login_validator + ) + constructor_kwargs.update(overrides) instance = app_cls(**constructor_kwargs) diff --git a/tests/test_auth.py b/tests/test_auth.py index 11b211a..33c7db7 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,6 +1,7 @@ """Tests for US3: Authentication and Permission classes.""" import os +from typing import Any import pytest from sqlalchemy import create_engine @@ -59,6 +60,11 @@ def jwt_secret(monkeypatch_session=None): return secret +def _login_validator(_username: str, _password: str) -> dict[str, Any] | None: + """Test validator; always returns None (tests use _make_token for tokens).""" + return None + + @pytest.fixture(scope="module") def client(jwt_secret): os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" @@ -67,7 +73,7 @@ def client(jwt_secret): connect_args={"check_same_thread": False}, poolclass=StaticPool, ) - app_instance = LightApi(engine=engine) + app_instance = LightApi(engine=engine, login_validator=_login_validator) app_instance.register( { "/secrets": SecretEndpoint, diff --git a/tests/test_login_auth.py b/tests/test_login_auth.py new file mode 100644 index 0000000..4903177 --- /dev/null +++ b/tests/test_login_auth.py @@ -0,0 +1,984 @@ +"""Tests for login and token endpoints (US-L1, US-L2, US-L3, US-L4).""" + +import base64 +import os + +import jwt +import pytest +from pydantic import ValidationError +from sqlalchemy import create_engine +from sqlalchemy.pool import StaticPool +from starlette.testclient import TestClient + +from lightapi import ( + Authentication, + BasicAuthentication, + JWTAuthentication, + LightApi, + RestEndpoint, +) +from lightapi._login import LoginRequest +from lightapi.exceptions import ConfigurationError +from lightapi.fields import Field as LField + + +def _valid_validator(username: str, password: str): + if username == "alice" and password == "secret": + return {"sub": "1", "email": "alice@example.com", "is_admin": False} + return None + + +def _valid_validator_admin(username: str, password: str): + if username == "admin" and password == "admin": + return {"sub": "2", "is_admin": True} + return None + + +class JWTProtectedEndpoint(RestEndpoint): + content: str = LField(min_length=1) + + class Meta: + authentication = Authentication(backend=JWTAuthentication) + + +class BasicProtectedEndpoint(RestEndpoint): + name: str = LField(min_length=1) + + class Meta: + authentication = Authentication(backend=BasicAuthentication) + + +@pytest.fixture(scope="module") +def jwt_client(): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/secrets": JWTProtectedEndpoint}) + return TestClient(app.build_app()) + + +@pytest.fixture(scope="module") +def basic_client(): + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/items": BasicProtectedEndpoint}) + return TestClient(app.build_app()) + + +class TestJWTTokenEndpoint: + def test_login_valid_credentials_returns_token(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "token" in data + assert "user" in data + assert data["user"]["sub"] == "1" + assert data["user"]["email"] == "alice@example.com" + + def test_login_invalid_credentials_returns_401(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": "alice", "password": "wrong"}, + ) + assert resp.status_code == 401 + assert resp.json()["detail"] == "Invalid credentials" + + def test_login_missing_body_returns_422(self, jwt_client): + resp = jwt_client.post("/auth/login", json={}) + assert resp.status_code == 422 + + def test_token_endpoint_same_as_login(self, jwt_client): + resp = jwt_client.post( + "/auth/token", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + assert "user" in resp.json() + + +class TestBodyValidation: + def test_login_missing_username_returns_422(self, jwt_client): + resp = jwt_client.post("/auth/login", json={"password": "x"}) + assert resp.status_code == 422 + detail = resp.json()["detail"] + assert any("username" in str(e.get("loc", [])) for e in detail) + + def test_login_missing_password_returns_422(self, jwt_client): + resp = jwt_client.post("/auth/login", json={"username": "x"}) + assert resp.status_code == 422 + detail = resp.json()["detail"] + assert any("password" in str(e.get("loc", [])) for e in detail) + + def test_login_empty_username_returns_422(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": "", "password": "x"}, + ) + assert resp.status_code == 422 + + def test_login_empty_password_returns_422(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": "x", "password": ""}, + ) + assert resp.status_code == 422 + + def test_login_invalid_json_returns_422(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + content="not json", + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 422 + + +class TestBasicHeaderInput: + def test_login_basic_header_valid_returns_token(self, jwt_client): + creds = base64.b64encode(b"alice:secret").decode() + resp = jwt_client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "token" in data + assert data["user"]["sub"] == "1" + + def test_login_basic_header_takes_precedence_over_body(self, jwt_client): + creds = base64.b64encode(b"alice:secret").decode() + resp = jwt_client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + json={"username": "alice", "password": "wrong"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + + def test_login_basic_header_malformed_returns_401(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + headers={"Authorization": "Basic not-valid-base64!!"}, + ) + assert resp.status_code == 401 + assert resp.json()["detail"] == "Invalid credentials" + + def test_login_basic_header_no_colon_returns_401(self, jwt_client): + creds = base64.b64encode(b"nocolon").decode() + resp = jwt_client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 401 + assert resp.json()["detail"] == "Invalid credentials" + + def test_login_bearer_header_falls_through_to_body_validation(self, jwt_client): + """Authorization: Bearer x (no Basic) + empty body → 422.""" + resp = jwt_client.post( + "/auth/login", + headers={"Authorization": "Bearer x"}, + json={}, + ) + assert resp.status_code == 422 + + +class TestBasicAuthEndpoint: + def test_basic_only_returns_user_without_token(self, basic_client): + resp = basic_client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "user" in data + assert "token" not in data + assert data["user"]["sub"] == "1" + + +class TestHttpMethodAndErrorFormat: + def test_login_get_returns_405(self, jwt_client): + resp = jwt_client.get("/auth/login") + assert resp.status_code == 405 + assert resp.headers.get("Allow") == "POST" + + +class TestConfigurationAndRegistration: + def test_register_jwt_without_validator_raises_configuration_error(self): + os.environ.setdefault("LIGHTAPI_JWT_SECRET", "test-secret-key") + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine) + with pytest.raises(ConfigurationError, match="login_validator"): + app.register({"/secrets": JWTProtectedEndpoint}) + + def test_register_basic_without_validator_raises_configuration_error(self): + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine) + with pytest.raises(ConfigurationError, match="login_validator"): + app.register({"/items": BasicProtectedEndpoint}) + + def test_auth_path_customization(self): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi( + engine=engine, + login_validator=_valid_validator, + auth_path="/api/auth", + ) + app.register({"/secrets": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/api/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + + +class TestJWTConfigOverrides: + def test_jwt_expiration_override(self): + class JWTWithExpEndpoint(RestEndpoint): + content: str = LField(min_length=1) + + class Meta: + authentication = Authentication( + backend=JWTAuthentication, + jwt_expiration=10, + ) + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/x": JWTWithExpEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + token = resp.json()["token"] + payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) + exp = payload.get("exp") + assert exp is not None + import time + + assert 0 < exp - int(time.time()) <= 15 + + def test_jwt_extra_claims_filters_payload(self): + def validator_with_extra(username: str, password: str): + if username == "alice" and password == "secret": + return { + "sub": "1", + "email": "a@b.com", + "secret": "must-not-appear", + } + return None + + class JWTWithExtraEndpoint(RestEndpoint): + content: str = LField(min_length=1) + + class Meta: + authentication = Authentication( + backend=JWTAuthentication, + jwt_extra_claims=["sub", "email"], + ) + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_with_extra) + app.register({"/x": JWTWithExtraEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + token = resp.json()["token"] + payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) + assert "sub" in payload + assert "email" in payload + assert "secret" not in payload + + +class TestBasicAuthProtectedEndpoints: + def test_basic_protected_get_with_valid_header_returns_200(self, basic_client): + creds = base64.b64encode(b"alice:secret").decode() + resp = basic_client.get( + "/items", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 200 + + def test_basic_protected_get_without_auth_returns_401(self, basic_client): + resp = basic_client.get("/items") + assert resp.status_code == 401 + + def test_basic_protected_get_invalid_credentials_returns_401(self, basic_client): + creds = base64.b64encode(b"alice:wrong").decode() + resp = basic_client.get( + "/items", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 401 + + +class TestTokenUsability: + def test_jwt_token_usable_on_protected_endpoint(self, jwt_client): + login_resp = jwt_client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert login_resp.status_code == 200 + token = login_resp.json()["token"] + resp = jwt_client.get( + "/secrets", + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 200 + + +class TestValidatorException: + def test_login_validator_exception_returns_500(self): + """Validator raising propagates to 500 (Starlette default handler).""" + + def failing_validator(username: str, password: str): + raise RuntimeError("DB unavailable") + + os.environ.setdefault("LIGHTAPI_JWT_SECRET", "test-secret-key") + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=failing_validator) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app(), raise_server_exceptions=False) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 500 + + +class TestLoginRequestModel: + def test_login_request_validates_min_length(self): + with pytest.raises(ValidationError): + LoginRequest.model_validate({"username": "x", "password": ""}) + + +class TestBasicHeaderEdgeCases: + def test_login_basic_header_password_with_colon_parsed_correctly(self): + def validator_colon(username: str, password: str): + if username == "alice" and password == "pass:word": + return {"sub": "1"} + return None + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_colon) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + creds = base64.b64encode(b"alice:pass:word").decode() + resp = client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 200 + assert resp.json()["user"]["sub"] == "1" + + def test_login_basic_header_empty_username_returns_401(self): + def validator_reject_empty(username: str, password: str): + if not username or not password: + return None + if username == "alice" and password == "secret": + return {"sub": "1"} + return None + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_reject_empty) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + creds = base64.b64encode(b":secret").decode() + resp = client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 401 + + def test_login_basic_header_empty_password_returns_401(self): + def validator_reject_empty(username: str, password: str): + if not username or not password: + return None + if username == "alice" and password == "secret": + return {"sub": "1"} + return None + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_reject_empty) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + creds = base64.b64encode(b"alice:").decode() + resp = client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 401 + + def test_login_basic_header_unicode_credentials_success(self): + def validator_unicode(username: str, password: str): + if username == "josé" and password == "contraseña": + return {"sub": "1"} + return None + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_unicode) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + creds = base64.b64encode("josé:contraseña".encode("utf-8")).decode() + resp = client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code == 200 + assert resp.json()["user"]["sub"] == "1" + + def test_login_basic_header_lowercase_falls_through_to_body(self): + """Authorization: basic x (lowercase) does not match Basic; falls to body.""" + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + creds = base64.b64encode(b"alice:secret").decode() + resp = client.post( + "/auth/login", + headers={"Authorization": f"basic {creds}"}, + json={}, + ) + assert resp.status_code == 422 + + def test_login_basic_header_no_value_returns_401(self): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + headers={"Authorization": "Basic "}, + json={}, + ) + assert resp.status_code == 401 + + def test_login_basic_header_multiple_spaces_after_basic(self): + """Authorization: Basic - leading space may affect decode.""" + creds = base64.b64encode(b"alice:secret").decode() + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + headers={"Authorization": f"Basic {creds}"}, + ) + assert resp.status_code in (200, 401) + if resp.status_code == 200: + assert resp.json()["user"]["sub"] == "1" + + +class TestBodyEdgeCases: + def test_login_body_with_extra_keys_ignored_success(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": "alice", "password": "secret", "extra": "x"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + + def test_login_body_wrong_type_username_returns_422(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": 123, "password": "x"}, + ) + assert resp.status_code == 422 + + def test_login_truncated_json_returns_422(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + content='{"username": "a",', + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 422 + + def test_login_empty_body_with_content_type_returns_422(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + content="{}", + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 422 + + +class TestAuthPathEdgeCases: + def test_auth_path_trailing_slash_normalized(self): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi( + engine=engine, + login_validator=_valid_validator, + auth_path="/auth/", + ) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + + def test_auth_path_root_routes_at_login_and_token(self): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi( + engine=engine, + login_validator=_valid_validator, + auth_path="/", + ) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + + def test_auth_path_nested(self): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi( + engine=engine, + login_validator=_valid_validator, + auth_path="/api/v1/auth", + ) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/api/v1/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + + +class TestJWTExtraClaimsEdgeCases: + def test_jwt_extra_claims_empty_list_uses_full_payload(self): + def validator_full(username: str, password: str): + if username == "alice" and password == "secret": + return {"sub": "1", "email": "a@b.com"} + return None + + class JWTEmptyExtraEndpoint(RestEndpoint): + content: str = LField(min_length=1) + + class Meta: + authentication = Authentication( + backend=JWTAuthentication, + jwt_extra_claims=[], + ) + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_full) + app.register({"/x": JWTEmptyExtraEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + token = resp.json()["token"] + payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) + assert "sub" in payload + assert "email" in payload + + def test_jwt_extra_claims_all_keys_missing_fallback_to_full_payload(self): + def validator_minimal(username: str, password: str): + if username == "alice" and password == "secret": + return {"sub": "1"} + return None + + class JWTMissingExtraEndpoint(RestEndpoint): + content: str = LField(min_length=1) + + class Meta: + authentication = Authentication( + backend=JWTAuthentication, + jwt_extra_claims=["x", "y"], + ) + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_minimal) + app.register({"/x": JWTMissingExtraEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + token = resp.json()["token"] + payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) + assert "sub" in payload + + def test_jwt_extra_claims_partial_overlap_only_included_in_token(self): + def validator_partial(username: str, password: str): + if username == "alice" and password == "secret": + return {"sub": "1", "email": "a@b.com"} + return None + + class JWTPartialExtraEndpoint(RestEndpoint): + content: str = LField(min_length=1) + + class Meta: + authentication = Authentication( + backend=JWTAuthentication, + jwt_extra_claims=["sub", "missing"], + ) + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_partial) + app.register({"/x": JWTPartialExtraEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + token = resp.json()["token"] + payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) + assert "sub" in payload + assert payload["sub"] == "1" + assert "missing" not in payload + + +class TestValidatorReturnShapes: + def test_validator_empty_dict_returns_200(self): + def validator_empty(username: str, password: str): + if username == "alice" and password == "secret": + return {} + return None + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_empty) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "token" in data + assert data["user"] == {} + token = data["token"] + payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) + assert "exp" in payload + + def test_validator_nested_payload_success(self): + def validator_nested(username: str, password: str): + if username == "alice" and password == "secret": + return {"sub": "1", "meta": {"role": "admin"}} + return None + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=validator_nested) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["user"]["sub"] == "1" + assert data["user"]["meta"]["role"] == "admin" + + +class TestMixedAppScenarios: + def test_both_jwt_and_basic_endpoints_returns_token(self): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register( + { + "/secrets": JWTProtectedEndpoint, + "/items": BasicProtectedEndpoint, + } + ) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + assert "token" in resp.json() + + def test_basic_only_auth_token_same_as_login(self, basic_client): + resp = basic_client.post( + "/auth/token", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "user" in data + assert "token" not in data + + def test_allowany_only_no_auth_routes_404(self): + from lightapi import AllowAny + + class AllowAnyEndpoint(RestEndpoint): + name: str = LField(min_length=1) + + class Meta: + authentication = Authentication(backend=AllowAny) + + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/public": AllowAnyEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 404 + + +class TestMethodEdgeCases: + def test_put_auth_login_returns_405(self, jwt_client): + resp = jwt_client.put( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 405 + assert resp.headers.get("Allow") == "POST" + + def test_patch_auth_token_returns_405(self, jwt_client): + resp = jwt_client.patch( + "/auth/token", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 405 + assert resp.headers.get("Allow") == "POST" + + def test_delete_auth_login_returns_405(self, jwt_client): + resp = jwt_client.delete("/auth/login") + assert resp.status_code == 405 + assert resp.headers.get("Allow") == "POST" + + +class TestResponseStructure: + def test_jwt_mode_response_has_exactly_token_and_user(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert set(data.keys()) == {"token", "user"} + + def test_basic_mode_response_has_no_token_key(self, basic_client): + resp = basic_client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + assert "token" not in resp.json() + assert "user" in resp.json() + + +class TestTokenVerification: + def test_expired_token_rejected_on_protected_endpoint(self): + import time + + class JWTExpiringEndpoint(RestEndpoint): + content: str = LField(min_length=1) + + class Meta: + authentication = Authentication( + backend=JWTAuthentication, + jwt_expiration=1, + ) + + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/x": JWTExpiringEndpoint}) + client = TestClient(app.build_app()) + resp = client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + token = resp.json()["token"] + time.sleep(2) + get_resp = client.get("/x", headers={"Authorization": f"Bearer {token}"}) + assert get_resp.status_code == 401 + + def test_wrong_secret_token_rejected(self): + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" + wrong_token = jwt.encode( + {"sub": "1", "exp": 9999999999}, + "wrong-secret", + algorithm="HS256", + ) + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + app = LightApi(engine=engine, login_validator=_valid_validator) + app.register({"/x": JWTProtectedEndpoint}) + client = TestClient(app.build_app()) + resp = client.get( + "/x", + headers={"Authorization": f"Bearer {wrong_token}"}, + ) + assert resp.status_code == 401 + + def test_token_structure_has_three_parts(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 200 + token = resp.json()["token"] + assert len(token.split(".")) == 3 + + +class TestBasicVsBodyPrecedence: + def test_malformed_basic_overrides_valid_body_returns_401(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + headers={"Authorization": "Basic not-valid-base64!!"}, + json={"username": "alice", "password": "secret"}, + ) + assert resp.status_code == 401 + + def test_invalid_basic_and_empty_body_returns_401(self, jwt_client): + resp = jwt_client.post( + "/auth/login", + headers={"Authorization": "Basic !!!"}, + json={}, + ) + assert resp.status_code == 401 diff --git a/tests/test_yaml_config.py b/tests/test_yaml_config.py index 6baab36..fc76a8d 100644 --- a/tests/test_yaml_config.py +++ b/tests/test_yaml_config.py @@ -26,10 +26,18 @@ def _write_yaml(content: str | dict) -> str: return f.name -def _from_str(content: str) -> LightApi: +def _dummy_login_validator(username: str, password: str): + return None + + +def _from_str(content: str, login_validator=None) -> LightApi: path = _write_yaml(content) try: - return LightApi.from_config(path) + # Only pass login_validator if explicitly provided + kwargs = {} + if login_validator is not None: + kwargs["login_validator"] = login_validator + return LightApi.from_config(path, **kwargs) finally: os.unlink(path) @@ -299,6 +307,169 @@ def test_from_config_kwargs_override_yaml(self): os.unlink(path) +# ───────────────────────────────────────────────────────────────────────────── +# Auth/login YAML configuration +# ───────────────────────────────────────────────────────────────────────────── + + +class TestYamlAuthConfig: + def test_auth_path_from_yaml(self): + """auth.auth_path in YAML configures login/token route prefix.""" + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret" + content = """\ + database: + url: "sqlite:///:memory:" + auth: + auth_path: /api/auth + defaults: + authentication: + backend: JWTAuthentication + permission: IsAuthenticated + endpoints: + - route: /x + fields: + data: { type: str } + meta: + methods: [GET] + """ + app = _from_str(content) + from starlette.testclient import TestClient + + client = TestClient(app.build_app()) + resp = client.post( + "/api/auth/login", + json={"username": "a", "password": "b"}, + ) + assert resp.status_code in (200, 401) + + def test_login_validator_dotted_path_from_yaml(self): + """auth.login_validator as dotted path resolves to callable.""" + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret" + + # Create a simple validator function + def test_validator(username: str, password: str): + return {"sub": "test"} + + # Monkey-patch it into the module + import tests.test_yaml_config + + tests.test_yaml_config.test_validator_func = test_validator + + try: + content = """\ + database: + url: "sqlite:///:memory:" + auth: + login_validator: tests.test_yaml_config.test_validator_func + defaults: + authentication: + backend: JWTAuthentication + permission: IsAuthenticated + endpoints: + - route: /x + fields: + data: { type: str } + meta: + methods: [GET] + """ + app = _from_str(content) + # The validator from YAML should be resolved and used + assert app._login_validator is test_validator + finally: + # Clean up + delattr(tests.test_yaml_config, "test_validator_func") + + def test_jwt_expiration_from_defaults(self): + """defaults.authentication.jwt_expiration flows to Meta.""" + content = """\ + database: + url: "sqlite:///:memory:" + defaults: + authentication: + backend: JWTAuthentication + permission: IsAuthenticated + jwt_expiration: 300 + endpoints: + - route: /x + fields: + data: { type: str } + meta: + methods: [GET] + """ + app = _from_str(content) + cls = app._endpoint_map["/x"] + assert cls.Meta.authentication.jwt_expiration == 300 + + def test_jwt_extra_claims_from_defaults(self): + """defaults.authentication.jwt_extra_claims flows to Meta.""" + content = """\ + database: + url: "sqlite:///:memory:" + defaults: + authentication: + backend: JWTAuthentication + permission: IsAuthenticated + jwt_extra_claims: [sub, email] + endpoints: + - route: /x + fields: + data: { type: str } + meta: + methods: [GET] + """ + app = _from_str(content) + cls = app._endpoint_map["/x"] + assert cls.Meta.authentication.jwt_extra_claims == ["sub", "email"] + + def test_basic_authentication_from_yaml(self): + """BasicAuthentication can be specified as backend in YAML.""" + content = """\ + database: + url: "sqlite:///:memory:" + defaults: + authentication: + backend: BasicAuthentication + permission: IsAuthenticated + endpoints: + - route: /items + fields: + name: { type: str } + meta: + methods: [GET] + """ + app = _from_str(content) + from lightapi.auth import BasicAuthentication + + cls = app._endpoint_map["/items"] + assert cls.Meta.authentication.backend is BasicAuthentication + + def test_login_validator_invalid_dotted_path_raises(self): + """Invalid login_validator dotted path raises ConfigurationError.""" + os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret" + content = """\ + database: + url: "sqlite:///:memory:" + auth: + login_validator: non.existent.module.foo + defaults: + authentication: + backend: JWTAuthentication + permission: IsAuthenticated + endpoints: + - route: /x + fields: + data: { type: str } + meta: + methods: [GET] + """ + path = _write_yaml(content) + try: + with pytest.raises(ConfigurationError, match="login_validator"): + LightApi.from_config(path) + finally: + os.unlink(path) + + # ───────────────────────────────────────────────────────────────────────────── # _resolve_name utility # ───────────────────────────────────────────────────────────────────────────── diff --git a/uv.lock b/uv.lock index cfc1d4a..8a87de6 100644 --- a/uv.lock +++ b/uv.lock @@ -526,7 +526,7 @@ wheels = [ [[package]] name = "lightapi" -version = "0.1.16" +version = "0.1.20" source = { editable = "." } dependencies = [ { name = "pydantic" }, From 62ae56b153d61b326a9e8617bf804cea42c1201e Mon Sep 17 00:00:00 2001 From: iklobato Date: Tue, 10 Mar 2026 15:42:54 -0300 Subject: [PATCH 2/5] fix: resolve linter and type checking issues - Fix ruff linting: remove unused import from auth.py - Fix mypy type errors in rate_limiter.py: correct function indentation - Fix mypy type errors in _login.py: add proper type hints for JSON parsing - Add RateLimiter import and type annotation in _login.py - Create tests/__init__.py to fix mypy module resolution - Fix import sorting in _login.py --- lightapi/_login.py | 11 +++++-- lightapi/auth.py | 2 +- lightapi/rate_limiter.py | 63 ++++++++++++++++++++-------------------- tests/__init__.py | 0 4 files changed, 40 insertions(+), 36 deletions(-) create mode 100644 tests/__init__.py diff --git a/lightapi/_login.py b/lightapi/_login.py index 0e205d6..8bfc25d 100644 --- a/lightapi/_login.py +++ b/lightapi/_login.py @@ -12,6 +12,8 @@ from starlette.requests import Request from starlette.responses import JSONResponse +from .rate_limiter import RateLimiter + # JWTAuthentication imported locally where needed to avoid circular import logger = logging.getLogger(__name__) @@ -67,7 +69,10 @@ async def _read_body(request: Request) -> dict[str, Any]: """Read JSON body; return {} on empty or invalid.""" try: body = await request.body() - return json.loads(body) if body else {} + if body: + result: dict[str, Any] = json.loads(body) + return result + return {} except (json.JSONDecodeError, TypeError): return {} @@ -80,7 +85,7 @@ async def login_handler( jwt_expiration: int | None = None, jwt_extra_claims: list[str] | None = None, jwt_algorithm: str | None = None, - rate_limiter: Optional[Any] = None, + rate_limiter: Optional[RateLimiter] = None, ) -> JSONResponse: """ Handle POST /auth/login and POST /auth/token. @@ -91,7 +96,7 @@ async def login_handler( # Apply rate limiting if a rate limiter is provided if rate_limiter is not None: is_limited, window = rate_limiter.is_rate_limited(request, endpoint="auth") - if is_limited: + if is_limited and window is not None: return rate_limiter.get_rate_limit_response(request, window) if request.method != "POST": diff --git a/lightapi/auth.py b/lightapi/auth.py index 7cb29d8..2de7a4e 100644 --- a/lightapi/auth.py +++ b/lightapi/auth.py @@ -1,6 +1,6 @@ import base64 from datetime import datetime, timedelta -from typing import Any, Dict, Optional +from typing import Any, Optional import jwt from starlette.requests import Request diff --git a/lightapi/rate_limiter.py b/lightapi/rate_limiter.py index 4af515a..1268e6b 100644 --- a/lightapi/rate_limiter.py +++ b/lightapi/rate_limiter.py @@ -140,35 +140,34 @@ def is_rate_limited( return (False, None) - -def get_rate_limit_response( - self, request: Request, window: str = "minute" -) -> JSONResponse: - """Get standard rate limit exceeded response.""" - # Determine window-specific values - if window == "hour": - limit = self.requests_per_hour - retry_after = 3600 - reset_seconds = 3600 - elif window == "day": - limit = self.requests_per_day - retry_after = 86400 - reset_seconds = 86400 - else: # minute - limit = self.requests_per_minute - retry_after = 60 - reset_seconds = 60 - - return JSONResponse( - { - "error": "rate_limit_exceeded", - "detail": "Too many requests. Please try again later.", - }, - status_code=429, - headers={ - "Retry-After": str(retry_after), - "X-RateLimit-Limit": str(limit), - "X-RateLimit-Remaining": "0", - "X-RateLimit-Reset": str(int(time.time() + reset_seconds)), - }, - ) + def get_rate_limit_response( + self, request: Request, window: str = "minute" + ) -> JSONResponse: + """Get standard rate limit exceeded response.""" + # Determine window-specific values + if window == "hour": + limit = self.requests_per_hour + retry_after = 3600 + reset_seconds = 3600 + elif window == "day": + limit = self.requests_per_day + retry_after = 86400 + reset_seconds = 86400 + else: # minute + limit = self.requests_per_minute + retry_after = 60 + reset_seconds = 60 + + return JSONResponse( + { + "error": "rate_limit_exceeded", + "detail": "Too many requests. Please try again later.", + }, + status_code=429, + headers={ + "Retry-After": str(retry_after), + "X-RateLimit-Limit": str(limit), + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(int(time.time() + reset_seconds)), + }, + ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 From 5e184d01fb0c0bbb85d9503a857930e0fc51a9d2 Mon Sep 17 00:00:00 2001 From: iklobato Date: Tue, 10 Mar 2026 18:51:44 -0300 Subject: [PATCH 3/5] fix: resolve test failures - Fix rate limiting in tests by disabling auth_rate_limiter in fixtures - Fix JWT configuration errors: replace reserved claim 'sub' with 'user_id' in jwt_extra_claims - Update test assertions to check for 'user_id' instead of 'sub' - Fix YAML config test by adding login_validator parameter --- tests/test_login_auth.py | 18 +++++++++++------- tests/test_yaml_config.py | 6 +++--- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/test_login_auth.py b/tests/test_login_auth.py index 4903177..063cd3a 100644 --- a/tests/test_login_auth.py +++ b/tests/test_login_auth.py @@ -58,6 +58,8 @@ def jwt_client(): ) app = LightApi(engine=engine, login_validator=_valid_validator) app.register({"/secrets": JWTProtectedEndpoint}) + # Disable rate limiting for tests + app._auth_rate_limiter = None return TestClient(app.build_app()) @@ -70,6 +72,8 @@ def basic_client(): ) app = LightApi(engine=engine, login_validator=_valid_validator) app.register({"/items": BasicProtectedEndpoint}) + # Disable rate limiting for tests + app._auth_rate_limiter = None return TestClient(app.build_app()) @@ -294,7 +298,7 @@ def test_jwt_extra_claims_filters_payload(self): def validator_with_extra(username: str, password: str): if username == "alice" and password == "secret": return { - "sub": "1", + "user_id": "1", "email": "a@b.com", "secret": "must-not-appear", } @@ -306,7 +310,7 @@ class JWTWithExtraEndpoint(RestEndpoint): class Meta: authentication = Authentication( backend=JWTAuthentication, - jwt_extra_claims=["sub", "email"], + jwt_extra_claims=["user_id", "email"], ) os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" @@ -325,7 +329,7 @@ class Meta: assert resp.status_code == 200 token = resp.json()["token"] payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) - assert "sub" in payload + assert "user_id" in payload assert "email" in payload assert "secret" not in payload @@ -717,7 +721,7 @@ class Meta: def test_jwt_extra_claims_partial_overlap_only_included_in_token(self): def validator_partial(username: str, password: str): if username == "alice" and password == "secret": - return {"sub": "1", "email": "a@b.com"} + return {"user_id": "1", "email": "a@b.com"} return None class JWTPartialExtraEndpoint(RestEndpoint): @@ -726,7 +730,7 @@ class JWTPartialExtraEndpoint(RestEndpoint): class Meta: authentication = Authentication( backend=JWTAuthentication, - jwt_extra_claims=["sub", "missing"], + jwt_extra_claims=["user_id", "missing"], ) os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" @@ -745,8 +749,8 @@ class Meta: assert resp.status_code == 200 token = resp.json()["token"] payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) - assert "sub" in payload - assert payload["sub"] == "1" + assert "user_id" in payload + assert payload["user_id"] == "1" assert "missing" not in payload diff --git a/tests/test_yaml_config.py b/tests/test_yaml_config.py index fc76a8d..4d9fadf 100644 --- a/tests/test_yaml_config.py +++ b/tests/test_yaml_config.py @@ -372,7 +372,7 @@ def test_validator(username: str, password: str): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) # The validator from YAML should be resolved and used assert app._login_validator is test_validator finally: @@ -409,7 +409,7 @@ def test_jwt_extra_claims_from_defaults(self): authentication: backend: JWTAuthentication permission: IsAuthenticated - jwt_extra_claims: [sub, email] + jwt_extra_claims: [user_id, email] endpoints: - route: /x fields: @@ -419,7 +419,7 @@ def test_jwt_extra_claims_from_defaults(self): """ app = _from_str(content) cls = app._endpoint_map["/x"] - assert cls.Meta.authentication.jwt_extra_claims == ["sub", "email"] + assert cls.Meta.authentication.jwt_extra_claims == ["user_id", "email"] def test_basic_authentication_from_yaml(self): """BasicAuthentication can be specified as backend in YAML.""" From 251c68d28cf8600bed6cdc21505ad644b39704b6 Mon Sep 17 00:00:00 2001 From: iklobato Date: Tue, 10 Mar 2026 19:39:07 -0300 Subject: [PATCH 4/5] fix: resolve YAML config test failures - Fix indentation errors in test_yaml_config.py - Add login_validator parameter to YAML tests using authentication - Fix test_login_validator_dotted_path_from_yaml to not override YAML-specified validator - Update jwt_extra_claims from 'sub' to 'user_id' to avoid reserved claim error --- tests/test_yaml_config.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/test_yaml_config.py b/tests/test_yaml_config.py index 4d9fadf..bb3f089 100644 --- a/tests/test_yaml_config.py +++ b/tests/test_yaml_config.py @@ -114,7 +114,7 @@ def test_dynamic_fields_create_endpoint_class(self): meta: methods: [GET, POST] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) assert "/articles" in app._endpoint_map def test_dynamic_fields_are_on_class_annotations(self): @@ -129,7 +129,7 @@ def test_dynamic_fields_are_on_class_annotations(self): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = self._route_cls(app, "/items2") assert "title" in cls.__annotations__ assert "count" in cls.__annotations__ @@ -150,7 +150,7 @@ def test_defaults_applied_to_endpoint(self): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = self._route_cls(app, "/secure") meta = cls.Meta assert hasattr(meta, "authentication") @@ -175,7 +175,7 @@ def test_endpoint_auth_overrides_defaults(self): authentication: permission: AllowAny """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = self._route_cls(app, "/public") from lightapi.auth import AllowAny @@ -198,7 +198,7 @@ def test_per_method_auth_in_meta(self): authentication: backend: JWTAuthentication """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = self._route_cls(app, "/itemsauth") from lightapi.auth import AllowAny, IsAdminUser @@ -222,7 +222,7 @@ def test_filtering_config_auto_selects_backends(self): fields: [published] ordering: [title] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = self._route_cls(app, "/posts") from lightapi.filters import FieldFilter, OrderingFilter @@ -247,7 +247,7 @@ def test_pagination_config_from_defaults(self): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = self._route_cls(app, "/things") meta = cls.Meta assert hasattr(meta, "pagination") @@ -276,7 +276,7 @@ def test_middleware_resolved_by_name(self): url: "sqlite:///:memory:" middleware: [CORSMiddleware] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) assert app is not None def test_from_config_kwargs_override_yaml(self): @@ -332,7 +332,7 @@ def test_auth_path_from_yaml(self): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) from starlette.testclient import TestClient client = TestClient(app.build_app()) @@ -372,7 +372,7 @@ def test_validator(username: str, password: str): meta: methods: [GET] """ - app = _from_str(content, login_validator=_dummy_login_validator) + app = _from_str(content) # The validator from YAML should be resolved and used assert app._login_validator is test_validator finally: @@ -396,7 +396,7 @@ def test_jwt_expiration_from_defaults(self): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = app._endpoint_map["/x"] assert cls.Meta.authentication.jwt_expiration == 300 @@ -417,7 +417,7 @@ def test_jwt_extra_claims_from_defaults(self): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = app._endpoint_map["/x"] assert cls.Meta.authentication.jwt_extra_claims == ["user_id", "email"] @@ -437,7 +437,7 @@ def test_basic_authentication_from_yaml(self): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) from lightapi.auth import BasicAuthentication cls = app._endpoint_map["/items"] From 3f30780b196b4090e232ede685d456cb735912da Mon Sep 17 00:00:00 2001 From: iklobato Date: Tue, 10 Mar 2026 19:52:05 -0300 Subject: [PATCH 5/5] fix: resolve cache test failures - Update cache test mocks to patch 'lightapi.lightapi' instead of 'lightapi.cache' - Fixes issue where mocks weren't being called because functions are imported at module level - All cache tests now pass --- tests/test_cache_v2.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_cache_v2.py b/tests/test_cache_v2.py index 23d46fb..5c16380 100644 --- a/tests/test_cache_v2.py +++ b/tests/test_cache_v2.py @@ -40,8 +40,8 @@ def client(): class TestCacheGetList: def test_cache_get_list_returns_cached_on_second_request(self, client): """First GET hits DB, second GET returns cached (when get_cached returns data).""" - with patch("lightapi.cache.get_cached") as mock_get: - with patch("lightapi.cache.set_cached") as mock_set: + with patch("lightapi.lightapi.get_cached") as mock_get: + with patch("lightapi.lightapi.set_cached") as mock_set: mock_get.return_value = None # First call: cache miss resp1 = client.get("/cached") assert resp1.status_code == 200 @@ -60,13 +60,13 @@ def test_cache_get_list_returns_cached_on_second_request(self, client): class TestCacheInvalidation: def test_cache_post_invalidates(self, client): """POST triggers cache invalidation.""" - with patch("lightapi.cache.invalidate_cache_prefix") as mock_inv: + with patch("lightapi.lightapi.invalidate_cache_prefix") as mock_inv: client.post("/cached", json={"name": "new"}) mock_inv.assert_called() def test_cache_put_invalidates(self, client): """PUT triggers cache invalidation.""" - with patch("lightapi.cache.invalidate_cache_prefix") as mock_inv: + with patch("lightapi.lightapi.invalidate_cache_prefix") as mock_inv: post_resp = client.post("/cached", json={"name": "item"}) item_id = post_resp.json()["id"] version = post_resp.json()["version"] @@ -78,7 +78,7 @@ def test_cache_put_invalidates(self, client): def test_cache_delete_invalidates(self, client): """DELETE triggers cache invalidation.""" - with patch("lightapi.cache.invalidate_cache_prefix") as mock_inv: + with patch("lightapi.lightapi.invalidate_cache_prefix") as mock_inv: post_resp = client.post("/cached", json={"name": "to_delete"}) item_id = post_resp.json()["id"] client.delete(f"/cached/{item_id}") @@ -101,8 +101,8 @@ def test_cache_redis_unreachable_startup_warning(self): def test_cache_redis_unreachable_mid_request_serves_db(self, client): """When get_cached raises/fails, GET still returns 200 from DB.""" - with patch("lightapi.cache.get_cached", side_effect=Exception("Redis down")): - with patch("lightapi.cache.set_cached"): + with patch("lightapi.lightapi.get_cached", side_effect=Exception("Redis down")): + with patch("lightapi.lightapi.set_cached"): resp = client.get("/cached") assert resp.status_code == 200 assert "results" in resp.json() @@ -121,8 +121,8 @@ def test_cache_vary_on_query_params_uses_different_keys(self): c = TestClient(app.build_app()) c.post("/cached_vary", json={"label": "x"}) - with patch("lightapi.cache.get_cached") as mock_get: - with patch("lightapi.cache.set_cached") as mock_set: + with patch("lightapi.lightapi.get_cached") as mock_get: + with patch("lightapi.lightapi.set_cached") as mock_set: mock_get.return_value = None c.get("/cached_vary?page=1") c.get("/cached_vary?page=2")