Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions backend/app_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

from __future__ import annotations

import base64
import hmac
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING

from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.responses import Response as StarletteResponse

from _routes._errors import HTTPError
from _routes.generation import router as generation_router
Expand Down Expand Up @@ -36,6 +40,7 @@ def create_app(
handler: "AppHandler",
allowed_origins: list[str] | None = None,
title: str = "LTX-2 Video Generation Server",
auth_token: str = "",
) -> FastAPI:
"""Create a configured FastAPI app bound to the provided handler."""
init_state_service(handler)
Expand All @@ -48,6 +53,37 @@ def create_app(
allow_headers=["*"],
)

@app.middleware("http")
async def _auth_middleware( # pyright: ignore[reportUnusedFunction]
request: Request,
call_next: Callable[[Request], Awaitable[StarletteResponse]],
) -> StarletteResponse:
if not auth_token:
return await call_next(request)
if request.method == "OPTIONS":

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this under the assumption that OPTIONS requests get swallowed by the CORS middleware? If so, it's worth adding a comment.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OPTIONS requests mustn't require auth, so the code lets them pass. There is no assumption that the CORS middleware exists at all.

return await call_next(request)
def _token_matches(candidate: str) -> bool:
    """Return True when *candidate* equals the configured auth token.

    Both values are encoded to UTF-8 bytes before the constant-time
    comparison. ``hmac.compare_digest`` raises ``TypeError`` when given
    ``str`` arguments containing non-ASCII characters, which would turn a
    malformed client-supplied token into an unhandled error instead of a
    plain mismatch.
    """
    return hmac.compare_digest(candidate.encode("utf-8"), auth_token.encode("utf-8"))

# WebSocket: check query param
if request.headers.get("upgrade", "").lower() == "websocket":
if _token_matches(request.query_params.get("token", "")):
return await call_next(request)
return JSONResponse(status_code=401, content={"error": "Unauthorized"})
# HTTP: Bearer or Basic auth
auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer ") and _token_matches(auth_header[7:]):
return await call_next(request)
if auth_header.startswith("Basic "):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed in addition to the Bearer token?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't, but curl is slightly more convenient to use with Basic auth: curl -u ':<token>' ...

try:
decoded = base64.b64decode(auth_header[6:]).decode()
_, _, password = decoded.partition(":")
if _token_matches(password):
return await call_next(request)
except Exception:
pass
return JSONResponse(status_code=401, content={"error": "Unauthorized"})

async def _route_http_error_handler(request: Request, exc: Exception) -> JSONResponse:
if isinstance(exc, HTTPError):
log_http_error(request, exc)
Expand Down
34 changes: 30 additions & 4 deletions backend/ltx2_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def patched_sdpa(
# Constants & Paths
# ============================================================

PORT = 8000
PORT = 0


def _get_device() -> torch.device:
Expand Down Expand Up @@ -219,7 +219,10 @@ def _resolve_force_api_generations() -> bool:
)

handler = build_initial_state(runtime_config, DEFAULT_APP_SETTINGS)
app = create_app(handler=handler, allowed_origins=DEFAULT_ALLOWED_ORIGINS)

auth_token = os.environ.get("LTX_AUTH_TOKEN", "")

app = create_app(handler=handler, allowed_origins=DEFAULT_ALLOWED_ORIGINS, auth_token=auth_token)


def precache_model_files(model_dir: Path) -> int:
Expand Down Expand Up @@ -257,9 +260,10 @@ def log_hardware_info() -> None:


if __name__ == "__main__":
import asyncio
import uvicorn

port = int(os.environ.get("LTX_PORT", PORT))
port = int(os.environ.get("LTX_PORT", "") or PORT)
logger.info("=" * 60)
logger.info("LTX-2 Video Generation Server (FastAPI + Uvicorn)")
log_hardware_info()
Expand All @@ -285,4 +289,26 @@ def log_hardware_info() -> None:
"uvicorn.access": {"handlers": ["default"], "level": "INFO", "propagate": False},
},
}
uvicorn.run(app, host="127.0.0.1", port=port, log_level="info", access_log=False, log_config=log_config)

