-
Notifications
You must be signed in to change notification settings - Fork 147
switch backend to listen on free port and require auth #35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+297
−128
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,12 +2,16 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import base64 | ||
| import hmac | ||
| from collections.abc import Awaitable, Callable | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| from fastapi import FastAPI, Request | ||
| from fastapi.exceptions import RequestValidationError | ||
| from fastapi.middleware.cors import CORSMiddleware | ||
| from fastapi.responses import JSONResponse | ||
| from starlette.responses import Response as StarletteResponse | ||
|
|
||
| from _routes._errors import HTTPError | ||
| from _routes.generation import router as generation_router | ||
|
|
@@ -36,6 +40,7 @@ def create_app( | |
| handler: "AppHandler", | ||
| allowed_origins: list[str] | None = None, | ||
| title: str = "LTX-2 Video Generation Server", | ||
| auth_token: str = "", | ||
| ) -> FastAPI: | ||
| """Create a configured FastAPI app bound to the provided handler.""" | ||
| init_state_service(handler) | ||
|
|
@@ -48,6 +53,37 @@ def create_app( | |
| allow_headers=["*"], | ||
| ) | ||
|
|
||
| @app.middleware("http") | ||
| async def _auth_middleware( # pyright: ignore[reportUnusedFunction] | ||
| request: Request, | ||
| call_next: Callable[[Request], Awaitable[StarletteResponse]], | ||
| ) -> StarletteResponse: | ||
| if not auth_token: | ||
| return await call_next(request) | ||
| if request.method == "OPTIONS": | ||
| return await call_next(request) | ||
| def _token_matches(candidate: str) -> bool: | ||
| return hmac.compare_digest(candidate, auth_token) | ||
|
|
||
| # WebSocket: check query param | ||
| if request.headers.get("upgrade", "").lower() == "websocket": | ||
| if _token_matches(request.query_params.get("token", "")): | ||
| return await call_next(request) | ||
| return JSONResponse(status_code=401, content={"error": "Unauthorized"}) | ||
| # HTTP: Bearer or Basic auth | ||
| auth_header = request.headers.get("authorization", "") | ||
| if auth_header.startswith("Bearer ") and _token_matches(auth_header[7:]): | ||
| return await call_next(request) | ||
| if auth_header.startswith("Basic "): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why needed in addition to bearer token?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it isn't but curl is slightly more convenient to use with basic: |
||
| try: | ||
| decoded = base64.b64decode(auth_header[6:]).decode() | ||
| _, _, password = decoded.partition(":") | ||
| if _token_matches(password): | ||
| return await call_next(request) | ||
| except Exception: | ||
| pass | ||
| return JSONResponse(status_code=401, content={"error": "Unauthorized"}) | ||
|
|
||
| async def _route_http_error_handler(request: Request, exc: Exception) -> JSONResponse: | ||
| if isinstance(exc, HTTPError): | ||
| log_http_error(request, exc) | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| """Tests for shared-secret authentication middleware.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import base64 | ||
|
|
||
| from starlette.testclient import TestClient | ||
|
|
||
| from app_factory import create_app | ||
|
|
||
|
|
||
| def test_request_without_token_returns_401(test_state): | ||
| app = create_app(handler=test_state, auth_token="test-secret") | ||
| with TestClient(app) as client: | ||
| response = client.get("/health") | ||
| assert response.status_code == 401 | ||
| assert response.json() == {"error": "Unauthorized"} | ||
|
|
||
|
|
||
| def test_request_with_correct_bearer_token(test_state): | ||
| app = create_app(handler=test_state, auth_token="test-secret") | ||
| with TestClient(app) as client: | ||
| response = client.get("/health", headers={"Authorization": "Bearer test-secret"}) | ||
| assert response.status_code == 200 | ||
|
|
||
|
|
||
| def test_request_with_correct_basic_auth(test_state): | ||
| app = create_app(handler=test_state, auth_token="test-secret") | ||
| credentials = base64.b64encode(b":test-secret").decode() | ||
| with TestClient(app) as client: | ||
| response = client.get("/health", headers={"Authorization": f"Basic {credentials}"}) | ||
| assert response.status_code == 200 | ||
|
|
||
|
|
||
| def test_request_with_wrong_token_returns_401(test_state): | ||
| app = create_app(handler=test_state, auth_token="test-secret") | ||
| with TestClient(app) as client: | ||
| response = client.get("/health", headers={"Authorization": "Bearer wrong-token"}) | ||
| assert response.status_code == 401 | ||
|
|
||
|
|
||
| def test_health_without_token_returns_401(test_state): | ||
| """Health endpoint is NOT exempt from auth.""" | ||
| app = create_app(handler=test_state, auth_token="test-secret") | ||
| with TestClient(app) as client: | ||
| response = client.get("/health") | ||
| assert response.status_code == 401 | ||
|
|
||
|
|
||
| def test_no_auth_token_disables_middleware(test_state): | ||
| """When auth_token is empty string, auth is disabled (dev/test mode).""" | ||
| app = create_app(handler=test_state, auth_token="") | ||
| with TestClient(app) as client: | ||
| response = client.get("/health") | ||
| assert response.status_code == 200 | ||
|
|
||
|
|
||
| def test_websocket_with_token_query_param(test_state): | ||
| app = create_app(handler=test_state, auth_token="test-secret") | ||
| with TestClient(app) as client: | ||
| # WebSocket upgrade without token should fail with 401 | ||
| response = client.get( | ||
| "/ws/download/test", | ||
| headers={"upgrade": "websocket", "connection": "upgrade"}, | ||
| ) | ||
| assert response.status_code == 401 | ||
|
|
||
| # WebSocket upgrade with correct token query param | ||
| response = client.get( | ||
| "/ws/download/test?token=test-secret", | ||
| headers={"upgrade": "websocket", "connection": "upgrade"}, | ||
| ) | ||
| # The route may not exist, but auth should pass (not 401) | ||
| assert response.status_code != 401 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this under assumption the options requests get swallowed by the cors middleware? If so, worth adding a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OPTIONS mustn't require auth, so the code lets it pass. There is no assumption cors middleware exists at all...