-
Notifications
You must be signed in to change notification settings - Fork 7
feat(auth): implement comprehensive JWT and Basic authentication #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| """Login and token endpoint handlers.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import base64 | ||
| import json | ||
| import logging | ||
| from collections.abc import Callable | ||
| from typing import Any, Optional | ||
|
|
||
| from pydantic import BaseModel, ConfigDict, Field, ValidationError | ||
| from starlette.requests import Request | ||
| from starlette.responses import JSONResponse | ||
|
|
||
| # JWTAuthentication imported locally where needed to avoid circular import | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class LoginRequest(BaseModel): | ||
| """Request body for POST /auth/login and /auth/token.""" | ||
|
|
||
| model_config = ConfigDict(frozen=True) | ||
|
|
||
| username: str = Field(min_length=1) | ||
| password: str = Field(min_length=1) | ||
|
|
||
|
|
||
| def _parse_basic_header(auth_header: str) -> tuple[str, str] | None: | ||
| """ | ||
| Decode Authorization: Basic header. | ||
|
|
||
| Returns (username, password) or None if malformed. | ||
| """ | ||
| if not auth_header.lower().startswith("basic "): | ||
| return None | ||
| try: | ||
| token = auth_header.split(" ", 1)[1] | ||
| decoded = base64.b64decode(token).decode("utf-8") | ||
| except (ValueError, IndexError, UnicodeDecodeError): | ||
| return None | ||
| parts = decoded.split(":", 1) | ||
| if len(parts) != 2: | ||
| return None | ||
| return parts[0], parts[1] | ||
|
|
||
|
|
||
| async def _parse_credentials(request: Request) -> tuple[str, str] | None: | ||
| """ | ||
| Extract (username, password) from request. | ||
|
|
||
| - If Authorization: Basic present: returns (u, p) or None if malformed. | ||
| - If no Basic header: reads body, validates with LoginRequest. | ||
| Returns (u, p) if valid. Raises ValidationError for body (caller returns 422). | ||
| - None means malformed Basic (caller returns 401). | ||
| """ | ||
| auth_header = request.headers.get("Authorization") | ||
| if auth_header and auth_header.startswith("Basic "): | ||
| return _parse_basic_header(auth_header) | ||
|
|
||
| body = await _read_body(request) | ||
| parsed = LoginRequest.model_validate(body if body else {}) | ||
| return parsed.username, parsed.password | ||
|
|
||
|
|
||
| async def _read_body(request: Request) -> dict[str, Any]: | ||
| """Read JSON body; return {} on empty or invalid.""" | ||
| try: | ||
| body = await request.body() | ||
| return json.loads(body) if body else {} | ||
| except (json.JSONDecodeError, TypeError): | ||
| return {} | ||
|
|
||
|
|
||
| async def login_handler( | ||
| request: Request, | ||
| *, | ||
| login_validator: Callable[[str, str], dict[str, Any] | None], | ||
| has_jwt: bool, | ||
| jwt_expiration: int | None = None, | ||
| jwt_extra_claims: list[str] | None = None, | ||
| jwt_algorithm: str | None = None, | ||
| rate_limiter: Optional[Any] = None, | ||
| ) -> JSONResponse: | ||
| """ | ||
| Handle POST /auth/login and POST /auth/token. | ||
|
|
||
| Returns 422 for body validation, 401 for malformed Basic or invalid credentials, | ||
| 500 for validator exception, 200 with token+user (JWT) or user only (Basic). | ||
| """ | ||
| # Apply rate limiting if a rate limiter is provided | ||
| if rate_limiter is not None: | ||
| is_limited, window = rate_limiter.is_rate_limited(request, endpoint="auth") | ||
| if is_limited: | ||
| return rate_limiter.get_rate_limit_response(request, window) | ||
|
|
||
| if request.method != "POST": | ||
| return JSONResponse( | ||
| {"detail": "method not allowed"}, | ||
| status_code=405, | ||
| headers={"Allow": "POST"}, | ||
| ) | ||
|
|
||
| try: | ||
| creds = await _parse_credentials(request) | ||
| except ValidationError as exc: | ||
| return JSONResponse({"detail": exc.errors()}, status_code=422) | ||
|
|
||
| if creds is None: | ||
| return JSONResponse({"detail": "Invalid credentials"}, status_code=401) | ||
|
|
||
| username, password = creds | ||
| try: | ||
| payload = login_validator(username, password) | ||
| except Exception as e: | ||
| logger.exception("login_validator raised: %s", e) | ||
| raise | ||
|
|
||
| if payload is None: | ||
| return JSONResponse({"detail": "Invalid credentials"}, status_code=401) | ||
|
|
||
| if has_jwt: | ||
| from lightapi.auth import JWTAuthentication | ||
|
|
||
| jwt_auth = JWTAuthentication(algorithm=jwt_algorithm) | ||
| if jwt_extra_claims and isinstance(payload, dict): | ||
| token_payload = {k: payload[k] for k in jwt_extra_claims if k in payload} | ||
| if not token_payload: | ||
| token_payload = payload | ||
| else: | ||
| token_payload = payload | ||
| token = jwt_auth.generate_token(token_payload, expiration=jwt_expiration) | ||
| return JSONResponse({"token": token, "user": payload}) | ||
|
|
||
| return JSONResponse({"user": payload}) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,31 +7,53 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Callable | ||
| from typing import Any, cast | ||
|
|
||
| from sqlalchemy import MetaData | ||
| from sqlalchemy.orm import registry | ||
|
|
||
| _registry: registry | None = None | ||
| _metadata: MetaData | None = None | ||
| _engine: object | None = None | ||
| LoginValidator = Callable[[str, str], dict[str, Any] | None] | ||
|
|
||
| _state: dict[str, object | None] = { | ||
| "registry": None, | ||
| "metadata": None, | ||
| "engine": None, | ||
| "login_validator": None, | ||
| } | ||
|
Comment on lines
+18
to
+23
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A single global
Also applies to: 51-59 🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| def get_registry_and_metadata() -> tuple[registry, MetaData]: | ||
| global _registry, _metadata | ||
| if _registry is None or _metadata is None: | ||
| _metadata = MetaData() | ||
| _registry = registry(metadata=_metadata) | ||
| return _registry, _metadata | ||
| reg = cast(registry | None, _state["registry"]) | ||
| meta = cast(MetaData | None, _state["metadata"]) | ||
| if reg is None or meta is None: | ||
| meta = MetaData() | ||
| reg = registry(metadata=meta) | ||
| _state["metadata"] = meta | ||
| _state["registry"] = reg | ||
| return reg, meta | ||
|
|
||
|
|
||
| def set_engine(engine: object) -> None: | ||
| global _engine | ||
| _engine = engine | ||
| _state["engine"] = engine | ||
|
|
||
|
|
||
| def get_engine() -> object: | ||
| if _engine is None: | ||
| engine = _state["engine"] | ||
| if engine is None: | ||
| raise RuntimeError( | ||
| "No engine configured. Call LightApi(engine=...) or ensure " | ||
| "database connection is set before the first request." | ||
| ) | ||
| return _engine | ||
| return engine | ||
|
|
||
|
|
||
| def set_login_validator(validator: LoginValidator) -> None: | ||
| _state["login_validator"] = validator | ||
|
|
||
|
|
||
| def get_login_validator() -> LoginValidator | None: | ||
| validator = _state["login_validator"] | ||
| if validator is None: | ||
| return None | ||
| return cast(LoginValidator, validator) | ||
Uh oh!
There was an error while loading. Please reload this page.