import socket as _socket

# Bind the socket ourselves so we know the actual port before uvicorn starts.
sock = _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM)
sock.setsockopt(_socket.SOL_SOCKET, _socket.SO_REUSEADDR, 1)
sock.bind(("127.0.0.1", port))
actual_port = int(sock.getsockname()[1])

config = uvicorn.Config(app, host="127.0.0.1", port=actual_port, log_level="info", access_log=False, log_config=log_config)
server = uvicorn.Server(config)

_orig_startup = server.startup

async def _startup_with_ready_msg(sockets: list[_socket.socket] | None = None) -> None:
    """Run the wrapped startup coroutine, then print the ready line on stdout."""
    await _orig_startup(sockets=sockets)
    if not server.started:
        # Startup did not complete, so no ready message is printed.
        return
    # Machine-parseable ready message — Electron matches this line
    ready_line = f"Server running on http://127.0.0.1:{actual_port}"
    print(ready_line, flush=True)

server.startup = _startup_with_ready_msg # type: ignore[assignment]

asyncio.run(server.serve(sockets=[sock]))
74 changes: 74 additions & 0 deletions backend/tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Tests for shared-secret authentication middleware."""

from __future__ import annotations

import base64

from starlette.testclient import TestClient

from app_factory import create_app


def test_request_without_token_returns_401(test_state):
    """A request with no credentials gets a 401 and the JSON error body."""
    secured_app = create_app(handler=test_state, auth_token="test-secret")
    with TestClient(secured_app) as http:
        reply = http.get("/health")
        assert reply.status_code == 401
        assert reply.json() == {"error": "Unauthorized"}


def test_request_with_correct_bearer_token(test_state):
    """A matching ``Authorization: Bearer`` header is accepted."""
    bearer_headers = {"Authorization": "Bearer test-secret"}
    secured_app = create_app(handler=test_state, auth_token="test-secret")
    with TestClient(secured_app) as http:
        reply = http.get("/health", headers=bearer_headers)
        assert reply.status_code == 200


def test_request_with_correct_basic_auth(test_state):
    """Basic auth with an empty username and the token as password is accepted."""
    secured_app = create_app(handler=test_state, auth_token="test-secret")
    # Same wire format curl produces for ``-u ':test-secret'``.
    credentials = base64.b64encode(b":test-secret").decode()
    basic_headers = {"Authorization": f"Basic {credentials}"}
    with TestClient(secured_app) as http:
        reply = http.get("/health", headers=basic_headers)
        assert reply.status_code == 200


def test_request_with_wrong_token_returns_401(test_state):
    """A Bearer header carrying a non-matching token is rejected with 401."""
    wrong_headers = {"Authorization": "Bearer wrong-token"}
    secured_app = create_app(handler=test_state, auth_token="test-secret")
    with TestClient(secured_app) as http:
        reply = http.get("/health", headers=wrong_headers)
        assert reply.status_code == 401


def test_health_without_token_returns_401(test_state):
    """The health endpoint gets no exemption: unauthenticated access yields 401."""
    secured_app = create_app(handler=test_state, auth_token="test-secret")
    with TestClient(secured_app) as http:
        status = http.get("/health").status_code
        assert status == 401


def test_no_auth_token_disables_middleware(test_state):
    """An empty ``auth_token`` turns the auth check off (dev/test mode)."""
    open_app = create_app(handler=test_state, auth_token="")
    with TestClient(open_app) as http:
        status = http.get("/health").status_code
        assert status == 200


def test_websocket_with_token_query_param(test_state):
    """WebSocket upgrade requests are authenticated via the ``token`` query parameter."""
    upgrade_headers = {"upgrade": "websocket", "connection": "upgrade"}
    secured_app = create_app(handler=test_state, auth_token="test-secret")
    with TestClient(secured_app) as http:
        # WebSocket upgrade without token should fail with 401
        rejected = http.get("/ws/download/test", headers=upgrade_headers)
        assert rejected.status_code == 401

        # WebSocket upgrade with correct token query param
        accepted = http.get("/ws/download/test?token=test-secret", headers=upgrade_headers)
        # The route may not exist, but auth should pass (not 401)
        assert accepted.status_code != 401
2 changes: 0 additions & 2 deletions electron/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ import path from 'path'
import os from 'os'
import { getProjectAssetsPath } from './app-state'

export const PYTHON_PORT = 8000
export const BACKEND_BASE_URL = `http://localhost:${PYTHON_PORT}`
export const isDev = !app.isPackaged

// Get directory - works in both CJS and ESM contexts
Expand Down
12 changes: 8 additions & 4 deletions electron/gpu.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import { execSync } from 'child_process'
import { BACKEND_BASE_URL } from './config'
import { logger } from './logger'
import { getPythonPath } from './python-backend'
import { getAuthToken, getBackendUrl, getPythonPath } from './python-backend'

// Check if NVIDIA GPU is available
export async function checkGPU(): Promise<{ available: boolean; name?: string; vram?: number }> {
try {
const url = getBackendUrl()
if (!url) throw new Error('Backend URL not available yet')
// Try to get GPU info from the backend API first (more reliable)
const response = await fetch(`${BACKEND_BASE_URL}/api/gpu-info`, {
const headers: Record<string, string> = { 'Content-Type': 'application/json' }
const token = getAuthToken()
if (token) headers['Authorization'] = `Bearer ${token}`
const response = await fetch(`${url}/api/gpu-info`, {
method: 'GET',
headers: { 'Content-Type': 'application/json' },
headers,
})

if (response.ok) {
Expand Down
7 changes: 3 additions & 4 deletions electron/ipc/app-handlers.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import { app, ipcMain } from 'electron'
import path from 'path'
import fs from 'fs'
import { BACKEND_BASE_URL } from '../config'
import { checkGPU } from '../gpu'
import { isPythonReady, downloadPythonEmbed } from '../python-setup'
import { getBackendHealthStatus, startPythonBackend } from '../python-backend'
import { getBackendHealthStatus, getBackendUrl, getAuthToken, startPythonBackend } from '../python-backend'
import { getMainWindow } from '../window'
import { getAnalyticsState, setAnalyticsEnabled, sendAnalyticsEvent } from '../analytics'

Expand Down Expand Up @@ -68,8 +67,8 @@ function markLicenseAccepted(settingsPath: string): void {
}

export function registerAppHandlers(): void {
ipcMain.handle('get-backend-url', () => {
return BACKEND_BASE_URL
ipcMain.handle('get-backend', () => {
return { url: getBackendUrl() ?? '', token: getAuthToken() ?? '' }
})

ipcMain.handle('get-models-path', () => {
Expand Down
6 changes: 3 additions & 3 deletions electron/preload.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ const { contextBridge, ipcRenderer } = require('electron')

// Expose protected methods to the renderer process
contextBridge.exposeInMainWorld('electronAPI', {
// Get the backend URL
getBackendUrl: (): Promise<string> => ipcRenderer.invoke('get-backend-url'),
// Get the backend URL and auth token
getBackend: (): Promise<{ url: string; token: string }> => ipcRenderer.invoke('get-backend'),

// Get the path where models are stored
getModelsPath: (): Promise<string> => ipcRenderer.invoke('get-models-path'),
Expand Down Expand Up @@ -138,7 +138,7 @@ interface BackendHealthStatus {
declare global {
interface Window {
electronAPI: {
getBackendUrl: () => Promise<string>
getBackend: () => Promise<{ url: string; token: string }>
getModelsPath: () => Promise<string>
readLocalFile: (filePath: string) => Promise<{ data: string; mimeType: string }>
checkGpu: () => Promise<{ available: boolean; name?: string; vram?: number }>
Expand Down
Loading