From c6068da4d611bc815bdfd4c0567257f7188f56cf Mon Sep 17 00:00:00 2001 From: Alex Gershovich Date: Sat, 7 Mar 2026 23:23:00 +0000 Subject: [PATCH] switch backend to listen on free port and require auth - instead of hard-coded port we dynamically pick a free one - backend<->frontend channel is guarded by a random auth token --- backend/app_factory.py | 36 ++++++++++ backend/ltx2_server.py | 34 ++++++++-- backend/tests/test_auth.py | 74 +++++++++++++++++++++ electron/config.ts | 2 - electron/gpu.ts | 12 ++-- electron/ipc/app-handlers.ts | 7 +- electron/preload.ts | 6 +- electron/python-backend.ts | 63 ++++++++++++++---- frontend/App.tsx | 4 +- frontend/components/FirstRunSetup.tsx | 17 ++--- frontend/components/ICLoraPanel.tsx | 13 ++-- frontend/components/ModelStatusDropdown.tsx | 22 ++---- frontend/components/SettingsModal.tsx | 9 ++- frontend/contexts/AppSettingsContext.tsx | 45 ++++++------- frontend/hooks/use-backend.ts | 17 +++-- frontend/hooks/use-generation.ts | 26 +++----- frontend/hooks/use-retake.ts | 4 +- frontend/lib/backend.ts | 24 +++++++ frontend/views/editor/useGapGeneration.ts | 4 +- frontend/views/editor/useRegeneration.ts | 4 +- frontend/vite-env.d.ts | 2 +- 21 files changed, 297 insertions(+), 128 deletions(-) create mode 100644 backend/tests/test_auth.py create mode 100644 frontend/lib/backend.ts diff --git a/backend/app_factory.py b/backend/app_factory.py index bd9651e4..08620354 100644 --- a/backend/app_factory.py +++ b/backend/app_factory.py @@ -2,12 +2,16 @@ from __future__ import annotations +import base64 +import hmac +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +from starlette.responses import Response as StarletteResponse from _routes._errors import HTTPError from _routes.generation import router as generation_router @@ -36,6 +40,7 @@ def create_app( handler: "AppHandler", allowed_origins: list[str] | None = None, title: str = "LTX-2 Video Generation Server", + auth_token: str = "", ) -> FastAPI: """Create a configured FastAPI app bound to the provided handler.""" init_state_service(handler) @@ -48,6 +53,37 @@ def create_app( allow_headers=["*"], ) + @app.middleware("http") + async def _auth_middleware( # pyright: ignore[reportUnusedFunction] + request: Request, + call_next: Callable[[Request], Awaitable[StarletteResponse]], + ) -> StarletteResponse: + if not auth_token: + return await call_next(request) + if request.method == "OPTIONS": + return await call_next(request) + def _token_matches(candidate: str) -> bool: + return hmac.compare_digest(candidate, auth_token) + + # WebSocket: check query param + if request.headers.get("upgrade", "").lower() == "websocket": + if _token_matches(request.query_params.get("token", "")): + return await call_next(request) + return JSONResponse(status_code=401, content={"error": "Unauthorized"}) + # HTTP: Bearer or Basic auth + auth_header = request.headers.get("authorization", "") + if auth_header.startswith("Bearer ") and _token_matches(auth_header[7:]): + return await call_next(request) + if auth_header.startswith("Basic "): + try: + decoded = base64.b64decode(auth_header[6:]).decode() + _, _, password = decoded.partition(":") + if _token_matches(password): + return await call_next(request) + except Exception: + pass + return JSONResponse(status_code=401, content={"error": "Unauthorized"}) + async def _route_http_error_handler(request: Request, exc: Exception) -> JSONResponse: if isinstance(exc, HTTPError): log_http_error(request, exc) diff --git a/backend/ltx2_server.py b/backend/ltx2_server.py index e65c66c2..2d55aabc 100644 --- a/backend/ltx2_server.py +++ b/backend/ltx2_server.py @@ -99,7 +99,7 @@ def patched_sdpa( # Constants & Paths # ============================================================ -PORT = 8000 +PORT = 0 def _get_device() -> torch.device: @@ -219,7 +219,10 @@ def _resolve_force_api_generations() -> bool: ) handler = build_initial_state(runtime_config, DEFAULT_APP_SETTINGS) -app = create_app(handler=handler, allowed_origins=DEFAULT_ALLOWED_ORIGINS) + +auth_token = os.environ.get("LTX_AUTH_TOKEN", "") + +app = create_app(handler=handler, allowed_origins=DEFAULT_ALLOWED_ORIGINS, auth_token=auth_token) def precache_model_files(model_dir: Path) -> int: @@ -257,9 +260,10 @@ def log_hardware_info() -> None: if __name__ == "__main__": + import asyncio import uvicorn - port = int(os.environ.get("LTX_PORT", PORT)) + port = int(os.environ.get("LTX_PORT", "") or PORT) logger.info("=" * 60) logger.info("LTX-2 Video Generation Server (FastAPI + Uvicorn)") log_hardware_info() @@ -285,4 +289,26 @@ def log_hardware_info() -> None: "uvicorn.access": {"handlers": ["default"], "level": "INFO", "propagate": False}, }, } - uvicorn.run(app, host="127.0.0.1", port=port, log_level="info", access_log=False, log_config=log_config) + + import socket as _socket + + # Bind the socket ourselves so we know the actual port before uvicorn starts. + sock = _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) + sock.setsockopt(_socket.SOL_SOCKET, _socket.SO_REUSEADDR, 1) + sock.bind(("127.0.0.1", port)) + actual_port = int(sock.getsockname()[1]) + + config = uvicorn.Config(app, host="127.0.0.1", port=actual_port, log_level="info", access_log=False, log_config=log_config) + server = uvicorn.Server(config) + + _orig_startup = server.startup + + async def _startup_with_ready_msg(sockets: list[_socket.socket] | None = None) -> None: + await _orig_startup(sockets=sockets) + if server.started: + # Machine-parseable ready message — Electron matches this line + print(f"Server running on http://127.0.0.1:{actual_port}", flush=True) + + server.startup = _startup_with_ready_msg # type: ignore[assignment] + + asyncio.run(server.serve(sockets=[sock])) diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py new file mode 100644 index 00000000..ee90f9b1 --- /dev/null +++ b/backend/tests/test_auth.py @@ -0,0 +1,74 @@ +"""Tests for shared-secret authentication middleware.""" + +from __future__ import annotations + +import base64 + +from starlette.testclient import TestClient + +from app_factory import create_app + + +def test_request_without_token_returns_401(test_state): + app = create_app(handler=test_state, auth_token="test-secret") + with TestClient(app) as client: + response = client.get("/health") + assert response.status_code == 401 + assert response.json() == {"error": "Unauthorized"} + + +def test_request_with_correct_bearer_token(test_state): + app = create_app(handler=test_state, auth_token="test-secret") + with TestClient(app) as client: + response = client.get("/health", headers={"Authorization": "Bearer test-secret"}) + assert response.status_code == 200 + + +def test_request_with_correct_basic_auth(test_state): + app = create_app(handler=test_state, auth_token="test-secret") + credentials = base64.b64encode(b":test-secret").decode() + with TestClient(app) as client: + response = client.get("/health", headers={"Authorization": f"Basic {credentials}"}) + assert response.status_code == 200 + + +def test_request_with_wrong_token_returns_401(test_state): + app = create_app(handler=test_state, auth_token="test-secret") + with TestClient(app) as client: + response = client.get("/health", headers={"Authorization": "Bearer wrong-token"}) + assert response.status_code == 401 + + +def test_health_without_token_returns_401(test_state): + """Health endpoint is NOT exempt from auth.""" + app = create_app(handler=test_state, auth_token="test-secret") + with TestClient(app) as client: + response = client.get("/health") + assert response.status_code == 401 + + +def test_no_auth_token_disables_middleware(test_state): + """When auth_token is empty string, auth is disabled (dev/test mode).""" + app = create_app(handler=test_state, auth_token="") + with TestClient(app) as client: + response = client.get("/health") + assert response.status_code == 200 + + +def test_websocket_with_token_query_param(test_state): + app = create_app(handler=test_state, auth_token="test-secret") + with TestClient(app) as client: + # WebSocket upgrade without token should fail with 401 + response = client.get( + "/ws/download/test", + headers={"upgrade": "websocket", "connection": "upgrade"}, + ) + assert response.status_code == 401 + + # WebSocket upgrade with correct token query param + response = client.get( + "/ws/download/test?token=test-secret", + headers={"upgrade": "websocket", "connection": "upgrade"}, + ) + # The route may not exist, but auth should pass (not 401) + assert response.status_code != 401 diff --git a/electron/config.ts b/electron/config.ts index d59a9be2..f636f7b7 100644 --- a/electron/config.ts +++ b/electron/config.ts @@ -3,8 +3,6 @@ import path from 'path' import os from 'os' import { getProjectAssetsPath } from './app-state' -export const PYTHON_PORT = 8000 -export const BACKEND_BASE_URL = `http://localhost:${PYTHON_PORT}` export const isDev = !app.isPackaged // Get directory - works in both CJS and ESM contexts diff --git a/electron/gpu.ts b/electron/gpu.ts index 35acd86b..2fe414c2 100644 --- a/electron/gpu.ts +++ b/electron/gpu.ts @@ -1,15 +1,19 @@ import { execSync } from 'child_process' -import { BACKEND_BASE_URL } from './config' import { logger } from './logger' -import { getPythonPath } from './python-backend' +import { getAuthToken, getBackendUrl, getPythonPath } from './python-backend' // Check if NVIDIA GPU is available export async function checkGPU(): Promise<{ available: boolean; name?: string; vram?: number }> { try { + const url = getBackendUrl() + if (!url) throw new Error('Backend URL not available yet') // Try to get GPU info from the backend API first (more reliable) - const response = await fetch(`${BACKEND_BASE_URL}/api/gpu-info`, { + const headers: Record = { 'Content-Type': 'application/json' } + const token = getAuthToken() + if (token) headers['Authorization'] = `Bearer ${token}` + const response = await fetch(`${url}/api/gpu-info`, { method: 'GET', - headers: { 'Content-Type': 'application/json' }, + headers, }) if (response.ok) { diff --git a/electron/ipc/app-handlers.ts b/electron/ipc/app-handlers.ts index cf6effc5..88cb9f25 100644 --- a/electron/ipc/app-handlers.ts +++ b/electron/ipc/app-handlers.ts @@ -1,10 +1,9 @@ import { app, ipcMain } from 'electron' import path from 'path' import fs from 'fs' -import { BACKEND_BASE_URL } from '../config' import { checkGPU } from '../gpu' import { isPythonReady, downloadPythonEmbed } from '../python-setup' -import { getBackendHealthStatus, startPythonBackend } from '../python-backend' +import { getBackendHealthStatus, getBackendUrl, getAuthToken, startPythonBackend } from '../python-backend' import { getMainWindow } from '../window' import { getAnalyticsState, setAnalyticsEnabled, sendAnalyticsEvent } from '../analytics' @@ -68,8 +67,8 @@ function markLicenseAccepted(settingsPath: string): void { } export function registerAppHandlers(): void { - ipcMain.handle('get-backend-url', () => { - return BACKEND_BASE_URL + ipcMain.handle('get-backend', () => { + return { url: getBackendUrl() ?? '', token: getAuthToken() ?? '' } }) ipcMain.handle('get-models-path', () => { diff --git a/electron/preload.ts b/electron/preload.ts index 93b813a3..80f86fa6 100644 --- a/electron/preload.ts +++ b/electron/preload.ts @@ -3,8 +3,8 @@ const { contextBridge, ipcRenderer } = require('electron') // Expose protected methods to the renderer process contextBridge.exposeInMainWorld('electronAPI', { - // Get the backend URL - getBackendUrl: (): Promise => ipcRenderer.invoke('get-backend-url'), + // Get the backend URL and auth token + getBackend: (): Promise<{ url: string; token: string }> => ipcRenderer.invoke('get-backend'), // Get the path where models are stored getModelsPath: (): Promise => ipcRenderer.invoke('get-models-path'), @@ -138,7 +138,7 @@ interface BackendHealthStatus { declare global { interface Window { electronAPI: { - getBackendUrl: () => Promise + getBackend: () => Promise<{ url: string; token: string }> getModelsPath: () => Promise readLocalFile: (filePath: string) => Promise<{ data: string; mimeType: string }> checkGpu: () => Promise<{ available: boolean; name?: string; vram?: number }> diff --git a/electron/python-backend.ts b/electron/python-backend.ts index 5d489bf0..ed42f81d 100644 --- a/electron/python-backend.ts +++ b/electron/python-backend.ts @@ -1,9 +1,11 @@ import { ChildProcess, spawn } from 'child_process' +import crypto from 'crypto' import fs from 'fs' import path from 'path' import { getAppDataDir } from './app-paths' -import { BACKEND_BASE_URL, getCurrentDir, isDev, PYTHON_PORT } from './config' +import { getCurrentDir, isDev } from './config' import { logger, writeLog } from './logger' +import { getCurrentLogFilename } from './logging-management' import { getPythonDir } from './python-setup' import { getMainWindow } from './window' @@ -14,6 +16,12 @@ const CRASH_DEBOUNCE_MS = 10_000 let startPromise: Promise | null = null let takeoverInFlight: Promise | null = null +let backendUrl: string | null = null +let authToken: string | null = null + +export function getBackendUrl(): string | null { return backendUrl } +export function getAuthToken(): string | null { return authToken } + type BackendOwnership = 'managed' | 'adopted' | null let backendOwnership: BackendOwnership = null @@ -50,12 +58,17 @@ function isPortConflictOutput(output: string): boolean { ) } -async function probeBackendHealth(timeoutMs = 1500): Promise { +async function probeBackendHealth(timeoutMs = 1500, probeUrl?: string): Promise { + const url = probeUrl || backendUrl + if (!url) return false const controller = new AbortController() const timeout = setTimeout(() => controller.abort(), timeoutMs) try { - const response = await fetch(`${BACKEND_BASE_URL}/health`, { + const headers: Record = {} + if (authToken) headers['Authorization'] = `Bearer ${authToken}` + const response = await fetch(`${url}/health`, { signal: controller.signal, + headers, }) return response.ok } catch { @@ -66,12 +79,16 @@ async function probeBackendHealth(timeoutMs = 1500): Promise { } async function requestAdoptedBackendShutdown(timeoutMs = 2000): Promise { + if (!backendUrl) return false const controller = new AbortController() const timeout = setTimeout(() => controller.abort(), timeoutMs) try { - const response = await fetch(`${BACKEND_BASE_URL}/api/system/shutdown`, { + const headers: Record = {} + if (authToken) headers['Authorization'] = `Bearer ${authToken}` + const response = await fetch(`${backendUrl}/api/system/shutdown`, { method: 'POST', signal: controller.signal, + headers, }) return response.ok } catch { @@ -216,13 +233,19 @@ export async function startPythonBackend(): Promise { pythonArgs = isDev ? ['-Xfrozen_modules=off', '-u', mainPy] : ['-u', mainPy] } + // Generate auth token for this backend session + authToken = crypto.randomBytes(32).toString('base64url') + pythonProcess = spawn(pythonPath, pythonArgs, { cwd: backendPath, env: { ...process.env, PYTHONUNBUFFERED: '1', PYTHONNOUSERSITE: '1', - LTX_PORT: String(PYTHON_PORT), + // Only pass LTX_PORT when the developer explicitly set it + ...(process.env.LTX_PORT ? { LTX_PORT: process.env.LTX_PORT } : {}), + LTX_AUTH_TOKEN: authToken, + LTX_LOG_FILE: getCurrentLogFilename(), LTX_APP_DATA_DIR: getAppDataDir(), PYTORCH_ENABLE_MPS_FALLBACK: '1', // Set PYTHONHOME for bundled Python on macOS so it finds its stdlib @@ -254,12 +277,22 @@ export async function startPythonBackend(): Promise { sawPortConflict = true } - // Check if server has started - if (!started && (output.includes('Server running on') || output.includes('Uvicorn running'))) { - started = true - backendOwnership = 'managed' - publishBackendHealthStatus({ status: 'alive' }) - settleResolve() + // Check if server has started — parse URL from ready message + if (!started) { + const readyMatch = output.match(/Server running on (http:\/\/\S+)/) + if (readyMatch) { + backendUrl = readyMatch[1] + started = true + backendOwnership = 'managed' + publishBackendHealthStatus({ status: 'alive' }) + settleResolve() + } else if (output.includes('Uvicorn running')) { + // Fallback for legacy/dev uvicorn output + started = true + backendOwnership = 'managed' + publishBackendHealthStatus({ status: 'alive' }) + settleResolve() + } } } @@ -295,6 +328,8 @@ export async function startPythonBackend(): Promise { pythonProcess.on('exit', async (code) => { logger.info(`Python backend exited with code ${code}`) pythonProcess = null + backendUrl = null + authToken = null if (!started) { if (isIntentionalShutdown) { @@ -304,9 +339,11 @@ export async function startPythonBackend(): Promise { return } - if (sawPortConflict) { - const healthyExistingBackend = await probeBackendHealth() + if (sawPortConflict && process.env.LTX_PORT) { + const explicitUrl = `http://127.0.0.1:${process.env.LTX_PORT}` + const healthyExistingBackend = await probeBackendHealth(1500, explicitUrl) if (healthyExistingBackend) { + backendUrl = explicitUrl backendOwnership = 'adopted' publishBackendHealthStatus({ status: 'alive' }) settleResolve() diff --git a/frontend/App.tsx b/frontend/App.tsx index 425217c5..a8ee1094 100644 --- a/frontend/App.tsx +++ b/frontend/App.tsx @@ -1,5 +1,6 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { Loader2, AlertCircle, Settings, FileText } from 'lucide-react' +import { backendFetch } from './lib/backend' import { ProjectProvider, useProjects } from './contexts/ProjectContext' import { KeyboardShortcutsProvider } from './contexts/KeyboardShortcutsContext' import { AppSettingsProvider, useAppSettings } from './contexts/AppSettingsContext' @@ -177,8 +178,7 @@ function AppContent() { isForcedFirstRun && isLoaded && settings.hasLtxApiKey && !isFinalizingFirstRun && !firstRunFinalizeError const areRequiredModelsDownloaded = useCallback(async () => { - const backendUrl = await window.electronAPI.getBackendUrl() - const response = await fetch(`${backendUrl}/api/models/status`) + const response = await backendFetch('/api/models/status') if (!response.ok) { throw new Error(`Model status fetch failed with status ${response.status}`) } diff --git a/frontend/components/FirstRunSetup.tsx b/frontend/components/FirstRunSetup.tsx index 5cac0968..771d2099 100644 --- a/frontend/components/FirstRunSetup.tsx +++ b/frontend/components/FirstRunSetup.tsx @@ -1,4 +1,5 @@ import { useState, useEffect } from 'react' +import { backendFetch } from '../lib/backend' import { logger } from '../lib/logger' import './FirstRunSetup.css' @@ -51,7 +52,6 @@ export function LaunchGate({ const [availableSpace, setAvailableSpace] = useState('...') const [videoPath, setVideoPath] = useState('/splash/splash.mp4') const [ltxApiKey, setLtxApiKey] = useState('') - const [backendUrl, setBackendUrl] = useState(null) const [licenseAccepted, setLicenseAccepted] = useState(false) const [licenseText, setLicenseText] = useState(null) const [licenseError, setLicenseError] = useState(null) @@ -114,9 +114,7 @@ export function LaunchGate({ // Get models path from backend try { - const url = await window.electronAPI.getBackendUrl() - setBackendUrl(url) - const response = await fetch(`${url}/api/models/status`) + const response = await backendFetch('/api/models/status') if (response.ok) { const data = await response.json() if (data.models_path) { @@ -152,11 +150,11 @@ export function LaunchGate({ // Poll download progress during installation useEffect(() => { - if (currentStep !== 'installing' || !backendUrl) return + if (currentStep !== 'installing') return const pollProgress = async () => { try { - const response = await fetch(`${backendUrl}/api/models/download/progress`) + const response = await backendFetch('/api/models/download/progress') if (response.ok) { const progress = await response.json() setDownloadProgress(progress) @@ -175,17 +173,16 @@ export function LaunchGate({ pollProgress() const interval = setInterval(pollProgress, 500) return () => clearInterval(interval) - }, [currentStep, backendUrl]) + }, [currentStep]) // Start installation const startInstallation = async () => { - if (!backendUrl) return setCurrentStep('installing') try { // If API key is provided, save it to settings first and skip text encoder download if (ltxApiKey.trim()) { try { - await fetch(`${backendUrl}/api/settings`, { + await backendFetch('/api/settings', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ ltxApiKey: ltxApiKey.trim() }), @@ -196,7 +193,7 @@ export function LaunchGate({ } // Start download - skip text encoder if API key is provided - await fetch(`${backendUrl}/api/models/download`, { + await backendFetch('/api/models/download', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ skipTextEncoder: !!ltxApiKey.trim() }), diff --git a/frontend/components/ICLoraPanel.tsx b/frontend/components/ICLoraPanel.tsx index ece5ff58..83bf3bd4 100644 --- a/frontend/components/ICLoraPanel.tsx +++ b/frontend/components/ICLoraPanel.tsx @@ -3,6 +3,7 @@ import { X, Play, Pause, Upload, Loader2, Film, Sparkles, FolderOpen, ChevronDown, RefreshCw, Settings, Download, Check, AlertCircle, } from 'lucide-react' +import { backendFetch } from '../lib/backend' import { logger } from '../lib/logger' interface ICLoraModel { @@ -136,8 +137,7 @@ export function ICLoraPanel({ // Fetch available models const fetchModels = useCallback(async () => { try { - const backendUrl = await window.electronAPI.getBackendUrl() - const resp = await fetch(`${backendUrl}/api/ic-lora/list-models`) + const resp = await backendFetch('/api/ic-lora/list-models') if (resp.ok) { const data = await resp.json() setModels(data.models || []) @@ -155,8 +155,7 @@ export function ICLoraPanel({ if (!inputVideoPath || isExtracting) return setIsExtracting(true) try { - const backendUrl = await window.electronAPI.getBackendUrl() - const resp = await fetch(`${backendUrl}/api/ic-lora/extract-conditioning`, { + const resp = await backendFetch('/api/ic-lora/extract-conditioning', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ @@ -254,8 +253,7 @@ export function ICLoraPanel({ if (downloadingModels[modelDef.id] === 'downloading') return setDownloadingModels(prev => ({ ...prev, [modelDef.id]: 'downloading' })) try { - const backendUrl = await window.electronAPI.getBackendUrl() - const resp = await fetch(`${backendUrl}/api/ic-lora/download-model`, { + const resp = await backendFetch('/api/ic-lora/download-model', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ model: modelDef.id }), @@ -302,9 +300,8 @@ export function ICLoraPanel({ setOutputVideoPath(null) try { - const backendUrl = await window.electronAPI.getBackendUrl() setGenerationStatus('Generating video with IC-LoRA...') - const resp = await fetch(`${backendUrl}/api/ic-lora/generate`, { + const resp = await backendFetch('/api/ic-lora/generate', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ diff --git a/frontend/components/ModelStatusDropdown.tsx b/frontend/components/ModelStatusDropdown.tsx index 21b6662d..7cbf764e 100644 --- a/frontend/components/ModelStatusDropdown.tsx +++ b/frontend/components/ModelStatusDropdown.tsx @@ -1,5 +1,6 @@ import { useState, useEffect, useRef } from 'react' import { Loader2, CheckCircle2, Download, Clock, ChevronDown, AlertCircle } from 'lucide-react' +import { backendFetch } from '../lib/backend' import { logger } from '../lib/logger' interface ModelInfo { @@ -42,21 +43,13 @@ export function ModelStatusDropdown({ className = '' }: ModelStatusDropdownProps const [isOpen, setIsOpen] = useState(false) const [modelsStatus, setModelsStatus] = useState(null) const [downloadProgress, setDownloadProgress] = useState(null) - const [backendUrl, setBackendUrl] = useState(null) const dropdownRef = useRef(null) - // Fetch backend URL once on mount - useEffect(() => { - window.electronAPI.getBackendUrl().then(setBackendUrl) - }, []) - // Fetch models status periodically useEffect(() => { - if (!backendUrl) return - const fetchModelsStatus = async () => { try { - const response = await fetch(`${backendUrl}/api/models/status`) + const response = await backendFetch('/api/models/status') if (response.ok) { setModelsStatus(await response.json()) } @@ -68,15 +61,15 @@ export function ModelStatusDropdown({ className = '' }: ModelStatusDropdownProps fetchModelsStatus() const interval = setInterval(fetchModelsStatus, 5000) return () => clearInterval(interval) - }, [backendUrl]) + }, []) // Poll download progress when downloading useEffect(() => { - if (!isOpen || !backendUrl) return + if (!isOpen) return const pollProgress = async () => { try { - const response = await fetch(`${backendUrl}/api/models/download/progress`) + const response = await backendFetch('/api/models/download/progress') if (response.ok) { setDownloadProgress(await response.json()) } @@ -88,7 +81,7 @@ export function ModelStatusDropdown({ className = '' }: ModelStatusDropdownProps pollProgress() const interval = setInterval(pollProgress, 1000) return () => clearInterval(interval) - }, [isOpen, backendUrl]) + }, [isOpen]) // Close dropdown when clicking outside useEffect(() => { @@ -135,9 +128,8 @@ export function ModelStatusDropdown({ className = '' }: ModelStatusDropdownProps } const startDownload = async () => { - if (!backendUrl) return try { - await fetch(`${backendUrl}/api/models/download`, { + await backendFetch('/api/models/download', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({}), diff --git a/frontend/components/SettingsModal.tsx b/frontend/components/SettingsModal.tsx index add5bf6e..6f6e9881 100644 --- a/frontend/components/SettingsModal.tsx +++ b/frontend/components/SettingsModal.tsx @@ -2,6 +2,7 @@ import { AlertCircle, Check, Download, Film, Folder, Info, KeyRound, Settings, S import React, { useEffect, useRef, useState } from 'react' import { Button } from './ui/button' import { useAppSettings, type AppSettings } from '../contexts/AppSettingsContext' +import { backendFetch } from '../lib/backend' import { logger } from '../lib/logger' import { ApiKeyHelperRow, LtxApiKeyInput, LtxApiKeyHelperRow } from './LtxApiKeyInput' @@ -86,8 +87,7 @@ export function SettingsModal({ isOpen, onClose, initialTab }: SettingsModalProp const fetchStatus = async () => { try { - const backendUrl = await window.electronAPI.getBackendUrl() - const response = await fetch(`${backendUrl}/api/models/status`) + const response = await backendFetch('/api/models/status') if (response.ok) { const data = await response.json() setTextEncoderStatus(data.text_encoder_status) @@ -108,8 +108,7 @@ export function SettingsModal({ isOpen, onClose, initialTab }: SettingsModalProp setIsDownloading(true) setDownloadError(null) try { - const backendUrl = await window.electronAPI.getBackendUrl() - const response = await fetch(`${backendUrl}/api/text-encoder/download`, { method: 'POST' }) + const response = await backendFetch('/api/text-encoder/download', { method: 'POST' }) const data = await response.json() if (data.status === 'already_downloaded') { @@ -118,7 +117,7 @@ export function SettingsModal({ isOpen, onClose, initialTab }: SettingsModalProp // Poll for completion const pollInterval = setInterval(async () => { try { - const statusRes = await fetch(`${backendUrl}/api/models/status`) + const statusRes = await backendFetch('/api/models/status') if (statusRes.ok) { const statusData = await statusRes.json() setTextEncoderStatus(statusData.text_encoder_status) diff --git a/frontend/contexts/AppSettingsContext.tsx b/frontend/contexts/AppSettingsContext.tsx index 63d1655c..97ab2c8c 100644 --- a/frontend/contexts/AppSettingsContext.tsx +++ b/frontend/contexts/AppSettingsContext.tsx @@ -1,4 +1,5 @@ import { createContext, useCallback, useContext, useEffect, useMemo, useState, type ReactNode } from 'react' +import { backendFetch, resetBackendCredentials } from '../lib/backend' export interface InferenceSettings { steps: number @@ -95,23 +96,18 @@ export function AppSettingsProvider({ children }: { children: ReactNode }) { const [settings, setSettings] = useState(DEFAULT_APP_SETTINGS) const [isLoaded, setIsLoaded] = useState(false) const [runtimePolicyLoaded, setRuntimePolicyLoaded] = useState(false) - const [backendUrl, setBackendUrl] = useState(null) const [forceApiGenerations, setForceApiGenerations] = useState(true) const [backendProcessStatus, setBackendProcessStatus] = useState(null) useEffect(() => { - window.electronAPI.getBackendUrl().then(setBackendUrl).catch(() => setBackendUrl(null)) - }, []) - - useEffect(() => { - if (!backendUrl || backendProcessStatus !== 'alive') return + if (backendProcessStatus !== 'alive') return let cancelled = false setRuntimePolicyLoaded(false) const fetchRuntimePolicy = async () => { try { - const response = await fetch(`${backendUrl}/api/runtime-policy`) + const response = await backendFetch('/api/runtime-policy') if (!response.ok) { throw new Error(`Runtime policy fetch failed with status ${response.status}`) } @@ -141,7 +137,7 @@ export function AppSettingsProvider({ children }: { children: ReactNode }) { return () => { cancelled = true } - }, [backendProcessStatus, backendUrl]) + }, [backendProcessStatus]) useEffect(() => { let cancelled = false @@ -151,6 +147,9 @@ export function AppSettingsProvider({ children }: { children: ReactNode }) { if (!nextStatus || cancelled) { return } + if (nextStatus === 'alive') { + resetBackendCredentials() + } setBackendProcessStatus(nextStatus) } @@ -173,18 +172,17 @@ export function AppSettingsProvider({ children }: { children: ReactNode }) { }, []) const refreshSettings = useCallback(async () => { - if (!backendUrl) return - const response = await fetch(`${backendUrl}/api/settings`) + const response = await backendFetch('/api/settings') if (!response.ok) { throw new Error(`Settings fetch failed with status ${response.status}`) } const data = await response.json() setSettings(normalizeAppSettings(data)) setIsLoaded(true) - }, [backendUrl]) + }, []) useEffect(() => { - if (!backendUrl || isLoaded || backendProcessStatus !== 'alive') return + if (isLoaded || backendProcessStatus !== 'alive') return let cancelled = false let retryTimer: ReturnType | null = null @@ -206,14 +204,14 @@ export function AppSettingsProvider({ children }: { children: ReactNode }) { cancelled = true if (retryTimer) clearTimeout(retryTimer) } - }, [backendProcessStatus, backendUrl, isLoaded, refreshSettings]) + }, [backendProcessStatus, isLoaded, refreshSettings]) useEffect(() => { - if (!backendUrl || !isLoaded || backendProcessStatus !== 'alive') return + if (!isLoaded || backendProcessStatus !== 'alive') return const syncTimer = setTimeout(async () => { try { const { hasLtxApiKey: _a, hasFalApiKey: _b, hasGeminiApiKey: _c, ...syncPayload } = settings - await fetch(`${backendUrl}/api/settings`, { + await backendFetch('/api/settings', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(syncPayload), @@ -223,7 +221,7 @@ export function AppSettingsProvider({ children }: { children: ReactNode }) { } }, 150) return () => clearTimeout(syncTimer) - }, [backendProcessStatus, backendUrl, isLoaded, settings]) + }, [backendProcessStatus, isLoaded, settings]) const updateSettings = useCallback((patch: Partial | ((prev: AppSettings) => AppSettings)) => { if (typeof patch === 'function') { @@ -234,8 +232,7 @@ export function AppSettingsProvider({ children }: { children: ReactNode }) { }, []) const saveLtxApiKey = useCallback(async (value: string) => { - if (!backendUrl) return - const response = await fetch(`${backendUrl}/api/settings`, { + const response = await backendFetch('/api/settings', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ ltxApiKey: value }), @@ -245,11 +242,10 @@ export function AppSettingsProvider({ children }: { children: ReactNode }) { throw new Error(detail || 'Failed to save LTX API key.') } await refreshSettings() - }, [backendUrl, refreshSettings]) + }, [refreshSettings]) const saveGeminiApiKey = useCallback(async (value: string) => { - if (!backendUrl) return - const response = await fetch(`${backendUrl}/api/settings`, { + const response = await backendFetch('/api/settings', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ geminiApiKey: value }), @@ -259,11 +255,10 @@ export function AppSettingsProvider({ children }: { children: ReactNode }) { throw new Error(detail || 'Failed to save Gemini API key.') } await refreshSettings() - }, [backendUrl, refreshSettings]) + }, [refreshSettings]) const saveFalApiKey = useCallback(async (value: string) => { - if (!backendUrl) return - const response = await fetch(`${backendUrl}/api/settings`, { + const response = await backendFetch('/api/settings', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ falApiKey: value }), @@ -273,7 +268,7 @@ export function AppSettingsProvider({ children }: { children: ReactNode }) { throw new Error(detail || 'Failed to save FAL API key.') } await refreshSettings() - }, [backendUrl, refreshSettings]) + }, [refreshSettings]) const shouldVideoGenerateWithLtxApi = forceApiGenerations || (settings.userPrefersLtxApiVideoGenerations && settings.hasLtxApiKey) diff --git a/frontend/hooks/use-backend.ts b/frontend/hooks/use-backend.ts index c1a2bedc..8caf9868 100644 --- a/frontend/hooks/use-backend.ts +++ b/frontend/hooks/use-backend.ts @@ -1,4 +1,5 @@ import { useState, useEffect, useCallback } from 'react' +import { backendFetch, backendWsUrl, resetBackendCredentials } from '../lib/backend' import { logger } from '../lib/logger' interface BackendStatus { @@ -65,9 +66,8 @@ export function useBackend(): UseBackendReturn { const checkHealth = useCallback(async (): Promise => { try { - const backendUrl = await window.electronAPI.getBackendUrl() - logger.info(`Checking backend health at: ${backendUrl}`) - const response = await fetch(`${backendUrl}/health`) + logger.info('Checking backend health...') + const response = await backendFetch('/health') if (response.ok) { const data = await response.json() @@ -92,8 +92,7 @@ export function useBackend(): UseBackendReturn { const fetchModels = useCallback(async () => { try { - const backendUrl = await window.electronAPI.getBackendUrl() - const response = await fetch(`${backendUrl}/api/models`) + const response = await backendFetch('/api/models') if (response.ok) { const data = await response.json() @@ -106,10 +105,8 @@ export function useBackend(): UseBackendReturn { const downloadModel = useCallback(async (modelId: string) => { try { - const backendUrl = await window.electronAPI.getBackendUrl() - // Connect to WebSocket for download progress - const wsUrl = backendUrl.replace('http://', 'ws://') + `/ws/download/${modelId}` + const wsUrl = await backendWsUrl(`/ws/download/${modelId}`) const ws = new WebSocket(wsUrl) ws.onmessage = (event) => { @@ -130,7 +127,7 @@ export function useBackend(): UseBackendReturn { } // Trigger download - await fetch(`${backendUrl}/api/models/${modelId}/download`, { + await backendFetch(`/api/models/${modelId}/download`, { method: 'POST', }) } catch (err) { @@ -142,6 +139,8 @@ export function useBackend(): UseBackendReturn { setProcessStatus(payload.status) if (payload.status === 'alive') { + // Reset cached credentials so the new port/token are fetched + resetBackendCredentials() const healthy = await checkHealth() if (healthy) { await fetchModels() diff --git a/frontend/hooks/use-generation.ts b/frontend/hooks/use-generation.ts index 97ac1181..eb185117 100644 --- a/frontend/hooks/use-generation.ts +++ b/frontend/hooks/use-generation.ts @@ -1,5 +1,6 @@ import { useState, useCallback, useRef } from 'react' import type { GenerationSettings } from '../components/SettingsPanel' +import { backendFetch } from '../lib/backend' import { useAppSettings } from '../contexts/AppSettingsContext' interface GenerationState { @@ -133,9 +134,6 @@ export function useGeneration(): UseGenerationReturn { let shouldApplyPollingUpdates = true try { - // Get backend URL from Electron - const backendUrl = await window.electronAPI.getBackendUrl() - // Prepare JSON body const body: Record = { prompt, @@ -163,11 +161,11 @@ export function useGeneration(): UseGenerationReturn { const pollProgress = async () => { if (!shouldApplyPollingUpdates) return try { - const res = await fetch(`${backendUrl}/api/generation/progress`) + const res = await backendFetch('/api/generation/progress') if (res.ok) { const data: GenerationProgress = await res.json() if (!shouldApplyPollingUpdates) return - + let displayProgress = data.progress let statusMessage = getPhaseMessage(data.phase) @@ -205,7 +203,7 @@ export function useGeneration(): UseGenerationReturn { progressInterval = setInterval(pollProgress, 500) // Start generation (HTTP POST - synchronous, returns when done) - const response = await fetch(`${backendUrl}/api/generate`, { + const response = await backendFetch('/api/generate', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(body), @@ -275,11 +273,8 @@ export function useGeneration(): UseGenerationReturn { // Also tell the backend to cancel try { - const backendUrl = await window.electronAPI.getBackendUrl() - await fetch(`${backendUrl}/api/generate/cancel`, { - method: 'POST', - }) - } catch (e) { + await backendFetch('/api/generate/cancel', { method: 'POST' }) + } catch { // Ignore errors from cancel request } @@ -296,8 +291,7 @@ export function useGeneration(): UseGenerationReturn { ) => { if (forceApiGenerations) { try { - const backendUrl = await window.electronAPI.getBackendUrl() - const response = await fetch(`${backendUrl}/api/settings`) + const response = await backendFetch('/api/settings') if (response.ok) { const payload = await response.json() if (!payload?.hasFalApiKey) { @@ -346,8 +340,6 @@ export function useGeneration(): UseGenerationReturn { abortControllerRef.current = new AbortController() try { - const backendUrl = await window.electronAPI.getBackendUrl() - // Skip prompt enhancement for T2I - use original prompt directly const finalPrompt = prompt @@ -357,7 +349,7 @@ export function useGeneration(): UseGenerationReturn { // Poll for progress const pollProgress = async () => { try { - const res = await fetch(`${backendUrl}/api/generation/progress`) + const res = await backendFetch('/api/generation/progress') if (res.ok) { const data = await res.json() const currentImage = data.currentStep || 0 @@ -383,7 +375,7 @@ export function useGeneration(): UseGenerationReturn { const progressInterval = setInterval(pollProgress, 500) - const response = await fetch(`${backendUrl}/api/generate-image`, { + const response = await backendFetch('/api/generate-image', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ diff --git a/frontend/hooks/use-retake.ts b/frontend/hooks/use-retake.ts index 7a4a6ab9..d08606be 100644 --- a/frontend/hooks/use-retake.ts +++ b/frontend/hooks/use-retake.ts @@ -1,4 +1,5 @@ import { useCallback, useState } from 'react' +import { backendFetch } from '../lib/backend' import { logger } from '../lib/logger' export type RetakeMode = 'replace_audio_and_video' | 'replace_video' | 'replace_audio' @@ -42,8 +43,7 @@ export function useRetake() { }) try { - const backendUrl = await window.electronAPI.getBackendUrl() - const response = await fetch(`${backendUrl}/api/retake`, { + const response = await backendFetch('/api/retake', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ diff --git a/frontend/lib/backend.ts b/frontend/lib/backend.ts new file mode 100644 index 00000000..b78d6ef3 --- /dev/null +++ b/frontend/lib/backend.ts @@ -0,0 +1,24 @@ +let cached: { url: string; token: string } | null = null + +export async function getBackendCredentials(): Promise<{ url: string; token: string }> { + if (!cached) cached = await window.electronAPI.getBackend() + return cached +} + +export function resetBackendCredentials(): void { + cached = null +} + +export async function backendFetch(path: string, init?: RequestInit): Promise { + const { url, token } = await getBackendCredentials() + const headers = new Headers(init?.headers) + if (token) headers.set('Authorization', `Bearer ${token}`) + return fetch(`${url}${path}`, { ...init, headers }) +} + +export async function backendWsUrl(path: string): Promise { + const { url, token } = await getBackendCredentials() + const ws = url.replace('http://', 'ws://') + const sep = path.includes('?') ? '&' : '?' + return `${ws}${path}${sep}token=${token}` +} diff --git a/frontend/views/editor/useGapGeneration.ts b/frontend/views/editor/useGapGeneration.ts index e105c7c6..a99a1d4b 100644 --- a/frontend/views/editor/useGapGeneration.ts +++ b/frontend/views/editor/useGapGeneration.ts @@ -3,6 +3,7 @@ import type { TimelineClip, Track, SubtitleClip, Asset } from '../../types/proje import { DEFAULT_COLOR_CORRECTION } from '../../types/project' import type { GenerationSettings } from '../../components/SettingsPanel' import { copyToAssetFolder } from '../../lib/asset-copy' +import { backendFetch } from '../../lib/backend' import { fileUrlToPath } from '../../lib/url-to-path' export interface UseGapGenerationParams { @@ -458,8 +459,7 @@ export function useGapGeneration({ } } - const backendUrl = await window.electronAPI.getBackendUrl() - const response = await fetch(`${backendUrl}/api/suggest-gap-prompt`, { + const response = await backendFetch('/api/suggest-gap-prompt', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ diff --git a/frontend/views/editor/useRegeneration.ts b/frontend/views/editor/useRegeneration.ts index 90132bee..aedf2d7b 100644 --- a/frontend/views/editor/useRegeneration.ts +++ b/frontend/views/editor/useRegeneration.ts @@ -2,6 +2,7 @@ import { useState, useCallback, useEffect } from 'react' import type { Asset, TimelineClip } from '../../types/project' import type { GenerationSettings } from '../../components/SettingsPanel' import { copyToAssetFolder } from '../../lib/asset-copy' +import { backendFetch } from '../../lib/backend' import { fileUrlToPath } from '../../lib/url-to-path' import { sanitizeForcedApiVideoSettings } from '../../lib/api-video-options' import { logger } from '../../lib/logger' @@ -208,8 +209,7 @@ export function useRegeneration(params: UseRegenerationParams) { if (framePath) { // Ask Gemini to describe the frame - const backendUrl = await window.electronAPI.getBackendUrl() - const resp = await fetch(`${backendUrl}/api/suggest-gap-prompt`, { + const resp = await backendFetch('/api/suggest-gap-prompt', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ diff --git a/frontend/vite-env.d.ts b/frontend/vite-env.d.ts index 433da967..a543d956 100644 --- a/frontend/vite-env.d.ts +++ b/frontend/vite-env.d.ts @@ -13,7 +13,7 @@ interface BackendHealthStatus { interface Window { electronAPI: { - getBackendUrl: () => Promise + getBackend: () => Promise<{ url: string; token: string }> getModelsPath: () => Promise readLocalFile: (filePath: string) => Promise<{ data: string; mimeType: string }> checkGpu: () => Promise<{ available: boolean; name?: string; vram?: number }>