diff --git a/lightapi/_login.py b/lightapi/_login.py index 0e205d6..8bfc25d 100644 --- a/lightapi/_login.py +++ b/lightapi/_login.py @@ -12,6 +12,8 @@ from starlette.requests import Request from starlette.responses import JSONResponse +from .rate_limiter import RateLimiter + # JWTAuthentication imported locally where needed to avoid circular import logger = logging.getLogger(__name__) @@ -67,7 +69,10 @@ async def _read_body(request: Request) -> dict[str, Any]: """Read JSON body; return {} on empty or invalid.""" try: body = await request.body() - return json.loads(body) if body else {} + if body: + result: dict[str, Any] = json.loads(body) + return result + return {} except (json.JSONDecodeError, TypeError): return {} @@ -80,7 +85,7 @@ async def login_handler( jwt_expiration: int | None = None, jwt_extra_claims: list[str] | None = None, jwt_algorithm: str | None = None, - rate_limiter: Optional[Any] = None, + rate_limiter: Optional[RateLimiter] = None, ) -> JSONResponse: """ Handle POST /auth/login and POST /auth/token. @@ -91,7 +96,7 @@ async def login_handler( # Apply rate limiting if a rate limiter is provided if rate_limiter is not None: is_limited, window = rate_limiter.is_rate_limited(request, endpoint="auth") - if is_limited: + if is_limited and window is not None: return rate_limiter.get_rate_limit_response(request, window) if request.method != "POST": diff --git a/lightapi/auth.py b/lightapi/auth.py index d3ead3a..2de7a4e 100644 --- a/lightapi/auth.py +++ b/lightapi/auth.py @@ -1,5 +1,6 @@ +import base64 from datetime import datetime, timedelta -from typing import Any, Dict, Optional +from typing import Any, Optional import jwt from starlette.requests import Request @@ -29,90 +30,130 @@ def authenticate(self, request: Request) -> bool: """ return True + def get_auth_error_response(self, request: Request) -> JSONResponse: + """ + Get the response to return when authentication fails. -def get_auth_error_response(self, request: Request) -> JSONResponse: - """ - Get the response to return when authentication fails. - - Args: - request: The HTTP request object. + Args: + request: The HTTP request object. - Returns: - Response object for authentication error. - """ - return JSONResponse( - {"error": "authentication failed"}, - status_code=401, - headers={"WWW-Authenticate": 'Basic realm="Restricted Area"'}, - ) + Returns: + Response object for authentication error. + """ + return JSONResponse({"error": "authentication failed"}, status_code=401) -class BasicAuthentication(BaseAuthentication): +class JWTAuthentication(BaseAuthentication): """ - Basic (Base64) authentication. + JWT (JSON Web Token) based authentication. - Authenticates requests using Authorization: Basic . - Delegates credential validation to the app-level login_validator from the registry. + Authenticates requests using JWT tokens from the Authorization header. + Validates token signatures and expiration times. + Automatically skips authentication for OPTIONS requests (CORS preflight). + + Attributes: + secret_key: Secret key for signing tokens. + algorithm: JWT algorithm to use. + expiration: Token expiration time in seconds. """ def __init__( self, - login_validator: Optional[LoginValidator] = None, - ) -> None: - self.login_validator = login_validator + secret_key: str | None = None, + algorithm: str | None = None, + expiration: int | None = None, + ): + self.secret_key = secret_key or config.jwt_secret + if not self.secret_key: + raise ValueError( + "JWT secret key not configured. Set LIGHTAPI_JWT_SECRET environment variable." + ) + + self.algorithm = algorithm or config.jwt_algorithm + self.expiration = expiration or 3600 # 1 hour default def authenticate(self, request: Request) -> bool: + """ + Authenticate a request using JWT token. + Automatically allows OPTIONS requests for CORS preflight. + + Args: + request: The HTTP request object. + + Returns: + bool: True if authentication succeeds, False otherwise. + """ + # Skip authentication for OPTIONS requests (CORS preflight) if request.method == "OPTIONS": return True auth_header = request.headers.get("Authorization") - if not auth_header: - return False - - # Use the shared Basic auth parsing function - from lightapi._login import _parse_basic_header - - credentials = _parse_basic_header(auth_header) - if credentials is None: - return False - - username, password = credentials - from lightapi._registry import get_login_validator - - validator = self.login_validator or get_login_validator() - if validator is None: + if not auth_header or not auth_header.lower().startswith("bearer "): return False try: - payload = validator(username, password) - except Exception: - return False - - if payload is None: + token = auth_header.split(" ", 1)[1] + payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm]) + except (jwt.InvalidTokenError, ValueError, IndexError): return False request.state.user = payload return True - def get_auth_error_response(self, request: Request) -> JSONResponse: + def generate_token( + self, payload: dict[str, Any], expiration: int | None = None + ) -> str: """ - Get the response to return when authentication fails. + Generate a JWT token with the given payload. Args: - request: The HTTP request object. + payload: Dictionary of claims to include in the token. + expiration: Optional expiration time in seconds (overrides default). Returns: - Response object for authentication error. + str: The encoded JWT token. + + Raises: + ValueError: If payload contains 'exp' claim which will be overwritten. """ - return JSONResponse({"error": "authentication failed"}, status_code=401) + # Check for 'exp' in payload since we overwrite it + if "exp" in payload: + raise ValueError( + "Payload contains 'exp' claim which will be overwritten. " + "Use the 'expiration' parameter instead." + ) + + exp_seconds = expiration or self.expiration + token_data = { + **payload, + "exp": datetime.utcnow() + timedelta(seconds=exp_seconds), + } + return jwt.encode(token_data, self.secret_key, algorithm=self.algorithm) + + +class BasicAuthentication(BaseAuthentication): + """ + Basic (Base64) authentication. + + Authenticates requests using Authorization: Basic . + Delegates credential validation to the provided login_validator. + """ + + def __init__( + self, + login_validator: Optional[LoginValidator] = None, + ) -> None: + self.login_validator = login_validator + + def authenticate(self, request: Request) -> bool: + if request.method == "OPTIONS": + return True auth_header = request.headers.get("Authorization") if not auth_header or not auth_header.lower().startswith("basic "): return False try: - import base64 - token = auth_header.split(" ", 1)[1] decoded = base64.b64decode(token).decode("utf-8") except (ValueError, IndexError, UnicodeDecodeError): @@ -123,15 +164,13 @@ def get_auth_error_response(self, request: Request) -> JSONResponse: return False username, password = parts[0], parts[1] - from lightapi._registry import get_login_validator - - validator = get_login_validator() + validator = self.login_validator or get_login_validator() if validator is None: return False try: payload = validator(username, password) - except Exception: + except (ValueError, TypeError, RuntimeError): return False if payload is None: @@ -140,11 +179,27 @@ def get_auth_error_response(self, request: Request) -> JSONResponse: request.state.user = payload return True + def get_auth_error_response(self, request: Request) -> JSONResponse: + """ + Get the response to return when authentication fails. + + Args: + request: The HTTP request object. + + Returns: + Response object for authentication error. + """ + return JSONResponse( + {"error": "authentication failed"}, + status_code=401, + headers={"WWW-Authenticate": 'Basic realm="Restricted Area"'}, + ) + class AllowAny: """Permits all requests regardless of authentication state.""" - def has_permission(self, request: Request) -> bool: + def has_permission(self, _request: Request) -> bool: return True diff --git a/lightapi/config.py b/lightapi/config.py index 13434ab..da66152 100644 --- a/lightapi/config.py +++ b/lightapi/config.py @@ -56,6 +56,9 @@ def jwt_algorithm(self) -> str: class Authentication: """Authentication configuration for a RestEndpoint.""" + # Standard JWT reserved claims that cannot be used as extra claims + RESERVED_CLAIMS = {"exp", "iat", "nbf", "iss", "sub", "aud", "jti"} + def __init__( self, backend: type | None = None, @@ -74,16 +77,16 @@ def __init__( # Validate jwt_extra_claims - reject reserved claims if jwt_extra_claims: - RESERVED_CLAIMS = {"exp", "iat", "nbf", "iss", "sub", "aud", "jti"} reserved_found = [] for claim in jwt_extra_claims: - if claim in RESERVED_CLAIMS: + if claim in self.RESERVED_CLAIMS: reserved_found.append(claim) if reserved_found: raise ConfigurationError( f"JWT extra claims cannot include reserved claims: " - f"{reserved_found}. Reserved claims are: {sorted(RESERVED_CLAIMS)}" + f"{reserved_found}. Reserved claims are: " + f"{sorted(self.RESERVED_CLAIMS)}" ) self.jwt_extra_claims = jwt_extra_claims diff --git a/lightapi/rate_limiter.py b/lightapi/rate_limiter.py index f7a8c2a..1268e6b 100644 --- a/lightapi/rate_limiter.py +++ b/lightapi/rate_limiter.py @@ -11,9 +11,13 @@ class RateLimiter: """ - Simple in-memory rate limiter. + Simple in-memory rate limiter. - Tracks requests by IP address and endpoint. + Tracks requests by IP address and endpoint. + NOTE: This implementation uses process-local counters. In a multi-process + deployment (e.g., with multiple workers), rate limiting will not be shared + across processes. For production use with multiple workers, consider using + a shared storage backend like Redis. """ def __init__( diff --git a/lightapi/yaml_loader.py b/lightapi/yaml_loader.py index 9e85db1..feb4339 100644 --- a/lightapi/yaml_loader.py +++ b/lightapi/yaml_loader.py @@ -112,9 +112,16 @@ def _resolve_callable(dotted_path: str) -> Any: try: sig = inspect.signature(fn) - except ValueError: - # Some callables (e.g., builtins) don't have inspectable signatures - return fn + except (ValueError, TypeError) as exc: + # Only allow specific cases where signature inspection legitimately fails + if hasattr(fn, "__name__") and fn.__name__ in ("",): + # Lambdas can't be properly inspected in some Python versions + return fn + # For other cases, raise a clear error about the validation function + raise ValueError( + f"Login validation function {fn!r} cannot be inspected: {exc}. " + f"Ensure it's a regular Python function with inspectable signature." + ) from exc # Count required positional parameters required_params = 0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_cache_v2.py b/tests/test_cache_v2.py index 23d46fb..5c16380 100644 --- a/tests/test_cache_v2.py +++ b/tests/test_cache_v2.py @@ -40,8 +40,8 @@ def client(): class TestCacheGetList: def test_cache_get_list_returns_cached_on_second_request(self, client): """First GET hits DB, second GET returns cached (when get_cached returns data).""" - with patch("lightapi.cache.get_cached") as mock_get: - with patch("lightapi.cache.set_cached") as mock_set: + with patch("lightapi.lightapi.get_cached") as mock_get: + with patch("lightapi.lightapi.set_cached") as mock_set: mock_get.return_value = None # First call: cache miss resp1 = client.get("/cached") assert resp1.status_code == 200 @@ -60,13 +60,13 @@ def test_cache_get_list_returns_cached_on_second_request(self, client): class TestCacheInvalidation: def test_cache_post_invalidates(self, client): """POST triggers cache invalidation.""" - with patch("lightapi.cache.invalidate_cache_prefix") as mock_inv: + with patch("lightapi.lightapi.invalidate_cache_prefix") as mock_inv: client.post("/cached", json={"name": "new"}) mock_inv.assert_called() def test_cache_put_invalidates(self, client): """PUT triggers cache invalidation.""" - with patch("lightapi.cache.invalidate_cache_prefix") as mock_inv: + with patch("lightapi.lightapi.invalidate_cache_prefix") as mock_inv: post_resp = client.post("/cached", json={"name": "item"}) item_id = post_resp.json()["id"] version = post_resp.json()["version"] @@ -78,7 +78,7 @@ def test_cache_put_invalidates(self, client): def test_cache_delete_invalidates(self, client): """DELETE triggers cache invalidation.""" - with patch("lightapi.cache.invalidate_cache_prefix") as mock_inv: + with patch("lightapi.lightapi.invalidate_cache_prefix") as mock_inv: post_resp = client.post("/cached", json={"name": "to_delete"}) item_id = post_resp.json()["id"] client.delete(f"/cached/{item_id}") @@ -101,8 +101,8 @@ def test_cache_redis_unreachable_startup_warning(self): def test_cache_redis_unreachable_mid_request_serves_db(self, client): """When get_cached raises/fails, GET still returns 200 from DB.""" - with patch("lightapi.cache.get_cached", side_effect=Exception("Redis down")): - with patch("lightapi.cache.set_cached"): + with patch("lightapi.lightapi.get_cached", side_effect=Exception("Redis down")): + with patch("lightapi.lightapi.set_cached"): resp = client.get("/cached") assert resp.status_code == 200 assert "results" in resp.json() @@ -121,8 +121,8 @@ def test_cache_vary_on_query_params_uses_different_keys(self): c = TestClient(app.build_app()) c.post("/cached_vary", json={"label": "x"}) - with patch("lightapi.cache.get_cached") as mock_get: - with patch("lightapi.cache.set_cached") as mock_set: + with patch("lightapi.lightapi.get_cached") as mock_get: + with patch("lightapi.lightapi.set_cached") as mock_set: mock_get.return_value = None c.get("/cached_vary?page=1") c.get("/cached_vary?page=2") diff --git a/tests/test_login_auth.py b/tests/test_login_auth.py index 4903177..063cd3a 100644 --- a/tests/test_login_auth.py +++ b/tests/test_login_auth.py @@ -58,6 +58,8 @@ def jwt_client(): ) app = LightApi(engine=engine, login_validator=_valid_validator) app.register({"/secrets": JWTProtectedEndpoint}) + # Disable rate limiting for tests + app._auth_rate_limiter = None return TestClient(app.build_app()) @@ -70,6 +72,8 @@ def basic_client(): ) app = LightApi(engine=engine, login_validator=_valid_validator) app.register({"/items": BasicProtectedEndpoint}) + # Disable rate limiting for tests + app._auth_rate_limiter = None return TestClient(app.build_app()) @@ -294,7 +298,7 @@ def test_jwt_extra_claims_filters_payload(self): def validator_with_extra(username: str, password: str): if username == "alice" and password == "secret": return { - "sub": "1", + "user_id": "1", "email": "a@b.com", "secret": "must-not-appear", } @@ -306,7 +310,7 @@ class JWTWithExtraEndpoint(RestEndpoint): class Meta: authentication = Authentication( backend=JWTAuthentication, - jwt_extra_claims=["sub", "email"], + jwt_extra_claims=["user_id", "email"], ) os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" @@ -325,7 +329,7 @@ class Meta: assert resp.status_code == 200 token = resp.json()["token"] payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) - assert "sub" in payload + assert "user_id" in payload assert "email" in payload assert "secret" not in payload @@ -717,7 +721,7 @@ class Meta: def test_jwt_extra_claims_partial_overlap_only_included_in_token(self): def validator_partial(username: str, password: str): if username == "alice" and password == "secret": - return {"sub": "1", "email": "a@b.com"} + return {"user_id": "1", "email": "a@b.com"} return None class JWTPartialExtraEndpoint(RestEndpoint): @@ -726,7 +730,7 @@ class JWTPartialExtraEndpoint(RestEndpoint): class Meta: authentication = Authentication( backend=JWTAuthentication, - jwt_extra_claims=["sub", "missing"], + jwt_extra_claims=["user_id", "missing"], ) os.environ["LIGHTAPI_JWT_SECRET"] = "test-secret-key" @@ -745,8 +749,8 @@ class Meta: assert resp.status_code == 200 token = resp.json()["token"] payload = jwt.decode(token, "test-secret-key", algorithms=["HS256"]) - assert "sub" in payload - assert payload["sub"] == "1" + assert "user_id" in payload + assert payload["user_id"] == "1" assert "missing" not in payload diff --git a/tests/test_yaml_config.py b/tests/test_yaml_config.py index fc76a8d..bb3f089 100644 --- a/tests/test_yaml_config.py +++ b/tests/test_yaml_config.py @@ -114,7 +114,7 @@ def test_dynamic_fields_create_endpoint_class(self): meta: methods: [GET, POST] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) assert "/articles" in app._endpoint_map def test_dynamic_fields_are_on_class_annotations(self): @@ -129,7 +129,7 @@ def test_dynamic_fields_are_on_class_annotations(self): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = self._route_cls(app, "/items2") assert "title" in cls.__annotations__ assert "count" in cls.__annotations__ @@ -150,7 +150,7 @@ def test_defaults_applied_to_endpoint(self): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = self._route_cls(app, "/secure") meta = cls.Meta assert hasattr(meta, "authentication") @@ -175,7 +175,7 @@ def test_endpoint_auth_overrides_defaults(self): authentication: permission: AllowAny """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = self._route_cls(app, "/public") from lightapi.auth import AllowAny @@ -198,7 +198,7 @@ def test_per_method_auth_in_meta(self): authentication: backend: JWTAuthentication """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = self._route_cls(app, "/itemsauth") from lightapi.auth import AllowAny, IsAdminUser @@ -222,7 +222,7 @@ def test_filtering_config_auto_selects_backends(self): fields: [published] ordering: [title] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = self._route_cls(app, "/posts") from lightapi.filters import FieldFilter, OrderingFilter @@ -247,7 +247,7 @@ def test_pagination_config_from_defaults(self): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = self._route_cls(app, "/things") meta = cls.Meta assert hasattr(meta, "pagination") @@ -276,7 +276,7 @@ def test_middleware_resolved_by_name(self): url: "sqlite:///:memory:" middleware: [CORSMiddleware] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) assert app is not None def test_from_config_kwargs_override_yaml(self): @@ -332,7 +332,7 @@ def test_auth_path_from_yaml(self): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) from starlette.testclient import TestClient client = TestClient(app.build_app()) @@ -396,7 +396,7 @@ def test_jwt_expiration_from_defaults(self): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = app._endpoint_map["/x"] assert cls.Meta.authentication.jwt_expiration == 300 @@ -409,7 +409,7 @@ def test_jwt_extra_claims_from_defaults(self): authentication: backend: JWTAuthentication permission: IsAuthenticated - jwt_extra_claims: [sub, email] + jwt_extra_claims: [user_id, email] endpoints: - route: /x fields: @@ -417,9 +417,9 @@ def test_jwt_extra_claims_from_defaults(self): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) cls = app._endpoint_map["/x"] - assert cls.Meta.authentication.jwt_extra_claims == ["sub", "email"] + assert cls.Meta.authentication.jwt_extra_claims == ["user_id", "email"] def test_basic_authentication_from_yaml(self): """BasicAuthentication can be specified as backend in YAML.""" @@ -437,7 +437,7 @@ def test_basic_authentication_from_yaml(self): meta: methods: [GET] """ - app = _from_str(content) + app = _from_str(content, login_validator=_dummy_login_validator) from lightapi.auth import BasicAuthentication cls = app._endpoint_map["/items"]