From 21bb3406b69b35de33c6bf4c89b5af4f70a62737 Mon Sep 17 00:00:00 2001 From: lapy Date: Sun, 8 Mar 2026 21:10:12 +0000 Subject: [PATCH] Refactoring --- Dockerfile | 18 +- README.md | 1050 ++-- backend/architecture_profiles.py | 265 - backend/cuda_installer.py | 4864 ++++++++--------- backend/data_store.py | 216 + backend/database.py | 364 -- backend/gguf_reader.py | 20 +- backend/huggingface.py | 655 ++- backend/llama_manager.py | 435 +- backend/llama_swap_config.py | 213 +- backend/llama_swap_manager.py | 159 +- backend/lmdeploy_installer.py | 778 +-- backend/lmdeploy_manager.py | 1718 +++--- backend/main.py | 306 +- backend/param_registry.py | 117 + backend/presets.py | 170 - backend/progress_manager.py | 236 + backend/routes/llama_version_manager.py | 244 +- backend/routes/llama_versions.py | 486 +- backend/routes/lmdeploy.py | 133 +- backend/routes/models.py | 1795 +++--- backend/routes/status.py | 61 +- backend/routes/unified_monitoring.py | 75 - backend/smart_auto/__init__.py | 480 -- backend/smart_auto/architecture_config.py | 103 - backend/smart_auto/calculators.py | 402 -- backend/smart_auto/config_builder.py | 79 - backend/smart_auto/constants.py | 184 - backend/smart_auto/cpu_config.py | 200 - backend/smart_auto/generation_params.py | 87 - backend/smart_auto/gpu_config.py | 469 -- backend/smart_auto/kv_cache.py | 30 - backend/smart_auto/memory_estimator.py | 478 -- backend/smart_auto/model_metadata.py | 106 - backend/smart_auto/models.py | 216 - backend/smart_auto/moe_handler.py | 98 - backend/smart_auto/optimizer.py | 199 - backend/smart_auto/recommendations.py | 372 -- backend/tests/conftest.py | 8 + backend/tests/test_app_smoke.py | 58 + backend/tests/test_architecture_profiles.py | 70 - backend/unified_monitor.py | 716 --- backend/websocket_manager.py | 192 - docker-compose.cuda.yml | 9 +- docker-compose.rocm.yml | 40 - docker-compose.vulkan.yml | 37 - docker-entrypoint.sh | 11 +- frontend/src/App.vue | 22 +- frontend/src/components/BuildProgress.vue 
| 375 -- frontend/src/components/DownloadProgress.vue | 217 - frontend/src/components/GgufModelList.vue | 455 -- .../src/components/SafetensorsModelList.vue | 1917 ------- frontend/src/components/SliderInput.vue | 477 -- frontend/src/components/common/BaseCard.vue | 35 - frontend/src/components/common/BaseDialog.vue | 70 - .../src/components/common/BaseFormField.vue | 62 - frontend/src/components/common/LogViewer.vue | 451 -- .../src/components/common/ProgressTracker.vue | 124 + .../src/components/common/StatusBadge.vue | 99 - .../src/components/config/AdvancedSection.vue | 81 - .../config/AdvancedSettingsSection.vue | 83 - .../components/config/ConfigChangePreview.vue | 329 -- .../src/components/config/ConfigField.vue | 91 - .../src/components/config/ConfigSection.vue | 145 - .../src/components/config/ConfigWarnings.vue | 90 - .../src/components/config/ConfigWizard.vue | 764 --- .../config/ContextParamsSection.vue | 163 - .../components/config/CustomArgsSection.vue | 32 - frontend/src/components/config/EmptyState.vue | 145 - .../config/EssentialSettingsSection.vue | 182 - .../config/GenerationParamsSection.vue | 244 - .../src/components/config/MemoryMonitor.vue | 439 -- .../components/config/MemoryParamsSection.vue | 148 - .../components/config/ModelInfoSection.vue | 177 - .../src/components/config/OnboardingTour.vue | 397 -- .../components/config/PerformanceSection.vue | 393 -- .../src/components/config/QuickStartModal.vue | 352 -- .../src/components/config/SettingsTooltip.vue | 129 - frontend/src/components/layout/AppFooter.vue | 81 +- .../src/components/layout/AppNavigation.vue | 146 +- .../src/components/system/CudaInstaller.vue | 534 -- .../src/components/system/LMDeployTab.vue | 347 -- .../system/LlamaCppManager/BuildDialog.vue | 715 --- .../LlamaCppManager/CudaInstallDialog.vue | 274 - .../system/LlamaCppManager/ReleaseDialog.vue | 419 -- .../system/LlamaCppManager/UpdateInfo.vue | 135 - .../system/LlamaCppManager/VersionCard.vue | 250 - 
.../system/LlamaCppManager/VersionList.vue | 52 - .../src/components/system/LlamaCppTab.vue | 7 - frontend/src/components/system/SystemTab.vue | 748 --- .../src/components/system/VersionTable.vue | 131 + frontend/src/main.js | 7 - frontend/src/router/index.js | 10 +- frontend/src/stores/engines.js | 189 + frontend/src/stores/lmdeploy.js | 172 - frontend/src/stores/models.js | 559 +- frontend/src/stores/progress.js | 151 + frontend/src/stores/system.js | 178 - frontend/src/stores/websocket.js | 322 -- frontend/src/styles/_base.css | 3 +- frontend/src/styles/_components.css | 61 - frontend/src/styles/_variables.css | 2 + frontend/src/utils/formatting.js | 24 +- frontend/src/views/EnginesView.vue | 1059 ++++ frontend/src/views/LMDeploy.vue | 298 - frontend/src/views/LlamaCppManager.vue | 339 -- frontend/src/views/ModelConfig.vue | 4131 ++------------ frontend/src/views/ModelLibrary.vue | 1329 ++--- frontend/src/views/ModelSearch.vue | 2055 ++----- frontend/src/views/System.vue | 74 - frontend/src/views/SystemStatus.vue | 723 --- frontend/vite.config.js | 87 +- migrate_db.py | 404 -- migrate_gguf_storage.py | 96 - package-lock.json | 184 + package.json | 14 +- requirements.txt | 10 +- 117 files changed, 10642 insertions(+), 33388 deletions(-) delete mode 100644 backend/architecture_profiles.py create mode 100644 backend/data_store.py delete mode 100644 backend/database.py create mode 100644 backend/param_registry.py delete mode 100644 backend/presets.py create mode 100644 backend/progress_manager.py delete mode 100644 backend/routes/unified_monitoring.py delete mode 100644 backend/smart_auto/__init__.py delete mode 100644 backend/smart_auto/architecture_config.py delete mode 100644 backend/smart_auto/calculators.py delete mode 100644 backend/smart_auto/config_builder.py delete mode 100644 backend/smart_auto/constants.py delete mode 100644 backend/smart_auto/cpu_config.py delete mode 100644 backend/smart_auto/generation_params.py delete mode 100644 
backend/smart_auto/gpu_config.py delete mode 100644 backend/smart_auto/kv_cache.py delete mode 100644 backend/smart_auto/memory_estimator.py delete mode 100644 backend/smart_auto/model_metadata.py delete mode 100644 backend/smart_auto/models.py delete mode 100644 backend/smart_auto/moe_handler.py delete mode 100644 backend/smart_auto/optimizer.py delete mode 100644 backend/smart_auto/recommendations.py create mode 100644 backend/tests/conftest.py create mode 100644 backend/tests/test_app_smoke.py delete mode 100644 backend/tests/test_architecture_profiles.py delete mode 100644 backend/unified_monitor.py delete mode 100644 backend/websocket_manager.py delete mode 100644 docker-compose.rocm.yml delete mode 100644 docker-compose.vulkan.yml delete mode 100644 frontend/src/components/BuildProgress.vue delete mode 100644 frontend/src/components/DownloadProgress.vue delete mode 100644 frontend/src/components/GgufModelList.vue delete mode 100644 frontend/src/components/SafetensorsModelList.vue delete mode 100644 frontend/src/components/SliderInput.vue delete mode 100644 frontend/src/components/common/BaseCard.vue delete mode 100644 frontend/src/components/common/BaseDialog.vue delete mode 100644 frontend/src/components/common/BaseFormField.vue delete mode 100644 frontend/src/components/common/LogViewer.vue create mode 100644 frontend/src/components/common/ProgressTracker.vue delete mode 100644 frontend/src/components/common/StatusBadge.vue delete mode 100644 frontend/src/components/config/AdvancedSection.vue delete mode 100644 frontend/src/components/config/AdvancedSettingsSection.vue delete mode 100644 frontend/src/components/config/ConfigChangePreview.vue delete mode 100644 frontend/src/components/config/ConfigField.vue delete mode 100644 frontend/src/components/config/ConfigSection.vue delete mode 100644 frontend/src/components/config/ConfigWarnings.vue delete mode 100644 frontend/src/components/config/ConfigWizard.vue delete mode 100644 
frontend/src/components/config/ContextParamsSection.vue delete mode 100644 frontend/src/components/config/CustomArgsSection.vue delete mode 100644 frontend/src/components/config/EmptyState.vue delete mode 100644 frontend/src/components/config/EssentialSettingsSection.vue delete mode 100644 frontend/src/components/config/GenerationParamsSection.vue delete mode 100644 frontend/src/components/config/MemoryMonitor.vue delete mode 100644 frontend/src/components/config/MemoryParamsSection.vue delete mode 100644 frontend/src/components/config/ModelInfoSection.vue delete mode 100644 frontend/src/components/config/OnboardingTour.vue delete mode 100644 frontend/src/components/config/PerformanceSection.vue delete mode 100644 frontend/src/components/config/QuickStartModal.vue delete mode 100644 frontend/src/components/config/SettingsTooltip.vue delete mode 100644 frontend/src/components/system/CudaInstaller.vue delete mode 100644 frontend/src/components/system/LMDeployTab.vue delete mode 100644 frontend/src/components/system/LlamaCppManager/BuildDialog.vue delete mode 100644 frontend/src/components/system/LlamaCppManager/CudaInstallDialog.vue delete mode 100644 frontend/src/components/system/LlamaCppManager/ReleaseDialog.vue delete mode 100644 frontend/src/components/system/LlamaCppManager/UpdateInfo.vue delete mode 100644 frontend/src/components/system/LlamaCppManager/VersionCard.vue delete mode 100644 frontend/src/components/system/LlamaCppManager/VersionList.vue delete mode 100644 frontend/src/components/system/LlamaCppTab.vue delete mode 100644 frontend/src/components/system/SystemTab.vue create mode 100644 frontend/src/components/system/VersionTable.vue create mode 100644 frontend/src/stores/engines.js delete mode 100644 frontend/src/stores/lmdeploy.js create mode 100644 frontend/src/stores/progress.js delete mode 100644 frontend/src/stores/system.js delete mode 100644 frontend/src/stores/websocket.js create mode 100644 frontend/src/views/EnginesView.vue delete mode 
100644 frontend/src/views/LMDeploy.vue delete mode 100644 frontend/src/views/LlamaCppManager.vue delete mode 100644 frontend/src/views/System.vue delete mode 100644 frontend/src/views/SystemStatus.vue delete mode 100644 migrate_db.py delete mode 100644 migrate_gguf_storage.py diff --git a/Dockerfile b/Dockerfile index 13ee749..0b8bd9b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -19,8 +19,8 @@ RUN if [ -f package-lock.json ] || [ -f npm-shrinkwrap.json ]; then \ npm install; \ fi -# Copy frontend source (vite.config.js expects files at /build root, not /build/frontend) -COPY frontend/ ./ +# Copy frontend source using the same layout as the repo root scripts expect +COPY frontend/ ./frontend/ RUN npm run build ################################################################################ @@ -81,6 +81,8 @@ ENV DEBIAN_FRONTEND=noninteractive \ CUDA_VISIBLE_DEVICES=all \ NVIDIA_VISIBLE_DEVICES=all \ NVIDIA_DRIVER_CAPABILITIES=compute,utility \ + HF_HOME=/app/data/temp/.cache/huggingface \ + HUGGINGFACE_HUB_CACHE=/app/data/temp/.cache/huggingface/hub \ VENV_PATH=/opt/venv \ PYTHONPATH=/app \ PATH="/app/data/cuda/current/bin:${PATH}" \ @@ -97,7 +99,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ pkg-config \ ninja-build \ curl \ - wget \ ca-certificates \ # Core libs for Python packages libssl3 \ @@ -112,8 +113,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ ocl-icd-libopencl1 \ libnuma1 \ pciutils \ - usbutils \ - lshw \ # Optional: ROCm (fails gracefully if unavailable) && (apt-get install -y --no-install-recommends rocminfo rocm-smi || echo "ROCm unavailable") \ && rm -rf /var/lib/apt/lists/* \ @@ -127,7 +126,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ # Ubuntu 24.04 may have a newer cmake, but we install a specific version for consistency # Placed here to avoid re-downloading when application code changes ARG CMAKE_VERSION=3.31.3 -RUN wget -qO /tmp/cmake.sh 
"https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-linux-x86_64.sh" \ +RUN curl -fsSL "https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-linux-x86_64.sh" -o /tmp/cmake.sh \ && chmod +x /tmp/cmake.sh \ && /tmp/cmake.sh --skip-license --prefix=/usr/local \ && rm /tmp/cmake.sh \ @@ -135,7 +134,7 @@ RUN wget -qO /tmp/cmake.sh "https://github.com/Kitware/CMake/releases/download/v # Install llama-swap binary ARG LLAMA_SWAP_VERSION=179 -RUN wget -q https://github.com/mostlygeek/llama-swap/releases/download/v${LLAMA_SWAP_VERSION}/llama-swap_${LLAMA_SWAP_VERSION}_linux_amd64.tar.gz -O /tmp/llama-swap.tar.gz && \ +RUN curl -fsSL "https://github.com/mostlygeek/llama-swap/releases/download/v${LLAMA_SWAP_VERSION}/llama-swap_${LLAMA_SWAP_VERSION}_linux_amd64.tar.gz" -o /tmp/llama-swap.tar.gz && \ tar -xzf /tmp/llama-swap.tar.gz -C /tmp && \ mv /tmp/llama-swap /usr/local/bin/llama-swap && \ chmod +x /usr/local/bin/llama-swap && \ @@ -151,8 +150,7 @@ WORKDIR /app # Copy application code (excluding data via .dockerignore) COPY backend/ ./backend/ -COPY migrate_db.py ./ -COPY --from=frontend-builder /build/dist ./frontend/dist +COPY --from=frontend-builder /build/frontend/dist ./frontend/dist COPY frontend/public ./frontend/public # Copy and setup entrypoint script and CUDA environment helper @@ -170,7 +168,7 @@ RUN ln -sf /usr/bin/python3 /usr/bin/python # Create non-root user and data directory structure RUN useradd -m -s /bin/bash appuser && \ - mkdir -p /app/data/models /app/data/configs /app/data/logs /app/data/llama-cpp /app/data/temp && \ + mkdir -p /app/data/models /app/data/config /app/data/configs /app/data/logs /app/data/llama-cpp /app/data/temp/.cache/huggingface/hub && \ chown -R appuser:appuser /app && \ # Ensure entrypoint script is accessible to appuser chmod 755 /usr/local/bin/docker-entrypoint.sh diff --git a/README.md b/README.md index fed2b77..89567a4 100644 --- a/README.md +++ 
b/README.md @@ -1,525 +1,525 @@ -# llama.cpp Studio - -A professional AI model management platform for llama.cpp models and versions, designed for modern AI workflows with comprehensive GPU support (NVIDIA CUDA, AMD Vulkan/ROCm, Metal, OpenBLAS). - -## Features - -### Model Management -- **Search & Download**: Search HuggingFace for GGUF models with comprehensive metadata and size information for each quantization -- **Multi-Quantization Support**: Download and manage multiple quantizations of the same model -- **Model Library**: Manage downloaded models with start/stop/delete functionality -- **Smart Configuration**: Auto-generate optimal llama.cpp parameters based on GPU capabilities -- **VRAM Estimation**: Real-time VRAM usage estimation with warnings for memory constraints -- **Metadata Extraction**: Rich model information including parameters, architecture, license, tags, and more -- **Safetensors Runner**: Configure and run safetensors checkpoints via LMDeploy TurboMind with an OpenAI-compatible endpoint on port 2001 - -### llama.cpp Version Management -- **Release Installation**: Download and install pre-built binaries from GitHub releases -- **Source Building**: Build from source with optional patches from GitHub PRs -- **Custom Build Configuration**: Customize GPU backends (CUDA, Vulkan, Metal, OpenBLAS), build type, and compiler flags -- **Update Checking**: Check for updates to both releases and source code -- **Version Management**: Install, update, and delete multiple llama.cpp versions -- **Build Validation**: Automatic validation of built binaries to ensure they work correctly - -### GPU Support -- **Multi-GPU Support**: Automatic detection and configuration for NVIDIA, AMD, and other GPUs -- **NVIDIA CUDA**: Full support for CUDA compute capabilities, flash attention, and multi-GPU -- **AMD GPU Support**: Vulkan and ROCm support for AMD GPUs -- **Apple Metal**: Support for Apple Silicon GPUs -- **OpenBLAS**: CPU acceleration with optimized BLAS 
routines -- **VRAM Monitoring**: Real-time GPU memory usage and temperature monitoring -- **NVLink Detection**: Automatic detection of NVLink connections and topology analysis - -### Multi-Model Serving -- **Concurrent Execution**: Run multiple models simultaneously via llama-swap proxy -- **OpenAI-Compatible API**: Standard API format for easy integration -- **Port 2000**: All models served through a single unified endpoint -- **Automatic Lifecycle Management**: Seamless starting/stopping of models - -### Web Interface -- **Modern UI**: Vue.js 3 with PrimeVue components -- **Real-time Updates**: WebSocket-based progress tracking and system monitoring -- **Responsive Design**: Works on desktop and mobile devices -- **System Status**: CPU, memory, disk, and GPU monitoring -- **LMDeploy Installer**: Dedicated UI to install/remove LMDeploy at runtime with live logs -- **Dark Mode**: Built-in theme support - -## Quick Start - -### Using Docker Compose - -1. Clone the repository: -```bash -git clone -cd llama-cpp-studio -``` - -2. Start the application: -```bash -# CPU-only mode -docker-compose -f docker-compose.cpu.yml up -d - -# GPU mode (NVIDIA CUDA) -docker-compose -f docker-compose.cuda.yml up -d - -# Vulkan/AMD GPU mode -docker-compose -f docker-compose.vulkan.yml up -d - -# ROCm mode -docker-compose -f docker-compose.rocm.yml up -d -``` - -3. Access the web interface at `http://localhost:8080` - -### Published Container Images - -Prebuilt images are pushed to GitHub Container Registry whenever the `publish-docker` workflow runs. - -- `ghcr.io//llama-cpp-studio:latest` – standard image based on `ubuntu:22.04` with GPU tooling installed at runtime - -Pull the image from GHCR: - -```bash -docker pull ghcr.io//llama-cpp-studio:latest -``` - -### Manual Docker Build - -1. Build the image: -```bash -docker build -t llama-cpp-studio . -``` - -2. 
Run the container: -```bash -# With GPU support -docker run -d \ - --name llama-cpp-studio \ - --gpus all \ - -p 8080:8080 \ - -v ./data:/app/data \ - llama-cpp-studio - -# CPU-only -docker run -d \ - --name llama-cpp-studio \ - -p 8080:8080 \ - -v ./data:/app/data \ - llama-cpp-studio -``` - -## Configuration - -### Environment Variables -- `CUDA_VISIBLE_DEVICES`: GPU device selection (default: all, set to "" for CPU-only) -- `PORT`: Web server port (default: 8080) -- `HUGGINGFACE_API_KEY`: HuggingFace API token for model search and download (optional) -- `LMDEPLOY_BIN`: Override path to the `lmdeploy` CLI (default: `lmdeploy` on PATH) -- `LMDEPLOY_PORT`: Override the LMDeploy OpenAI port (default: 2001) - -### Volume Mounts -- `/app/data`: Persistent storage for models, configurations, and database - -### HuggingFace API Key - -To enable model search and download functionality, you need to set your HuggingFace API key. You can do this in several ways: - -#### Option 1: Docker Compose Environment Variable -Uncomment and set the token in your `docker-compose.yml`: -```yaml -environment: - - CUDA_VISIBLE_DEVICES=all - - HUGGINGFACE_API_KEY=your_huggingface_token_here -``` - -#### Option 2: .env File -Create a `.env` file in your project root: -```bash -HUGGINGFACE_API_KEY=your_huggingface_token_here -``` - -Then uncomment the `env_file` section in `docker-compose.yml`: -```yaml -env_file: - - .env -``` - -#### Option 3: System Environment Variable -Set the environment variable before running Docker Compose: -```bash -export HUGGINGFACE_API_KEY=your_huggingface_token_here -docker-compose up -d -``` - -#### Getting Your HuggingFace Token -1. Go to [HuggingFace Settings](https://huggingface.co/settings/tokens) -2. Create a new token with "Read" permissions -3. Copy the token and use it in one of the methods above - -**Note**: When the API key is set via environment variable, it cannot be modified through the web UI for security reasons. 
- -### GPU Requirements -- **NVIDIA**: NVIDIA GPU with CUDA support, NVIDIA Container Toolkit installed -- **AMD**: AMD GPU with Vulkan/ROCm drivers -- **Apple**: Apple Silicon with Metal support -- **CPU**: OpenBLAS for CPU acceleration (included in Docker image) -- Minimum 8GB VRAM recommended for most models - -### LMDeploy Requirement - -Safetensors execution relies on [LMDeploy](https://github.com/InternLM/lmdeploy), but the base image intentionally omits it to keep Docker builds lightweight (critical for GitHub Actions). Use the **LMDeploy** page in the UI to install or remove LMDeploy inside the running container—installs happen via `pip` at runtime and logs are streamed live. The installer creates a dedicated virtual environment under `/app/data/lmdeploy/venv`, so the package lives on the writable volume and can be removed by deleting that folder. If you are running outside the container, you can still `pip install lmdeploy` manually or point `LMDEPLOY_BIN` to a custom binary. The runtime uses `lmdeploy serve turbomind` to expose an OpenAI-compatible server on port `2001`. - -## Usage - -### 1. Model Management - -#### Search Models -- Use the search bar to find GGUF models on HuggingFace -- Filter by tags, parameters, or model name -- View comprehensive metadata including downloads, likes, tags, and file sizes - -#### Download Models -- Click download on any quantization to start downloading -- Multiple quantizations of the same model are automatically grouped -- Progress tracking with real-time updates via WebSocket - -#### Configure Models -- Set llama.cpp parameters or use Smart Auto for optimal settings -- View VRAM estimation before starting -- Configure context size, batch sizes, temperature, and more - -#### Run Models -- Start/stop models with one click -- Multiple models can run simultaneously -- View running instances and resource usage - -### 2. 
llama.cpp Versions - -#### Check Updates -- View available releases and source updates -- See commit history and release notes - -#### Install Release -- Download pre-built binaries from GitHub -- Automatic verification and installation - -#### Build from Source -- Compile from source with custom configuration -- Select GPU backends (CUDA, Vulkan, Metal, OpenBLAS) -- Configure build type (Release, Debug, RelWithDebInfo) -- Add custom CMake flags and compiler options -- Apply patches from GitHub PRs -- Automatic validation of built binaries - -#### Manage Versions -- Delete old versions to free up space -- View installation details and build configuration - -### 3. System Monitoring -- **Overview**: CPU, memory, disk, and GPU usage -- **GPU Details**: Individual GPU information and utilization -- **Running Instances**: Active model instances with resource usage -- **WebSocket**: Real-time updates for all metrics - -## Multi-Model Serving - -llama-cpp-studio uses llama-swap to serve multiple models simultaneously on port 2000. - -### Starting Models - -Simply start any model from the Model Library. All models run on port 2000 simultaneously. - -### OpenAI-Compatible API - -```bash -curl http://localhost:2000/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "llama-3-2-1b-instruct-iq2-xs", - "messages": [{"role": "user", "content": "Hello!"}] - }' -``` - -Model names are shown in System Status after starting a model. 
- -### Features - -- Multiple models run concurrently -- No loading time - instant switching between models -- Standard OpenAI API format -- Automatic lifecycle management -- Single unified endpoint - -### Troubleshooting - -- Check available models: `http://localhost:2000/v1/models` -- Check proxy health: `http://localhost:2000/health` -- View logs: `docker logs llama-cpp-studio` - -### LMDeploy TurboMind (Safetensors) - -- Run exactly one safetensors checkpoint at a time via LMDeploy -- Configure tensor/pipeline parallelism, context length, temperature, and other runtime flags from the Model Library -- Serves an OpenAI-compatible endpoint at `http://localhost:2001/v1/chat/completions` -- Install LMDeploy on demand from the LMDeploy page (or manually via `pip`) before starting safetensors runtimes -- Start/stop directly from the Safetensors panel; status is reported in System Status and the LMDeploy status chip - -## Build Customization - -### GPU Backends - -Enable specific GPU backends during source builds: - -- **CUDA**: NVIDIA GPU acceleration with cuBLAS -- **Vulkan**: AMD/Intel GPU acceleration with Vulkan compute -- **Metal**: Apple Silicon GPU acceleration -- **OpenBLAS**: CPU optimization with OpenBLAS routines - -### Build Configuration - -Customize your build with: - -- **Build Type**: Release (optimal), Debug (development), RelWithDebInfo -- **Custom CMake Flags**: Additional CMake configuration -- **Compiler Flags**: CFLAGS and CXXFLAGS for optimization -- **Git Patches**: Apply patches from GitHub PRs - -### Example Build Configuration - -```json -{ - "commit_sha": "master", - "patches": [ - "https://github.com/ggerganov/llama.cpp/pull/1234.patch" - ], - "build_config": { - "build_type": "Release", - "enable_cuda": true, - "enable_vulkan": false, - "enable_metal": false, - "enable_openblas": true, - "custom_cmake_args": "-DGGML_CUDA_CUBLAS=ON", - "cflags": "-O3 -march=native", - "cxxflags": "-O3 -march=native" - } -} -``` - -## Smart Auto 
Configuration - -The Smart Auto feature automatically generates optimal llama.cpp parameters based on: - -- **GPU Capabilities**: VRAM, compute capability, multi-GPU support -- **NVLink Topology**: Automatic detection and optimization for NVLink clusters -- **Model Architecture**: Detected from model name (Llama, Mistral, etc.) -- **Available Resources**: CPU cores, memory, disk space -- **Performance Optimization**: Flash attention, tensor parallelism, batch sizing - -### NVLink Optimization Strategies - -The system automatically detects NVLink topology and applies appropriate strategies: - -- **Unified NVLink**: All GPUs connected via NVLink - uses aggressive tensor splitting and higher parallelism -- **Clustered NVLink**: Multiple NVLink clusters - optimizes for the largest cluster -- **Partial NVLink**: Some GPUs connected via NVLink - uses hybrid approach -- **PCIe Only**: No NVLink detected - uses conservative PCIe-based configuration - -### Supported Parameters -- Context size, batch sizes, GPU layers -- Temperature, top-k, top-p, repeat penalty -- CPU threads, parallel sequences -- RoPE scaling, YaRN factors -- Multi-GPU tensor splitting -- Custom arguments via YAML config - -## API Endpoints - -### Models -- `GET /api/models` - List all models -- `POST /api/models/search` - Search HuggingFace -- `POST /api/models/download` - Download model -- `GET /api/models/{id}/config` - Get model configuration -- `PUT /api/models/{id}/config` - Update configuration -- `POST /api/models/{id}/auto-config` - Generate smart configuration -- `POST /api/models/{id}/start` - Start model -- `POST /api/models/{id}/stop` - Stop model -- `DELETE /api/models/{id}` - Delete model -- `GET /api/models/safetensors/{model_id}/lmdeploy/config` - Get LMDeploy config for a safetensors download -- `PUT /api/models/safetensors/{model_id}/lmdeploy/config` - Update LMDeploy config -- `POST /api/models/safetensors/{model_id}/lmdeploy/start` - Start LMDeploy runtime -- `POST 
/api/models/safetensors/{model_id}/lmdeploy/stop` - Stop LMDeploy runtime -- `GET /api/models/safetensors/lmdeploy/status` - LMDeploy manager status - -### LMDeploy Installer -- `GET /api/lmdeploy/status` - Installer status (version, binary path, current operation) -- `POST /api/lmdeploy/install` - Install LMDeploy via pip at runtime -- `POST /api/lmdeploy/remove` - Remove LMDeploy from the runtime environment -- `GET /api/lmdeploy/logs` - Tail the LMDeploy installer log - -### llama.cpp Versions -- `GET /api/llama-versions` - List installed versions -- `GET /api/llama-versions/check-updates` - Check for updates -- `GET /api/llama-versions/build-capabilities` - Get build capabilities -- `POST /api/llama-versions/install-release` - Install release -- `POST /api/llama-versions/build-source` - Build from source -- `DELETE /api/llama-versions/{id}` - Delete version - -### System -- `GET /api/status` - System status -- `GET /api/gpu-info` - GPU information -- `WebSocket /ws` - Real-time updates - -## Database Migration - -If upgrading from an older version, you may need to migrate your database: - -```bash -# Run migration to support multi-quantization -python migrate_db.py -``` - -## Troubleshooting - -### Common Issues - -1. **GPU Not Detected** - - Ensure NVIDIA Container Toolkit is installed (for NVIDIA) - - Check `nvidia-smi` output - - Verify `--gpus all` flag in docker run - - For AMD: Check Vulkan/ROCm drivers - -2. **Build Failures** - - Check CUDA version compatibility (for NVIDIA) - - Ensure sufficient disk space (at least 10GB free) - - Verify internet connectivity for downloads - - For Vulkan builds: Ensure `glslang-tools` is installed - - Check build logs for specific errors - -3. **Memory Issues** - - Use Smart Auto configuration - - Reduce context size or batch size - - Enable memory mapping - - Check available system RAM and VRAM - -4. 
**Model Download Failures** - - Check HuggingFace connectivity - - Verify model exists and is public - - Ensure sufficient disk space - - Set HUGGINGFACE_API_KEY if using private models - -5. **Validation Failed** - - Binary exists and is executable - - Binary runs `--version` successfully - - Output contains "llama" or "version:" string - -### Logs -- Application logs: `docker logs llama-cpp-studio` -- Model logs: Available in the web interface -- Build logs: Shown during source compilation -- WebSocket logs: DEBUG level for detailed connection info - -## Development - -### Backend -- FastAPI with async support -- SQLAlchemy for database management -- WebSocket for real-time updates -- Background tasks for long operations -- Llama-swap integration for multi-model serving - -### Frontend -- Vue.js 3 with Composition API -- PrimeVue component library -- Pinia for state management -- Vite for build tooling -- Dark mode support - -### Database -- SQLite for simplicity -- Models, versions, and instances tracking -- Configuration storage -- Multi-quantization support - -## Memory Estimation Model - -The studio’s capacity planning tooling is grounded in a three-component model for llama.cpp that provides a conservative upper bound on peak memory usage. - -- **Formula**: `M_total = M_weights + M_kv + M_compute` -- **Model weights (`M_weights`)**: Treat the GGUF file size as the ground truth. When `--no-mmap` is disabled (default), the file is memory-mapped so only referenced pages touch physical RAM, but the virtual footprint still equals the file size. -- **KV cache (`M_kv`)**: Uses the GQA-aware formula `n_ctx × N_layers × N_head_kv × (N_embd / N_head) × (p_a_k + p_a_v)`, where `p_a_*` are the bytes-per-value chosen via `--cache-type-k` / `--cache-type-v`. -- **Compute buffers (`M_compute`)**: Approximate as a fixed CUDA overhead (~550 MB) plus a scratch buffer that scales with micro-batch size (`n_ubatch × 0.5 MB` by default). 
- -### RAM vs VRAM Allocation - -- `-ngl 0` (CPU-only): All components stay in RAM. -- `-ngl > 0` (hybrid/full GPU): Model weights split by layer between RAM and VRAM, while **both `M_kv` and `M_compute` move entirely to VRAM**—the “VRAM trap”. -- Full offload avoids PCIe contention; hybrid splits suffer a “performance cliff” because activations bounce between CPU and GPU. - -### Optimization Strategy - -1. Attempt full offload first (best throughput). If weights + compute fit, deduce `n_ctx_max` from remaining VRAM budget. -2. When full offload fails, search decreasing `n_ngl` values that satisfy RAM limits while maximizing context length, accepting the hybrid performance penalty. -3. Iterate quantization choices to find the smallest model that still enables full offload on the target hardware profile. - -## Smart Auto Module Report - -The Smart Auto subsystem applies the model above to recommend llama.cpp launch parameters. Priority 1 fixes are complete, eliminating prior memory underestimation bugs. - -- **Resolutions**: - - Corrected KV cache math to respect grouped-query attention head counts. - - Removed the dangerous 0.30 multiplier on cache size; estimates now use real memory. - - Ensured KV cache/compute buffers migrate to VRAM whenever GPU layers are in play. - - Modeled compute overhead as `550 MB + 0.5 MB × n_ubatch`. - - Improved GPU layer estimation using GGUF file size with a 20 % safety buffer. -- **Open improvements**: - - Reorder calculations so KV cache quantization feeds batch/context sizing directly. - - Replace remaining heuristics with joint optimization across `n_ctx`, `n_ngl`, and `n_ubatch`. - -### Recommended Validation - -- Benchmark against known examples (e.g., 13B @ 2 048 tokens → ~1.6 GB KV cache, 7B @ 4 096 tokens → ~6 GB total). -- Stress-test large contexts, tight VRAM scenarios, MoE models, and hybrid modes. -- Expand automated regression coverage around the estimator and Smart Auto flows. 
- -## Memory Estimation Test Results - -Empirical testing with `Llama-3.2-1B-Instruct.IQ1_M` demonstrates that the estimator acts as a safe upper bound. - -- **Setup**: `n_ctx ≈ 35 K`, batch 32, CPU-only run. -- **Estimated peak**: 4.99 GB (weights 394 MB, KV cache 4.34 GB, batch 12 MB, llama.cpp overhead 256 MB). -- **Observed deltas**: - - With mmap enabled: ~608 MB (11.9 % of estimate). Lower usage is expected because the KV cache grows as context fills and weights are paged on demand. - - With `--no-mmap`: ~1.16 GB (23 % of estimate). Weights load fully, but KV cache still expands progressively. -- **Takeaways**: - - Estimates intentionally err on the high side to prevent OOM once the context window reaches capacity. - - Divergence between virtual and physical usage stems from memory mapping and lazy KV cache allocation. - - Additional GPU-focused measurements and long session traces are encouraged to correlate VRAM predictions with reality. - -## License - -This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. - -Copyright (c) 2024 llama.cpp Studio - -## Contributing - -1. Fork the repository -2. Create a feature branch -3. Make your changes -4. Add tests if applicable -5. Submit a pull request - -## Support - -For issues and questions: -- Create an issue on GitHub -- Check the troubleshooting section -- Review the API documentation - -## Acknowledgments - -- [llama.cpp](https://github.com/ggerganov/llama.cpp) - The core inference engine -- [llama-swap](https://github.com/mostlygeek/llama-swap) - Multi-model serving proxy -- [HuggingFace](https://huggingface.co) - Model hosting and search -- [Vue.js](https://vuejs.org) - Frontend framework -- [FastAPI](https://fastapi.tiangolo.com) - Backend framework +# llama.cpp Studio + +A professional AI model management platform for llama.cpp models and versions, designed for modern AI workflows with comprehensive GPU support (NVIDIA CUDA, AMD Vulkan/ROCm, Metal, OpenBLAS). 
+ +## Features + +### Model Management +- **Search & Download**: Search HuggingFace for GGUF models with comprehensive metadata and size information for each quantization +- **Multi-Quantization Support**: Download and manage multiple quantizations of the same model +- **Model Library**: Manage downloaded models with start/stop/delete functionality +- **Smart Configuration**: Auto-generate optimal llama.cpp parameters based on GPU capabilities +- **VRAM Estimation**: Real-time VRAM usage estimation with warnings for memory constraints +- **Metadata Extraction**: Rich model information including parameters, architecture, license, tags, and more +- **Safetensors Runner**: Configure and run safetensors checkpoints via LMDeploy TurboMind with an OpenAI-compatible endpoint on port 2001 + +### llama.cpp Version Management +- **Release Installation**: Download and install pre-built binaries from GitHub releases +- **Source Building**: Build from source with optional patches from GitHub PRs +- **Custom Build Configuration**: Customize GPU backends (CUDA, Vulkan, Metal, OpenBLAS), build type, and compiler flags +- **Update Checking**: Check for updates to both releases and source code +- **Version Management**: Install, update, and delete multiple llama.cpp versions +- **Build Validation**: Automatic validation of built binaries to ensure they work correctly + +### GPU Support +- **Multi-GPU Support**: Automatic detection and configuration for NVIDIA, AMD, and other GPUs +- **NVIDIA CUDA**: Full support for CUDA compute capabilities, flash attention, and multi-GPU +- **AMD GPU Support**: Vulkan and ROCm support for AMD GPUs +- **Apple Metal**: Support for Apple Silicon GPUs +- **OpenBLAS**: CPU acceleration with optimized BLAS routines +- **VRAM Monitoring**: Real-time GPU memory usage and temperature monitoring +- **NVLink Detection**: Automatic detection of NVLink connections and topology analysis + +### Multi-Model Serving +- **Concurrent Execution**: Run multiple models 
simultaneously via llama-swap proxy +- **OpenAI-Compatible API**: Standard API format for easy integration +- **Port 2000**: All models served through a single unified endpoint +- **Automatic Lifecycle Management**: Seamless starting/stopping of models + +### Web Interface +- **Modern UI**: Vue.js 3 with PrimeVue components +- **Real-time Updates**: SSE-based progress tracking and system monitoring +- **Responsive Design**: Works on desktop and mobile devices +- **System Status**: CPU, memory, disk, and GPU monitoring +- **LMDeploy Installer**: Dedicated UI to install/remove LMDeploy at runtime with live logs +- **Dark Mode**: Built-in theme support + +## Quick Start + +### Using Docker Compose + +1. Clone the repository: +```bash +git clone <repository-url> +cd llama-cpp-studio +``` + +2. Start the application: +```bash +# CPU-only mode +docker-compose -f docker-compose.cpu.yml up -d + +# GPU mode (NVIDIA CUDA) +docker-compose -f docker-compose.cuda.yml up -d + +# Vulkan/AMD GPU mode +docker-compose -f docker-compose.vulkan.yml up -d + +# ROCm mode +docker-compose -f docker-compose.rocm.yml up -d +``` + +3. Access the web interface at `http://localhost:8080` + +### Published Container Images + +Prebuilt images are pushed to GitHub Container Registry whenever the `publish-docker` workflow runs. + +- `ghcr.io/<owner>/llama-cpp-studio:latest` – standard image based on `ubuntu:22.04` with GPU tooling installed at runtime + +Pull the image from GHCR: + +```bash +docker pull ghcr.io/<owner>/llama-cpp-studio:latest +``` + +### Manual Docker Build + +1. Build the image: +```bash +docker build -t llama-cpp-studio . +``` + +2. 
Run the container: +```bash +# With GPU support +docker run -d \ + --name llama-cpp-studio \ + --gpus all \ + -p 8080:8080 \ + -v ./data:/app/data \ + llama-cpp-studio + +# CPU-only +docker run -d \ + --name llama-cpp-studio \ + -p 8080:8080 \ + -v ./data:/app/data \ + llama-cpp-studio +``` + +## Configuration + +### Environment Variables +- `CUDA_VISIBLE_DEVICES`: GPU device selection (default: all, set to "" for CPU-only) +- `PORT`: Web server port (default: 8080) +- `HUGGINGFACE_API_KEY`: HuggingFace API token for model search and download (optional) +- `LMDEPLOY_BIN`: Override path to the `lmdeploy` CLI (default: `lmdeploy` on PATH) +- `LMDEPLOY_PORT`: Override the LMDeploy OpenAI port (default: 2001) + +### Volume Mounts +- `/app/data`: Persistent storage for models, configurations, and database + +### HuggingFace API Key + +To enable model search and download functionality, you need to set your HuggingFace API key. You can do this in several ways: + +#### Option 1: Docker Compose Environment Variable +Uncomment and set the token in your `docker-compose.yml`: +```yaml +environment: + - CUDA_VISIBLE_DEVICES=all + - HUGGINGFACE_API_KEY=your_huggingface_token_here +``` + +#### Option 2: .env File +Create a `.env` file in your project root: +```bash +HUGGINGFACE_API_KEY=your_huggingface_token_here +``` + +Then uncomment the `env_file` section in `docker-compose.yml`: +```yaml +env_file: + - .env +``` + +#### Option 3: System Environment Variable +Set the environment variable before running Docker Compose: +```bash +export HUGGINGFACE_API_KEY=your_huggingface_token_here +docker-compose up -d +``` + +#### Getting Your HuggingFace Token +1. Go to [HuggingFace Settings](https://huggingface.co/settings/tokens) +2. Create a new token with "Read" permissions +3. Copy the token and use it in one of the methods above + +**Note**: When the API key is set via environment variable, it cannot be modified through the web UI for security reasons. 
+ +### GPU Requirements +- **NVIDIA**: NVIDIA GPU with CUDA support, NVIDIA Container Toolkit installed +- **AMD**: AMD GPU with Vulkan/ROCm drivers +- **Apple**: Apple Silicon with Metal support +- **CPU**: OpenBLAS for CPU acceleration (included in Docker image) +- Minimum 8GB VRAM recommended for most models + +### LMDeploy Requirement + +Safetensors execution relies on [LMDeploy](https://github.com/InternLM/lmdeploy), but the base image intentionally omits it to keep Docker builds lightweight (critical for GitHub Actions). Use the **LMDeploy** page in the UI to install or remove LMDeploy inside the running container—installs happen via `pip` at runtime and logs are streamed live. The installer creates a dedicated virtual environment under `/app/data/lmdeploy/venv`, so the package lives on the writable volume and can be removed by deleting that folder. If you are running outside the container, you can still `pip install lmdeploy` manually or point `LMDEPLOY_BIN` to a custom binary. The runtime uses `lmdeploy serve turbomind` to expose an OpenAI-compatible server on port `2001`. + +## Usage + +### 1. Model Management + +#### Search Models +- Use the search bar to find GGUF models on HuggingFace +- Filter by tags, parameters, or model name +- View comprehensive metadata including downloads, likes, tags, and file sizes + +#### Download Models +- Click download on any quantization to start downloading +- Multiple quantizations of the same model are automatically grouped +- Progress tracking with real-time updates via SSE + +#### Configure Models +- Set llama.cpp parameters or use Smart Auto for optimal settings +- View VRAM estimation before starting +- Configure context size, batch sizes, temperature, and more + +#### Run Models +- Start/stop models with one click +- Multiple models can run simultaneously +- View running instances and resource usage + +### 2. 
llama.cpp Versions + +#### Check Updates +- View available releases and source updates +- See commit history and release notes + +#### Install Release +- Download pre-built binaries from GitHub +- Automatic verification and installation + +#### Build from Source +- Compile from source with custom configuration +- Select GPU backends (CUDA, Vulkan, Metal, OpenBLAS) +- Configure build type (Release, Debug, RelWithDebInfo) +- Add custom CMake flags and compiler options +- Apply patches from GitHub PRs +- Automatic validation of built binaries + +#### Manage Versions +- Delete old versions to free up space +- View installation details and build configuration + +### 3. System Monitoring +- **Overview**: CPU, memory, disk, and GPU usage +- **GPU Details**: Individual GPU information and utilization +- **Running Instances**: Active model instances with resource usage +- **SSE**: Real-time updates for all metrics + +## Multi-Model Serving + +llama-cpp-studio uses llama-swap to serve multiple models simultaneously on port 2000. + +### Starting Models + +Simply start any model from the Model Library. All models run on port 2000 simultaneously. + +### OpenAI-Compatible API + +```bash +curl http://localhost:2000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "llama-3-2-1b-instruct-iq2-xs", + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +Model names are shown in System Status after starting a model. 
+ +### Features + +- Multiple models run concurrently +- No loading time - instant switching between models +- Standard OpenAI API format +- Automatic lifecycle management +- Single unified endpoint + +### Troubleshooting + +- Check available models: `http://localhost:2000/v1/models` +- Check proxy health: `http://localhost:2000/health` +- View logs: `docker logs llama-cpp-studio` + +### LMDeploy TurboMind (Safetensors) + +- Run exactly one safetensors checkpoint at a time via LMDeploy +- Configure tensor/pipeline parallelism, context length, temperature, and other runtime flags from the Model Library +- Serves an OpenAI-compatible endpoint at `http://localhost:2001/v1/chat/completions` +- Install LMDeploy on demand from the LMDeploy page (or manually via `pip`) before starting safetensors runtimes +- Start/stop directly from the Safetensors panel; status is reported in System Status and the LMDeploy status chip + +## Build Customization + +### GPU Backends + +Enable specific GPU backends during source builds: + +- **CUDA**: NVIDIA GPU acceleration with cuBLAS +- **Vulkan**: AMD/Intel GPU acceleration with Vulkan compute +- **Metal**: Apple Silicon GPU acceleration +- **OpenBLAS**: CPU optimization with OpenBLAS routines + +### Build Configuration + +Customize your build with: + +- **Build Type**: Release (optimal), Debug (development), RelWithDebInfo +- **Custom CMake Flags**: Additional CMake configuration +- **Compiler Flags**: CFLAGS and CXXFLAGS for optimization +- **Git Patches**: Apply patches from GitHub PRs + +### Example Build Configuration + +```json +{ + "commit_sha": "master", + "patches": [ + "https://github.com/ggerganov/llama.cpp/pull/1234.patch" + ], + "build_config": { + "build_type": "Release", + "enable_cuda": true, + "enable_vulkan": false, + "enable_metal": false, + "enable_openblas": true, + "custom_cmake_args": "-DGGML_CUDA_CUBLAS=ON", + "cflags": "-O3 -march=native", + "cxxflags": "-O3 -march=native" + } +} +``` + +## Smart Auto 
Configuration + +The Smart Auto feature automatically generates optimal llama.cpp parameters based on: + +- **GPU Capabilities**: VRAM, compute capability, multi-GPU support +- **NVLink Topology**: Automatic detection and optimization for NVLink clusters +- **Model Architecture**: Detected from model name (Llama, Mistral, etc.) +- **Available Resources**: CPU cores, memory, disk space +- **Performance Optimization**: Flash attention, tensor parallelism, batch sizing + +### NVLink Optimization Strategies + +The system automatically detects NVLink topology and applies appropriate strategies: + +- **Unified NVLink**: All GPUs connected via NVLink - uses aggressive tensor splitting and higher parallelism +- **Clustered NVLink**: Multiple NVLink clusters - optimizes for the largest cluster +- **Partial NVLink**: Some GPUs connected via NVLink - uses hybrid approach +- **PCIe Only**: No NVLink detected - uses conservative PCIe-based configuration + +### Supported Parameters +- Context size, batch sizes, GPU layers +- Temperature, top-k, top-p, repeat penalty +- CPU threads, parallel sequences +- RoPE scaling, YaRN factors +- Multi-GPU tensor splitting +- Custom arguments via YAML config + +## API Endpoints + +### Models +- `GET /api/models` - List all models +- `POST /api/models/search` - Search HuggingFace +- `POST /api/models/download` - Download model +- `GET /api/models/{id}/config` - Get model configuration +- `PUT /api/models/{id}/config` - Update configuration +- `POST /api/models/{id}/auto-config` - Generate smart configuration +- `POST /api/models/{id}/start` - Start model +- `POST /api/models/{id}/stop` - Stop model +- `DELETE /api/models/{id}` - Delete model +- `GET /api/models/safetensors/{model_id}/lmdeploy/config` - Get LMDeploy config for a safetensors download +- `PUT /api/models/safetensors/{model_id}/lmdeploy/config` - Update LMDeploy config +- `POST /api/models/safetensors/{model_id}/lmdeploy/start` - Start LMDeploy runtime +- `POST 
/api/models/safetensors/{model_id}/lmdeploy/stop` - Stop LMDeploy runtime +- `GET /api/models/safetensors/lmdeploy/status` - LMDeploy manager status + +### LMDeploy Installer +- `GET /api/lmdeploy/status` - Installer status (version, binary path, current operation) +- `POST /api/lmdeploy/install` - Install LMDeploy via pip at runtime +- `POST /api/lmdeploy/remove` - Remove LMDeploy from the runtime environment +- `GET /api/lmdeploy/logs` - Tail the LMDeploy installer log + +### llama.cpp Versions +- `GET /api/llama-versions` - List installed versions +- `GET /api/llama-versions/check-updates` - Check for updates +- `GET /api/llama-versions/build-capabilities` - Get build capabilities +- `POST /api/llama-versions/install-release` - Install release +- `POST /api/llama-versions/build-source` - Build from source +- `DELETE /api/llama-versions/{id}` - Delete version + +### System +- `GET /api/status` - System status +- `GET /api/gpu-info` - GPU information +- `GET /api/events` - Server-Sent Events for real-time updates + +## Database Migration + +If upgrading from an older version, you may need to migrate your database: + +```bash +# Run migration to support multi-quantization +python migrate_db.py +``` + +## Troubleshooting + +### Common Issues + +1. **GPU Not Detected** + - Ensure NVIDIA Container Toolkit is installed (for NVIDIA) + - Check `nvidia-smi` output + - Verify `--gpus all` flag in docker run + - For AMD: Check Vulkan/ROCm drivers + +2. **Build Failures** + - Check CUDA version compatibility (for NVIDIA) + - Ensure sufficient disk space (at least 10GB free) + - Verify internet connectivity for downloads + - For Vulkan builds: Ensure `glslang-tools` is installed + - Check build logs for specific errors + +3. **Memory Issues** + - Use Smart Auto configuration + - Reduce context size or batch size + - Enable memory mapping + - Check available system RAM and VRAM + +4. 
**Model Download Failures** + - Check HuggingFace connectivity + - Verify model exists and is public + - Ensure sufficient disk space + - Set HUGGINGFACE_API_KEY if using private models + +5. **Validation Failed** + - Binary exists and is executable + - Binary runs `--version` successfully + - Output contains "llama" or "version:" string + +### Logs +- Application logs: `docker logs llama-cpp-studio` +- Model logs: Available in the web interface +- Build logs: Shown during source compilation +- SSE event stream: GET /api/events for real-time progress and status + +## Development + +### Backend +- FastAPI with async support +- YAML-backed data store (models, engines, settings) +- SSE (GET /api/events) for real-time updates +- Background tasks for long operations +- Llama-swap integration for multi-model serving + +### Frontend +- Vue.js 3 with Composition API +- PrimeVue component library +- Pinia for state management +- Vite for build tooling +- Dark mode support + +### Testing +- Backend tests: `pytest` (install deps first: `pip install -r requirements.txt pytest pytest-asyncio`) +- Run from repo root: `PYTHONPATH=. pytest backend/tests/ -v` +- Smoke tests in `backend/tests/test_app_smoke.py` verify the app starts and key API routes respond (`/api/status`, `/api/models/param-registry`, `/api/models/`, `/api/events`) +- LMDeploy installer and config validation tests in `backend/tests/test_lmdeploy_*.py` + +## Memory Estimation Model + +The studio’s capacity planning tooling is grounded in a three-component model for llama.cpp that provides a conservative upper bound on peak memory usage. + +- **Formula**: `M_total = M_weights + M_kv + M_compute` +- **Model weights (`M_weights`)**: Treat the GGUF file size as the ground truth. When `--no-mmap` is disabled (default), the file is memory-mapped so only referenced pages touch physical RAM, but the virtual footprint still equals the file size. 
+- **KV cache (`M_kv`)**: Uses the GQA-aware formula `n_ctx × N_layers × N_head_kv × (N_embd / N_head) × (p_a_k + p_a_v)`, where `p_a_*` are the bytes-per-value chosen via `--cache-type-k` / `--cache-type-v`. +- **Compute buffers (`M_compute`)**: Approximate as a fixed CUDA overhead (~550 MB) plus a scratch buffer that scales with micro-batch size (`n_ubatch × 0.5 MB` by default). + +### RAM vs VRAM Allocation + +- `-ngl 0` (CPU-only): All components stay in RAM. +- `-ngl > 0` (hybrid/full GPU): Model weights split by layer between RAM and VRAM, while **both `M_kv` and `M_compute` move entirely to VRAM**—the “VRAM trap”. +- Full offload avoids PCIe contention; hybrid splits suffer a “performance cliff” because activations bounce between CPU and GPU. + +### Optimization Strategy + +1. Attempt full offload first (best throughput). If weights + compute fit, deduce `n_ctx_max` from remaining VRAM budget. +2. When full offload fails, search decreasing `n_ngl` values that satisfy RAM limits while maximizing context length, accepting the hybrid performance penalty. +3. Iterate quantization choices to find the smallest model that still enables full offload on the target hardware profile. + +## Smart Auto Module Report + +The Smart Auto subsystem applies the model above to recommend llama.cpp launch parameters. Priority 1 fixes are complete, eliminating prior memory underestimation bugs. + +- **Resolutions**: + - Corrected KV cache math to respect grouped-query attention head counts. + - Removed the dangerous 0.30 multiplier on cache size; estimates now use real memory. + - Ensured KV cache/compute buffers migrate to VRAM whenever GPU layers are in play. + - Modeled compute overhead as `550 MB + 0.5 MB × n_ubatch`. + - Improved GPU layer estimation using GGUF file size with a 20 % safety buffer. +- **Open improvements**: + - Reorder calculations so KV cache quantization feeds batch/context sizing directly. 
+ - Replace remaining heuristics with joint optimization across `n_ctx`, `n_ngl`, and `n_ubatch`. + +### Recommended Validation + +- Benchmark against known examples (e.g., 13B @ 2 048 tokens → ~1.6 GB KV cache, 7B @ 4 096 tokens → ~6 GB total). +- Stress-test large contexts, tight VRAM scenarios, MoE models, and hybrid modes. +- Expand automated regression coverage around the estimator and Smart Auto flows. + +## Memory Estimation Test Results + +Empirical testing with `Llama-3.2-1B-Instruct.IQ1_M` demonstrates that the estimator acts as a safe upper bound. + +- **Setup**: `n_ctx ≈ 35 K`, batch 32, CPU-only run. +- **Estimated peak**: 4.99 GB (weights 394 MB, KV cache 4.34 GB, batch 12 MB, llama.cpp overhead 256 MB). +- **Observed deltas**: + - With mmap enabled: ~608 MB (11.9 % of estimate). Lower usage is expected because the KV cache grows as context fills and weights are paged on demand. + - With `--no-mmap`: ~1.16 GB (23 % of estimate). Weights load fully, but KV cache still expands progressively. +- **Takeaways**: + - Estimates intentionally err on the high side to prevent OOM once the context window reaches capacity. + - Divergence between virtual and physical usage stems from memory mapping and lazy KV cache allocation. + - Additional GPU-focused measurements and long session traces are encouraged to correlate VRAM predictions with reality. + +## License + +This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. + +Copyright (c) 2024 llama.cpp Studio + +## Contributing + +1. Fork the repository +2. Create a feature branch +3. Make your changes +4. Add tests if applicable +5. 
Submit a pull request + +## Support + +For issues and questions: +- Create an issue on GitHub +- Check the troubleshooting section +- Review the API documentation + +## Acknowledgments + +- [llama.cpp](https://github.com/ggerganov/llama.cpp) - The core inference engine +- [llama-swap](https://github.com/mostlygeek/llama-swap) - Multi-model serving proxy +- [HuggingFace](https://huggingface.co) - Model hosting and search +- [Vue.js](https://vuejs.org) - Frontend framework +- [FastAPI](https://fastapi.tiangolo.com) - Backend framework diff --git a/backend/architecture_profiles.py b/backend/architecture_profiles.py deleted file mode 100644 index 319e0f1..0000000 --- a/backend/architecture_profiles.py +++ /dev/null @@ -1,265 +0,0 @@ -""" -Architecture-aware profiles for interpreting GGUF metadata. - -Each profile is responsible for turning raw GGUF metadata into: -- block_count: architectural depth (number of transformer blocks) -- effective_layer_count: layers llama.cpp can offload (including output layer) -""" - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type - -from backend.logging_config import get_logger - -logger = get_logger(__name__) - - -@dataclass(frozen=True) -class LayerConfig: - """Standardized output for layer calculations.""" - - block_count: int - effective_layer_count: int - - -# --- Helper Utilities --- - - -def _get_first_valid_int( - metadata: Dict[str, Any], keys: List[str], default: Optional[int] = None -) -> Optional[int]: - """ - Scans metadata for the first key that contains a valid >0 number. 
- """ - for key in keys: - val = metadata.get(key) - # GGUF metadata values can be various numeric types - if isinstance(val, (int, float)) and val > 0: - return int(val) - return default - - -# --- Registry System --- - -_PROFILE_REGISTRY: List["ArchitectureProfile"] = [] - - -def register_profile(cls: Type["ArchitectureProfile"]) -> Type["ArchitectureProfile"]: - """ - Registers a profile class. - Profiles are stored and later sorted by specificity (longest name match first). - """ - _PROFILE_REGISTRY.append(cls()) - return cls - - -# --- Base Class --- - - -class ArchitectureProfile(ABC): - """Base class for architecture-specific GGUF metadata interpretation.""" - - def __init__(self, names: Tuple[str, ...]): - self.names = names - - def matches(self, architecture: str) -> bool: - """ - Checks if the architecture string matches this profile. - """ - arch = architecture.lower() - # "llama" should match "llama", "llama-2", etc. - return any(arch == n or arch.startswith(n) for n in self.names) - - def compute( - self, - metadata: Dict[str, Any], - base_block_count: int, - ) -> LayerConfig: - """ - Public interface that wraps the calculation with standardized logging. - """ - result = self._calculate_layers(metadata, base_block_count) - - logger.debug( - "%s: matched. block_count=%s, effective_layer_count=%s (base=%s)", - self.__class__.__name__, - result.block_count, - result.effective_layer_count, - base_block_count, - ) - return result - - @abstractmethod - def _calculate_layers( - self, - metadata: Dict[str, Any], - base_block_count: int, - ) -> LayerConfig: - """Implementation specific logic.""" - raise NotImplementedError - - -# --- Standard Profile (Handles 95% of cases) --- - - -class StandardDecoderProfile(ArchitectureProfile): - """ - Generic profile for standard decoder-only models (Llama, Qwen, DeepSeek, etc.). - - Logic: - 1. Look for specific keys (names.block_count, names.n_layer). - 2. Fallback to base_block_count. - 3. 
Effective layers = block_count + 1 (for the output head). - - Note on MoEs: Even for MoE models (Qwen2-MoE, DeepSeek-V2), llama.cpp - counts the 'offloadable layers' as the number of transformer blocks. - Expert offloading is managed internally within those blocks. - """ - - def _calculate_layers( - self, metadata: Dict[str, Any], base_block_count: int - ) -> LayerConfig: - # Generate candidate keys based on architecture names - # e.g., ["llama.block_count", "llama.n_layer", ...] - candidate_keys = [] - for name in self.names: - candidate_keys.extend( - [ - f"{name}.block_count", - f"{name}.n_layer", - f"{name}.n_layers", # Some older models use plural - f"{name}.num_hidden_layers", # Some models use num_hidden_layers (e.g., Seed OSS) - ] - ) - - block_count = ( - _get_first_valid_int(metadata, candidate_keys, default=base_block_count) - or 0 - ) - - # Standard decoder: blocks + output head - effective = (block_count + 1) if block_count > 0 else 0 - - return LayerConfig(block_count=block_count, effective_layer_count=effective) - - -# --- Concrete Profiles --- - - -@register_profile -class GlmProfile(StandardDecoderProfile): - """ - Profile for GLM family (GLM-4, GLM-4-MoE, etc.). - """ - - def __init__(self) -> None: - super().__init__(names=("glm", "glm4", "glm4moe")) - - -@register_profile -class DeepseekProfile(StandardDecoderProfile): - """ - DeepSeek decoder LMs and MoE variants. - Crucial: Must check 'deepseek2' for V2/V3 models. 
- """ - - def __init__(self) -> None: - super().__init__(names=("deepseek", "deepseek2")) - - -@register_profile -class QwenFamilyProfile(StandardDecoderProfile): - """Qwen / Qwen2 / Qwen2.5 / Qwen2-MoE.""" - - def __init__(self) -> None: - super().__init__(names=("qwen", "qwen2", "qwen3", "qwen2moe", "qwen3moe")) - - -@register_profile -class LlamaLikeProfile(StandardDecoderProfile): - """LLaMA, Mistral, Mixtral, Gemma, Phi, etc.""" - - def __init__(self) -> None: - # "phi" added as it follows the same decoder structure in GGUF - super().__init__( - names=( - "llama", - "mistral", - "mixtral", - "gemma", - "phi", - "seed", - "seed-oss", - "seedoss", - "seed_oss", - ) - ) - - -@register_profile -class MiniMaxProfile(StandardDecoderProfile): - """MiniMax models (MiniMax-M2.1 and variants).""" - - def __init__(self) -> None: - super().__init__(names=("minimax", "minimax-m2", "minimax_m2", "m2")) - - -# --- Main Accessor --- - - -def get_sorted_profiles() -> List[ArchitectureProfile]: - """ - Returns profiles sorted by specificity (longest name match first). - Example: 'glm4moe' (len 7) is checked before 'glm' (len 3). - """ - return sorted( - _PROFILE_REGISTRY, key=lambda p: max(len(n) for n in p.names), reverse=True - ) - - -def compute_layers_for_architecture( - architecture: str, - metadata: Dict[str, Any], - base_block_count: int, -) -> Dict[str, int]: - """ - Compute block_count and effective_layer_count. - """ - arch = architecture.lower() - - # Iterate through automatically sorted profiles - for profile in get_sorted_profiles(): - if profile.matches(arch): - result = profile.compute(metadata, base_block_count) - return { - "block_count": result.block_count, - "effective_layer_count": result.effective_layer_count, - } - - # --- Generic Fallback --- - # Even if unknown architecture, if we have a base_block_count, - # it's safe to assume it's a decoder stack + 1 output head. 
- block_count = base_block_count or 0 - - if block_count > 0: - effective_layer_count = block_count + 1 - logger.info( - "Generic profile: architecture=%s, block_count=%s, " - "effective_layer_count=%s", - arch, - block_count, - effective_layer_count, - ) - return { - "block_count": block_count, - "effective_layer_count": effective_layer_count, - } - - # Complete fallback - logger.warning( - "Could not determine block_count for architecture=%s; " - "using default effective_layer_count=32", - arch, - ) - return {"block_count": 0, "effective_layer_count": 32} diff --git a/backend/cuda_installer.py b/backend/cuda_installer.py index e407c6c..12b307b 100644 --- a/backend/cuda_installer.py +++ b/backend/cuda_installer.py @@ -1,2432 +1,2432 @@ -""" -CUDA Toolkit Installer - -Handles downloading and installing CUDA Toolkit on Linux systems. -""" - -import asyncio -import json -import os -import platform -import re -import shutil -import subprocess -import sys -import tempfile -import time -import gzip -from datetime import datetime, timezone -from typing import Any, Awaitable, Dict, Optional, Tuple -import aiohttp -import aiofiles - -from backend.logging_config import get_logger -from backend.websocket_manager import websocket_manager - -logger = get_logger(__name__) - -_installer_instance: Optional["CUDAInstaller"] = None - - -def get_cuda_installer() -> "CUDAInstaller": - global _installer_instance - if _installer_instance is None: - _installer_instance = CUDAInstaller() - return _installer_instance - - -def _utcnow() -> str: - return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") - - -class CUDAInstaller: - """Install CUDA Toolkit on Linux systems.""" - - # Supported CUDA versions - URLs are fetched dynamically from NVIDIA's archive - # Format: version -> platform -> architecture (URLs fetched on demand) - SUPPORTED_VERSIONS = [ - "13.0", - "12.9", - "12.8", - "12.7", - "12.6", - "12.5", - "12.4", - "12.3", - "12.2", - "12.1", - "12.0", - "11.9", - 
"11.8", - ] - - # cuDNN version mappings by CUDA major version - CUDNN_VERSIONS = { - "13": "9.5.1", # cuDNN 9.x for CUDA 13.x - "12": "9.5.1", # cuDNN 9.x for CUDA 12.x - "11": "8.9.7", # cuDNN 8.x for CUDA 11.x - } - - # TensorRT version mappings by CUDA major version - TENSORRT_VERSIONS = { - "13": "10.7.0", # TensorRT 10.x for CUDA 13.x - "12": "10.7.0", # TensorRT 10.x for CUDA 12.x - "11": "8.6.1", # TensorRT 8.x for CUDA 11.x - } - - def __init__( - self, - *, - log_path: Optional[str] = None, - state_path: Optional[str] = None, - download_dir: Optional[str] = None, - ) -> None: - self._lock = asyncio.Lock() - self._operation: Optional[str] = None - self._operation_started_at: Optional[str] = None - self._current_task: Optional[asyncio.Task] = None - self._last_error: Optional[str] = None - self._download_progress: Dict[str, Any] = {} - self._last_logged_percentage: int = -1 - self._last_progress_broadcast_time: float = 0.0 - self._pending_progress: Optional[Dict[str, Any]] = None - self._progress_broadcast_count: int = 0 - - # Determine data root - check Docker path first, then fallback to local - if os.path.exists("/app/data"): - data_root = "/app/data" - else: - data_root = os.path.abspath("data") - - log_path = log_path or os.path.join(data_root, "logs", "cuda_install.log") - state_path = state_path or os.path.join( - data_root, "configs", "cuda_installer.json" - ) - download_dir = download_dir or os.path.join( - data_root, "temp", "cuda_installers" - ) - self._cuda_install_dir = os.path.join(data_root, "cuda") - - self._log_path = os.path.abspath(log_path) - self._state_path = os.path.abspath(state_path) - self._download_dir = os.path.abspath(download_dir) - self._url_cache: Dict[str, str] = {} # Cache for dynamically fetched URLs - self._repo_cache: Dict[str, list] = {} # Cache for NVIDIA repo packages - self._ensure_directories() - - def _ensure_directories(self) -> None: - os.makedirs(self._download_dir, exist_ok=True) - 
os.makedirs(os.path.dirname(self._log_path), exist_ok=True) - os.makedirs(os.path.dirname(self._state_path), exist_ok=True) - os.makedirs(self._cuda_install_dir, exist_ok=True) - - def _update_current_symlink(self, install_path: str) -> None: - """Create or update the /app/data/cuda/current symlink to point to the active CUDA installation.""" - current_symlink = os.path.join(self._cuda_install_dir, "current") - try: - # Remove existing symlink if it exists - if os.path.islink(current_symlink): - os.remove(current_symlink) - elif os.path.exists(current_symlink): - # If it's not a symlink but exists, remove it (shouldn't happen, but be safe) - os.remove(current_symlink) - - # Create new symlink pointing to the installation - os.symlink(install_path, current_symlink) - logger.info(f"Updated CUDA current symlink: {current_symlink} -> {install_path}") - except OSError as e: - logger.warning(f"Failed to update CUDA current symlink: {e}") - - def _remove_current_symlink(self) -> None: - """Remove the current symlink and optionally re-point it to another installed version.""" - current_symlink = os.path.join(self._cuda_install_dir, "current") - try: - if os.path.islink(current_symlink) or os.path.exists(current_symlink): - os.remove(current_symlink) - - # Try to find another installed version to point to - state = self._load_state() - installations = state.get("installations", {}) - - # Find the most recently installed version that still exists - latest_version = None - latest_time = None - for v, info in installations.items(): - install_path = info.get("path") - if install_path and os.path.exists(install_path): - installed_at = info.get("installed_at", "") - if not latest_time or installed_at > latest_time: - latest_time = installed_at - latest_version = v - - # Re-point to the latest remaining installation - if latest_version: - install_path = installations[latest_version].get("path") - if install_path and os.path.exists(install_path): - os.symlink(install_path, 
current_symlink) - logger.info(f"Re-pointed CUDA current symlink to: {install_path}") - except OSError as e: - logger.warning(f"Failed to update CUDA current symlink: {e}") - - def _get_platform(self) -> Tuple[str, str]: - """Get platform (os, arch) tuple.""" - system = platform.system().lower() - machine = platform.machine().lower() - - if machine in ("x86_64", "amd64"): - arch = "x86_64" - else: - arch = machine - - return system, arch - - def _get_ubuntu_version(self) -> str: - """Get Ubuntu version for NVIDIA repository URLs.""" - # Try to detect Ubuntu version from /etc/os-release - try: - if os.path.exists("/etc/os-release"): - with open("/etc/os-release", "r") as f: - for line in f: - if line.startswith("VERSION_ID="): - version = line.split("=")[1].strip().strip('"') - # Extract major.minor (e.g., "24.04" from "24.04.1") - parts = version.split(".") - if len(parts) >= 2: - major_minor = f"{parts[0]}{parts[1]}" - # Check if it's 24.04 or newer - if major_minor >= "2404": - return "ubuntu2404" - else: - return "ubuntu2204" - except Exception: - pass - - # Default to ubuntu2404 for Ubuntu 24.04 base image - return "ubuntu2404" - - def _get_archive_target_version(self) -> str: - """Get archive target version for CUDA runfile lookups.""" - ubuntu_version = self._get_ubuntu_version() - if ubuntu_version == "ubuntu2404": - return "24.04" - return "22.04" - - async def _get_repo_packages(self, ubuntu_version: str) -> list: - """Fetch and cache NVIDIA CUDA repo package metadata.""" - if ubuntu_version in self._repo_cache: - return self._repo_cache[ubuntu_version] - - base_url = ( - f"https://developer.download.nvidia.com/compute/cuda/repos/{ubuntu_version}/x86_64" - ) - packages_url = f"{base_url}/Packages.gz" - packages_plain_url = f"{base_url}/Packages" - packages: list = [] - - async with aiohttp.ClientSession() as session: - data = None - try: - async with session.get(packages_url) as response: - if response.status == 200: - compressed = await response.read() - 
data = gzip.decompress(compressed) - except Exception: - data = None - - if data is None: - try: - async with session.get(packages_plain_url) as response: - if response.status == 200: - data = await response.read() - except Exception: - data = None - - if not data: - self._repo_cache[ubuntu_version] = [] - return [] - - text = data.decode("utf-8", errors="replace") - current = {} - for line in text.splitlines(): - if not line.strip(): - if current: - packages.append(current) - current = {} - continue - if line.startswith("Package:"): - current["Package"] = line.split(":", 1)[1].strip() - elif line.startswith("Version:"): - current["Version"] = line.split(":", 1)[1].strip() - elif line.startswith("Filename:"): - current["Filename"] = line.split(":", 1)[1].strip() - - if current: - packages.append(current) - - self._repo_cache[ubuntu_version] = packages - return packages - - def _version_key(self, version: str) -> tuple: - """Create a sortable key for package version strings.""" - tokens = re.split(r"[^\w]+", version) - key = [] - for token in tokens: - if token.isdigit(): - key.append(int(token)) - elif token: - key.append(token) - return tuple(key) - - def _select_repo_package( - self, - packages: list, - package_name: str, - version_prefix: Optional[str] = None, - version_contains: Optional[str] = None, - ) -> Optional[Dict[str, str]]: - """Select the best matching package from repo metadata.""" - candidates = [ - pkg for pkg in packages if pkg.get("Package") == package_name - ] - if version_prefix: - candidates = [ - pkg - for pkg in candidates - if pkg.get("Version", "").startswith(version_prefix) - ] - if version_contains: - candidates = [ - pkg - for pkg in candidates - if version_contains in pkg.get("Version", "") - ] - if not candidates: - return None - return max(candidates, key=lambda pkg: self._version_key(pkg.get("Version", ""))) - - def _load_state(self) -> Dict[str, Any]: - if not os.path.exists(self._state_path): - return {} - try: - with 
open(self._state_path, "r", encoding="utf-8") as f: - data = json.load(f) - return data if isinstance(data, dict) else {} - except Exception as exc: - logger.warning(f"Failed to load CUDA installer state: {exc}") - return {} - - def _save_state(self, state: Dict[str, Any]) -> None: - tmp_path = f"{self._state_path}.tmp" - with open(tmp_path, "w", encoding="utf-8") as f: - json.dump(state, f, indent=2) - os.replace(tmp_path, self._state_path) - - def _detect_installed_version(self) -> Optional[str]: - """Detect installed CUDA version by checking nvcc or state.""" - # First check state for installed versions - state = self._load_state() - installations = state.get("installations", {}) - if installations: - # Return the most recently installed version - latest_version = None - latest_time = None - for v, info in installations.items(): - installed_at = info.get("installed_at", "") - if not latest_time or installed_at > latest_time: - latest_time = installed_at - latest_version = v - if latest_version: - install_path = installations[latest_version].get("path") - if install_path and os.path.exists(install_path): - return latest_version - - # Fallback: try to detect via nvcc command - try: - # Get CUDA environment to find nvcc - cuda_env = self.get_cuda_env() - env = os.environ.copy() - env.update(cuda_env) - - nvcc_path = shutil.which("nvcc", path=env.get("PATH", "")) - if not nvcc_path: - return None - - result = subprocess.run( - [nvcc_path, "--version"], - capture_output=True, - text=True, - timeout=5, - env=env, - ) - if result.returncode == 0: - # Parse version from output - for line in result.stdout.split("\n"): - if "release" in line.lower(): - parts = line.split() - for i, part in enumerate(parts): - if "release" in part.lower() and i + 1 < len(parts): - version_str = parts[i + 1].rstrip(",") - # Extract major.minor - version_parts = version_str.split(".") - if len(version_parts) >= 2: - return f"{version_parts[0]}.{version_parts[1]}" - except 
(subprocess.TimeoutExpired, FileNotFoundError, OSError): - pass - return None - - def _get_cuda_path(self, version: Optional[str] = None) -> Optional[str]: - """Get CUDA installation path.""" - # First, check the current symlink (most reliable for active installation) - current_symlink = os.path.join(self._cuda_install_dir, "current") - if os.path.islink(current_symlink) or os.path.exists(current_symlink): - try: - resolved_path = os.path.realpath(current_symlink) - if os.path.exists(resolved_path): - nvcc_path = os.path.join(resolved_path, "bin", "nvcc") - if os.path.exists(nvcc_path): - return resolved_path - except (OSError, ValueError): - pass - - # Check state for installed versions - state = self._load_state() - installations = state.get("installations", {}) - - # If version specified, return that installation path - if version and version in installations: - install_path = installations[version].get("path") - if install_path and os.path.exists(install_path): - return install_path - - # Check for latest installed version in state - if installations: - # Get the most recently installed version - latest_version = None - latest_time = None - for v, info in installations.items(): - installed_at = info.get("installed_at", "") - if not latest_time or installed_at > latest_time: - latest_time = installed_at - latest_version = v - - if latest_version: - install_path = installations[latest_version].get("path") - if install_path and os.path.exists(install_path): - return install_path - - # Check environment variables (only accept paths under data directory) - env_path = os.environ.get("CUDA_PATH") or os.environ.get("CUDA_HOME") - if ( - env_path - and os.path.exists(env_path) - and os.path.abspath(env_path).startswith(self._cuda_install_dir) - ): - return env_path - - # Scan the data directory for CUDA installs as fallback - try: - if os.path.exists(self._cuda_install_dir): - for item in sorted(os.listdir(self._cuda_install_dir), reverse=True): - # Skip the current 
symlink - if item == "current": - continue - full_path = os.path.join(self._cuda_install_dir, item) - if os.path.isdir(full_path): - nvcc_path = os.path.join(full_path, "bin", "nvcc") - if os.path.exists(nvcc_path): - return full_path - except OSError: - pass - - return None - - def get_cuda_env(self, version: Optional[str] = None) -> Dict[str, str]: - """Get environment variables for CUDA installation.""" - cuda_path = self._get_cuda_path(version) - if not cuda_path: - return {} - - cuda_bin = os.path.join(cuda_path, "bin") - cuda_lib = os.path.join(cuda_path, "lib64") - - env = { - "CUDA_HOME": cuda_path, - "CUDA_PATH": cuda_path, - } - - # Add to PATH if bin directory exists - if os.path.exists(cuda_bin): - current_path = os.environ.get("PATH", "") - if cuda_bin not in current_path: - env["PATH"] = f"{cuda_bin}:{current_path}" if current_path else cuda_bin - - # Add to LD_LIBRARY_PATH if lib64 directory exists - if os.path.exists(cuda_lib): - current_ld_path = os.environ.get("LD_LIBRARY_PATH", "") - if cuda_lib not in current_ld_path: - env["LD_LIBRARY_PATH"] = ( - f"{cuda_lib}:{current_ld_path}" if current_ld_path else cuda_lib - ) - - # Add TensorRT path if TensorRT is installed - tensorrt_version = self._detect_tensorrt_version(cuda_path) - if tensorrt_version: - env["TENSORRT_PATH"] = cuda_path - env["TENSORRT_ROOT"] = cuda_path - - return env - - def _get_archive_url(self, version: str) -> str: - """Get NVIDIA download archive URL for a CUDA version.""" - # Convert version like "12.8" to "12-8-0" for URL - version_parts = version.split(".") - major = version_parts[0] - minor = version_parts[1] if len(version_parts) > 1 else "0" - patch = version_parts[2] if len(version_parts) > 2 else "0" - version_slug = f"{major}-{minor}-{patch}" - target_version = self._get_archive_target_version() - - return ( - f"https://developer.nvidia.com/cuda-{version_slug}-download-archive" - 
f"?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version={target_version}&target_type=runfile_local" - ) - - async def _fetch_download_url(self, version: str) -> str: - """Fetch the actual download URL from NVIDIA's archive page.""" - # Check cache first - cache_key = f"{version}_linux_x86_64" - if cache_key in self._url_cache: - return self._url_cache[cache_key] - - archive_url = self._get_archive_url(version) - logger.info(f"Fetching CUDA {version} download URL from {archive_url}") - - async with aiohttp.ClientSession() as session: - try: - async with session.get( - archive_url, timeout=aiohttp.ClientTimeout(total=30) - ) as response: - if response.status != 200: - raise RuntimeError( - f"Failed to fetch archive page: HTTP {response.status}" - ) - - html = await response.text() - - # The page contains JSON data with download URLs - # The JSON structure has keys like "Linux/x86_64/Ubuntu/24.04/runfile_local" - # The URL is in the "details" field which contains HTML with href attributes - target_version = self._get_archive_target_version() - json_key = f"Linux/x86_64/Ubuntu/{target_version}/runfile_local" - - # Pattern 1: Look for href in the details field (HTML may be escaped) - # Match: "Linux/x86_64/Ubuntu//runfile_local":{..."details":"...href=\"URL\"..."} - pattern1 = rf'"{re.escape(json_key)}"[^}}]*"details"[^"]*href[=:][\\"]*([^"\\s<>]+cuda_\d+\.\d+\.\d+_[^"\\s<>]+_linux\.run)' - matches = re.findall(pattern1, html, re.IGNORECASE | re.DOTALL) - - if not matches: - # Pattern 2: Look for href with escaped quotes (\u0022 or \") - pattern2 = rf'"{re.escape(json_key)}"[^}}]*href[\\u0022=:]*([^"\\s<>]+cuda_\d+\.\d+\.\d+_[^"\\s<>]+_linux\.run)' - matches = re.findall(pattern2, html, re.IGNORECASE | re.DOTALL) - - if not matches: - # Pattern 3: Look for the filename field and construct URL - pattern3 = rf'"{re.escape(json_key)}"[^}}]*"filename"[^"]*"([^"]+_linux\.run)"' - filename_matches = re.findall(pattern3, html, re.IGNORECASE) - if 
filename_matches: - filename = filename_matches[0] - version_full = f"{version}.0" - url = f"https://developer.download.nvidia.com/compute/cuda/{version_full}/local_installers/{filename}" - matches = [url] - - if not matches: - # Pattern 4: Fallback - look for any URL matching the pattern - version_escaped = version.replace(".", r"\.") - pattern4 = rf'https://developer\.download\.nvidia\.com/compute/cuda/{version_escaped}\.0/local_installers/cuda_{version_escaped}\.0_[^"\'\s<>]+_linux\.run' - matches = re.findall(pattern4, html, re.IGNORECASE) - - if matches: - url = matches[0] - # Cache it - self._url_cache[cache_key] = url - logger.info(f"Found CUDA {version} download URL: {url}") - return url - else: - raise RuntimeError( - f"Could not find download URL for CUDA {version} on archive page" - ) - - except aiohttp.ClientError as e: - raise RuntimeError(f"Failed to fetch archive page: {e}") - - async def _broadcast_log_line(self, line: str) -> None: - try: - await websocket_manager.broadcast( - { - "type": "cuda_install_log", - "line": line, - "timestamp": _utcnow(), - } - ) - except Exception as exc: - logger.debug(f"Failed to broadcast CUDA log line: {exc}") - - async def _broadcast_progress(self, progress: Dict[str, Any]) -> None: - """Broadcast progress updates, throttled to 1 second intervals.""" - try: - current_time = time.time() - progress_value = progress.get("progress", 0) - is_complete = progress_value >= 100 - - # Always send completion updates immediately - if is_complete: - await websocket_manager.broadcast( - { - "type": "cuda_install_progress", - **progress, - "timestamp": _utcnow(), - } - ) - self._last_progress_broadcast_time = current_time - self._pending_progress = None - return - - # Always send the first few updates immediately (first 3 updates) - # then throttle to 1 second intervals - is_first_update = self._last_progress_broadcast_time == 0.0 - time_since_last_broadcast = current_time - self._last_progress_broadcast_time - is_early_update = 
self._progress_broadcast_count < 3 - should_send = is_first_update or is_early_update or time_since_last_broadcast >= 1.0 - - if should_send: - await websocket_manager.broadcast( - { - "type": "cuda_install_progress", - **progress, - "timestamp": _utcnow(), - } - ) - self._last_progress_broadcast_time = current_time - self._pending_progress = None - self._progress_broadcast_count += 1 - else: - # Store the latest progress data for next send - self._pending_progress = progress - except Exception as exc: - logger.exception(f"Failed to broadcast CUDA progress: {exc}") - - async def _set_operation(self, operation: str) -> None: - self._operation = operation - self._operation_started_at = _utcnow() - self._last_error = None - await websocket_manager.broadcast( - { - "type": "cuda_install_status", - "status": operation, - "started_at": self._operation_started_at, - } - ) - - async def _finish_operation(self, success: bool, message: str = "") -> None: - payload = { - "type": "cuda_install_status", - "status": "completed" if success else "failed", - "operation": self._operation, - "message": message, - "ended_at": _utcnow(), - } - await websocket_manager.broadcast(payload) - self._operation = None - self._operation_started_at = None - - def _create_task(self, coro: Awaitable[Any]) -> None: - loop = asyncio.get_running_loop() - task = loop.create_task(coro) - self._current_task = task - - def _cleanup(fut: asyncio.Future) -> None: - try: - fut.result() - except Exception as exc: - logger.exception("CUDA installer task error") - finally: - self._current_task = None - - task.add_done_callback(_cleanup) - - async def _download_installer( - self, version: str, url: str, installer_path: str - ) -> None: - """Download CUDA installer with progress tracking.""" - # Check if installer already exists - if os.path.exists(installer_path): - file_size = os.path.getsize(installer_path) - file_size_mb = file_size / (1024 * 1024) - - # Verify existing file is not corrupted (should be at 
least 100MB for CUDA installers) - if file_size < 100 * 1024 * 1024: - await self._broadcast_log_line( - f"Existing installer file appears corrupted (too small: {file_size_mb:.1f} MB), re-downloading..." - ) - try: - os.remove(installer_path) - except OSError: - pass - else: - # Verify the file is actually valid and matches expected size from server - try: - # First, check if it's a valid shell script - with open(installer_path, "rb") as f: - header = f.read(100) - if not header.startswith(b"#!/"): - await self._broadcast_log_line( - f"Existing installer file is not a valid shell script, re-downloading..." - ) - try: - os.remove(installer_path) - except OSError: - pass - else: - # File appears valid, now verify size matches server expectation - # Fetch the expected file size from the server - try: - async with aiohttp.ClientSession() as session: - async with session.head(url, allow_redirects=True) as head_response: - expected_size = int(head_response.headers.get("Content-Length", 0)) - - if expected_size > 0: - # Verify file size matches (with small tolerance) - size_diff = abs(file_size - expected_size) - if size_diff > 1024: # Allow 1KB tolerance - await self._broadcast_log_line( - f"Existing installer file size mismatch: expected {expected_size / (1024*1024):.1f} MB, " - f"got {file_size_mb:.1f} MB (difference: {size_diff} bytes). Re-downloading..." - ) - try: - os.remove(installer_path) - except OSError: - pass - else: - # File size matches, verify it's stable (not currently being written) - await asyncio.sleep(0.2) # Brief pause to ensure file is fully written if being written - new_size = os.path.getsize(installer_path) - if new_size != file_size: - await self._broadcast_log_line( - f"File size changed during verification (was {file_size_mb:.1f} MB, now {new_size / (1024*1024):.1f} MB), " - f"file may still be downloading. Re-downloading..." 
- ) - try: - os.remove(installer_path) - except OSError: - pass - else: - await self._broadcast_log_line( - f"Installer file already exists and verified: {installer_path} ({file_size_mb:.1f} MB)" - ) - await self._broadcast_progress( - { - "stage": "download", - "progress": 100, - "message": f"Using existing installer file ({file_size_mb:.1f} MB)", - } - ) - return - else: - # Couldn't get expected size, but file looks valid - use it - await self._broadcast_log_line( - f"Installer file already exists: {installer_path} ({file_size_mb:.1f} MB). " - f"Could not verify size from server, but file appears valid." - ) - await self._broadcast_progress( - { - "stage": "download", - "progress": 100, - "message": f"Using existing installer file ({file_size_mb:.1f} MB)", - } - ) - return - except Exception as size_check_error: - # If we can't verify size from server, but file looks valid, use it - await self._broadcast_log_line( - f"Could not verify file size from server: {size_check_error}. " - f"File appears valid, using existing file: {installer_path} ({file_size_mb:.1f} MB)" - ) - await self._broadcast_progress( - { - "stage": "download", - "progress": 100, - "message": f"Using existing installer file ({file_size_mb:.1f} MB)", - } - ) - return - except (OSError, IOError) as e: - await self._broadcast_log_line( - f"Failed to verify existing installer file: {e}, re-downloading..." - ) - try: - os.remove(installer_path) - except OSError: - pass - - # Reset logging state for new download - self._last_logged_percentage = -1 - self._last_progress_broadcast_time = 0.0 - self._pending_progress = None - self._progress_broadcast_count = 0 - - log_header = f"[{_utcnow()}] Downloading CUDA {version} installer from {url}\n" - with open(self._log_path, "w", encoding="utf-8") as log_file: - log_file.write(log_header) - - await self._broadcast_log_line( - f"Starting download of CUDA {version} installer..." 
- ) - await self._broadcast_progress( - { - "stage": "download", - "progress": 0, - "message": f"Downloading CUDA {version} installer...", - } - ) - - # Configure timeout for large file downloads: - # - total: 1 hour (3600s) for very large files and slow connections - # - connect: 30s to establish connection - # - sock_read: 5 minutes (300s) to allow for slow network during chunk reads - timeout = aiohttp.ClientTimeout( - total=3600, # 1 hour total timeout - connect=30, # 30 seconds to connect - sock_read=300, # 5 minutes per read operation - ) - - downloaded = 0 - total_size = 0 - try: - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.get(url) as response: - response.raise_for_status() - total_size = int(response.headers.get("Content-Length", 0)) - - async with aiofiles.open(installer_path, "wb") as f: - async for chunk in response.content.iter_chunked(8192): - await f.write(chunk) - downloaded += len(chunk) - - if total_size > 0: - progress = int((downloaded / total_size) * 100) - # Format sizes in MB - downloaded_mb = downloaded / (1024 * 1024) - total_mb = total_size / (1024 * 1024) - await self._broadcast_progress( - { - "stage": "download", - "progress": progress, - "message": f"Downloading CUDA {version} installer... 
({downloaded_mb:.1f}/{total_mb:.1f} MB)", - "bytes_downloaded": downloaded, - "total_bytes": total_size, - } - ) - - # Log progress only at key percentage milestones (10%, 25%, 50%, 75%, 90%, 100%) - # Only log when we cross a milestone, not when we're within it - should_log = False - - # Check if we've crossed a key percentage milestone - if total_size > 0: - progress = int((downloaded / total_size) * 100) - if progress != self._last_logged_percentage and progress in [ - 10, - 25, - 50, - 75, - 90, - 100, - ]: - should_log = True - self._last_logged_percentage = progress - - if should_log: - downloaded_mb = downloaded / (1024 * 1024) - total_mb = total_size / (1024 * 1024) - log_line = f"Downloaded {downloaded_mb:.1f}/{total_mb:.1f} MB ({progress}%)\n" - with open( - self._log_path, "a", encoding="utf-8" - ) as log_file: - log_file.write(log_line) - await self._broadcast_log_line( - f"Downloaded {downloaded_mb:.1f} MB / {total_mb:.1f} MB ({progress}%)" - ) - - # File is automatically flushed when the context manager exits - except asyncio.TimeoutError as e: - # Clean up partial download on timeout - if os.path.exists(installer_path): - try: - os.remove(installer_path) - except OSError: - pass - downloaded_mb = downloaded / (1024 * 1024) if downloaded > 0 else 0 - total_mb = total_size / (1024 * 1024) if total_size > 0 else 0 - error_msg = ( - f"Download timeout: Failed to download CUDA {version} installer. " - f"Downloaded {downloaded_mb:.1f} MB of {total_mb:.1f} MB. " - f"This may be due to a slow network connection. Please try again." - ) - await self._broadcast_log_line(error_msg) - raise RuntimeError(error_msg) from e - except aiohttp.ClientError as e: - # Clean up partial download on client error - if os.path.exists(installer_path): - try: - os.remove(installer_path) - except OSError: - pass - error_msg = ( - f"Network error while downloading CUDA {version} installer: {e}. " - f"Please check your network connection and try again." 
- ) - await self._broadcast_log_line(error_msg) - raise RuntimeError(error_msg) from e - - # Wait a brief moment to ensure file system has fully written the file - # This helps ensure the file is completely written to disk before verification - await asyncio.sleep(0.5) - - # Verify downloaded file exists and is complete - if not os.path.exists(installer_path): - raise RuntimeError(f"Downloaded file not found: {installer_path}") - - # Verify file size matches expected size (with a small tolerance for filesystem differences) - actual_size = os.path.getsize(installer_path) - if total_size > 0: - size_diff = abs(actual_size - total_size) - if size_diff > 1024: # Allow 1KB tolerance for filesystem differences - raise RuntimeError( - f"Downloaded file size mismatch: expected {total_size} bytes, " - f"got {actual_size} bytes (difference: {size_diff} bytes). File may be corrupted or incomplete." - ) - - if actual_size < 100 * 1024 * 1024: # Less than 100MB is suspicious - raise RuntimeError( - f"Downloaded file appears to be corrupted or incomplete: " - f"{installer_path} (size: {actual_size} bytes)" - ) - - # Verify the file is a valid shell script (CUDA .run files are self-extracting) - try: - with open(installer_path, "rb") as verify_file: - header = verify_file.read(100) - if not header.startswith(b"#!/"): - raise RuntimeError( - f"Downloaded file does not appear to be a valid shell script: {installer_path}" - ) - except Exception as e: - raise RuntimeError( - f"Failed to verify downloaded file integrity: {installer_path}, error: {e}" - ) - - await self._broadcast_log_line( - f"Download completed and verified: {installer_path} ({actual_size / (1024*1024):.1f} MB)" - ) - await self._broadcast_progress( - { - "stage": "download", - "progress": 100, - "message": "Download completed and verified", - } - ) - - def _is_docker_container(self) -> bool: - """Check if running inside a Docker container.""" - # Check for Docker-specific files - docker_indicators = [ - 
"/.dockerenv", - "/proc/self/cgroup", - ] - - # Check /.dockerenv - if os.path.exists("/.dockerenv"): - return True - - # Check /proc/self/cgroup for Docker - try: - if os.path.exists("/proc/self/cgroup"): - with open("/proc/self/cgroup", "r") as f: - content = f.read() - if "docker" in content or "containerd" in content: - return True - except (OSError, IOError): - pass - - return False - - async def _install_linux( - self, - installer_path: str, - version: str, - install_cudnn: bool = False, - install_tensorrt: bool = False, - ) -> str: - """ - Install CUDA on Linux using runfile installer. - - Uses optimized installer options for custom location installation: - - Silent installation with EULA acceptance - - Toolkit-only installation (no driver) - - Override installation checks for custom paths - - Skip OpenGL libraries (not needed in Docker/headless environments) - - Skip man pages to reduce installation size - - Args: - installer_path: Path to the CUDA installer runfile - version: CUDA version being installed - install_cudnn: Whether to install cuDNN - install_tensorrt: Whether to install TensorRT - """ - await self._broadcast_log_line("Starting CUDA installation on Linux...") - await self._broadcast_progress( - { - "stage": "install", - "progress": 0, - "message": "Installing CUDA Toolkit...", - } - ) - - # Verify installer file exists and is not corrupted - if not os.path.exists(installer_path): - raise RuntimeError(f"Installer file not found: {installer_path}") - - file_size = os.path.getsize(installer_path) - if file_size < 100 * 1024 * 1024: # Less than 100MB is suspicious for CUDA installers - raise RuntimeError( - f"Installer file appears to be corrupted or incomplete: {installer_path} " - f"(size: {file_size / (1024*1024):.1f} MB, expected > 100 MB)" - ) - - # Verify the file starts with a shell script header (CUDA .run files are self-extracting) - try: - with open(installer_path, "rb") as f: - header = f.read(100) - if not header.startswith(b"#!/"): - 
raise RuntimeError( - f"Installer file does not appear to be a valid shell script: {installer_path}" - ) - except Exception as e: - raise RuntimeError( - f"Failed to verify installer file: {installer_path}, error: {e}" - ) - - await self._broadcast_log_line( - f"Verifying installer file: {installer_path} ({file_size / (1024*1024):.1f} MB)" - ) - - # Make installer executable - os.chmod(installer_path, 0o755) - - # Always install to the data directory for persistence - install_path = os.path.join(self._cuda_install_dir, f"cuda-{version}") - await self._broadcast_log_line(f"Installing to data directory: {install_path}") - os.makedirs(install_path, exist_ok=True) - - # Build installer arguments with optimized options for custom location installation - # - # Selected options based on NVIDIA CUDA installer documentation: - # - --silent: Required for silent installation, implies EULA acceptance - # - --toolkit: Install toolkit only (not driver) - required for non-root installations - # - --override: Override compiler, third-party library, and toolkit detection checks - # (essential for custom installation paths) - # - --toolkitpath: Install to custom data directory path - # - --no-opengl-libs: Skip OpenGL libraries (not needed in Docker/headless environments) - # - --no-man-page: Skip man pages to reduce installation size - # - install_args = [ - "bash", - installer_path, - "--silent", # Silent installation with EULA acceptance - "--toolkit", # Install toolkit only (not driver) - "--override", # Override installation checks for custom paths - f"--toolkitpath={install_path}", # Install to custom data directory - "--no-opengl-libs", # Skip OpenGL libraries (not needed in Docker) - "--no-man-page", # Skip man pages to reduce size - ] - - await self._broadcast_log_line(f"Installer arguments: {' '.join(install_args[2:])}") # Skip 'bash' and installer_path - - # Set up environment to prevent /dev/tty access issues in Docker - env = os.environ.copy() - env["DEBIAN_FRONTEND"] = 
"noninteractive" - # Disable interactive prompts - env["PERL_BADLANG"] = "0" - # Ensure we're in a non-interactive environment - env["TERM"] = "dumb" - # Prevent installer from trying to access /dev/tty - env["CI"] = "true" # Indicate we're in a CI/non-interactive environment - - process = await asyncio.create_subprocess_exec( - *install_args, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.STDOUT, - stdin=asyncio.subprocess.DEVNULL, # Redirect stdin to prevent /dev/tty access - env=env, - ) - - # Collect output for error analysis - output_lines = [] - - async def _stream_output(): - if process.stdout is None: - return - with open(self._log_path, "a", encoding="utf-8", buffering=1) as log_file: - while True: - chunk = await process.stdout.readline() - if not chunk: - break - text = chunk.decode("utf-8", errors="replace") - output_lines.append(text) - log_file.write(text) - await self._broadcast_log_line(text.rstrip("\n")) - - await asyncio.gather(process.wait(), _stream_output()) - - if process.returncode != 0: - # Check for specific error patterns - output_text = "".join(output_lines) - - # Check for /dev/tty errors - if "/dev/tty" in output_text.lower() or "cannot create /dev/tty" in output_text.lower(): - error_msg = ( - f"CUDA installer failed due to /dev/tty access issue (common in Docker). " - f"This may indicate the installer file is corrupted or the environment is not properly configured. " - f"Exit code: {process.returncode}. " - f"Please check the installation logs for details. " - f"If the file appears corrupted, try deleting it and re-downloading." - ) - # Check for gzip/corruption errors - elif "gzip" in output_text.lower() and ("unexpected end" in output_text.lower() or "corrupt" in output_text.lower()): - error_msg = ( - f"CUDA installer file appears to be corrupted (gzip error detected). " - f"Please delete the installer file at {installer_path} and try again. " - f"Exit code: {process.returncode}." 
- ) - else: - error_msg = ( - f"CUDA installer exited with code {process.returncode}. " - "Please check the installation logs for details." - ) - - raise RuntimeError(error_msg) - - # Verify installation and set up environment - cuda_home = install_path - cuda_bin = os.path.join(cuda_home, "bin") - cuda_lib = os.path.join(cuda_home, "lib64") - - # Verify key directories exist - if not os.path.exists(cuda_bin) or not os.path.exists(cuda_lib): - raise RuntimeError( - f"CUDA installation completed but expected directories not found. " - f"Expected: {cuda_bin}, {cuda_lib}" - ) - - await self._broadcast_log_line( - f"CUDA installed successfully to: {install_path}" - ) - await self._broadcast_log_line(f"CUDA_HOME={cuda_home}") - await self._broadcast_log_line(f"Adding to PATH: {cuda_bin}") - await self._broadcast_log_line(f"Adding to LD_LIBRARY_PATH: {cuda_lib}") - - # Install NCCL (required for multi-GPU and llama.cpp CUDA builds) - await self._install_nccl_linux(version, install_path) - - # Install nvidia-smi (required for GPU monitoring) - await self._install_nvidia_smi_linux(install_path) - - # Install cuDNN if requested - if install_cudnn: - await self._install_cudnn_linux(version, install_path) - - # Install TensorRT if requested - if install_tensorrt: - await self._install_tensorrt_linux(version, install_path) - - # Save installation path to state - state = self._load_state() - if "installations" not in state: - state["installations"] = {} - state["installations"][version] = { - "path": install_path, - "installed_at": _utcnow(), - "is_system_install": False, - "cudnn_installed": install_cudnn, - "tensorrt_installed": install_tensorrt, - } - self._save_state(state) - - # Update the current symlink to point to this installation - self._update_current_symlink(install_path) - await self._broadcast_log_line( - f"Updated CUDA current symlink: /app/data/cuda/current -> {install_path}" - ) - - components = ["CUDA", "NCCL", "nvidia-smi"] - if install_cudnn: - 
components.append("cuDNN") - if install_tensorrt: - components.append("TensorRT") - - await self._broadcast_progress( - { - "stage": "install", - "progress": 100, - "message": f"{', '.join(components)} installation completed", - } - ) - - return install_path - - async def _install_nccl_linux(self, cuda_version: str, cuda_path: str) -> None: - """Install NCCL library for multi-GPU support.""" - await self._broadcast_log_line( - "Installing NCCL (NVIDIA Collective Communications Library)..." - ) - await self._broadcast_progress( - { - "stage": "nccl", - "progress": 0, - "message": "Installing NCCL...", - } - ) - - ubuntu_version = self._get_ubuntu_version() - - # Download NCCL from NVIDIA's repo package index - await self._broadcast_log_line("Attempting manual NCCL installation...") - - try: - cuda_major = cuda_version.split(".")[0] - packages = await self._get_repo_packages(ubuntu_version) - nccl_pkg = self._select_repo_package( - packages, - "libnccl2", - version_prefix="2.", - version_contains=f"+cuda{cuda_major}", - ) - nccl_dev_pkg = self._select_repo_package( - packages, - "libnccl-dev", - version_prefix="2.", - version_contains=f"+cuda{cuda_major}", - ) - - if not nccl_pkg or not nccl_dev_pkg: - await self._broadcast_log_line( - "NCCL packages not found in repository, skipping NCCL installation" - ) - await self._broadcast_progress( - { - "stage": "nccl", - "progress": 100, - "message": "NCCL installation skipped (optional)", - } - ) - return - - base_url = ( - f"https://developer.download.nvidia.com/compute/cuda/repos/{ubuntu_version}/x86_64/" - ) - nccl_url = base_url + nccl_pkg.get("Filename", "").lstrip("./") - nccl_dev_url = base_url + nccl_dev_pkg.get("Filename", "").lstrip("./") - - nccl_path = os.path.join(self._download_dir, "libnccl2.deb") - nccl_dev_path = os.path.join(self._download_dir, "libnccl-dev.deb") - - await self._broadcast_progress( - { - "stage": "nccl", - "progress": 25, - "message": "Downloading NCCL packages...", - } - ) - - # Download 
NCCL packages - async with aiohttp.ClientSession() as session: - for url, path, name in [ - (nccl_url, nccl_path, "libnccl2"), - (nccl_dev_url, nccl_dev_path, "libnccl-dev"), - ]: - try: - await self._broadcast_log_line(f"Downloading {name}...") - async with session.get(url) as response: - if response.status == 200: - async with aiofiles.open(path, "wb") as f: - await f.write(await response.read()) - await self._broadcast_log_line(f"Downloaded {name}") - else: - await self._broadcast_log_line( - f"Failed to download {name}: HTTP {response.status}" - ) - # Try alternative URL with different NCCL version - continue - except Exception as download_err: - await self._broadcast_log_line( - f"Download error for {name}: {download_err}" - ) - continue - - await self._broadcast_progress( - { - "stage": "nccl", - "progress": 50, - "message": "Installing NCCL packages...", - } - ) - - if os.path.exists(nccl_path): - await self._broadcast_log_line( - "Extracting NCCL to CUDA directory..." - ) - - # Extract .deb file (it's an ar archive containing data.tar) - extract_dir = os.path.join(self._download_dir, "nccl_extract") - os.makedirs(extract_dir, exist_ok=True) - - for deb_path in [nccl_path, nccl_dev_path]: - if os.path.exists(deb_path): - # Extract using ar and tar - extract_process = await asyncio.create_subprocess_exec( - "bash", - "-c", - f"cd {extract_dir} && ar x {deb_path} && tar xf data.tar.* 2>/dev/null || tar xf data.tar 2>/dev/null", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.STDOUT, - ) - await extract_process.wait() - - # Copy NCCL files to CUDA installation - nccl_lib_src = os.path.join( - extract_dir, "usr", "lib", "x86_64-linux-gnu" - ) - nccl_include_src = os.path.join(extract_dir, "usr", "include") - - cuda_lib_dst = os.path.join(cuda_path, "lib64") - cuda_include_dst = os.path.join(cuda_path, "include") - - if os.path.exists(nccl_lib_src): - # First pass: collect files and symlinks, copy actual files first - files_to_copy = [] - 
symlinks_to_create = [] - - for f in os.listdir(nccl_lib_src): - if "nccl" in f.lower(): - src = os.path.join(nccl_lib_src, f) - dst = os.path.join(cuda_lib_dst, f) - - if os.path.islink(src): - # Resolve symlink to find actual target - link_target = os.readlink(src) - # If relative symlink, resolve relative to source directory - if not os.path.isabs(link_target): - link_target = os.path.normpath( - os.path.join(os.path.dirname(src), link_target) - ) - # Find the actual target file name - actual_target = os.path.basename(link_target) - symlinks_to_create.append((f, actual_target, dst)) - else: - files_to_copy.append((f, src, dst)) - - # Copy all actual files first - for f, src, dst in files_to_copy: - try: - shutil.copy2(src, dst) - await self._broadcast_log_line( - f"Copied {f} to CUDA lib directory" - ) - except Exception as copy_err: - await self._broadcast_log_line( - f"Failed to copy {f}: {copy_err}" - ) - - # Then create symlinks pointing to the copied files - for link_name, target_name, dst in symlinks_to_create: - try: - if os.path.exists(dst): - os.remove(dst) - # Create symlink pointing to target in same directory - os.symlink(target_name, dst) - await self._broadcast_log_line( - f"Created symlink {link_name} -> {target_name} in CUDA lib directory" - ) - except Exception as link_err: - await self._broadcast_log_line( - f"Failed to create symlink {link_name}: {link_err}" - ) - - if os.path.exists(nccl_include_src): - for f in os.listdir(nccl_include_src): - if "nccl" in f.lower(): - src = os.path.join(nccl_include_src, f) - dst = os.path.join(cuda_include_dst, f) - try: - if os.path.isdir(src): - # Handle directories by copying recursively - if os.path.exists(dst): - shutil.rmtree(dst) - shutil.copytree(src, dst) - await self._broadcast_log_line( - f"Copied directory {f} to CUDA include directory" - ) - else: - # Handle regular files - shutil.copy2(src, dst) - await self._broadcast_log_line( - f"Copied {f} to CUDA include directory" - ) - except Exception 
as copy_err: - await self._broadcast_log_line( - f"Failed to copy {f}: {copy_err}" - ) - - # Cleanup temporary extract directory only (keep .deb files) - shutil.rmtree(extract_dir, ignore_errors=True) - # Keep .deb files for future use - logger.info(f"NCCL packages kept at: {nccl_path}, {nccl_dev_path}") - - await self._broadcast_log_line("NCCL extracted to CUDA directory") - await self._broadcast_progress( - { - "stage": "nccl", - "progress": 100, - "message": "NCCL installed successfully", - } - ) - else: - await self._broadcast_log_line( - "NCCL packages not available, skipping NCCL installation" - ) - await self._broadcast_log_line( - "Note: NCCL is optional but recommended for multi-GPU builds" - ) - await self._broadcast_progress( - { - "stage": "nccl", - "progress": 100, - "message": "NCCL installation skipped (optional)", - } - ) - - except Exception as e: - await self._broadcast_log_line(f"NCCL installation error: {e}") - await self._broadcast_log_line( - "Note: NCCL is optional. The build will continue without multi-GPU support." - ) - await self._broadcast_progress( - { - "stage": "nccl", - "progress": 100, - "message": "NCCL installation skipped (optional)", - } - ) - - async def _install_nvidia_smi_linux(self, cuda_path: str) -> None: - """Install nvidia-smi binary for GPU monitoring.""" - await self._broadcast_log_line( - "Installing nvidia-smi (NVIDIA System Management Interface)..." 
- ) - await self._broadcast_progress( - { - "stage": "nvidia-smi", - "progress": 0, - "message": "Installing nvidia-smi...", - } - ) - - # Check if nvidia-smi already exists in CUDA installation - cuda_bin = os.path.join(cuda_path, "bin") - nvidia_smi_dst = os.path.join(cuda_bin, "nvidia-smi") - if os.path.exists(nvidia_smi_dst): - await self._broadcast_log_line( - "nvidia-smi already exists in CUDA installation, skipping" - ) - await self._broadcast_progress( - { - "stage": "nvidia-smi", - "progress": 100, - "message": "nvidia-smi already installed", - } - ) - return - - ubuntu_version = self._get_ubuntu_version() - - try: - # Try to find nvidia-utils package which contains nvidia-smi - packages = await self._get_repo_packages(ubuntu_version) - nvidia_utils_pkg = None - - # Try multiple package name patterns - for pkg_name in ["nvidia-utils", "nvidia-driver-utils", "nvidia-utils-"]: - nvidia_utils_pkg = self._select_repo_package( - packages, - pkg_name, - ) - if nvidia_utils_pkg: - break - - if not nvidia_utils_pkg: - await self._broadcast_log_line( - "nvidia-utils package not found in repository, skipping nvidia-smi installation" - ) - await self._broadcast_log_line( - "Note: nvidia-smi will not be available. GPU monitoring may be limited." 
- ) - await self._broadcast_progress( - { - "stage": "nvidia-smi", - "progress": 100, - "message": "nvidia-smi installation skipped (package not available)", - } - ) - return - - base_url = ( - f"https://developer.download.nvidia.com/compute/cuda/repos/{ubuntu_version}/x86_64/" - ) - nvidia_utils_url = base_url + nvidia_utils_pkg.get("Filename", "").lstrip("./") - - nvidia_utils_path = os.path.join(self._download_dir, "nvidia-utils.deb") - - await self._broadcast_progress( - { - "stage": "nvidia-smi", - "progress": 25, - "message": "Downloading nvidia-utils package...", - } - ) - - # Download nvidia-utils package - async with aiohttp.ClientSession() as session: - try: - await self._broadcast_log_line("Downloading nvidia-utils...") - async with session.get(nvidia_utils_url) as response: - if response.status == 200: - async with aiofiles.open(nvidia_utils_path, "wb") as f: - await f.write(await response.read()) - await self._broadcast_log_line("Downloaded nvidia-utils") - else: - await self._broadcast_log_line( - f"Failed to download nvidia-utils: HTTP {response.status}" - ) - raise RuntimeError(f"Failed to download nvidia-utils: HTTP {response.status}") - except Exception as download_err: - await self._broadcast_log_line( - f"Download error for nvidia-utils: {download_err}" - ) - raise - - await self._broadcast_progress( - { - "stage": "nvidia-smi", - "progress": 50, - "message": "Extracting nvidia-smi...", - } - ) - - if os.path.exists(nvidia_utils_path): - await self._broadcast_log_line( - "Extracting nvidia-smi to CUDA directory..." 
- ) - - # Extract .deb file - extract_dir = os.path.join(self._download_dir, "nvidia_utils_extract") - os.makedirs(extract_dir, exist_ok=True) - - # Extract using ar and tar - extract_process = await asyncio.create_subprocess_exec( - "bash", - "-c", - f"cd {extract_dir} && ar x {nvidia_utils_path} && tar xf data.tar.* 2>/dev/null || tar xf data.tar 2>/dev/null", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.STDOUT, - ) - await extract_process.wait() - - # Copy nvidia-smi binary to CUDA installation - nvidia_smi_src = os.path.join(extract_dir, "usr", "bin", "nvidia-smi") - cuda_bin_dst = os.path.join(cuda_path, "bin") - nvidia_smi_dst = os.path.join(cuda_bin_dst, "nvidia-smi") - - if os.path.exists(nvidia_smi_src): - os.makedirs(cuda_bin_dst, exist_ok=True) - try: - shutil.copy2(nvidia_smi_src, nvidia_smi_dst) - os.chmod(nvidia_smi_dst, 0o755) - await self._broadcast_log_line( - "Copied nvidia-smi to CUDA bin directory" - ) - await self._broadcast_progress( - { - "stage": "nvidia-smi", - "progress": 100, - "message": "nvidia-smi installed successfully", - } - ) - except Exception as copy_err: - await self._broadcast_log_line( - f"Failed to copy nvidia-smi: {copy_err}" - ) - raise - else: - await self._broadcast_log_line( - "nvidia-smi not found in extracted package" - ) - await self._broadcast_progress( - { - "stage": "nvidia-smi", - "progress": 100, - "message": "nvidia-smi installation skipped (not in package)", - } - ) - - # Cleanup temporary extract directory only (keep .deb file) - shutil.rmtree(extract_dir, ignore_errors=True) - # Keep .deb file for future use - if os.path.exists(nvidia_utils_path): - logger.info(f"nvidia-utils package kept at: {nvidia_utils_path}") - - else: - await self._broadcast_log_line( - "nvidia-utils package not available, skipping nvidia-smi installation" - ) - await self._broadcast_progress( - { - "stage": "nvidia-smi", - "progress": 100, - "message": "nvidia-smi installation skipped (package not available)", - } - ) 
- - except Exception as e: - await self._broadcast_log_line(f"nvidia-smi installation error: {e}") - await self._broadcast_log_line( - "Note: nvidia-smi installation failed. GPU monitoring may be limited." - ) - await self._broadcast_progress( - { - "stage": "nvidia-smi", - "progress": 100, - "message": "nvidia-smi installation skipped (error occurred)", - } - ) - - async def _install_cudnn_linux(self, cuda_version: str, cuda_path: str) -> None: - """Install cuDNN library for deep learning primitives.""" - await self._broadcast_log_line( - "Installing cuDNN (CUDA Deep Neural Network library)..." - ) - await self._broadcast_progress( - { - "stage": "cudnn", - "progress": 0, - "message": "Installing cuDNN...", - } - ) - - try: - # Determine CUDA major version for cuDNN compatibility - cuda_major = cuda_version.split(".")[0] - cudnn_version = self.CUDNN_VERSIONS.get(cuda_major) - - if not cudnn_version: - await self._broadcast_log_line( - f"cuDNN version not available for CUDA {cuda_version}, skipping" - ) - await self._broadcast_progress( - { - "stage": "cudnn", - "progress": 100, - "message": "cuDNN installation skipped (version not available)", - } - ) - return - - ubuntu_version = self._get_ubuntu_version() - - # cuDNN package names vary by CUDA version - # For CUDA 12.x: libcudnn9-cuda-12, libcudnn9-dev-cuda-12 - # For CUDA 11.x: libcudnn8-cuda-11, libcudnn8-dev-cuda-11 - if cuda_major == "12" or cuda_major == "13": - cudnn_pkg = "libcudnn9" - cudnn_cuda_suffix = f"cuda-{cuda_major}" - else: - cudnn_pkg = "libcudnn8" - cudnn_cuda_suffix = f"cuda-{cuda_major}" - - # Manual cuDNN installation - await self._broadcast_log_line("Installing cuDNN packages...") - - cudnn_package_name = f"{cudnn_pkg}-{cudnn_cuda_suffix}" - cudnn_dev_package_name = f"{cudnn_pkg}-dev-{cudnn_cuda_suffix}" - packages = await self._get_repo_packages(ubuntu_version) - cudnn_pkg_entry = self._select_repo_package( - packages, cudnn_package_name, version_prefix=cudnn_version - ) - 
cudnn_dev_pkg_entry = self._select_repo_package( - packages, cudnn_dev_package_name, version_prefix=cudnn_version - ) - - if not cudnn_pkg_entry or not cudnn_dev_pkg_entry: - await self._broadcast_log_line( - "cuDNN packages not found in repository, skipping cuDNN installation" - ) - await self._broadcast_progress( - { - "stage": "cudnn", - "progress": 100, - "message": "cuDNN installation skipped (optional)", - } - ) - return - - base_url = ( - f"https://developer.download.nvidia.com/compute/cuda/repos/{ubuntu_version}/x86_64/" - ) - cudnn_url = base_url + cudnn_pkg_entry.get("Filename", "").lstrip("./") - cudnn_dev_url = base_url + cudnn_dev_pkg_entry.get("Filename", "").lstrip("./") - - cudnn_path = os.path.join(self._download_dir, f"{cudnn_pkg}.deb") - cudnn_dev_path = os.path.join(self._download_dir, f"{cudnn_pkg}-dev.deb") - - await self._broadcast_progress( - { - "stage": "cudnn", - "progress": 25, - "message": "Downloading cuDNN packages...", - } - ) - - # Download cuDNN packages - async with aiohttp.ClientSession() as session: - for url, path, name in [ - (cudnn_url, cudnn_path, cudnn_pkg), - (cudnn_dev_url, cudnn_dev_path, f"{cudnn_pkg}-dev"), - ]: - try: - await self._broadcast_log_line(f"Downloading {name}...") - async with session.get(url) as response: - if response.status == 200: - async with aiofiles.open(path, "wb") as f: - await f.write(await response.read()) - await self._broadcast_log_line(f"Downloaded {name}") - else: - await self._broadcast_log_line( - f"Failed to download {name}: HTTP {response.status}" - ) - # Try alternative URL pattern - continue - except Exception as download_err: - await self._broadcast_log_line( - f"Download error for {name}: {download_err}" - ) - continue - - await self._broadcast_progress( - { - "stage": "cudnn", - "progress": 50, - "message": "Installing cuDNN packages...", - } - ) - - if os.path.exists(cudnn_path): - await self._broadcast_log_line( - "Extracting cuDNN to CUDA directory..." 
- ) - - # Extract .deb file - extract_dir = os.path.join(self._download_dir, "cudnn_extract") - os.makedirs(extract_dir, exist_ok=True) - - for deb_path in [cudnn_path, cudnn_dev_path]: - if os.path.exists(deb_path): - # Extract using ar and tar - extract_process = await asyncio.create_subprocess_exec( - "bash", - "-c", - f"cd {extract_dir} && ar x {deb_path} && tar xf data.tar.* 2>/dev/null || tar xf data.tar 2>/dev/null", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.STDOUT, - ) - await extract_process.wait() - - # Copy cuDNN files to CUDA installation - cudnn_lib_src = os.path.join( - extract_dir, "usr", "lib", "x86_64-linux-gnu" - ) - cudnn_include_src = os.path.join(extract_dir, "usr", "include") - - cuda_lib_dst = os.path.join(cuda_path, "lib64") - cuda_include_dst = os.path.join(cuda_path, "include") - - if os.path.exists(cudnn_lib_src): - for f in os.listdir(cudnn_lib_src): - if "cudnn" in f.lower(): - src = os.path.join(cudnn_lib_src, f) - dst = os.path.join(cuda_lib_dst, f) - try: - if os.path.islink(src): - linkto = os.readlink(src) - if os.path.exists(dst): - os.remove(dst) - os.symlink(linkto, dst) - else: - shutil.copy2(src, dst) - await self._broadcast_log_line( - f"Copied {f} to CUDA lib directory" - ) - except Exception as copy_err: - await self._broadcast_log_line( - f"Failed to copy {f}: {copy_err}" - ) - - if os.path.exists(cudnn_include_src): - for f in os.listdir(cudnn_include_src): - if "cudnn" in f.lower(): - src = os.path.join(cudnn_include_src, f) - dst = os.path.join(cuda_include_dst, f) - try: - shutil.copy2(src, dst) - await self._broadcast_log_line( - f"Copied {f} to CUDA include directory" - ) - except Exception as copy_err: - await self._broadcast_log_line( - f"Failed to copy {f}: {copy_err}" - ) - - # Cleanup temporary extract directory only (keep .deb files) - shutil.rmtree(extract_dir, ignore_errors=True) - # Keep .deb files for future use - logger.info(f"cuDNN packages kept at: {cudnn_path}, {cudnn_dev_path}") - - 
await self._broadcast_log_line("cuDNN extracted to CUDA directory") - await self._broadcast_progress( - { - "stage": "cudnn", - "progress": 100, - "message": "cuDNN installed successfully", - } - ) - else: - await self._broadcast_log_line( - "cuDNN packages not available, skipping cuDNN installation" - ) - await self._broadcast_progress( - { - "stage": "cudnn", - "progress": 100, - "message": "cuDNN installation skipped (optional)", - } - ) - - except Exception as e: - await self._broadcast_log_line(f"cuDNN installation error: {e}") - await self._broadcast_log_line( - "Note: cuDNN is optional. The build will continue without cuDNN support." - ) - await self._broadcast_progress( - { - "stage": "cudnn", - "progress": 100, - "message": "cuDNN installation skipped (optional)", - } - ) - - async def _install_tensorrt_linux(self, cuda_version: str, cuda_path: str) -> None: - """Install TensorRT library for inference optimization.""" - await self._broadcast_log_line( - "Installing TensorRT (NVIDIA TensorRT inference library)..." 
- ) - await self._broadcast_progress( - { - "stage": "tensorrt", - "progress": 0, - "message": "Installing TensorRT...", - } - ) - - try: - # Determine CUDA major version for TensorRT compatibility - cuda_major = cuda_version.split(".")[0] - tensorrt_version = self.TENSORRT_VERSIONS.get(cuda_major) - - if not tensorrt_version: - await self._broadcast_log_line( - f"TensorRT version not available for CUDA {cuda_version}, skipping" - ) - await self._broadcast_progress( - { - "stage": "tensorrt", - "progress": 100, - "message": "TensorRT installation skipped (version not available)", - } - ) - return - - ubuntu_version = self._get_ubuntu_version() - - # TensorRT package names - # For CUDA 12.x/13.x: libnvinfer10, libnvinfer-dev, libnvinfer-plugin10, libnvinfer-plugin-dev - # For CUDA 11.x: libnvinfer8, libnvinfer-dev, libnvinfer-plugin8, libnvinfer-plugin-dev - if cuda_major == "12" or cuda_major == "13": - tensorrt_pkg = "libnvinfer10" - tensorrt_plugin_pkg = "libnvinfer-plugin10" - else: - tensorrt_pkg = "libnvinfer8" - tensorrt_plugin_pkg = "libnvinfer-plugin8" - - # Manual TensorRT installation - await self._broadcast_log_line("Installing TensorRT packages...") - - packages = await self._get_repo_packages(ubuntu_version) - tensorrt_pkg_entry = self._select_repo_package( - packages, tensorrt_pkg, version_prefix=tensorrt_version - ) - tensorrt_dev_pkg_entry = self._select_repo_package( - packages, f"{tensorrt_pkg}-dev", version_prefix=tensorrt_version - ) - tensorrt_plugin_entry = self._select_repo_package( - packages, tensorrt_plugin_pkg, version_prefix=tensorrt_version - ) - tensorrt_plugin_dev_entry = self._select_repo_package( - packages, f"{tensorrt_plugin_pkg}-dev", version_prefix=tensorrt_version - ) - - if not all( - [ - tensorrt_pkg_entry, - tensorrt_dev_pkg_entry, - tensorrt_plugin_entry, - tensorrt_plugin_dev_entry, - ] - ): - await self._broadcast_log_line( - "TensorRT packages not found in repository, skipping TensorRT installation" - ) - await 
self._broadcast_progress( - { - "stage": "tensorrt", - "progress": 100, - "message": "TensorRT installation skipped (optional)", - } - ) - return - - base_url = ( - f"https://developer.download.nvidia.com/compute/cuda/repos/{ubuntu_version}/x86_64/" - ) - tensorrt_url = base_url + tensorrt_pkg_entry.get("Filename", "").lstrip("./") - tensorrt_dev_url = base_url + tensorrt_dev_pkg_entry.get("Filename", "").lstrip("./") - tensorrt_plugin_url = base_url + tensorrt_plugin_entry.get("Filename", "").lstrip("./") - tensorrt_plugin_dev_url = base_url + tensorrt_plugin_dev_entry.get("Filename", "").lstrip("./") - - tensorrt_path = os.path.join(self._download_dir, f"{tensorrt_pkg}.deb") - tensorrt_dev_path = os.path.join(self._download_dir, f"{tensorrt_pkg}-dev.deb") - tensorrt_plugin_path = os.path.join(self._download_dir, f"{tensorrt_plugin_pkg}.deb") - tensorrt_plugin_dev_path = os.path.join(self._download_dir, f"{tensorrt_plugin_pkg}-dev.deb") - - await self._broadcast_progress( - { - "stage": "tensorrt", - "progress": 25, - "message": "Downloading TensorRT packages...", - } - ) - - # Download TensorRT packages - async with aiohttp.ClientSession() as session: - for url, path, name in [ - (tensorrt_url, tensorrt_path, tensorrt_pkg), - (tensorrt_dev_url, tensorrt_dev_path, f"{tensorrt_pkg}-dev"), - (tensorrt_plugin_url, tensorrt_plugin_path, tensorrt_plugin_pkg), - (tensorrt_plugin_dev_url, tensorrt_plugin_dev_path, f"{tensorrt_plugin_pkg}-dev"), - ]: - try: - await self._broadcast_log_line(f"Downloading {name}...") - async with session.get(url) as response: - if response.status == 200: - async with aiofiles.open(path, "wb") as f: - await f.write(await response.read()) - await self._broadcast_log_line(f"Downloaded {name}") - else: - await self._broadcast_log_line( - f"Failed to download {name}: HTTP {response.status}" - ) - continue - except Exception as download_err: - await self._broadcast_log_line( - f"Download error for {name}: {download_err}" - ) - continue - - await 
self._broadcast_progress( - { - "stage": "tensorrt", - "progress": 50, - "message": "Installing TensorRT packages...", - } - ) - - if os.path.exists(tensorrt_path): - await self._broadcast_log_line( - "Extracting TensorRT to CUDA directory..." - ) - - # Extract .deb file - extract_dir = os.path.join(self._download_dir, "tensorrt_extract") - os.makedirs(extract_dir, exist_ok=True) - - for deb_path in [ - tensorrt_path, - tensorrt_dev_path, - tensorrt_plugin_path, - tensorrt_plugin_dev_path, - ]: - if os.path.exists(deb_path): - # Extract using ar and tar - extract_process = await asyncio.create_subprocess_exec( - "bash", - "-c", - f"cd {extract_dir} && ar x {deb_path} && tar xf data.tar.* 2>/dev/null || tar xf data.tar 2>/dev/null", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.STDOUT, - ) - await extract_process.wait() - - # Copy TensorRT files to CUDA installation - tensorrt_lib_src = os.path.join( - extract_dir, "usr", "lib", "x86_64-linux-gnu" - ) - tensorrt_include_src = os.path.join(extract_dir, "usr", "include") - tensorrt_bin_src = os.path.join(extract_dir, "usr", "bin") - - cuda_lib_dst = os.path.join(cuda_path, "lib64") - cuda_include_dst = os.path.join(cuda_path, "include") - cuda_bin_dst = os.path.join(cuda_path, "bin") - - # Copy libraries - if os.path.exists(tensorrt_lib_src): - for f in os.listdir(tensorrt_lib_src): - if "nvinfer" in f.lower() or "tensorrt" in f.lower(): - src = os.path.join(tensorrt_lib_src, f) - dst = os.path.join(cuda_lib_dst, f) - try: - if os.path.islink(src): - linkto = os.readlink(src) - if os.path.exists(dst): - os.remove(dst) - os.symlink(linkto, dst) - else: - shutil.copy2(src, dst) - await self._broadcast_log_line( - f"Copied {f} to CUDA lib directory" - ) - except Exception as copy_err: - await self._broadcast_log_line( - f"Failed to copy {f}: {copy_err}" - ) - - # Copy headers - if os.path.exists(tensorrt_include_src): - for f in os.listdir(tensorrt_include_src): - if "nvinfer" in f.lower() or "tensorrt" 
in f.lower(): - src = os.path.join(tensorrt_include_src, f) - dst = os.path.join(cuda_include_dst, f) - try: - if os.path.isdir(src): - shutil.copytree(src, dst, dirs_exist_ok=True) - else: - shutil.copy2(src, dst) - await self._broadcast_log_line( - f"Copied {f} to CUDA include directory" - ) - except Exception as copy_err: - await self._broadcast_log_line( - f"Failed to copy {f}: {copy_err}" - ) - - # Copy binaries (like trtexec) - if os.path.exists(tensorrt_bin_src): - for f in os.listdir(tensorrt_bin_src): - if "trt" in f.lower() or "nvinfer" in f.lower(): - src = os.path.join(tensorrt_bin_src, f) - dst = os.path.join(cuda_bin_dst, f) - try: - shutil.copy2(src, dst) - os.chmod(dst, 0o755) - await self._broadcast_log_line( - f"Copied {f} to CUDA bin directory" - ) - except Exception as copy_err: - await self._broadcast_log_line( - f"Failed to copy {f}: {copy_err}" - ) - - # Cleanup temporary extract directory only (keep .deb files) - shutil.rmtree(extract_dir, ignore_errors=True) - # Keep .deb files for future use - logger.info( - f"TensorRT packages kept at: {tensorrt_path}, {tensorrt_dev_path}, " - f"{tensorrt_plugin_path}, {tensorrt_plugin_dev_path}" - ) - - await self._broadcast_log_line("TensorRT extracted to CUDA directory") - await self._broadcast_progress( - { - "stage": "tensorrt", - "progress": 100, - "message": "TensorRT installed successfully", - } - ) - else: - await self._broadcast_log_line( - "TensorRT packages not available, skipping TensorRT installation" - ) - await self._broadcast_progress( - { - "stage": "tensorrt", - "progress": 100, - "message": "TensorRT installation skipped (optional)", - } - ) - - except Exception as e: - await self._broadcast_log_line(f"TensorRT installation error: {e}") - await self._broadcast_log_line( - "Note: TensorRT is optional. The build will continue without TensorRT support." 
- ) - await self._broadcast_progress( - { - "stage": "tensorrt", - "progress": 100, - "message": "TensorRT installation skipped (optional)", - } - ) - - async def install( - self, - version: str = "12.6", - install_cudnn: bool = False, - install_tensorrt: bool = False, - ) -> Dict[str, Any]: - """Install CUDA Toolkit with optional cuDNN and TensorRT.""" - async with self._lock: - if self._operation: - raise RuntimeError( - "Another CUDA installer operation is already running" - ) - - system, arch = self._get_platform() - - if system != "linux": - raise RuntimeError( - f"CUDA installation is only supported on Linux, not {system}" - ) - - if version not in self.SUPPORTED_VERSIONS: - raise ValueError( - f"Unsupported CUDA version: {version}. Supported versions: {', '.join(self.SUPPORTED_VERSIONS)}" - ) - - # Fetch the download URL dynamically - await self._broadcast_log_line( - f"Fetching download URL for CUDA {version}..." - ) - url = await self._fetch_download_url(version) - installer_filename = os.path.basename(url) - installer_path = os.path.join(self._download_dir, installer_filename) - - await self._set_operation("install") - - async def _runner(): - try: - # Download installer - await self._download_installer(version, url, installer_path) - - # Install (Linux only) - returns the installation path - install_path = await self._install_linux( - installer_path, version, install_cudnn, install_tensorrt - ) - - # Update state (already saved in _install_linux, but update main fields) - state = self._load_state() - state["installed_version"] = version - state["installed_at"] = _utcnow() - state["cuda_path"] = install_path - if install_cudnn: - state["cudnn_installed"] = True - if install_tensorrt: - state["tensorrt_installed"] = True - self._save_state(state) - - components = ["CUDA Toolkit"] - if install_cudnn: - components.append("cuDNN") - if install_tensorrt: - components.append("TensorRT") - - await self._finish_operation( - True, f"{', '.join(components)} 
installed successfully" - ) - - # Update current process environment with CUDA paths - # This ensures the running application can use CUDA immediately - cuda_env = self.get_cuda_env(version) - if cuda_env: - os.environ.update(cuda_env) - logger.info( - f"Updated process environment with CUDA {version} paths" - ) - - # Restart llama-swap to pick up new CUDA environment variables - # llama-swap needs to be restarted because subprocess environment - # variables are set at process creation time and can't be changed - try: - from backend.llama_swap_manager import get_llama_swap_manager - llama_swap_manager = get_llama_swap_manager() - await llama_swap_manager.restart_proxy() - logger.info("Restarted llama-swap to pick up new CUDA environment") - except Exception as restart_error: - # Don't fail the installation if restart fails - logger.warning( - f"Failed to restart llama-swap after CUDA installation: {restart_error}. " - f"You may need to manually restart llama-swap to use the new CUDA version." 
- ) - - # Keep installer file for future use (not deleting) - logger.info(f"Installer file kept at: {installer_path}") - - except Exception as exc: - self._last_error = str(exc) - await self._finish_operation(False, str(exc)) - raise - - self._create_task(_runner()) - return {"message": f"CUDA {version} installation started"} - - def _detect_cudnn_version(self, cuda_path: Optional[str]) -> Optional[str]: - """Detect installed cuDNN version by checking library files.""" - if not cuda_path: - return None - - lib_path = os.path.join(cuda_path, "lib64") - if not os.path.exists(lib_path): - return None - - try: - for f in os.listdir(lib_path): - if "libcudnn" in f and ".so" in f: - match = re.search(r"\.so(?:\.(\d+(?:\.\d+){0,2}))?", f) - if match and match.group(1): - return match.group(1) - except Exception: - pass - - return None - - def _detect_tensorrt_version(self, cuda_path: Optional[str]) -> Optional[str]: - """Detect installed TensorRT version by checking library files.""" - if not cuda_path: - return None - - lib_path = os.path.join(cuda_path, "lib64") - if not os.path.exists(lib_path): - return None - - try: - for f in os.listdir(lib_path): - if "libnvinfer" in f and ".so" in f and "plugin" not in f: - match = re.search(r"\.so(?:\.(\d+(?:\.\d+){0,2}))?", f) - if match and match.group(1): - return match.group(1) - except Exception: - pass - - return None - - def status(self) -> Dict[str, Any]: - """Get CUDA installation status.""" - version = self._detect_installed_version() - cuda_path = self._get_cuda_path() - installed = version is not None and cuda_path is not None - state = self._load_state() - installations = state.get("installations", {}) - - # Detect cuDNN and TensorRT - cudnn_version = None - tensorrt_version = None - if cuda_path: - cudnn_version = self._detect_cudnn_version(cuda_path) - tensorrt_version = self._detect_tensorrt_version(cuda_path) - - # Get all installed versions with their details - installed_versions = [] - for v, info in 
installations.items(): - install_path = info.get("path") - if install_path and os.path.exists(install_path): - installed_versions.append( - { - "version": v, - "path": install_path, - "installed_at": info.get("installed_at"), - "is_system_install": info.get("is_system_install", False), - "is_current": v == version, - "cudnn_installed": info.get("cudnn_installed", False), - "tensorrt_installed": info.get("tensorrt_installed", False), - } - ) - - return { - "installed": installed, - "version": version, - "cuda_path": cuda_path, - "installed_at": state.get("installed_at"), - "installed_versions": installed_versions, - "operation": self._operation, - "operation_started_at": self._operation_started_at, - "last_error": self._last_error, - "log_path": self._log_path, - "available_versions": self.SUPPORTED_VERSIONS, - "platform": self._get_platform(), - "cudnn": { - "installed": cudnn_version is not None, - "version": cudnn_version, - }, - "tensorrt": { - "installed": tensorrt_version is not None, - "version": tensorrt_version, - }, - } - - def is_operation_running(self) -> bool: - return self._operation is not None - - def read_log_tail(self, max_bytes: int = 8192) -> str: - if not os.path.exists(self._log_path): - return "" - with open(self._log_path, "rb") as log_file: - log_file.seek(0, os.SEEK_END) - size = log_file.tell() - log_file.seek(max(0, size - max_bytes)) - data = log_file.read().decode("utf-8", errors="replace") - if size > max_bytes: - data = data.split("\n", 1)[-1] - return data.strip() - - async def uninstall(self, version: Optional[str] = None) -> Dict[str, Any]: - """Uninstall CUDA Toolkit.""" - async with self._lock: - if self._operation: - raise RuntimeError( - "Another CUDA installer operation is already running" - ) - - # Determine which version to uninstall - if not version: - # Uninstall the currently detected version - version = self._detect_installed_version() - if not version: - raise RuntimeError("No CUDA installation found to uninstall") - - 
state = self._load_state() - installations = state.get("installations", {}) - - if version not in installations: - raise RuntimeError(f"CUDA {version} installation not found in state") - - install_info = installations[version] - install_path = install_info.get("path") - - if not install_path or not os.path.exists(install_path): - # Path doesn't exist, just remove from state - logger.warning( - f"CUDA installation path {install_path} does not exist, removing from state only" - ) - installations.pop(version, None) - if state.get("installed_version") == version: - state["installed_version"] = None - state["installed_at"] = None - state["cuda_path"] = None - self._save_state(state) - return { - "message": f"CUDA {version} removed from state (installation path not found)" - } - - await self._set_operation("uninstall") - - async def _runner(): - try: - await self._broadcast_log_line( - f"Starting uninstallation of CUDA {version}..." - ) - await self._broadcast_progress( - { - "stage": "uninstall", - "progress": 0, - "message": f"Uninstalling CUDA {version}...", - } - ) - - # Remove the installation directory - if os.path.exists(install_path): - await self._broadcast_log_line( - f"Removing installation directory: {install_path}" - ) - try: - shutil.rmtree(install_path) - await self._broadcast_log_line( - f"Successfully removed {install_path}" - ) - except Exception as e: - logger.error( - f"Failed to remove CUDA installation directory: {e}" - ) - raise RuntimeError( - f"Failed to remove installation directory: {e}" - ) - - # Update state - installations.pop(version, None) - if state.get("installed_version") == version: - state["installed_version"] = None - state["installed_at"] = None - state["cuda_path"] = None - self._save_state(state) - - # Update or remove the current symlink - self._remove_current_symlink() - await self._broadcast_log_line( - "Updated CUDA current symlink (removed or re-pointed to another version)" - ) - - await self._broadcast_progress( - { - 
"stage": "uninstall", - "progress": 100, - "message": "CUDA uninstallation completed", - } - ) - await self._broadcast_log_line( - f"CUDA {version} uninstalled successfully" - ) - await self._finish_operation( - True, f"CUDA {version} uninstalled successfully" - ) - - except Exception as exc: - self._last_error = str(exc) - await self._finish_operation(False, str(exc)) - raise - - self._create_task(_runner()) - return {"message": f"CUDA {version} uninstallation started"} +""" +CUDA Toolkit Installer + +Handles downloading and installing CUDA Toolkit on Linux systems. +""" + +import asyncio +import json +import os +import platform +import re +import shutil +import subprocess +import sys +import tempfile +import time +import gzip +from datetime import datetime, timezone +from typing import Any, Awaitable, Dict, Optional, Tuple +import aiohttp +import aiofiles + +from backend.logging_config import get_logger +from backend.progress_manager import get_progress_manager + +logger = get_logger(__name__) + +_installer_instance: Optional["CUDAInstaller"] = None + + +def get_cuda_installer() -> "CUDAInstaller": + global _installer_instance + if _installer_instance is None: + _installer_instance = CUDAInstaller() + return _installer_instance + + +def _utcnow() -> str: + return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + + +class CUDAInstaller: + """Install CUDA Toolkit on Linux systems.""" + + # Supported CUDA versions - URLs are fetched dynamically from NVIDIA's archive + # Format: version -> platform -> architecture (URLs fetched on demand) + SUPPORTED_VERSIONS = [ + "13.0", + "12.9", + "12.8", + "12.7", + "12.6", + "12.5", + "12.4", + "12.3", + "12.2", + "12.1", + "12.0", + "11.9", + "11.8", + ] + + # cuDNN version mappings by CUDA major version + CUDNN_VERSIONS = { + "13": "9.5.1", # cuDNN 9.x for CUDA 13.x + "12": "9.5.1", # cuDNN 9.x for CUDA 12.x + "11": "8.9.7", # cuDNN 8.x for CUDA 11.x + } + + # TensorRT version mappings by CUDA major version + 
TENSORRT_VERSIONS = { + "13": "10.7.0", # TensorRT 10.x for CUDA 13.x + "12": "10.7.0", # TensorRT 10.x for CUDA 12.x + "11": "8.6.1", # TensorRT 8.x for CUDA 11.x + } + + def __init__( + self, + *, + log_path: Optional[str] = None, + state_path: Optional[str] = None, + download_dir: Optional[str] = None, + ) -> None: + self._lock = asyncio.Lock() + self._operation: Optional[str] = None + self._operation_started_at: Optional[str] = None + self._current_task: Optional[asyncio.Task] = None + self._last_error: Optional[str] = None + self._download_progress: Dict[str, Any] = {} + self._last_logged_percentage: int = -1 + self._last_progress_broadcast_time: float = 0.0 + self._pending_progress: Optional[Dict[str, Any]] = None + self._progress_broadcast_count: int = 0 + + # Determine data root - check Docker path first, then fallback to local + if os.path.exists("/app/data"): + data_root = "/app/data" + else: + data_root = os.path.abspath("data") + + log_path = log_path or os.path.join(data_root, "logs", "cuda_install.log") + state_path = state_path or os.path.join( + data_root, "configs", "cuda_installer.json" + ) + download_dir = download_dir or os.path.join( + data_root, "temp", "cuda_installers" + ) + self._cuda_install_dir = os.path.join(data_root, "cuda") + + self._log_path = os.path.abspath(log_path) + self._state_path = os.path.abspath(state_path) + self._download_dir = os.path.abspath(download_dir) + self._url_cache: Dict[str, str] = {} # Cache for dynamically fetched URLs + self._repo_cache: Dict[str, list] = {} # Cache for NVIDIA repo packages + self._ensure_directories() + + def _ensure_directories(self) -> None: + os.makedirs(self._download_dir, exist_ok=True) + os.makedirs(os.path.dirname(self._log_path), exist_ok=True) + os.makedirs(os.path.dirname(self._state_path), exist_ok=True) + os.makedirs(self._cuda_install_dir, exist_ok=True) + + def _update_current_symlink(self, install_path: str) -> None: + """Create or update the /app/data/cuda/current symlink to 
point to the active CUDA installation.""" + current_symlink = os.path.join(self._cuda_install_dir, "current") + try: + # Remove existing symlink if it exists + if os.path.islink(current_symlink): + os.remove(current_symlink) + elif os.path.exists(current_symlink): + # If it's not a symlink but exists, remove it (shouldn't happen, but be safe) + os.remove(current_symlink) + + # Create new symlink pointing to the installation + os.symlink(install_path, current_symlink) + logger.info(f"Updated CUDA current symlink: {current_symlink} -> {install_path}") + except OSError as e: + logger.warning(f"Failed to update CUDA current symlink: {e}") + + def _remove_current_symlink(self) -> None: + """Remove the current symlink and optionally re-point it to another installed version.""" + current_symlink = os.path.join(self._cuda_install_dir, "current") + try: + if os.path.islink(current_symlink) or os.path.exists(current_symlink): + os.remove(current_symlink) + + # Try to find another installed version to point to + state = self._load_state() + installations = state.get("installations", {}) + + # Find the most recently installed version that still exists + latest_version = None + latest_time = None + for v, info in installations.items(): + install_path = info.get("path") + if install_path and os.path.exists(install_path): + installed_at = info.get("installed_at", "") + if not latest_time or installed_at > latest_time: + latest_time = installed_at + latest_version = v + + # Re-point to the latest remaining installation + if latest_version: + install_path = installations[latest_version].get("path") + if install_path and os.path.exists(install_path): + os.symlink(install_path, current_symlink) + logger.info(f"Re-pointed CUDA current symlink to: {install_path}") + except OSError as e: + logger.warning(f"Failed to update CUDA current symlink: {e}") + + def _get_platform(self) -> Tuple[str, str]: + """Get platform (os, arch) tuple.""" + system = platform.system().lower() + machine = 
platform.machine().lower() + + if machine in ("x86_64", "amd64"): + arch = "x86_64" + else: + arch = machine + + return system, arch + + def _get_ubuntu_version(self) -> str: + """Get Ubuntu version for NVIDIA repository URLs.""" + # Try to detect Ubuntu version from /etc/os-release + try: + if os.path.exists("/etc/os-release"): + with open("/etc/os-release", "r") as f: + for line in f: + if line.startswith("VERSION_ID="): + version = line.split("=")[1].strip().strip('"') + # Extract major.minor (e.g., "24.04" from "24.04.1") + parts = version.split(".") + if len(parts) >= 2: + major_minor = f"{parts[0]}{parts[1]}" + # Check if it's 24.04 or newer + if major_minor >= "2404": + return "ubuntu2404" + else: + return "ubuntu2204" + except Exception: + pass + + # Default to ubuntu2404 for Ubuntu 24.04 base image + return "ubuntu2404" + + def _get_archive_target_version(self) -> str: + """Get archive target version for CUDA runfile lookups.""" + ubuntu_version = self._get_ubuntu_version() + if ubuntu_version == "ubuntu2404": + return "24.04" + return "22.04" + + async def _get_repo_packages(self, ubuntu_version: str) -> list: + """Fetch and cache NVIDIA CUDA repo package metadata.""" + if ubuntu_version in self._repo_cache: + return self._repo_cache[ubuntu_version] + + base_url = ( + f"https://developer.download.nvidia.com/compute/cuda/repos/{ubuntu_version}/x86_64" + ) + packages_url = f"{base_url}/Packages.gz" + packages_plain_url = f"{base_url}/Packages" + packages: list = [] + + async with aiohttp.ClientSession() as session: + data = None + try: + async with session.get(packages_url) as response: + if response.status == 200: + compressed = await response.read() + data = gzip.decompress(compressed) + except Exception: + data = None + + if data is None: + try: + async with session.get(packages_plain_url) as response: + if response.status == 200: + data = await response.read() + except Exception: + data = None + + if not data: + self._repo_cache[ubuntu_version] = [] + 
return [] + + text = data.decode("utf-8", errors="replace") + current = {} + for line in text.splitlines(): + if not line.strip(): + if current: + packages.append(current) + current = {} + continue + if line.startswith("Package:"): + current["Package"] = line.split(":", 1)[1].strip() + elif line.startswith("Version:"): + current["Version"] = line.split(":", 1)[1].strip() + elif line.startswith("Filename:"): + current["Filename"] = line.split(":", 1)[1].strip() + + if current: + packages.append(current) + + self._repo_cache[ubuntu_version] = packages + return packages + + def _version_key(self, version: str) -> tuple: + """Create a sortable key for package version strings.""" + tokens = re.split(r"[^\w]+", version) + key = [] + for token in tokens: + if token.isdigit(): + key.append(int(token)) + elif token: + key.append(token) + return tuple(key) + + def _select_repo_package( + self, + packages: list, + package_name: str, + version_prefix: Optional[str] = None, + version_contains: Optional[str] = None, + ) -> Optional[Dict[str, str]]: + """Select the best matching package from repo metadata.""" + candidates = [ + pkg for pkg in packages if pkg.get("Package") == package_name + ] + if version_prefix: + candidates = [ + pkg + for pkg in candidates + if pkg.get("Version", "").startswith(version_prefix) + ] + if version_contains: + candidates = [ + pkg + for pkg in candidates + if version_contains in pkg.get("Version", "") + ] + if not candidates: + return None + return max(candidates, key=lambda pkg: self._version_key(pkg.get("Version", ""))) + + def _load_state(self) -> Dict[str, Any]: + if not os.path.exists(self._state_path): + return {} + try: + with open(self._state_path, "r", encoding="utf-8") as f: + data = json.load(f) + return data if isinstance(data, dict) else {} + except Exception as exc: + logger.warning(f"Failed to load CUDA installer state: {exc}") + return {} + + def _save_state(self, state: Dict[str, Any]) -> None: + tmp_path = 
f"{self._state_path}.tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + json.dump(state, f, indent=2) + os.replace(tmp_path, self._state_path) + + def _detect_installed_version(self) -> Optional[str]: + """Detect installed CUDA version by checking nvcc or state.""" + # First check state for installed versions + state = self._load_state() + installations = state.get("installations", {}) + if installations: + # Return the most recently installed version + latest_version = None + latest_time = None + for v, info in installations.items(): + installed_at = info.get("installed_at", "") + if not latest_time or installed_at > latest_time: + latest_time = installed_at + latest_version = v + if latest_version: + install_path = installations[latest_version].get("path") + if install_path and os.path.exists(install_path): + return latest_version + + # Fallback: try to detect via nvcc command + try: + # Get CUDA environment to find nvcc + cuda_env = self.get_cuda_env() + env = os.environ.copy() + env.update(cuda_env) + + nvcc_path = shutil.which("nvcc", path=env.get("PATH", "")) + if not nvcc_path: + return None + + result = subprocess.run( + [nvcc_path, "--version"], + capture_output=True, + text=True, + timeout=5, + env=env, + ) + if result.returncode == 0: + # Parse version from output + for line in result.stdout.split("\n"): + if "release" in line.lower(): + parts = line.split() + for i, part in enumerate(parts): + if "release" in part.lower() and i + 1 < len(parts): + version_str = parts[i + 1].rstrip(",") + # Extract major.minor + version_parts = version_str.split(".") + if len(version_parts) >= 2: + return f"{version_parts[0]}.{version_parts[1]}" + except (subprocess.TimeoutExpired, FileNotFoundError, OSError): + pass + return None + + def _get_cuda_path(self, version: Optional[str] = None) -> Optional[str]: + """Get CUDA installation path.""" + # First, check the current symlink (most reliable for active installation) + current_symlink = 
os.path.join(self._cuda_install_dir, "current") + if os.path.islink(current_symlink) or os.path.exists(current_symlink): + try: + resolved_path = os.path.realpath(current_symlink) + if os.path.exists(resolved_path): + nvcc_path = os.path.join(resolved_path, "bin", "nvcc") + if os.path.exists(nvcc_path): + return resolved_path + except (OSError, ValueError): + pass + + # Check state for installed versions + state = self._load_state() + installations = state.get("installations", {}) + + # If version specified, return that installation path + if version and version in installations: + install_path = installations[version].get("path") + if install_path and os.path.exists(install_path): + return install_path + + # Check for latest installed version in state + if installations: + # Get the most recently installed version + latest_version = None + latest_time = None + for v, info in installations.items(): + installed_at = info.get("installed_at", "") + if not latest_time or installed_at > latest_time: + latest_time = installed_at + latest_version = v + + if latest_version: + install_path = installations[latest_version].get("path") + if install_path and os.path.exists(install_path): + return install_path + + # Check environment variables (only accept paths under data directory) + env_path = os.environ.get("CUDA_PATH") or os.environ.get("CUDA_HOME") + if ( + env_path + and os.path.exists(env_path) + and os.path.abspath(env_path).startswith(self._cuda_install_dir) + ): + return env_path + + # Scan the data directory for CUDA installs as fallback + try: + if os.path.exists(self._cuda_install_dir): + for item in sorted(os.listdir(self._cuda_install_dir), reverse=True): + # Skip the current symlink + if item == "current": + continue + full_path = os.path.join(self._cuda_install_dir, item) + if os.path.isdir(full_path): + nvcc_path = os.path.join(full_path, "bin", "nvcc") + if os.path.exists(nvcc_path): + return full_path + except OSError: + pass + + return None + + def 
get_cuda_env(self, version: Optional[str] = None) -> Dict[str, str]: + """Get environment variables for CUDA installation.""" + cuda_path = self._get_cuda_path(version) + if not cuda_path: + return {} + + cuda_bin = os.path.join(cuda_path, "bin") + cuda_lib = os.path.join(cuda_path, "lib64") + + env = { + "CUDA_HOME": cuda_path, + "CUDA_PATH": cuda_path, + } + + # Add to PATH if bin directory exists + if os.path.exists(cuda_bin): + current_path = os.environ.get("PATH", "") + if cuda_bin not in current_path: + env["PATH"] = f"{cuda_bin}:{current_path}" if current_path else cuda_bin + + # Add to LD_LIBRARY_PATH if lib64 directory exists + if os.path.exists(cuda_lib): + current_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + if cuda_lib not in current_ld_path: + env["LD_LIBRARY_PATH"] = ( + f"{cuda_lib}:{current_ld_path}" if current_ld_path else cuda_lib + ) + + # Add TensorRT path if TensorRT is installed + tensorrt_version = self._detect_tensorrt_version(cuda_path) + if tensorrt_version: + env["TENSORRT_PATH"] = cuda_path + env["TENSORRT_ROOT"] = cuda_path + + return env + + def _get_archive_url(self, version: str) -> str: + """Get NVIDIA download archive URL for a CUDA version.""" + # Convert version like "12.8" to "12-8-0" for URL + version_parts = version.split(".") + major = version_parts[0] + minor = version_parts[1] if len(version_parts) > 1 else "0" + patch = version_parts[2] if len(version_parts) > 2 else "0" + version_slug = f"{major}-{minor}-{patch}" + target_version = self._get_archive_target_version() + + return ( + f"https://developer.nvidia.com/cuda-{version_slug}-download-archive" + f"?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version={target_version}&target_type=runfile_local" + ) + + async def _fetch_download_url(self, version: str) -> str: + """Fetch the actual download URL from NVIDIA's archive page.""" + # Check cache first + cache_key = f"{version}_linux_x86_64" + if cache_key in self._url_cache: + return 
self._url_cache[cache_key] + + archive_url = self._get_archive_url(version) + logger.info(f"Fetching CUDA {version} download URL from {archive_url}") + + async with aiohttp.ClientSession() as session: + try: + async with session.get( + archive_url, timeout=aiohttp.ClientTimeout(total=30) + ) as response: + if response.status != 200: + raise RuntimeError( + f"Failed to fetch archive page: HTTP {response.status}" + ) + + html = await response.text() + + # The page contains JSON data with download URLs + # The JSON structure has keys like "Linux/x86_64/Ubuntu/24.04/runfile_local" + # The URL is in the "details" field which contains HTML with href attributes + target_version = self._get_archive_target_version() + json_key = f"Linux/x86_64/Ubuntu/{target_version}/runfile_local" + + # Pattern 1: Look for href in the details field (HTML may be escaped) + # Match: "Linux/x86_64/Ubuntu//runfile_local":{..."details":"...href=\"URL\"..."} + pattern1 = rf'"{re.escape(json_key)}"[^}}]*"details"[^"]*href[=:][\\"]*([^"\\s<>]+cuda_\d+\.\d+\.\d+_[^"\\s<>]+_linux\.run)' + matches = re.findall(pattern1, html, re.IGNORECASE | re.DOTALL) + + if not matches: + # Pattern 2: Look for href with escaped quotes (\u0022 or \") + pattern2 = rf'"{re.escape(json_key)}"[^}}]*href[\\u0022=:]*([^"\\s<>]+cuda_\d+\.\d+\.\d+_[^"\\s<>]+_linux\.run)' + matches = re.findall(pattern2, html, re.IGNORECASE | re.DOTALL) + + if not matches: + # Pattern 3: Look for the filename field and construct URL + pattern3 = rf'"{re.escape(json_key)}"[^}}]*"filename"[^"]*"([^"]+_linux\.run)"' + filename_matches = re.findall(pattern3, html, re.IGNORECASE) + if filename_matches: + filename = filename_matches[0] + version_full = f"{version}.0" + url = f"https://developer.download.nvidia.com/compute/cuda/{version_full}/local_installers/{filename}" + matches = [url] + + if not matches: + # Pattern 4: Fallback - look for any URL matching the pattern + version_escaped = version.replace(".", r"\.") + pattern4 = 
rf'https://developer\.download\.nvidia\.com/compute/cuda/{version_escaped}\.0/local_installers/cuda_{version_escaped}\.0_[^"\'\s<>]+_linux\.run' + matches = re.findall(pattern4, html, re.IGNORECASE) + + if matches: + url = matches[0] + # Cache it + self._url_cache[cache_key] = url + logger.info(f"Found CUDA {version} download URL: {url}") + return url + else: + raise RuntimeError( + f"Could not find download URL for CUDA {version} on archive page" + ) + + except aiohttp.ClientError as e: + raise RuntimeError(f"Failed to fetch archive page: {e}") + + async def _broadcast_log_line(self, line: str) -> None: + try: + await get_progress_manager().broadcast( + { + "type": "cuda_install_log", + "line": line, + "timestamp": _utcnow(), + } + ) + except Exception as exc: + logger.debug(f"Failed to broadcast CUDA log line: {exc}") + + async def _broadcast_progress(self, progress: Dict[str, Any]) -> None: + """Broadcast progress updates, throttled to 1 second intervals.""" + try: + current_time = time.time() + progress_value = progress.get("progress", 0) + is_complete = progress_value >= 100 + + # Always send completion updates immediately + if is_complete: + await get_progress_manager().broadcast( + { + "type": "cuda_install_progress", + **progress, + "timestamp": _utcnow(), + } + ) + self._last_progress_broadcast_time = current_time + self._pending_progress = None + return + + # Always send the first few updates immediately (first 3 updates) + # then throttle to 1 second intervals + is_first_update = self._last_progress_broadcast_time == 0.0 + time_since_last_broadcast = current_time - self._last_progress_broadcast_time + is_early_update = self._progress_broadcast_count < 3 + should_send = is_first_update or is_early_update or time_since_last_broadcast >= 1.0 + + if should_send: + await get_progress_manager().broadcast( + { + "type": "cuda_install_progress", + **progress, + "timestamp": _utcnow(), + } + ) + self._last_progress_broadcast_time = current_time + 
self._pending_progress = None + self._progress_broadcast_count += 1 + else: + # Store the latest progress data for next send + self._pending_progress = progress + except Exception as exc: + logger.exception(f"Failed to broadcast CUDA progress: {exc}") + + async def _set_operation(self, operation: str) -> None: + self._operation = operation + self._operation_started_at = _utcnow() + self._last_error = None + await get_progress_manager().broadcast( + { + "type": "cuda_install_status", + "status": operation, + "started_at": self._operation_started_at, + } + ) + + async def _finish_operation(self, success: bool, message: str = "") -> None: + payload = { + "type": "cuda_install_status", + "status": "completed" if success else "failed", + "operation": self._operation, + "message": message, + "ended_at": _utcnow(), + } + await get_progress_manager().broadcast(payload) + self._operation = None + self._operation_started_at = None + + def _create_task(self, coro: Awaitable[Any]) -> None: + loop = asyncio.get_running_loop() + task = loop.create_task(coro) + self._current_task = task + + def _cleanup(fut: asyncio.Future) -> None: + try: + fut.result() + except Exception as exc: + logger.exception("CUDA installer task error") + finally: + self._current_task = None + + task.add_done_callback(_cleanup) + + async def _download_installer( + self, version: str, url: str, installer_path: str + ) -> None: + """Download CUDA installer with progress tracking.""" + # Check if installer already exists + if os.path.exists(installer_path): + file_size = os.path.getsize(installer_path) + file_size_mb = file_size / (1024 * 1024) + + # Verify existing file is not corrupted (should be at least 100MB for CUDA installers) + if file_size < 100 * 1024 * 1024: + await self._broadcast_log_line( + f"Existing installer file appears corrupted (too small: {file_size_mb:.1f} MB), re-downloading..." 
+ ) + try: + os.remove(installer_path) + except OSError: + pass + else: + # Verify the file is actually valid and matches expected size from server + try: + # First, check if it's a valid shell script + with open(installer_path, "rb") as f: + header = f.read(100) + if not header.startswith(b"#!/"): + await self._broadcast_log_line( + f"Existing installer file is not a valid shell script, re-downloading..." + ) + try: + os.remove(installer_path) + except OSError: + pass + else: + # File appears valid, now verify size matches server expectation + # Fetch the expected file size from the server + try: + async with aiohttp.ClientSession() as session: + async with session.head(url, allow_redirects=True) as head_response: + expected_size = int(head_response.headers.get("Content-Length", 0)) + + if expected_size > 0: + # Verify file size matches (with small tolerance) + size_diff = abs(file_size - expected_size) + if size_diff > 1024: # Allow 1KB tolerance + await self._broadcast_log_line( + f"Existing installer file size mismatch: expected {expected_size / (1024*1024):.1f} MB, " + f"got {file_size_mb:.1f} MB (difference: {size_diff} bytes). Re-downloading..." + ) + try: + os.remove(installer_path) + except OSError: + pass + else: + # File size matches, verify it's stable (not currently being written) + await asyncio.sleep(0.2) # Brief pause to ensure file is fully written if being written + new_size = os.path.getsize(installer_path) + if new_size != file_size: + await self._broadcast_log_line( + f"File size changed during verification (was {file_size_mb:.1f} MB, now {new_size / (1024*1024):.1f} MB), " + f"file may still be downloading. Re-downloading..." 
+ ) + try: + os.remove(installer_path) + except OSError: + pass + else: + await self._broadcast_log_line( + f"Installer file already exists and verified: {installer_path} ({file_size_mb:.1f} MB)" + ) + await self._broadcast_progress( + { + "stage": "download", + "progress": 100, + "message": f"Using existing installer file ({file_size_mb:.1f} MB)", + } + ) + return + else: + # Couldn't get expected size, but file looks valid - use it + await self._broadcast_log_line( + f"Installer file already exists: {installer_path} ({file_size_mb:.1f} MB). " + f"Could not verify size from server, but file appears valid." + ) + await self._broadcast_progress( + { + "stage": "download", + "progress": 100, + "message": f"Using existing installer file ({file_size_mb:.1f} MB)", + } + ) + return + except Exception as size_check_error: + # If we can't verify size from server, but file looks valid, use it + await self._broadcast_log_line( + f"Could not verify file size from server: {size_check_error}. " + f"File appears valid, using existing file: {installer_path} ({file_size_mb:.1f} MB)" + ) + await self._broadcast_progress( + { + "stage": "download", + "progress": 100, + "message": f"Using existing installer file ({file_size_mb:.1f} MB)", + } + ) + return + except (OSError, IOError) as e: + await self._broadcast_log_line( + f"Failed to verify existing installer file: {e}, re-downloading..." + ) + try: + os.remove(installer_path) + except OSError: + pass + + # Reset logging state for new download + self._last_logged_percentage = -1 + self._last_progress_broadcast_time = 0.0 + self._pending_progress = None + self._progress_broadcast_count = 0 + + log_header = f"[{_utcnow()}] Downloading CUDA {version} installer from {url}\n" + with open(self._log_path, "w", encoding="utf-8") as log_file: + log_file.write(log_header) + + await self._broadcast_log_line( + f"Starting download of CUDA {version} installer..." 
+ ) + await self._broadcast_progress( + { + "stage": "download", + "progress": 0, + "message": f"Downloading CUDA {version} installer...", + } + ) + + # Configure timeout for large file downloads: + # - total: 1 hour (3600s) for very large files and slow connections + # - connect: 30s to establish connection + # - sock_read: 5 minutes (300s) to allow for slow network during chunk reads + timeout = aiohttp.ClientTimeout( + total=3600, # 1 hour total timeout + connect=30, # 30 seconds to connect + sock_read=300, # 5 minutes per read operation + ) + + downloaded = 0 + total_size = 0 + try: + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url) as response: + response.raise_for_status() + total_size = int(response.headers.get("Content-Length", 0)) + + async with aiofiles.open(installer_path, "wb") as f: + async for chunk in response.content.iter_chunked(8192): + await f.write(chunk) + downloaded += len(chunk) + + if total_size > 0: + progress = int((downloaded / total_size) * 100) + # Format sizes in MB + downloaded_mb = downloaded / (1024 * 1024) + total_mb = total_size / (1024 * 1024) + await self._broadcast_progress( + { + "stage": "download", + "progress": progress, + "message": f"Downloading CUDA {version} installer... 
({downloaded_mb:.1f}/{total_mb:.1f} MB)", + "bytes_downloaded": downloaded, + "total_bytes": total_size, + } + ) + + # Log progress only at key percentage milestones (10%, 25%, 50%, 75%, 90%, 100%) + # Only log when we cross a milestone, not when we're within it + should_log = False + + # Check if we've crossed a key percentage milestone + if total_size > 0: + progress = int((downloaded / total_size) * 100) + if progress != self._last_logged_percentage and progress in [ + 10, + 25, + 50, + 75, + 90, + 100, + ]: + should_log = True + self._last_logged_percentage = progress + + if should_log: + downloaded_mb = downloaded / (1024 * 1024) + total_mb = total_size / (1024 * 1024) + log_line = f"Downloaded {downloaded_mb:.1f}/{total_mb:.1f} MB ({progress}%)\n" + with open( + self._log_path, "a", encoding="utf-8" + ) as log_file: + log_file.write(log_line) + await self._broadcast_log_line( + f"Downloaded {downloaded_mb:.1f} MB / {total_mb:.1f} MB ({progress}%)" + ) + + # File is automatically flushed when the context manager exits + except asyncio.TimeoutError as e: + # Clean up partial download on timeout + if os.path.exists(installer_path): + try: + os.remove(installer_path) + except OSError: + pass + downloaded_mb = downloaded / (1024 * 1024) if downloaded > 0 else 0 + total_mb = total_size / (1024 * 1024) if total_size > 0 else 0 + error_msg = ( + f"Download timeout: Failed to download CUDA {version} installer. " + f"Downloaded {downloaded_mb:.1f} MB of {total_mb:.1f} MB. " + f"This may be due to a slow network connection. Please try again." + ) + await self._broadcast_log_line(error_msg) + raise RuntimeError(error_msg) from e + except aiohttp.ClientError as e: + # Clean up partial download on client error + if os.path.exists(installer_path): + try: + os.remove(installer_path) + except OSError: + pass + error_msg = ( + f"Network error while downloading CUDA {version} installer: {e}. " + f"Please check your network connection and try again." 
+ ) + await self._broadcast_log_line(error_msg) + raise RuntimeError(error_msg) from e + + # Wait a brief moment to ensure file system has fully written the file + # This helps ensure the file is completely written to disk before verification + await asyncio.sleep(0.5) + + # Verify downloaded file exists and is complete + if not os.path.exists(installer_path): + raise RuntimeError(f"Downloaded file not found: {installer_path}") + + # Verify file size matches expected size (with a small tolerance for filesystem differences) + actual_size = os.path.getsize(installer_path) + if total_size > 0: + size_diff = abs(actual_size - total_size) + if size_diff > 1024: # Allow 1KB tolerance for filesystem differences + raise RuntimeError( + f"Downloaded file size mismatch: expected {total_size} bytes, " + f"got {actual_size} bytes (difference: {size_diff} bytes). File may be corrupted or incomplete." + ) + + if actual_size < 100 * 1024 * 1024: # Less than 100MB is suspicious + raise RuntimeError( + f"Downloaded file appears to be corrupted or incomplete: " + f"{installer_path} (size: {actual_size} bytes)" + ) + + # Verify the file is a valid shell script (CUDA .run files are self-extracting) + try: + with open(installer_path, "rb") as verify_file: + header = verify_file.read(100) + if not header.startswith(b"#!/"): + raise RuntimeError( + f"Downloaded file does not appear to be a valid shell script: {installer_path}" + ) + except Exception as e: + raise RuntimeError( + f"Failed to verify downloaded file integrity: {installer_path}, error: {e}" + ) + + await self._broadcast_log_line( + f"Download completed and verified: {installer_path} ({actual_size / (1024*1024):.1f} MB)" + ) + await self._broadcast_progress( + { + "stage": "download", + "progress": 100, + "message": "Download completed and verified", + } + ) + + def _is_docker_container(self) -> bool: + """Check if running inside a Docker container.""" + # Check for Docker-specific files + docker_indicators = [ + 
"/.dockerenv", + "/proc/self/cgroup", + ] + + # Check /.dockerenv + if os.path.exists("/.dockerenv"): + return True + + # Check /proc/self/cgroup for Docker + try: + if os.path.exists("/proc/self/cgroup"): + with open("/proc/self/cgroup", "r") as f: + content = f.read() + if "docker" in content or "containerd" in content: + return True + except (OSError, IOError): + pass + + return False + + async def _install_linux( + self, + installer_path: str, + version: str, + install_cudnn: bool = False, + install_tensorrt: bool = False, + ) -> str: + """ + Install CUDA on Linux using runfile installer. + + Uses optimized installer options for custom location installation: + - Silent installation with EULA acceptance + - Toolkit-only installation (no driver) + - Override installation checks for custom paths + - Skip OpenGL libraries (not needed in Docker/headless environments) + - Skip man pages to reduce installation size + + Args: + installer_path: Path to the CUDA installer runfile + version: CUDA version being installed + install_cudnn: Whether to install cuDNN + install_tensorrt: Whether to install TensorRT + """ + await self._broadcast_log_line("Starting CUDA installation on Linux...") + await self._broadcast_progress( + { + "stage": "install", + "progress": 0, + "message": "Installing CUDA Toolkit...", + } + ) + + # Verify installer file exists and is not corrupted + if not os.path.exists(installer_path): + raise RuntimeError(f"Installer file not found: {installer_path}") + + file_size = os.path.getsize(installer_path) + if file_size < 100 * 1024 * 1024: # Less than 100MB is suspicious for CUDA installers + raise RuntimeError( + f"Installer file appears to be corrupted or incomplete: {installer_path} " + f"(size: {file_size / (1024*1024):.1f} MB, expected > 100 MB)" + ) + + # Verify the file starts with a shell script header (CUDA .run files are self-extracting) + try: + with open(installer_path, "rb") as f: + header = f.read(100) + if not header.startswith(b"#!/"): + 
raise RuntimeError( + f"Installer file does not appear to be a valid shell script: {installer_path}" + ) + except Exception as e: + raise RuntimeError( + f"Failed to verify installer file: {installer_path}, error: {e}" + ) + + await self._broadcast_log_line( + f"Verifying installer file: {installer_path} ({file_size / (1024*1024):.1f} MB)" + ) + + # Make installer executable + os.chmod(installer_path, 0o755) + + # Always install to the data directory for persistence + install_path = os.path.join(self._cuda_install_dir, f"cuda-{version}") + await self._broadcast_log_line(f"Installing to data directory: {install_path}") + os.makedirs(install_path, exist_ok=True) + + # Build installer arguments with optimized options for custom location installation + # + # Selected options based on NVIDIA CUDA installer documentation: + # - --silent: Required for silent installation, implies EULA acceptance + # - --toolkit: Install toolkit only (not driver) - required for non-root installations + # - --override: Override compiler, third-party library, and toolkit detection checks + # (essential for custom installation paths) + # - --toolkitpath: Install to custom data directory path + # - --no-opengl-libs: Skip OpenGL libraries (not needed in Docker/headless environments) + # - --no-man-page: Skip man pages to reduce installation size + # + install_args = [ + "bash", + installer_path, + "--silent", # Silent installation with EULA acceptance + "--toolkit", # Install toolkit only (not driver) + "--override", # Override installation checks for custom paths + f"--toolkitpath={install_path}", # Install to custom data directory + "--no-opengl-libs", # Skip OpenGL libraries (not needed in Docker) + "--no-man-page", # Skip man pages to reduce size + ] + + await self._broadcast_log_line(f"Installer arguments: {' '.join(install_args[2:])}") # Skip 'bash' and installer_path + + # Set up environment to prevent /dev/tty access issues in Docker + env = os.environ.copy() + env["DEBIAN_FRONTEND"] = 
"noninteractive" + # Disable interactive prompts + env["PERL_BADLANG"] = "0" + # Ensure we're in a non-interactive environment + env["TERM"] = "dumb" + # Prevent installer from trying to access /dev/tty + env["CI"] = "true" # Indicate we're in a CI/non-interactive environment + + process = await asyncio.create_subprocess_exec( + *install_args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + stdin=asyncio.subprocess.DEVNULL, # Redirect stdin to prevent /dev/tty access + env=env, + ) + + # Collect output for error analysis + output_lines = [] + + async def _stream_output(): + if process.stdout is None: + return + with open(self._log_path, "a", encoding="utf-8", buffering=1) as log_file: + while True: + chunk = await process.stdout.readline() + if not chunk: + break + text = chunk.decode("utf-8", errors="replace") + output_lines.append(text) + log_file.write(text) + await self._broadcast_log_line(text.rstrip("\n")) + + await asyncio.gather(process.wait(), _stream_output()) + + if process.returncode != 0: + # Check for specific error patterns + output_text = "".join(output_lines) + + # Check for /dev/tty errors + if "/dev/tty" in output_text.lower() or "cannot create /dev/tty" in output_text.lower(): + error_msg = ( + f"CUDA installer failed due to /dev/tty access issue (common in Docker). " + f"This may indicate the installer file is corrupted or the environment is not properly configured. " + f"Exit code: {process.returncode}. " + f"Please check the installation logs for details. " + f"If the file appears corrupted, try deleting it and re-downloading." + ) + # Check for gzip/corruption errors + elif "gzip" in output_text.lower() and ("unexpected end" in output_text.lower() or "corrupt" in output_text.lower()): + error_msg = ( + f"CUDA installer file appears to be corrupted (gzip error detected). " + f"Please delete the installer file at {installer_path} and try again. " + f"Exit code: {process.returncode}." 
+ ) + else: + error_msg = ( + f"CUDA installer exited with code {process.returncode}. " + "Please check the installation logs for details." + ) + + raise RuntimeError(error_msg) + + # Verify installation and set up environment + cuda_home = install_path + cuda_bin = os.path.join(cuda_home, "bin") + cuda_lib = os.path.join(cuda_home, "lib64") + + # Verify key directories exist + if not os.path.exists(cuda_bin) or not os.path.exists(cuda_lib): + raise RuntimeError( + f"CUDA installation completed but expected directories not found. " + f"Expected: {cuda_bin}, {cuda_lib}" + ) + + await self._broadcast_log_line( + f"CUDA installed successfully to: {install_path}" + ) + await self._broadcast_log_line(f"CUDA_HOME={cuda_home}") + await self._broadcast_log_line(f"Adding to PATH: {cuda_bin}") + await self._broadcast_log_line(f"Adding to LD_LIBRARY_PATH: {cuda_lib}") + + # Install NCCL (required for multi-GPU and llama.cpp CUDA builds) + await self._install_nccl_linux(version, install_path) + + # Install nvidia-smi (required for GPU monitoring) + await self._install_nvidia_smi_linux(install_path) + + # Install cuDNN if requested + if install_cudnn: + await self._install_cudnn_linux(version, install_path) + + # Install TensorRT if requested + if install_tensorrt: + await self._install_tensorrt_linux(version, install_path) + + # Save installation path to state + state = self._load_state() + if "installations" not in state: + state["installations"] = {} + state["installations"][version] = { + "path": install_path, + "installed_at": _utcnow(), + "is_system_install": False, + "cudnn_installed": install_cudnn, + "tensorrt_installed": install_tensorrt, + } + self._save_state(state) + + # Update the current symlink to point to this installation + self._update_current_symlink(install_path) + await self._broadcast_log_line( + f"Updated CUDA current symlink: /app/data/cuda/current -> {install_path}" + ) + + components = ["CUDA", "NCCL", "nvidia-smi"] + if install_cudnn: + 
components.append("cuDNN") + if install_tensorrt: + components.append("TensorRT") + + await self._broadcast_progress( + { + "stage": "install", + "progress": 100, + "message": f"{', '.join(components)} installation completed", + } + ) + + return install_path + + async def _install_nccl_linux(self, cuda_version: str, cuda_path: str) -> None: + """Install NCCL library for multi-GPU support.""" + await self._broadcast_log_line( + "Installing NCCL (NVIDIA Collective Communications Library)..." + ) + await self._broadcast_progress( + { + "stage": "nccl", + "progress": 0, + "message": "Installing NCCL...", + } + ) + + ubuntu_version = self._get_ubuntu_version() + + # Download NCCL from NVIDIA's repo package index + await self._broadcast_log_line("Attempting manual NCCL installation...") + + try: + cuda_major = cuda_version.split(".")[0] + packages = await self._get_repo_packages(ubuntu_version) + nccl_pkg = self._select_repo_package( + packages, + "libnccl2", + version_prefix="2.", + version_contains=f"+cuda{cuda_major}", + ) + nccl_dev_pkg = self._select_repo_package( + packages, + "libnccl-dev", + version_prefix="2.", + version_contains=f"+cuda{cuda_major}", + ) + + if not nccl_pkg or not nccl_dev_pkg: + await self._broadcast_log_line( + "NCCL packages not found in repository, skipping NCCL installation" + ) + await self._broadcast_progress( + { + "stage": "nccl", + "progress": 100, + "message": "NCCL installation skipped (optional)", + } + ) + return + + base_url = ( + f"https://developer.download.nvidia.com/compute/cuda/repos/{ubuntu_version}/x86_64/" + ) + nccl_url = base_url + nccl_pkg.get("Filename", "").lstrip("./") + nccl_dev_url = base_url + nccl_dev_pkg.get("Filename", "").lstrip("./") + + nccl_path = os.path.join(self._download_dir, "libnccl2.deb") + nccl_dev_path = os.path.join(self._download_dir, "libnccl-dev.deb") + + await self._broadcast_progress( + { + "stage": "nccl", + "progress": 25, + "message": "Downloading NCCL packages...", + } + ) + + # Download 
NCCL packages + async with aiohttp.ClientSession() as session: + for url, path, name in [ + (nccl_url, nccl_path, "libnccl2"), + (nccl_dev_url, nccl_dev_path, "libnccl-dev"), + ]: + try: + await self._broadcast_log_line(f"Downloading {name}...") + async with session.get(url) as response: + if response.status == 200: + async with aiofiles.open(path, "wb") as f: + await f.write(await response.read()) + await self._broadcast_log_line(f"Downloaded {name}") + else: + await self._broadcast_log_line( + f"Failed to download {name}: HTTP {response.status}" + ) + # Try alternative URL with different NCCL version + continue + except Exception as download_err: + await self._broadcast_log_line( + f"Download error for {name}: {download_err}" + ) + continue + + await self._broadcast_progress( + { + "stage": "nccl", + "progress": 50, + "message": "Installing NCCL packages...", + } + ) + + if os.path.exists(nccl_path): + await self._broadcast_log_line( + "Extracting NCCL to CUDA directory..." + ) + + # Extract .deb file (it's an ar archive containing data.tar) + extract_dir = os.path.join(self._download_dir, "nccl_extract") + os.makedirs(extract_dir, exist_ok=True) + + for deb_path in [nccl_path, nccl_dev_path]: + if os.path.exists(deb_path): + # Extract using ar and tar + extract_process = await asyncio.create_subprocess_exec( + "bash", + "-c", + f"cd {extract_dir} && ar x {deb_path} && tar xf data.tar.* 2>/dev/null || tar xf data.tar 2>/dev/null", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + await extract_process.wait() + + # Copy NCCL files to CUDA installation + nccl_lib_src = os.path.join( + extract_dir, "usr", "lib", "x86_64-linux-gnu" + ) + nccl_include_src = os.path.join(extract_dir, "usr", "include") + + cuda_lib_dst = os.path.join(cuda_path, "lib64") + cuda_include_dst = os.path.join(cuda_path, "include") + + if os.path.exists(nccl_lib_src): + # First pass: collect files and symlinks, copy actual files first + files_to_copy = [] + 
symlinks_to_create = [] + + for f in os.listdir(nccl_lib_src): + if "nccl" in f.lower(): + src = os.path.join(nccl_lib_src, f) + dst = os.path.join(cuda_lib_dst, f) + + if os.path.islink(src): + # Resolve symlink to find actual target + link_target = os.readlink(src) + # If relative symlink, resolve relative to source directory + if not os.path.isabs(link_target): + link_target = os.path.normpath( + os.path.join(os.path.dirname(src), link_target) + ) + # Find the actual target file name + actual_target = os.path.basename(link_target) + symlinks_to_create.append((f, actual_target, dst)) + else: + files_to_copy.append((f, src, dst)) + + # Copy all actual files first + for f, src, dst in files_to_copy: + try: + shutil.copy2(src, dst) + await self._broadcast_log_line( + f"Copied {f} to CUDA lib directory" + ) + except Exception as copy_err: + await self._broadcast_log_line( + f"Failed to copy {f}: {copy_err}" + ) + + # Then create symlinks pointing to the copied files + for link_name, target_name, dst in symlinks_to_create: + try: + if os.path.exists(dst): + os.remove(dst) + # Create symlink pointing to target in same directory + os.symlink(target_name, dst) + await self._broadcast_log_line( + f"Created symlink {link_name} -> {target_name} in CUDA lib directory" + ) + except Exception as link_err: + await self._broadcast_log_line( + f"Failed to create symlink {link_name}: {link_err}" + ) + + if os.path.exists(nccl_include_src): + for f in os.listdir(nccl_include_src): + if "nccl" in f.lower(): + src = os.path.join(nccl_include_src, f) + dst = os.path.join(cuda_include_dst, f) + try: + if os.path.isdir(src): + # Handle directories by copying recursively + if os.path.exists(dst): + shutil.rmtree(dst) + shutil.copytree(src, dst) + await self._broadcast_log_line( + f"Copied directory {f} to CUDA include directory" + ) + else: + # Handle regular files + shutil.copy2(src, dst) + await self._broadcast_log_line( + f"Copied {f} to CUDA include directory" + ) + except Exception 
as copy_err: + await self._broadcast_log_line( + f"Failed to copy {f}: {copy_err}" + ) + + # Cleanup temporary extract directory only (keep .deb files) + shutil.rmtree(extract_dir, ignore_errors=True) + # Keep .deb files for future use + logger.info(f"NCCL packages kept at: {nccl_path}, {nccl_dev_path}") + + await self._broadcast_log_line("NCCL extracted to CUDA directory") + await self._broadcast_progress( + { + "stage": "nccl", + "progress": 100, + "message": "NCCL installed successfully", + } + ) + else: + await self._broadcast_log_line( + "NCCL packages not available, skipping NCCL installation" + ) + await self._broadcast_log_line( + "Note: NCCL is optional but recommended for multi-GPU builds" + ) + await self._broadcast_progress( + { + "stage": "nccl", + "progress": 100, + "message": "NCCL installation skipped (optional)", + } + ) + + except Exception as e: + await self._broadcast_log_line(f"NCCL installation error: {e}") + await self._broadcast_log_line( + "Note: NCCL is optional. The build will continue without multi-GPU support." + ) + await self._broadcast_progress( + { + "stage": "nccl", + "progress": 100, + "message": "NCCL installation skipped (optional)", + } + ) + + async def _install_nvidia_smi_linux(self, cuda_path: str) -> None: + """Install nvidia-smi binary for GPU monitoring.""" + await self._broadcast_log_line( + "Installing nvidia-smi (NVIDIA System Management Interface)..." 
+ ) + await self._broadcast_progress( + { + "stage": "nvidia-smi", + "progress": 0, + "message": "Installing nvidia-smi...", + } + ) + + # Check if nvidia-smi already exists in CUDA installation + cuda_bin = os.path.join(cuda_path, "bin") + nvidia_smi_dst = os.path.join(cuda_bin, "nvidia-smi") + if os.path.exists(nvidia_smi_dst): + await self._broadcast_log_line( + "nvidia-smi already exists in CUDA installation, skipping" + ) + await self._broadcast_progress( + { + "stage": "nvidia-smi", + "progress": 100, + "message": "nvidia-smi already installed", + } + ) + return + + ubuntu_version = self._get_ubuntu_version() + + try: + # Try to find nvidia-utils package which contains nvidia-smi + packages = await self._get_repo_packages(ubuntu_version) + nvidia_utils_pkg = None + + # Try multiple package name patterns + for pkg_name in ["nvidia-utils", "nvidia-driver-utils", "nvidia-utils-"]: + nvidia_utils_pkg = self._select_repo_package( + packages, + pkg_name, + ) + if nvidia_utils_pkg: + break + + if not nvidia_utils_pkg: + await self._broadcast_log_line( + "nvidia-utils package not found in repository, skipping nvidia-smi installation" + ) + await self._broadcast_log_line( + "Note: nvidia-smi will not be available. GPU monitoring may be limited." 
+ ) + await self._broadcast_progress( + { + "stage": "nvidia-smi", + "progress": 100, + "message": "nvidia-smi installation skipped (package not available)", + } + ) + return + + base_url = ( + f"https://developer.download.nvidia.com/compute/cuda/repos/{ubuntu_version}/x86_64/" + ) + nvidia_utils_url = base_url + nvidia_utils_pkg.get("Filename", "").lstrip("./") + + nvidia_utils_path = os.path.join(self._download_dir, "nvidia-utils.deb") + + await self._broadcast_progress( + { + "stage": "nvidia-smi", + "progress": 25, + "message": "Downloading nvidia-utils package...", + } + ) + + # Download nvidia-utils package + async with aiohttp.ClientSession() as session: + try: + await self._broadcast_log_line("Downloading nvidia-utils...") + async with session.get(nvidia_utils_url) as response: + if response.status == 200: + async with aiofiles.open(nvidia_utils_path, "wb") as f: + await f.write(await response.read()) + await self._broadcast_log_line("Downloaded nvidia-utils") + else: + await self._broadcast_log_line( + f"Failed to download nvidia-utils: HTTP {response.status}" + ) + raise RuntimeError(f"Failed to download nvidia-utils: HTTP {response.status}") + except Exception as download_err: + await self._broadcast_log_line( + f"Download error for nvidia-utils: {download_err}" + ) + raise + + await self._broadcast_progress( + { + "stage": "nvidia-smi", + "progress": 50, + "message": "Extracting nvidia-smi...", + } + ) + + if os.path.exists(nvidia_utils_path): + await self._broadcast_log_line( + "Extracting nvidia-smi to CUDA directory..." 
+ ) + + # Extract .deb file + extract_dir = os.path.join(self._download_dir, "nvidia_utils_extract") + os.makedirs(extract_dir, exist_ok=True) + + # Extract using ar and tar + extract_process = await asyncio.create_subprocess_exec( + "bash", + "-c", + f"cd {extract_dir} && ar x {nvidia_utils_path} && tar xf data.tar.* 2>/dev/null || tar xf data.tar 2>/dev/null", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + await extract_process.wait() + + # Copy nvidia-smi binary to CUDA installation + nvidia_smi_src = os.path.join(extract_dir, "usr", "bin", "nvidia-smi") + cuda_bin_dst = os.path.join(cuda_path, "bin") + nvidia_smi_dst = os.path.join(cuda_bin_dst, "nvidia-smi") + + if os.path.exists(nvidia_smi_src): + os.makedirs(cuda_bin_dst, exist_ok=True) + try: + shutil.copy2(nvidia_smi_src, nvidia_smi_dst) + os.chmod(nvidia_smi_dst, 0o755) + await self._broadcast_log_line( + "Copied nvidia-smi to CUDA bin directory" + ) + await self._broadcast_progress( + { + "stage": "nvidia-smi", + "progress": 100, + "message": "nvidia-smi installed successfully", + } + ) + except Exception as copy_err: + await self._broadcast_log_line( + f"Failed to copy nvidia-smi: {copy_err}" + ) + raise + else: + await self._broadcast_log_line( + "nvidia-smi not found in extracted package" + ) + await self._broadcast_progress( + { + "stage": "nvidia-smi", + "progress": 100, + "message": "nvidia-smi installation skipped (not in package)", + } + ) + + # Cleanup temporary extract directory only (keep .deb file) + shutil.rmtree(extract_dir, ignore_errors=True) + # Keep .deb file for future use + if os.path.exists(nvidia_utils_path): + logger.info(f"nvidia-utils package kept at: {nvidia_utils_path}") + + else: + await self._broadcast_log_line( + "nvidia-utils package not available, skipping nvidia-smi installation" + ) + await self._broadcast_progress( + { + "stage": "nvidia-smi", + "progress": 100, + "message": "nvidia-smi installation skipped (package not available)", + } + ) 
+ + except Exception as e: + await self._broadcast_log_line(f"nvidia-smi installation error: {e}") + await self._broadcast_log_line( + "Note: nvidia-smi installation failed. GPU monitoring may be limited." + ) + await self._broadcast_progress( + { + "stage": "nvidia-smi", + "progress": 100, + "message": "nvidia-smi installation skipped (error occurred)", + } + ) + + async def _install_cudnn_linux(self, cuda_version: str, cuda_path: str) -> None: + """Install cuDNN library for deep learning primitives.""" + await self._broadcast_log_line( + "Installing cuDNN (CUDA Deep Neural Network library)..." + ) + await self._broadcast_progress( + { + "stage": "cudnn", + "progress": 0, + "message": "Installing cuDNN...", + } + ) + + try: + # Determine CUDA major version for cuDNN compatibility + cuda_major = cuda_version.split(".")[0] + cudnn_version = self.CUDNN_VERSIONS.get(cuda_major) + + if not cudnn_version: + await self._broadcast_log_line( + f"cuDNN version not available for CUDA {cuda_version}, skipping" + ) + await self._broadcast_progress( + { + "stage": "cudnn", + "progress": 100, + "message": "cuDNN installation skipped (version not available)", + } + ) + return + + ubuntu_version = self._get_ubuntu_version() + + # cuDNN package names vary by CUDA version + # For CUDA 12.x: libcudnn9-cuda-12, libcudnn9-dev-cuda-12 + # For CUDA 11.x: libcudnn8-cuda-11, libcudnn8-dev-cuda-11 + if cuda_major == "12" or cuda_major == "13": + cudnn_pkg = "libcudnn9" + cudnn_cuda_suffix = f"cuda-{cuda_major}" + else: + cudnn_pkg = "libcudnn8" + cudnn_cuda_suffix = f"cuda-{cuda_major}" + + # Manual cuDNN installation + await self._broadcast_log_line("Installing cuDNN packages...") + + cudnn_package_name = f"{cudnn_pkg}-{cudnn_cuda_suffix}" + cudnn_dev_package_name = f"{cudnn_pkg}-dev-{cudnn_cuda_suffix}" + packages = await self._get_repo_packages(ubuntu_version) + cudnn_pkg_entry = self._select_repo_package( + packages, cudnn_package_name, version_prefix=cudnn_version + ) + 
cudnn_dev_pkg_entry = self._select_repo_package( + packages, cudnn_dev_package_name, version_prefix=cudnn_version + ) + + if not cudnn_pkg_entry or not cudnn_dev_pkg_entry: + await self._broadcast_log_line( + "cuDNN packages not found in repository, skipping cuDNN installation" + ) + await self._broadcast_progress( + { + "stage": "cudnn", + "progress": 100, + "message": "cuDNN installation skipped (optional)", + } + ) + return + + base_url = ( + f"https://developer.download.nvidia.com/compute/cuda/repos/{ubuntu_version}/x86_64/" + ) + cudnn_url = base_url + cudnn_pkg_entry.get("Filename", "").lstrip("./") + cudnn_dev_url = base_url + cudnn_dev_pkg_entry.get("Filename", "").lstrip("./") + + cudnn_path = os.path.join(self._download_dir, f"{cudnn_pkg}.deb") + cudnn_dev_path = os.path.join(self._download_dir, f"{cudnn_pkg}-dev.deb") + + await self._broadcast_progress( + { + "stage": "cudnn", + "progress": 25, + "message": "Downloading cuDNN packages...", + } + ) + + # Download cuDNN packages + async with aiohttp.ClientSession() as session: + for url, path, name in [ + (cudnn_url, cudnn_path, cudnn_pkg), + (cudnn_dev_url, cudnn_dev_path, f"{cudnn_pkg}-dev"), + ]: + try: + await self._broadcast_log_line(f"Downloading {name}...") + async with session.get(url) as response: + if response.status == 200: + async with aiofiles.open(path, "wb") as f: + await f.write(await response.read()) + await self._broadcast_log_line(f"Downloaded {name}") + else: + await self._broadcast_log_line( + f"Failed to download {name}: HTTP {response.status}" + ) + # Try alternative URL pattern + continue + except Exception as download_err: + await self._broadcast_log_line( + f"Download error for {name}: {download_err}" + ) + continue + + await self._broadcast_progress( + { + "stage": "cudnn", + "progress": 50, + "message": "Installing cuDNN packages...", + } + ) + + if os.path.exists(cudnn_path): + await self._broadcast_log_line( + "Extracting cuDNN to CUDA directory..." 
+ ) + + # Extract .deb file + extract_dir = os.path.join(self._download_dir, "cudnn_extract") + os.makedirs(extract_dir, exist_ok=True) + + for deb_path in [cudnn_path, cudnn_dev_path]: + if os.path.exists(deb_path): + # Extract using ar and tar + extract_process = await asyncio.create_subprocess_exec( + "bash", + "-c", + f"cd {extract_dir} && ar x {deb_path} && tar xf data.tar.* 2>/dev/null || tar xf data.tar 2>/dev/null", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + await extract_process.wait() + + # Copy cuDNN files to CUDA installation + cudnn_lib_src = os.path.join( + extract_dir, "usr", "lib", "x86_64-linux-gnu" + ) + cudnn_include_src = os.path.join(extract_dir, "usr", "include") + + cuda_lib_dst = os.path.join(cuda_path, "lib64") + cuda_include_dst = os.path.join(cuda_path, "include") + + if os.path.exists(cudnn_lib_src): + for f in os.listdir(cudnn_lib_src): + if "cudnn" in f.lower(): + src = os.path.join(cudnn_lib_src, f) + dst = os.path.join(cuda_lib_dst, f) + try: + if os.path.islink(src): + linkto = os.readlink(src) + if os.path.exists(dst): + os.remove(dst) + os.symlink(linkto, dst) + else: + shutil.copy2(src, dst) + await self._broadcast_log_line( + f"Copied {f} to CUDA lib directory" + ) + except Exception as copy_err: + await self._broadcast_log_line( + f"Failed to copy {f}: {copy_err}" + ) + + if os.path.exists(cudnn_include_src): + for f in os.listdir(cudnn_include_src): + if "cudnn" in f.lower(): + src = os.path.join(cudnn_include_src, f) + dst = os.path.join(cuda_include_dst, f) + try: + shutil.copy2(src, dst) + await self._broadcast_log_line( + f"Copied {f} to CUDA include directory" + ) + except Exception as copy_err: + await self._broadcast_log_line( + f"Failed to copy {f}: {copy_err}" + ) + + # Cleanup temporary extract directory only (keep .deb files) + shutil.rmtree(extract_dir, ignore_errors=True) + # Keep .deb files for future use + logger.info(f"cuDNN packages kept at: {cudnn_path}, {cudnn_dev_path}") + + 
await self._broadcast_log_line("cuDNN extracted to CUDA directory") + await self._broadcast_progress( + { + "stage": "cudnn", + "progress": 100, + "message": "cuDNN installed successfully", + } + ) + else: + await self._broadcast_log_line( + "cuDNN packages not available, skipping cuDNN installation" + ) + await self._broadcast_progress( + { + "stage": "cudnn", + "progress": 100, + "message": "cuDNN installation skipped (optional)", + } + ) + + except Exception as e: + await self._broadcast_log_line(f"cuDNN installation error: {e}") + await self._broadcast_log_line( + "Note: cuDNN is optional. The build will continue without cuDNN support." + ) + await self._broadcast_progress( + { + "stage": "cudnn", + "progress": 100, + "message": "cuDNN installation skipped (optional)", + } + ) + + async def _install_tensorrt_linux(self, cuda_version: str, cuda_path: str) -> None: + """Install TensorRT library for inference optimization.""" + await self._broadcast_log_line( + "Installing TensorRT (NVIDIA TensorRT inference library)..." 
+ ) + await self._broadcast_progress( + { + "stage": "tensorrt", + "progress": 0, + "message": "Installing TensorRT...", + } + ) + + try: + # Determine CUDA major version for TensorRT compatibility + cuda_major = cuda_version.split(".")[0] + tensorrt_version = self.TENSORRT_VERSIONS.get(cuda_major) + + if not tensorrt_version: + await self._broadcast_log_line( + f"TensorRT version not available for CUDA {cuda_version}, skipping" + ) + await self._broadcast_progress( + { + "stage": "tensorrt", + "progress": 100, + "message": "TensorRT installation skipped (version not available)", + } + ) + return + + ubuntu_version = self._get_ubuntu_version() + + # TensorRT package names + # For CUDA 12.x/13.x: libnvinfer10, libnvinfer-dev, libnvinfer-plugin10, libnvinfer-plugin-dev + # For CUDA 11.x: libnvinfer8, libnvinfer-dev, libnvinfer-plugin8, libnvinfer-plugin-dev + if cuda_major == "12" or cuda_major == "13": + tensorrt_pkg = "libnvinfer10" + tensorrt_plugin_pkg = "libnvinfer-plugin10" + else: + tensorrt_pkg = "libnvinfer8" + tensorrt_plugin_pkg = "libnvinfer-plugin8" + + # Manual TensorRT installation + await self._broadcast_log_line("Installing TensorRT packages...") + + packages = await self._get_repo_packages(ubuntu_version) + tensorrt_pkg_entry = self._select_repo_package( + packages, tensorrt_pkg, version_prefix=tensorrt_version + ) + tensorrt_dev_pkg_entry = self._select_repo_package( + packages, f"{tensorrt_pkg}-dev", version_prefix=tensorrt_version + ) + tensorrt_plugin_entry = self._select_repo_package( + packages, tensorrt_plugin_pkg, version_prefix=tensorrt_version + ) + tensorrt_plugin_dev_entry = self._select_repo_package( + packages, f"{tensorrt_plugin_pkg}-dev", version_prefix=tensorrt_version + ) + + if not all( + [ + tensorrt_pkg_entry, + tensorrt_dev_pkg_entry, + tensorrt_plugin_entry, + tensorrt_plugin_dev_entry, + ] + ): + await self._broadcast_log_line( + "TensorRT packages not found in repository, skipping TensorRT installation" + ) + await 
self._broadcast_progress( + { + "stage": "tensorrt", + "progress": 100, + "message": "TensorRT installation skipped (optional)", + } + ) + return + + base_url = ( + f"https://developer.download.nvidia.com/compute/cuda/repos/{ubuntu_version}/x86_64/" + ) + tensorrt_url = base_url + tensorrt_pkg_entry.get("Filename", "").lstrip("./") + tensorrt_dev_url = base_url + tensorrt_dev_pkg_entry.get("Filename", "").lstrip("./") + tensorrt_plugin_url = base_url + tensorrt_plugin_entry.get("Filename", "").lstrip("./") + tensorrt_plugin_dev_url = base_url + tensorrt_plugin_dev_entry.get("Filename", "").lstrip("./") + + tensorrt_path = os.path.join(self._download_dir, f"{tensorrt_pkg}.deb") + tensorrt_dev_path = os.path.join(self._download_dir, f"{tensorrt_pkg}-dev.deb") + tensorrt_plugin_path = os.path.join(self._download_dir, f"{tensorrt_plugin_pkg}.deb") + tensorrt_plugin_dev_path = os.path.join(self._download_dir, f"{tensorrt_plugin_pkg}-dev.deb") + + await self._broadcast_progress( + { + "stage": "tensorrt", + "progress": 25, + "message": "Downloading TensorRT packages...", + } + ) + + # Download TensorRT packages + async with aiohttp.ClientSession() as session: + for url, path, name in [ + (tensorrt_url, tensorrt_path, tensorrt_pkg), + (tensorrt_dev_url, tensorrt_dev_path, f"{tensorrt_pkg}-dev"), + (tensorrt_plugin_url, tensorrt_plugin_path, tensorrt_plugin_pkg), + (tensorrt_plugin_dev_url, tensorrt_plugin_dev_path, f"{tensorrt_plugin_pkg}-dev"), + ]: + try: + await self._broadcast_log_line(f"Downloading {name}...") + async with session.get(url) as response: + if response.status == 200: + async with aiofiles.open(path, "wb") as f: + await f.write(await response.read()) + await self._broadcast_log_line(f"Downloaded {name}") + else: + await self._broadcast_log_line( + f"Failed to download {name}: HTTP {response.status}" + ) + continue + except Exception as download_err: + await self._broadcast_log_line( + f"Download error for {name}: {download_err}" + ) + continue + + await 
self._broadcast_progress( + { + "stage": "tensorrt", + "progress": 50, + "message": "Installing TensorRT packages...", + } + ) + + if os.path.exists(tensorrt_path): + await self._broadcast_log_line( + "Extracting TensorRT to CUDA directory..." + ) + + # Extract .deb file + extract_dir = os.path.join(self._download_dir, "tensorrt_extract") + os.makedirs(extract_dir, exist_ok=True) + + for deb_path in [ + tensorrt_path, + tensorrt_dev_path, + tensorrt_plugin_path, + tensorrt_plugin_dev_path, + ]: + if os.path.exists(deb_path): + # Extract using ar and tar + extract_process = await asyncio.create_subprocess_exec( + "bash", + "-c", + f"cd {extract_dir} && ar x {deb_path} && tar xf data.tar.* 2>/dev/null || tar xf data.tar 2>/dev/null", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + await extract_process.wait() + + # Copy TensorRT files to CUDA installation + tensorrt_lib_src = os.path.join( + extract_dir, "usr", "lib", "x86_64-linux-gnu" + ) + tensorrt_include_src = os.path.join(extract_dir, "usr", "include") + tensorrt_bin_src = os.path.join(extract_dir, "usr", "bin") + + cuda_lib_dst = os.path.join(cuda_path, "lib64") + cuda_include_dst = os.path.join(cuda_path, "include") + cuda_bin_dst = os.path.join(cuda_path, "bin") + + # Copy libraries + if os.path.exists(tensorrt_lib_src): + for f in os.listdir(tensorrt_lib_src): + if "nvinfer" in f.lower() or "tensorrt" in f.lower(): + src = os.path.join(tensorrt_lib_src, f) + dst = os.path.join(cuda_lib_dst, f) + try: + if os.path.islink(src): + linkto = os.readlink(src) + if os.path.exists(dst): + os.remove(dst) + os.symlink(linkto, dst) + else: + shutil.copy2(src, dst) + await self._broadcast_log_line( + f"Copied {f} to CUDA lib directory" + ) + except Exception as copy_err: + await self._broadcast_log_line( + f"Failed to copy {f}: {copy_err}" + ) + + # Copy headers + if os.path.exists(tensorrt_include_src): + for f in os.listdir(tensorrt_include_src): + if "nvinfer" in f.lower() or "tensorrt" 
in f.lower(): + src = os.path.join(tensorrt_include_src, f) + dst = os.path.join(cuda_include_dst, f) + try: + if os.path.isdir(src): + shutil.copytree(src, dst, dirs_exist_ok=True) + else: + shutil.copy2(src, dst) + await self._broadcast_log_line( + f"Copied {f} to CUDA include directory" + ) + except Exception as copy_err: + await self._broadcast_log_line( + f"Failed to copy {f}: {copy_err}" + ) + + # Copy binaries (like trtexec) + if os.path.exists(tensorrt_bin_src): + for f in os.listdir(tensorrt_bin_src): + if "trt" in f.lower() or "nvinfer" in f.lower(): + src = os.path.join(tensorrt_bin_src, f) + dst = os.path.join(cuda_bin_dst, f) + try: + shutil.copy2(src, dst) + os.chmod(dst, 0o755) + await self._broadcast_log_line( + f"Copied {f} to CUDA bin directory" + ) + except Exception as copy_err: + await self._broadcast_log_line( + f"Failed to copy {f}: {copy_err}" + ) + + # Cleanup temporary extract directory only (keep .deb files) + shutil.rmtree(extract_dir, ignore_errors=True) + # Keep .deb files for future use + logger.info( + f"TensorRT packages kept at: {tensorrt_path}, {tensorrt_dev_path}, " + f"{tensorrt_plugin_path}, {tensorrt_plugin_dev_path}" + ) + + await self._broadcast_log_line("TensorRT extracted to CUDA directory") + await self._broadcast_progress( + { + "stage": "tensorrt", + "progress": 100, + "message": "TensorRT installed successfully", + } + ) + else: + await self._broadcast_log_line( + "TensorRT packages not available, skipping TensorRT installation" + ) + await self._broadcast_progress( + { + "stage": "tensorrt", + "progress": 100, + "message": "TensorRT installation skipped (optional)", + } + ) + + except Exception as e: + await self._broadcast_log_line(f"TensorRT installation error: {e}") + await self._broadcast_log_line( + "Note: TensorRT is optional. The build will continue without TensorRT support." 
+ ) + await self._broadcast_progress( + { + "stage": "tensorrt", + "progress": 100, + "message": "TensorRT installation skipped (optional)", + } + ) + + async def install( + self, + version: str = "12.6", + install_cudnn: bool = False, + install_tensorrt: bool = False, + ) -> Dict[str, Any]: + """Install CUDA Toolkit with optional cuDNN and TensorRT.""" + async with self._lock: + if self._operation: + raise RuntimeError( + "Another CUDA installer operation is already running" + ) + + system, arch = self._get_platform() + + if system != "linux": + raise RuntimeError( + f"CUDA installation is only supported on Linux, not {system}" + ) + + if version not in self.SUPPORTED_VERSIONS: + raise ValueError( + f"Unsupported CUDA version: {version}. Supported versions: {', '.join(self.SUPPORTED_VERSIONS)}" + ) + + # Fetch the download URL dynamically + await self._broadcast_log_line( + f"Fetching download URL for CUDA {version}..." + ) + url = await self._fetch_download_url(version) + installer_filename = os.path.basename(url) + installer_path = os.path.join(self._download_dir, installer_filename) + + await self._set_operation("install") + + async def _runner(): + try: + # Download installer + await self._download_installer(version, url, installer_path) + + # Install (Linux only) - returns the installation path + install_path = await self._install_linux( + installer_path, version, install_cudnn, install_tensorrt + ) + + # Update state (already saved in _install_linux, but update main fields) + state = self._load_state() + state["installed_version"] = version + state["installed_at"] = _utcnow() + state["cuda_path"] = install_path + if install_cudnn: + state["cudnn_installed"] = True + if install_tensorrt: + state["tensorrt_installed"] = True + self._save_state(state) + + components = ["CUDA Toolkit"] + if install_cudnn: + components.append("cuDNN") + if install_tensorrt: + components.append("TensorRT") + + await self._finish_operation( + True, f"{', '.join(components)} 
installed successfully" + ) + + # Update current process environment with CUDA paths + # This ensures the running application can use CUDA immediately + cuda_env = self.get_cuda_env(version) + if cuda_env: + os.environ.update(cuda_env) + logger.info( + f"Updated process environment with CUDA {version} paths" + ) + + # Restart llama-swap to pick up new CUDA environment variables + # llama-swap needs to be restarted because subprocess environment + # variables are set at process creation time and can't be changed + try: + from backend.llama_swap_manager import get_llama_swap_manager + llama_swap_manager = get_llama_swap_manager() + await llama_swap_manager.restart_proxy() + logger.info("Restarted llama-swap to pick up new CUDA environment") + except Exception as restart_error: + # Don't fail the installation if restart fails + logger.warning( + f"Failed to restart llama-swap after CUDA installation: {restart_error}. " + f"You may need to manually restart llama-swap to use the new CUDA version." 
+ ) + + # Keep installer file for future use (not deleting) + logger.info(f"Installer file kept at: {installer_path}") + + except Exception as exc: + self._last_error = str(exc) + await self._finish_operation(False, str(exc)) + raise + + self._create_task(_runner()) + return {"message": f"CUDA {version} installation started"} + + def _detect_cudnn_version(self, cuda_path: Optional[str]) -> Optional[str]: + """Detect installed cuDNN version by checking library files.""" + if not cuda_path: + return None + + lib_path = os.path.join(cuda_path, "lib64") + if not os.path.exists(lib_path): + return None + + try: + for f in os.listdir(lib_path): + if "libcudnn" in f and ".so" in f: + match = re.search(r"\.so(?:\.(\d+(?:\.\d+){0,2}))?", f) + if match and match.group(1): + return match.group(1) + except Exception: + pass + + return None + + def _detect_tensorrt_version(self, cuda_path: Optional[str]) -> Optional[str]: + """Detect installed TensorRT version by checking library files.""" + if not cuda_path: + return None + + lib_path = os.path.join(cuda_path, "lib64") + if not os.path.exists(lib_path): + return None + + try: + for f in os.listdir(lib_path): + if "libnvinfer" in f and ".so" in f and "plugin" not in f: + match = re.search(r"\.so(?:\.(\d+(?:\.\d+){0,2}))?", f) + if match and match.group(1): + return match.group(1) + except Exception: + pass + + return None + + def status(self) -> Dict[str, Any]: + """Get CUDA installation status.""" + version = self._detect_installed_version() + cuda_path = self._get_cuda_path() + installed = version is not None and cuda_path is not None + state = self._load_state() + installations = state.get("installations", {}) + + # Detect cuDNN and TensorRT + cudnn_version = None + tensorrt_version = None + if cuda_path: + cudnn_version = self._detect_cudnn_version(cuda_path) + tensorrt_version = self._detect_tensorrt_version(cuda_path) + + # Get all installed versions with their details + installed_versions = [] + for v, info in 
installations.items(): + install_path = info.get("path") + if install_path and os.path.exists(install_path): + installed_versions.append( + { + "version": v, + "path": install_path, + "installed_at": info.get("installed_at"), + "is_system_install": info.get("is_system_install", False), + "is_current": v == version, + "cudnn_installed": info.get("cudnn_installed", False), + "tensorrt_installed": info.get("tensorrt_installed", False), + } + ) + + return { + "installed": installed, + "version": version, + "cuda_path": cuda_path, + "installed_at": state.get("installed_at"), + "installed_versions": installed_versions, + "operation": self._operation, + "operation_started_at": self._operation_started_at, + "last_error": self._last_error, + "log_path": self._log_path, + "available_versions": self.SUPPORTED_VERSIONS, + "platform": self._get_platform(), + "cudnn": { + "installed": cudnn_version is not None, + "version": cudnn_version, + }, + "tensorrt": { + "installed": tensorrt_version is not None, + "version": tensorrt_version, + }, + } + + def is_operation_running(self) -> bool: + return self._operation is not None + + def read_log_tail(self, max_bytes: int = 8192) -> str: + if not os.path.exists(self._log_path): + return "" + with open(self._log_path, "rb") as log_file: + log_file.seek(0, os.SEEK_END) + size = log_file.tell() + log_file.seek(max(0, size - max_bytes)) + data = log_file.read().decode("utf-8", errors="replace") + if size > max_bytes: + data = data.split("\n", 1)[-1] + return data.strip() + + async def uninstall(self, version: Optional[str] = None) -> Dict[str, Any]: + """Uninstall CUDA Toolkit.""" + async with self._lock: + if self._operation: + raise RuntimeError( + "Another CUDA installer operation is already running" + ) + + # Determine which version to uninstall + if not version: + # Uninstall the currently detected version + version = self._detect_installed_version() + if not version: + raise RuntimeError("No CUDA installation found to uninstall") + + 
state = self._load_state() + installations = state.get("installations", {}) + + if version not in installations: + raise RuntimeError(f"CUDA {version} installation not found in state") + + install_info = installations[version] + install_path = install_info.get("path") + + if not install_path or not os.path.exists(install_path): + # Path doesn't exist, just remove from state + logger.warning( + f"CUDA installation path {install_path} does not exist, removing from state only" + ) + installations.pop(version, None) + if state.get("installed_version") == version: + state["installed_version"] = None + state["installed_at"] = None + state["cuda_path"] = None + self._save_state(state) + return { + "message": f"CUDA {version} removed from state (installation path not found)" + } + + await self._set_operation("uninstall") + + async def _runner(): + try: + await self._broadcast_log_line( + f"Starting uninstallation of CUDA {version}..." + ) + await self._broadcast_progress( + { + "stage": "uninstall", + "progress": 0, + "message": f"Uninstalling CUDA {version}...", + } + ) + + # Remove the installation directory + if os.path.exists(install_path): + await self._broadcast_log_line( + f"Removing installation directory: {install_path}" + ) + try: + shutil.rmtree(install_path) + await self._broadcast_log_line( + f"Successfully removed {install_path}" + ) + except Exception as e: + logger.error( + f"Failed to remove CUDA installation directory: {e}" + ) + raise RuntimeError( + f"Failed to remove installation directory: {e}" + ) + + # Update state + installations.pop(version, None) + if state.get("installed_version") == version: + state["installed_version"] = None + state["installed_at"] = None + state["cuda_path"] = None + self._save_state(state) + + # Update or remove the current symlink + self._remove_current_symlink() + await self._broadcast_log_line( + "Updated CUDA current symlink (removed or re-pointed to another version)" + ) + + await self._broadcast_progress( + { + 
"stage": "uninstall", + "progress": 100, + "message": "CUDA uninstallation completed", + } + ) + await self._broadcast_log_line( + f"CUDA {version} uninstalled successfully" + ) + await self._finish_operation( + True, f"CUDA {version} uninstalled successfully" + ) + + except Exception as exc: + self._last_error = str(exc) + await self._finish_operation(False, str(exc)) + raise + + self._create_task(_runner()) + return {"message": f"CUDA {version} uninstallation started"} diff --git a/backend/data_store.py b/backend/data_store.py new file mode 100644 index 0000000..36aebdf --- /dev/null +++ b/backend/data_store.py @@ -0,0 +1,216 @@ +"""YAML-backed data store replacing SQLite.""" + +import os +import threading +from typing import Any, Dict, List, Optional + +import yaml + +from backend.logging_config import get_logger + +logger = get_logger(__name__) + + +def _get_config_dir() -> str: + """Return config directory (Docker: /app/data/config, local: data/config).""" + if os.path.exists("/app/data"): + return "/app/data/config" + return os.path.abspath("data/config") + + +def generate_proxy_name(huggingface_id: str, quantization: Optional[str] = None) -> str: + """ + Generate a proxy name for llama-swap using HuggingFace ID and optional quantization. 
+ """ + huggingface_slug = ( + huggingface_id.replace("/", "-").replace(" ", "-").replace(".", "-").lower() + ) + if quantization: + quantization_slug = quantization.replace(" ", "-").lower() + return f"{huggingface_slug}.{quantization_slug}" + return huggingface_slug + + +class DataStore: + """Thread-safe YAML-backed data store replacing SQLite.""" + + def __init__(self, config_dir: Optional[str] = None): + self._config_dir = os.path.abspath(config_dir or _get_config_dir()) + self._lock = threading.Lock() + self._ensure_files_exist() + + def _ensure_files_exist(self) -> None: + """Create config dir and default YAML files if they don't exist.""" + os.makedirs(self._config_dir, exist_ok=True) + for filename, default in [ + ("models.yaml", {"models": []}), + ( + "engines.yaml", + { + "llama_cpp": {"active_version": None, "versions": []}, + "ik_llama": {"active_version": None, "versions": []}, + "lmdeploy": { + "installed": False, + "version": None, + "install_type": None, + "source_repo": None, + "source_branch": None, + "venv_path": None, + }, + "cuda": {"installed_version": None, "install_path": None}, + }, + ), + ("settings.yaml", {"huggingface_token": "", "proxy_port": 2000}), + ]: + path = os.path.join(self._config_dir, filename) + if not os.path.exists(path): + self._write_yaml(path, default) + + def _read_yaml(self, filename: str) -> dict: + """Read and parse a YAML file. 
Returns empty dict on error.""" + path = os.path.join(self._config_dir, filename) + with self._lock: + if not os.path.exists(path): + return {} + try: + with open(path, "r") as f: + return yaml.safe_load(f) or {} + except Exception as e: + logger.warning(f"Failed to read {path}: {e}") + return {} + + def _write_yaml(self, path: str, data: dict) -> None: + """Atomic write: write to temp file then rename.""" + tmp_path = path + ".tmp" + try: + with open(tmp_path, "w") as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False) + os.replace(tmp_path, path) + except Exception as e: + if os.path.exists(tmp_path): + try: + os.remove(tmp_path) + except OSError: + pass + raise e + + def _save_yaml(self, filename: str, data: dict) -> None: + """Thread-safe write to a YAML file.""" + path = os.path.join(self._config_dir, filename) + with self._lock: + self._write_yaml(path, data) + + # --- Models --- + + def list_models(self) -> List[dict]: + return self._read_yaml("models.yaml").get("models", []) + + def get_model(self, model_id: str) -> Optional[dict]: + for m in self.list_models(): + if m.get("id") == model_id: + return m + return None + + def add_model(self, model: dict) -> dict: + data = self._read_yaml("models.yaml") + data.setdefault("models", []).append(model) + self._save_yaml("models.yaml", data) + return model + + def update_model(self, model_id: str, updates: dict) -> Optional[dict]: + data = self._read_yaml("models.yaml") + for m in data.get("models", []): + if m.get("id") == model_id: + m.update(updates) + self._save_yaml("models.yaml", data) + return m + return None + + def delete_model(self, model_id: str) -> bool: + data = self._read_yaml("models.yaml") + models = data.get("models", []) + new_models = [m for m in models if m.get("id") != model_id] + if len(new_models) == len(models): + return False + data["models"] = new_models + self._save_yaml("models.yaml", data) + return True + + # --- Engines (llama_cpp, ik_llama) --- + + def 
get_engine_versions(self, engine: str) -> List[dict]: + """engine is 'llama_cpp' or 'ik_llama'.""" + return self._read_yaml("engines.yaml").get(engine, {}).get("versions", []) + + def get_active_engine_version(self, engine: str) -> Optional[dict]: + data = self._read_yaml("engines.yaml").get(engine, {}) + active = data.get("active_version") + if not active: + return None + for v in data.get("versions", []): + if v.get("version") == active: + return v + return None + + def add_engine_version(self, engine: str, version_data: dict) -> None: + data = self._read_yaml("engines.yaml") + data.setdefault(engine, {}).setdefault("versions", []).append(version_data) + self._save_yaml("engines.yaml", data) + + def set_active_engine_version(self, engine: str, version: str) -> None: + data = self._read_yaml("engines.yaml") + data.setdefault(engine, {})["active_version"] = version + self._save_yaml("engines.yaml", data) + + def delete_engine_version(self, engine: str, version: str) -> bool: + data = self._read_yaml("engines.yaml") + engine_data = data.get(engine, {}) + versions = engine_data.get("versions", []) + new_versions = [v for v in versions if v.get("version") != version] + if len(new_versions) == len(versions): + return False + engine_data["versions"] = new_versions + if engine_data.get("active_version") == version: + engine_data["active_version"] = None + self._save_yaml("engines.yaml", data) + return True + + # --- LMDeploy --- + + def get_lmdeploy_status(self) -> dict: + return self._read_yaml("engines.yaml").get("lmdeploy", {}) + + def update_lmdeploy(self, updates: dict) -> None: + data = self._read_yaml("engines.yaml") + data.setdefault("lmdeploy", {}).update(updates) + self._save_yaml("engines.yaml", data) + + # --- CUDA --- + + def get_cuda_status(self) -> dict: + return self._read_yaml("engines.yaml").get("cuda", {}) + + def update_cuda(self, updates: dict) -> None: + data = self._read_yaml("engines.yaml") + data.setdefault("cuda", {}).update(updates) + 
self._save_yaml("engines.yaml", data) + + # --- Settings --- + + def get_settings(self) -> dict: + return self._read_yaml("settings.yaml") + + def update_settings(self, updates: dict) -> None: + data = self._read_yaml("settings.yaml") + data.update(updates) + self._save_yaml("settings.yaml", data) + + +_store: Optional[DataStore] = None + + +def get_store() -> DataStore: + global _store + if _store is None: + _store = DataStore() + return _store diff --git a/backend/database.py b/backend/database.py deleted file mode 100644 index 65a29fa..0000000 --- a/backend/database.py +++ /dev/null @@ -1,364 +0,0 @@ -from sqlalchemy import ( - create_engine, - Column, - Integer, - String, - DateTime, - Boolean, - Text, - Float, - ForeignKey, - JSON, - text, - inspect, -) -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, relationship -from datetime import datetime -from typing import Dict, List -import os -from backend.logging_config import get_logger - -logger = get_logger(__name__) - -# Determine database path - use /app/data in Docker, ./data locally -if os.path.exists("/app/data"): - db_dir = "/app/data" - db_path = "/app/data/db.sqlite" -else: - db_dir = "data" - db_path = "data/db.sqlite" - -DATABASE_URL = f"sqlite:///{db_path}" - -engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - -Base = declarative_base() - - -def get_db(): - """Dependency to get database session""" - db = SessionLocal() - try: - yield db - finally: - db.close() - - -def generate_proxy_name(huggingface_id: str, quantization: str) -> str: - """ - Generate a centralized proxy name for llama-swap using HuggingFace ID and quantization. - This ensures consistent naming across all components. 
- """ - # Create unique proxy name using HuggingFace ID and quantization to avoid conflicts - huggingface_slug = ( - huggingface_id.replace("/", "-").replace(" ", "-").replace(".", "-").lower() - ) - quantization_slug = quantization.replace(" ", "-").lower() - return f"{huggingface_slug}.{quantization_slug}" - - -class Model(Base): - __tablename__ = "models" - - id = Column(Integer, primary_key=True, index=True) - name = Column(String, index=True) - huggingface_id = Column(String, index=True) # Removed unique constraint - base_model_name = Column(String, index=True) # Model name without quantization - file_path = Column(String) - file_size = Column(Integer) # in bytes - quantization = Column(String) # Q4_K_M, Q8_0, etc. - model_type = Column(String) # llama, mistral, etc. - downloaded_at = Column(DateTime) - is_active = Column(Boolean, default=False) - config = Column(JSON) # JSON object of llama.cpp parameters - proxy_name = Column(String, index=True) # Centralized proxy name for llama-swap - model_format = Column(String, default="gguf", server_default="gguf", index=True) - pipeline_tag = Column(String, index=True) - - -class LlamaVersion(Base): - __tablename__ = "llama_versions" - - id = Column(Integer, primary_key=True, index=True) - version = Column(String, unique=True, index=True) - install_type = Column(String) # "release", "source", "patched" - binary_path = Column(String) - source_commit = Column(String) # For source builds - patches = Column(Text) # JSON array of patch URLs/metadata - installed_at = Column(DateTime, default=datetime.utcnow) - is_active = Column(Boolean, default=False) # Changed from is_default to is_active - build_config = Column(JSON) # Store BuildConfig as JSON - repository_source = Column( - String, default="llama.cpp" - ) # "llama.cpp" or "ik_llama.cpp" - - -class RunningInstance(Base): - __tablename__ = "running_instances" - - id = Column(Integer, primary_key=True, index=True) - model_id = Column(Integer, index=True) - llama_version = 
Column(String) - proxy_model_name = Column(String) # NEW: Model name in llama-swap - started_at = Column(DateTime) - config = Column(Text) # JSON string of runtime config - runtime_type = Column( - String, default="llama_cpp", server_default="llama_cpp", index=True - ) - - -def sync_model_active_status(db): - """Sync model is_active status with running instances""" - - # Get all running instances - running_instances = db.query(RunningInstance).all() - active_model_ids = set() - - for instance in running_instances: - active_model_ids.add(instance.model_id) - - # Update all models' is_active status - all_models = db.query(Model).all() - updated_count = 0 - - for model in all_models: - new_status = model.id in active_model_ids - if model.is_active != new_status: - model.is_active = new_status - updated_count += 1 - - if updated_count > 0: - db.commit() - logger.info(f"Synced {updated_count} models' is_active status") - - return updated_count - - -async def init_db(): - """Initialize database tables""" - # Use the same db_dir determined at module load time - os.makedirs(db_dir, exist_ok=True) - # Ensure the database directory is writable - import stat - try: - os.chmod(db_dir, stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP | stat.S_IROTH | stat.S_IXOTH) - except Exception as perm_error: - logger.warning(f"Could not set permissions on {db_dir}: {perm_error}") - - # If database file exists, ensure it's writable - if os.path.exists(db_path): - try: - # Check if we can write to the database file - if not os.access(db_path, os.W_OK): - logger.error(f"Database file {db_path} is not writable. 
Please check file permissions.") - logger.error(f"Current user: {os.getuid() if hasattr(os, 'getuid') else 'unknown'}") - logger.error(f"File owner: {os.stat(db_path).st_uid if hasattr(os.stat(db_path), 'st_uid') else 'unknown'}") - raise PermissionError(f"Database file {db_path} is not writable") - except Exception as perm_error: - logger.warning(f"Could not check database file permissions: {perm_error}") - - Base.metadata.create_all(bind=engine) - - try: - ensure_model_format_column() - except Exception as exc: - logger.warning(f"Failed to ensure model_format column: {exc}") - try: - ensure_running_instance_runtime_column() - except Exception as exc: - logger.warning(f"Failed to ensure running_instances.runtime_type column: {exc}") - try: - ensure_pipeline_tag_column() - except Exception as exc: - logger.warning(f"Failed to ensure models.pipeline_tag column: {exc}") - try: - ensure_repository_source_column() - except Exception as exc: - logger.warning(f"Failed to ensure repository_source column: {exc}") - - # Migrate existing models to populate base_model_name - migrate_existing_models() - - # Migrate safetensors models: merge multiple rows per repo into single row - try: - migrate_safetensors_models_to_unified() - except Exception as exc: - logger.warning(f"Failed to migrate safetensors models: {exc}") - - -def migrate_existing_models(): - """Migrate existing models to populate base_model_name field""" - db = SessionLocal() - try: - models = db.query(Model).filter(Model.base_model_name.is_(None)).all() - - for model in models: - # Extract base model name from huggingface_id or name - if model.huggingface_id: - # Extract model name from huggingface_id (e.g., "microsoft/DialoGPT-medium" -> "DialoGPT") - parts = model.huggingface_id.split("/") - if len(parts) > 1: - model.base_model_name = parts[-1].split("-")[ - 0 - ] # Remove quantization suffix - else: - model.base_model_name = model.huggingface_id - elif model.name: - # Extract from name if no huggingface_id - 
model.base_model_name = model.name.split("-")[0] - else: - model.base_model_name = "unknown" - - db.commit() - logger.info(f"Migrated {len(models)} models with base_model_name") - - except Exception as e: - logger.error(f"Error migrating models: {e}") - db.rollback() - finally: - db.close() - - -def ensure_model_format_column(): - """Ensure the models table has the model_format column (retrofit for existing DBs).""" - inspector = inspect(engine) - columns = [column["name"] for column in inspector.get_columns("models")] - if "model_format" in columns: - return - - with engine.connect() as connection: - connection.execute(text("ALTER TABLE models ADD COLUMN model_format VARCHAR")) - connection.execute( - text("UPDATE models SET model_format = 'gguf' WHERE model_format IS NULL") - ) - logger.info("Added model_format column to models table") - - -def ensure_running_instance_runtime_column(): - """Ensure running_instances table tracks runtime_type.""" - inspector = inspect(engine) - columns = [column["name"] for column in inspector.get_columns("running_instances")] - if "runtime_type" in columns: - return - - with engine.connect() as connection: - connection.execute( - text("ALTER TABLE running_instances ADD COLUMN runtime_type VARCHAR") - ) - connection.execute( - text( - "UPDATE running_instances SET runtime_type = 'llama_cpp' WHERE runtime_type IS NULL" - ) - ) - logger.info("Added runtime_type column to running_instances table") - - -def migrate_safetensors_models_to_unified(): - """Migrate safetensors models: merge multiple Model rows per repo into a single row.""" - db = SessionLocal() - try: - # Find all safetensors models grouped by huggingface_id - safetensors_models = ( - db.query(Model).filter(Model.model_format == "safetensors").all() - ) - - # Group by huggingface_id - by_repo: Dict[str, List[Model]] = {} - for model in safetensors_models: - hf_id = model.huggingface_id or "unknown" - by_repo.setdefault(hf_id, []).append(model) - - merged_count = 0 - for 
huggingface_id, models in by_repo.items(): - if len(models) <= 1: - continue # Already unified - - # Keep the first model, merge others into it - primary = models[0] - others = models[1:] - - # Aggregate file_size - total_size = sum(m.file_size or 0 for m in models) - if total_size: - primary.file_size = total_size - - # Merge metadata: use most complete pipeline_tag, model_type, etc. - for other in others: - if not primary.pipeline_tag and other.pipeline_tag: - primary.pipeline_tag = other.pipeline_tag - if not primary.model_type and other.model_type: - primary.model_type = other.model_type - if not primary.base_model_name and other.base_model_name: - primary.base_model_name = other.base_model_name - # Use earliest downloaded_at - if other.downloaded_at and ( - not primary.downloaded_at - or other.downloaded_at < primary.downloaded_at - ): - primary.downloaded_at = other.downloaded_at - - # Update RunningInstance records to point to primary model - for other in others: - instances = ( - db.query(RunningInstance) - .filter(RunningInstance.model_id == other.id) - .all() - ) - for instance in instances: - instance.model_id = primary.id - - # Delete duplicate models - for other in others: - db.delete(other) - - merged_count += len(others) - logger.info( - f"Merged {len(others)} safetensors Model rows for {huggingface_id} into model_id={primary.id}" - ) - - if merged_count > 0: - db.commit() - logger.info( - f"Migration complete: merged {merged_count} safetensors Model rows" - ) - else: - logger.debug("No safetensors Model rows to merge") - - except Exception as e: - logger.error(f"Error migrating safetensors models: {e}") - db.rollback() - finally: - db.close() - - -def ensure_pipeline_tag_column(): - """Ensure the models table stores pipeline tags.""" - inspector = inspect(engine) - columns = [column["name"] for column in inspector.get_columns("models")] - if "pipeline_tag" in columns: - return - - with engine.connect() as connection: - connection.execute(text("ALTER 
TABLE models ADD COLUMN pipeline_tag VARCHAR")) - logger.info("Added pipeline_tag column to models table") - - -def ensure_repository_source_column(): - """Ensure the llama_versions table has the repository_source column.""" - inspector = inspect(engine) - columns = [column["name"] for column in inspector.get_columns("llama_versions")] - if "repository_source" in columns: - return - - with engine.connect() as connection: - connection.execute( - text("ALTER TABLE llama_versions ADD COLUMN repository_source VARCHAR") - ) - connection.execute( - text( - "UPDATE llama_versions SET repository_source = 'llama.cpp' WHERE repository_source IS NULL" - ) - ) - logger.info("Added repository_source column to llama_versions table") diff --git a/backend/gguf_reader.py b/backend/gguf_reader.py index ef4c3c5..2ad9dd1 100644 --- a/backend/gguf_reader.py +++ b/backend/gguf_reader.py @@ -9,11 +9,27 @@ from typing import Dict, Optional, Any, List, Tuple, BinaryIO from backend.logging_config import get_logger -from backend.architecture_profiles import compute_layers_for_architecture logger = get_logger(__name__) +def _compute_layers_for_architecture( + architecture: str, + metadata: dict, + base_block_count: int, +) -> dict: + """Compute block_count and effective_layer_count from architecture and metadata.""" + block_count = max(0, int(base_block_count)) + # Most architectures add one output head layer + effective = block_count + 1 + arch = (architecture or "").lower() + if arch == "glm4moe": + nextn = metadata.get("glm4moe.nextn_predict_layers") + if nextn is not None: + effective = block_count + int(nextn) + return {"block_count": block_count, "effective_layer_count": effective} + + class GGUFValueType(IntEnum): """ GGUF Value Types as defined in the specification. 
@@ -677,7 +693,7 @@ def read_gguf_metadata(file_path: str) -> Optional[Dict[str, Any]]: # Then compute architecture-aware block_count and effective_layer_count architecture = metadata.get("general.architecture", "").lower() - layer_info = compute_layers_for_architecture( + layer_info = _compute_layers_for_architecture( architecture=architecture, metadata=metadata, base_block_count=base_block_count, diff --git a/backend/huggingface.py b/backend/huggingface.py index 6d59a42..288ab57 100644 --- a/backend/huggingface.py +++ b/backend/huggingface.py @@ -1,11 +1,9 @@ from huggingface_hub import HfApi, hf_hub_download, list_models from typing import List, Dict, Optional, Tuple, Any import asyncio -import aiohttp import json import os import threading -from tqdm import tqdm import time import re import traceback @@ -39,6 +37,45 @@ _safetensors_metadata_ttl = 600 # 10 minutes +def get_accurate_file_sizes(repo_id: str, paths: List[str]) -> Dict[str, Optional[int]]: + """Fetch accurate file sizes from HuggingFace API via get_paths_info.""" + if not paths: + return {} + try: + paths_info = hf_api.get_paths_info(repo_id=repo_id, paths=paths) + return { + getattr(pi, "path", getattr(pi, "rfilename", "")): getattr(pi, "size", None) + for pi in paths_info + } + except Exception as e: + logger.warning(f"get_paths_info failed for {repo_id}: {e}") + return {} + + +def get_mmproj_f16_filename(repo_id: str) -> Optional[str]: + """ + If the repo contains vision projector (mmproj) GGUF files, return the F16 one to download. + Prefers mmproj-F16.gguf, then any *mmproj*F16*.gguf, then first mmproj*.gguf. + Returns None if no mmproj files or on API error. 
+ """ + try: + files = list(hf_api.list_repo_files(repo_id=repo_id)) + except Exception as e: + logger.debug(f"list_repo_files failed for {repo_id}: {e}") + return None + mmproj = [f for f in files if "mmproj" in f.lower() and f.lower().endswith(".gguf")] + if not mmproj: + return None + # Prefer exact mmproj-F16.gguf, then any filename containing F16, then first mmproj + for name in mmproj: + if name == "mmproj-F16.gguf": + return name + for name in mmproj: + if "f16" in name.lower(): + return name + return mmproj[0] + + def _download_repo_json(repo_id: str, filename: str) -> Optional[Dict[str, Any]]: try: path = hf_hub_download(repo_id, filename, local_dir_use_symlinks=False) @@ -92,6 +129,127 @@ def _get_download_directory(model_format: str, huggingface_id: str) -> str: return MODEL_BASE_DIR +def _hf_repo_folder_name(huggingface_id: str) -> str: + """Return the HF cache folder name for a model repo (e.g. models--Org--Repo).""" + return "models--" + huggingface_id.replace("/", "--") + + +def resolve_cached_model_path(huggingface_id: str, filename: str) -> Optional[str]: + """Return the local path for a cached HF model file without triggering a download. + + Returns None if the file is not in the HF cache. + """ + try: + return hf_hub_download( + repo_id=huggingface_id, + filename=filename, + local_files_only=True, + ) + except Exception: + return None + + +def delete_cached_model_file(huggingface_id: str, filename: str) -> bool: + """Delete a specific model file from the HuggingFace cache. + + Removes both the snapshot symlink and the underlying content blob. + Returns True if the file was found and deleted, False otherwise. 
+ """ + try: + cached_path = hf_hub_download( + repo_id=huggingface_id, + filename=filename, + local_files_only=True, + ) + except Exception: + logger.warning( + f"delete_cached_model_file: {huggingface_id}/{filename} not found in HF cache" + ) + return False + + if os.path.islink(cached_path): + blob_path = os.path.realpath(cached_path) + try: + os.unlink(cached_path) + except OSError as e: + logger.warning(f"Could not remove symlink {cached_path}: {e}") + if os.path.exists(blob_path): + try: + os.remove(blob_path) + except OSError as e: + logger.warning(f"Could not remove blob {blob_path}: {e}") + elif os.path.exists(cached_path): + try: + os.remove(cached_path) + except OSError as e: + logger.warning(f"Could not remove file {cached_path}: {e}") + + logger.info(f"Deleted cached model file: {huggingface_id}/{filename}") + return True + + +def resolve_model_path( + huggingface_id: str, + filename: Optional[str] = None, + model_format: str = "gguf", +) -> Optional[str]: + """ + Resolve a model's local path from current storage (data/models/...). + For GGUF: returns path to the specific file if filename is given. + For safetensors: returns the repo directory (filename ignored). + Returns None if the path does not exist. Does not create directories. 
+ """ + if not huggingface_id: + return None + safe_repo = _safe_repo_name(huggingface_id) + base_dir = FORMAT_SUBDIRS.get(model_format, MODEL_BASE_DIR) + repo_dir = os.path.join(base_dir, safe_repo) + for prefix in ("", "/app"): + candidate = repo_dir if not prefix else os.path.join(prefix, repo_dir) + if not os.path.exists(candidate): + continue + if model_format == "gguf" and filename: + path = os.path.join(candidate, filename) + if os.path.isfile(path): + return path + continue + if model_format == "safetensors" or not filename: + if os.path.isdir(candidate): + return candidate + return None + + +def get_model_disk_size( + huggingface_id: str, + filename: Optional[str] = None, + model_format: str = "gguf", +) -> int: + """ + Compute actual disk usage in bytes for a model in current storage. + For GGUF: size of the given file. For safetensors: sum of all files in repo dir. + """ + path = resolve_model_path(huggingface_id, filename, model_format) + if not path: + return 0 + if os.path.isfile(path): + try: + return os.path.getsize(path) + except OSError: + return 0 + if os.path.isdir(path): + total = 0 + try: + for _dirpath, _dirnames, filenames in os.walk(path): + for f in filenames: + fp = os.path.join(_dirpath, f) + if os.path.isfile(fp): + total += os.path.getsize(fp) + except OSError: + pass + return total + return 0 + + def _get_manifest_lock( model_format: str, huggingface_id: Optional[str] = None ) -> threading.Lock: @@ -997,7 +1155,6 @@ async def _search_with_api(query: str, limit: int, model_format: str) -> List[Di search=query, limit=min(limit * 2, 50), # Get more models to filter from sort="downloads", - direction=-1, filter=filter_value, expand=[ "author", @@ -1069,11 +1226,13 @@ async def _process_single_model(model, model_format: str) -> Optional[Dict]: if model_format == "gguf": # Group GGUF files by logical quantization, handling multi-part shards # Accept both plain `.gguf` and multi-part patterns like `.gguf.part1of2` + # Exclude mmproj 
(vision/multimodal projection) files – they are extensions, not standalone quants gguf_siblings = [ s for s in model.siblings if isinstance(getattr(s, "rfilename", None), str) and re.search(r"\.gguf(\.|$)", s.rfilename) + and "mmproj" not in s.rfilename.lower() ] logger.info(f"Model {model.id}: {len(gguf_siblings)} GGUF files found") if not gguf_siblings: @@ -1139,6 +1298,23 @@ async def _process_single_model(model, model_format: str) -> Optional[Dict]: else 0.0 ) + # Siblings from list_models often have size=None; fetch accurate sizes from Hub + try: + all_filenames = [s.rfilename for s in gguf_siblings] + accurate_sizes = get_accurate_file_sizes(model.id, all_filenames) + if accurate_sizes: + for entry in quantizations.values(): + for f in entry["files"]: + f["size"] = accurate_sizes.get(f["filename"]) or f["size"] or 0 + entry["total_size"] = sum(f["size"] for f in entry["files"]) + entry["size_mb"] = ( + round(entry["total_size"] / (1024 * 1024), 2) + if entry["total_size"] + else 0.0 + ) + except Exception as size_err: + logger.debug(f"Could not fetch accurate sizes for {model.id}: {size_err}") + # If no quantizations were detected after grouping, skip this model if not quantizations: return None @@ -1162,6 +1338,15 @@ async def _process_single_model(model, model_format: str) -> Optional[Dict]: ) if not safetensors_files: return None + # Fetch accurate sizes; list_models siblings often have size=None + try: + st_filenames = [f["filename"] for f in safetensors_files] + accurate_sizes = get_accurate_file_sizes(model.id, st_filenames) + if accurate_sizes: + for f in safetensors_files: + f["size"] = accurate_sizes.get(f["filename"]) or 0 + except Exception as size_err: + logger.debug(f"Could not fetch accurate sizes for {model.id}: {size_err}") else: return None @@ -1510,23 +1695,18 @@ async def get_model_details(model_id: str) -> Dict: async def download_model( huggingface_id: str, filename: str, model_format: str = "gguf" ) -> tuple[str, int]: - """Download 
model from HuggingFace""" + """Download model from HuggingFace to the native HF cache.""" try: - models_dir = _get_download_directory(model_format, huggingface_id) - - # Sanitize filename filename = _sanitize_filename(filename) - # Download the file file_path = hf_hub_download( repo_id=huggingface_id, filename=filename, - local_dir=models_dir, - local_dir_use_symlinks=False, ) - # Get file size - file_size = os.path.getsize(file_path) + # Use realpath so getsize works even when file_path is a symlink + real_path = os.path.realpath(file_path) + file_size = os.path.getsize(real_path if os.path.exists(real_path) else file_path) return file_path, file_size @@ -1535,331 +1715,154 @@ async def download_model( raise -async def download_model_with_websocket_progress( +async def download_model_with_progress( huggingface_id: str, filename: str, - websocket_manager, + progress_manager, task_id: str, total_bytes: int = 0, model_format: str = "gguf", huggingface_id_for_progress: str = None, ): - """Download model with WebSocket progress updates by tracking filesystem size""" - import asyncio + """Download model to the HF native cache with SSE progress updates. + + Progress is tracked by monitoring the .incomplete blob file that hf_hub_download + writes to the HF cache during the download. 
+ """ + import threading import time + from huggingface_hub.constants import HF_HUB_CACHE - logger.info(f"=== DOWNLOAD PROGRESS START ===") - logger.info(f"Download task: {task_id}") - logger.info(f"HuggingFace ID: {huggingface_id}") - logger.info(f"Filename: {filename}") - logger.info(f"Total bytes from search: {total_bytes}") - logger.info(f"WebSocket manager: {websocket_manager}") - logger.info(f"Active connections: {len(websocket_manager.active_connections)}") + filename = _sanitize_filename(filename) + progress_hf_id = huggingface_id_for_progress or huggingface_id - try: - models_dir = _get_download_directory(model_format, huggingface_id) + logger.info(f"Starting HF-cache download: {huggingface_id}/{filename} task={task_id}") - # Sanitize filename and build path - filename = _sanitize_filename(filename) - file_path = os.path.join(models_dir, filename) - directory = os.path.dirname(file_path) - if directory and not os.path.exists(directory): - os.makedirs(directory, exist_ok=True) - - # Send initial progress - logger.info(f"Sending initial progress message...") - progress_hf_id = huggingface_id_for_progress or huggingface_id - await websocket_manager.send_download_progress( - task_id=task_id, - progress=0, - message=f"Starting download of {filename}", - bytes_downloaded=0, - total_bytes=total_bytes, - speed_mbps=0, - eta_seconds=0, - filename=filename, - model_format=model_format, - huggingface_id=progress_hf_id, - ) - logger.info(f"Initial progress message sent") + # Resolve total size if not provided + if total_bytes == 0: + try: + file_info = HfApi().repo_file_info(repo_id=huggingface_id, filename=filename) + total_bytes = file_info.size or 0 + logger.info(f"Got file size from HuggingFace API: {total_bytes}") + except Exception as e: + logger.warning(f"Could not get file size: {e}") + + await progress_manager.send_download_progress( + task_id=task_id, + progress=0, + message=f"Starting download of {filename}", + bytes_downloaded=0, + total_bytes=total_bytes, 
+ speed_mbps=0, + eta_seconds=0, + filename=filename, + model_format=model_format, + huggingface_id=progress_hf_id, + ) - # Get file size from HuggingFace API if not provided - if total_bytes == 0: - try: - from huggingface_hub import HfApi - - api = HfApi() - file_info = api.repo_file_info(repo_id=huggingface_id, path=filename) - total_bytes = file_info.size - logger.info(f"Got file size from HuggingFace API: {total_bytes}") - except Exception as e: - logger.warning(f"Could not get file size from HuggingFace API: {e}") - # If we can't get the size, we'll estimate it - total_bytes = 0 - - # Send total size update - if total_bytes > 0: - await websocket_manager.send_download_progress( + # Run the blocking hf_hub_download in a background thread + repo_folder = _hf_repo_folder_name(huggingface_id) + blobs_dir = os.path.join(HF_HUB_CACHE, repo_folder, "blobs") + + download_result: dict = {"file_path": None, "error": None, "done": False} + + def _do_download(): + try: + download_result["file_path"] = hf_hub_download( + repo_id=huggingface_id, + filename=filename, + ) + except Exception as exc: + download_result["error"] = exc + finally: + download_result["done"] = True + + thread = threading.Thread(target=_do_download, daemon=True) + thread.start() + + # Poll the .incomplete blob for progress + start_time = time.time() + last_bytes = 0 + last_poll = start_time + + while not download_result["done"]: + await asyncio.sleep(0.5) + + incomplete_bytes = 0 + if os.path.isdir(blobs_dir): + for fname in os.listdir(blobs_dir): + if fname.endswith(".incomplete"): + try: + incomplete_bytes = max( + incomplete_bytes, + os.path.getsize(os.path.join(blobs_dir, fname)), + ) + except OSError: + pass + + if incomplete_bytes > 0: + now = time.time() + elapsed_total = now - start_time + elapsed_poll = now - last_poll + delta = incomplete_bytes - last_bytes + speed_mbps = (delta / elapsed_poll / (1024 * 1024)) if elapsed_poll > 0 else 0 + progress = min(99, int(incomplete_bytes / 
total_bytes * 100)) if total_bytes else 0 + eta = ( + int((total_bytes - incomplete_bytes) / (incomplete_bytes / elapsed_total)) + if elapsed_total > 0 and incomplete_bytes > 0 and total_bytes > incomplete_bytes + else 0 + ) + await progress_manager.send_download_progress( task_id=task_id, - progress=0, + progress=progress, message=f"Downloading {filename}", - bytes_downloaded=0, + bytes_downloaded=incomplete_bytes, total_bytes=total_bytes, - speed_mbps=0, - eta_seconds=0, + speed_mbps=round(speed_mbps, 2), + eta_seconds=eta, filename=filename, model_format=model_format, huggingface_id=progress_hf_id, ) + last_bytes = incomplete_bytes + last_poll = now - # Start the download with built-in progress tracking - logger.info(f"🚀 Starting download with built-in progress tracking...") - - file_path, file_size = await download_with_progress_tracking( - huggingface_id, - filename, - file_path, - models_dir, - websocket_manager, - task_id, - total_bytes, - model_format, - progress_hf_id, - ) - - # Send final completion - await websocket_manager.send_download_progress( + if download_result["error"]: + err = download_result["error"] + await progress_manager.send_download_progress( task_id=task_id, - progress=100, - message=f"Download completed: {filename}", - bytes_downloaded=file_size, - total_bytes=file_size, + progress=0, + message=f"Download failed: {err}", + bytes_downloaded=0, + total_bytes=total_bytes, speed_mbps=0, eta_seconds=0, filename=filename, model_format=model_format, huggingface_id=progress_hf_id, ) + raise err + + # Success: get final path and size + file_path = download_result["file_path"] + real_path = os.path.realpath(file_path) if file_path else file_path + file_size = os.path.getsize(real_path if os.path.exists(real_path) else file_path) + + await progress_manager.send_download_progress( + task_id=task_id, + progress=100, + message=f"Download completed: {filename}", + bytes_downloaded=file_size, + total_bytes=file_size, + speed_mbps=0, + eta_seconds=0, + 
filename=filename, + model_format=model_format, + huggingface_id=progress_hf_id, + ) - return file_path, file_size - - except Exception as e: - # Send error notification - if websocket_manager and task_id: - progress_hf_id = huggingface_id_for_progress or huggingface_id - await websocket_manager.send_download_progress( - task_id=task_id, - progress=0, - message=f"Download failed: {str(e)}", - bytes_downloaded=0, - total_bytes=0, - speed_mbps=0, - eta_seconds=0, - filename=filename, - model_format=model_format, - huggingface_id=progress_hf_id, - ) - await websocket_manager.send_notification( - "error", - "Download Failed", - f"Failed to download {filename}: {str(e)}", - task_id, - ) - raise - - -async def download_with_progress_tracking( - huggingface_id: str, - filename: str, - file_path: str, - models_dir: str, - websocket_manager, - task_id: str, - total_bytes: int, - model_format: str, - huggingface_id_for_progress: str = None, -): - """Download the file using custom http_get method with progress tracking""" - try: - import aiofiles - - logger.info( - f"📁 Starting download of {filename} ({total_bytes} bytes) [{model_format}]" - ) - - # Use the standard HuggingFace resolve URL (this is the default/preferred method) - safe_filename = _sanitize_filename(filename) - download_url = ( - f"https://huggingface.co/{huggingface_id}/resolve/main/{safe_filename}" - ) - actual_file_size = total_bytes # Start with the provided size - - # Optionally get exact file size from HuggingFace API - try: - api = HfApi() - file_info = api.repo_file_info( - repo_id=huggingface_id, filename=safe_filename - ) - if hasattr(file_info, "size") and file_info.size: - actual_file_size = file_info.size - logger.info( - f"📊 Got file size from HuggingFace API: {actual_file_size} bytes ({actual_file_size / (1024*1024):.2f} MB)" - ) - except Exception as e: - logger.debug( - f"Could not get file size from API: {e}, using provided size: {total_bytes}" - ) - - logger.info(f"📁 Download URL: 
{download_url}") - - # Build headers manually - hf_headers = { - "User-Agent": "llama-cpp-studio/1.0.0", - "Accept": "*/*", - "Accept-Encoding": "gzip, deflate", - } - - # Create final destination path - final_path = os.path.join(models_dir, safe_filename) - final_dir = os.path.dirname(final_path) - if final_dir and not os.path.exists(final_dir): - os.makedirs(final_dir, exist_ok=True) - - # Custom progress bar that sends WebSocket updates - progress_hf_id = huggingface_id_for_progress or huggingface_id - - class WebSocketProgressBar(tqdm): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.websocket_manager = websocket_manager - self.task_id = task_id - self.filename = filename - self.huggingface_id = progress_hf_id - self.start_time = time.time() - self.last_update_time = self.start_time - - def update(self, n=1): - super().update(n) - # Send WebSocket update with current progress - current_time = time.time() - if ( - current_time - self.last_update_time >= 0.5 - ): # Update every 0.5 seconds - if self.total > 0: - progress = int((self.n / self.total) * 100) - current_bytes = int(self.n) - - # Calculate speed and ETA - elapsed_time = current_time - self.start_time - speed_bytes_per_sec = ( - current_bytes / elapsed_time if elapsed_time > 0 else 0 - ) - speed_mbps = speed_bytes_per_sec / (1024 * 1024) - - remaining_bytes = self.total - self.n - eta_seconds = ( - int(remaining_bytes / speed_bytes_per_sec) - if speed_bytes_per_sec > 0 - else 0 - ) - - logger.debug( - f"📊 Progress: {progress}% ({current_bytes}/{self.total} bytes) - {speed_mbps:.1f} MB/s" - ) - - # Send WebSocket update - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - asyncio.create_task( - self.websocket_manager.send_download_progress( - task_id=self.task_id, - progress=progress, - message=f"Downloading {self.filename}", - bytes_downloaded=current_bytes, - total_bytes=self.total, - speed_mbps=speed_mbps, - eta_seconds=eta_seconds, - 
filename=self.filename, - model_format=model_format, - huggingface_id=self.huggingface_id, - ) - ) - except Exception as e: - logger.error(f"Error sending progress update: {e}") - - self.last_update_time = current_time - - # Create our custom progress bar - custom_progress_bar = WebSocketProgressBar( - desc=safe_filename, - total=actual_file_size, # Use the actual file size - unit="B", - unit_scale=True, - unit_divisor=1024, - disable=False, - ) - - # Download using aiohttp with timeout and our custom progress bar - timeout = aiohttp.ClientTimeout( - total=3600, connect=30 - ) # 1 hour total, 30s connect - async with aiohttp.ClientSession( - headers=hf_headers, timeout=timeout - ) as session: - async with session.get(download_url) as response: - if response.status != 200: - raise Exception(f"Failed to download: HTTP {response.status}") - - # Get actual file size from response headers - content_length = response.headers.get("content-length") - if content_length: - response_size = int(content_length) - if response_size != actual_file_size: - logger.debug( - f"📏 Size difference: API said {actual_file_size}, response says {response_size} (diff: {abs(response_size - actual_file_size)} bytes)" - ) - # Use the response size as it's more accurate - actual_file_size = response_size - custom_progress_bar.total = actual_file_size - logger.info( - f"📊 Using response size: {actual_file_size} bytes ({actual_file_size / (1024*1024):.2f} MB)" - ) - - # Download with progress tracking - # Use 64KB chunks for better performance with large files - chunk_size = 65536 - downloaded_bytes = 0 - async with aiofiles.open(final_path, "wb") as f: - async for chunk in response.content.iter_chunked(chunk_size): - await f.write(chunk) - downloaded_bytes += len(chunk) - custom_progress_bar.update(len(chunk)) - - # Close the progress bar - custom_progress_bar.close() - - logger.info(f"📁 Downloaded to: {final_path}") - - # Validate downloaded file size - file_size = os.path.getsize(final_path) - 
if actual_file_size and actual_file_size > 0 and file_size != actual_file_size: - logger.warning( - f"⚠️ Download size mismatch: expected {actual_file_size}, got {file_size}" - ) - # Allow small differences (like metadata) - if abs(file_size - actual_file_size) > 1024: # More than 1KB difference - raise Exception( - f"Download incomplete: expected {actual_file_size} bytes, got {file_size} bytes" - ) - - return final_path, file_size + return file_path, file_size - except Exception as e: - logger.error(f"Download error: {str(e)}") - logger.error(f"Error type: {type(e).__name__}") - logger.error(f"Traceback:\n{traceback.format_exc()}") - raise async def get_quantization_sizes_from_hf( @@ -1898,27 +1901,23 @@ async def get_quantization_sizes_from_hf( updated: Dict[str, Dict] = {} if all_filenames: - try: - # Newer API: batch query specific paths for metadata - paths_info = hf_api.get_paths_info( - repo_id=huggingface_id, paths=all_filenames - ) - # Build lookup - file_sizes: Dict[str, Optional[int]] = { - pi.path: getattr(pi, "size", None) for pi in paths_info - } - except Exception as batch_err: - logger.warning( - f"get_paths_info failed for {huggingface_id}: {batch_err}" - ) + file_sizes = get_accurate_file_sizes(huggingface_id, all_filenames) + if not file_sizes: # Fallback: fetch full metadata once - model_info = hf_api.model_info( - repo_id=huggingface_id, files_metadata=True - ) - file_sizes = {} - if hasattr(model_info, "siblings") and model_info.siblings: - for sibling in model_info.siblings: - file_sizes[sibling.rfilename] = getattr(sibling, "size", None) + try: + model_info = hf_api.model_info( + repo_id=huggingface_id, files_metadata=True + ) + if hasattr(model_info, "siblings") and model_info.siblings: + for sibling in model_info.siblings: + key = getattr(sibling, "path", getattr(sibling, "rfilename", "")) + if key: + file_sizes[key] = getattr(sibling, "size", None) + except Exception as fallback_err: + logger.warning( + f"model_info fallback failed for 
{huggingface_id}: {fallback_err}" + ) + file_sizes = {} for quant_name, filenames in quant_to_files.items(): files_with_sizes = [] diff --git a/backend/llama_manager.py b/backend/llama_manager.py index a4bfa88..cbcb928 100644 --- a/backend/llama_manager.py +++ b/backend/llama_manager.py @@ -76,6 +76,8 @@ class LlamaManager: # Repository URLs LLAMA_CPP_REPO = "https://github.com/ggerganov/llama.cpp.git" IK_LLAMA_CPP_REPO = "https://github.com/ikawrakow/ik_llama.cpp.git" + # Pre-built CUDA releases (ai-dock builds; used for "Install Release") + LLAMA_CPP_CUDA_RELEASES_API = "https://api.github.com/repos/ai-dock/llama.cpp-cuda/releases" REPOSITORY_SOURCES = { "llama.cpp": LLAMA_CPP_REPO, @@ -83,7 +85,11 @@ class LlamaManager: } def __init__(self): - self.llama_dir = "data/llama-cpp" + # Use absolute path so clone/build work regardless of process cwd (e.g. --app-dir backend) + if os.path.exists("/app/data"): + self.llama_dir = "/app/data/llama-cpp" + else: + self.llama_dir = os.path.abspath(os.path.join(os.getcwd(), "data", "llama-cpp")) os.makedirs(self.llama_dir, exist_ok=True) # Ensure directory has proper permissions (read, write, execute for owner) try: @@ -92,6 +98,7 @@ def __init__(self): except Exception as e: logger.warning(f"Could not set permissions on {self.llama_dir}: {e}") self._cached_cuda_architectures: Optional[str] = None + self._cached_cmake_path: Optional[str] = None def _check_cuda_toolkit_available( self, @@ -326,8 +333,11 @@ def _verify_cuda_toolkit_complete(self, cuda_root: str) -> Tuple[bool, List[str] def _get_cmake_version(self) -> Optional[Tuple[int, int, int]]: """Get CMake version as tuple (major, minor, patch).""" try: + cmake_exe = self._find_cmake_executable() + if not cmake_exe: + return None result = subprocess.run( - ["cmake", "--version"], capture_output=True, text=True, timeout=5 + [cmake_exe, "--version"], capture_output=True, text=True, timeout=5 ) if result.returncode == 0: # Parse "cmake version X.Y.Z" @@ -342,6 +352,24 @@ def 
_get_cmake_version(self) -> Optional[Tuple[int, int, int]]: pass return None + def _find_cmake_executable(self) -> Optional[str]: + """Find a usable cmake executable from env or PATH.""" + if self._cached_cmake_path and os.path.exists(self._cached_cmake_path): + return self._cached_cmake_path + + candidates = [ + os.getenv("CMAKE"), + os.getenv("CMAKE_EXECUTABLE"), + shutil.which("cmake"), + ] + + for candidate in candidates: + if candidate and os.path.exists(candidate): + self._cached_cmake_path = candidate + return candidate + + return None + def _get_cuda_version(self, nvcc_path: str) -> Optional[Tuple[int, int]]: """Get CUDA version from nvcc as tuple (major, minor).""" try: @@ -401,9 +429,9 @@ async def _detect_cuda_architectures(self) -> Optional[str]: return detected def _fetch_release(self, tag_name: str) -> Dict: - """Fetch release metadata for a tag.""" + """Fetch release metadata for a tag from ai-dock/llama.cpp-cuda (CUDA builds).""" response = requests.get( - f"https://api.github.com/repos/ggerganov/llama.cpp/releases/tags/{tag_name}", + f"{self.LLAMA_CPP_CUDA_RELEASES_API}/tags/{tag_name}", allow_redirects=True, ) response.raise_for_status() @@ -413,6 +441,10 @@ def _tokenize_asset_name(self, asset_name: str) -> List[str]: return [token for token in re.split(r"[.\-_\s]+", asset_name.lower()) if token] def _is_asset_compatible(self, asset_name: str) -> Tuple[bool, Optional[str]]: + # ai-dock/llama.cpp-cuda: single .tar.gz per release (e.g. 
llama.cpp-b8233-cuda-12.8.tar.gz) + if re.match(r"^llama\.cpp-[^\-]+-cuda-[0-9.]+\.tar\.gz$", asset_name, re.IGNORECASE): + return True, None + tokens = self._tokenize_asset_name(asset_name) if not tokens: @@ -453,6 +485,11 @@ def _extract_asset_features(self, asset_name: str) -> List[str]: tokens = self._tokenize_asset_name(asset_name) features = [] + # ai-dock tarballs contain llama-server and are CUDA builds + if re.match(r"^llama\.cpp-[^\-]+-cuda-[0-9.]+\.tar\.gz$", asset_name, re.IGNORECASE): + features.extend(["llama-server", "CUDA"]) + return features + feature_map = { "cuda": "CUDA", "vulkan": "Vulkan", @@ -701,8 +738,31 @@ def get_optimal_build_threads(self) -> int: except: return 1 # Fallback to single thread + async def _run_command( + self, + *args, + cwd: Optional[str] = None, + env: Optional[dict] = None, + timeout: Optional[int] = None, + merge_stderr: bool = False, + ) -> subprocess.CompletedProcess: + """Run a subprocess in a thread for cross-platform compatibility.""" + + def _runner(): + return subprocess.run( + list(args), + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT if merge_stderr else subprocess.PIPE, + timeout=timeout, + check=False, + ) + + return await asyncio.to_thread(_runner) + async def validate_build( - self, binary_path: str, websocket_manager=None, task_id: str = None + self, binary_path: str, progress_manager=None, task_id: str = None ) -> bool: """Run basic validation on built binary""" try: @@ -711,13 +771,9 @@ async def validate_build( return False # Test 2: Run --version command - process = await asyncio.create_subprocess_exec( - binary_path, - "--version", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=10) + process = await self._run_command(binary_path, "--version", timeout=10) + stdout = process.stdout or b"" + stderr = process.stderr or b"" if process.returncode != 0: return False @@ -736,15 
+792,15 @@ async def validate_build( async def install_release( self, tag_name: str, - websocket_manager=None, + progress_manager=None, task_id: str = None, asset_id: Optional[int] = None, ) -> str: - """Install llama.cpp from GitHub release with WebSocket progress updates""" + """Install llama.cpp from GitHub release with SSE progress updates""" try: # Stage 1: Get release info - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="fetch", progress=10, @@ -768,8 +824,8 @@ async def install_release( ) # Stage 2: Download binary - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="download", progress=30, @@ -817,8 +873,8 @@ async def install_release( logger.warning(f"Unable to verify downloaded artifact size: {exc}") # Stage 3: Extract binary - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="extract", progress=70, @@ -852,8 +908,8 @@ async def install_release( os.remove(download_path) # Stage 4: Find and verify executable - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="verify", progress=90, @@ -877,8 +933,8 @@ async def install_release( logger.info( f"llama-server executable found and verified: {final_server_path}" ) - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="verify", progress=100, @@ -905,9 +961,9 @@ async def install_release( except Exception as e: logger.error(f"Installation failed 
with error: {e}") - if websocket_manager and task_id: + if progress_manager and task_id: try: - await websocket_manager.send_build_progress( + await progress_manager.send_build_progress( task_id=task_id, stage="error", progress=0, @@ -915,7 +971,7 @@ async def install_release( log_lines=[f"Error: {str(e)}"], ) except Exception as ws_error: - logger.error(f"Failed to send error to WebSocket: {ws_error}") + logger.error(f"Failed to send error via SSE: {ws_error}") raise Exception(f"Failed to install release {tag_name}: {e}") async def build_source( @@ -923,7 +979,7 @@ async def build_source( commit_sha: str, patches: List[str] = None, build_config: BuildConfig = None, - websocket_manager=None, + progress_manager=None, task_id: str = None, repository_url: str = None, version_name: str = None, @@ -942,8 +998,8 @@ async def build_source( break # Send initial progress - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="init", progress=0, @@ -971,8 +1027,8 @@ async def build_source( logger.warning(f"Could not set permissions on {version_dir}: {e}") # Stage 1: Clone repository (simplified) - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="clone", progress=20, @@ -984,20 +1040,16 @@ async def build_source( # Simple git clone with timeout try: - clone_process = await asyncio.create_subprocess_exec( + clone_process = await self._run_command( "git", "clone", repository_url, clone_dir, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - - clone_stdout, clone_stderr = await asyncio.wait_for( - clone_process.communicate(), timeout=300 # 5 minute timeout + timeout=300, ) if clone_process.returncode != 0: + clone_stderr = clone_process.stderr or b"" error_msg = clone_stderr.decode().strip() 
raise Exception(f"Git clone failed: {error_msg}") @@ -1005,13 +1057,11 @@ async def build_source( except asyncio.TimeoutError: logger.error("Git clone timed out") - clone_process.kill() - await clone_process.wait() raise Exception("Git clone timed out - network issues") # Stage 2: Checkout specific commit/branch (simplified) - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="checkout", progress=40, @@ -1020,38 +1070,30 @@ async def build_source( ) try: - checkout_process = await asyncio.create_subprocess_exec( + checkout_process = await self._run_command( "git", "checkout", commit_sha, cwd=clone_dir, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - - checkout_stdout, checkout_stderr = await asyncio.wait_for( - checkout_process.communicate(), timeout=60 + timeout=60, ) if checkout_process.returncode != 0: + checkout_stderr = checkout_process.stderr or b"" error_msg = checkout_stderr.decode().strip() # Try main as fallback for "master" (legacy support) if commit_sha == "master": logger.info("Failed to checkout 'master', trying 'main'") - main_process = await asyncio.create_subprocess_exec( + main_process = await self._run_command( "git", "checkout", "main", cwd=clone_dir, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - - main_stdout, main_stderr = await asyncio.wait_for( - main_process.communicate(), timeout=60 + timeout=60, ) if main_process.returncode != 0: + main_stderr = main_process.stderr or b"" raise Exception( f"Failed to checkout both 'master' and 'main': {main_stderr.decode()}" ) @@ -1065,8 +1107,8 @@ async def build_source( # Stage 3: Apply patches (if any) if patches: - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="patch", 
progress=50, @@ -1078,8 +1120,8 @@ async def build_source( await self._apply_patch(clone_dir, patch_url) # Stage 4: Build following official documentation - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="configure", progress=60, @@ -1133,8 +1175,8 @@ async def build_source( error_msg = f"CUDA build requested but CUDA Toolkit not found.\n\n{cuda_error}" logger.error(error_msg) - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="configure", progress=60, @@ -1158,8 +1200,8 @@ async def build_source( if result.returncode != 0: error_msg = f"nvcc found at {nvcc_path} but failed to execute (exit code {result.returncode})" logger.error(error_msg) - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="configure", progress=60, @@ -1174,8 +1216,8 @@ async def build_source( except (subprocess.TimeoutExpired, FileNotFoundError, OSError) as e: error_msg = f"Failed to verify nvcc at {nvcc_path}: {e}" logger.error(error_msg) - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="configure", progress=60, @@ -1186,8 +1228,8 @@ async def build_source( else: error_msg = f"nvcc not found at expected path {nvcc_path} (CUDA root: {cuda_root})" logger.error(error_msg) - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="configure", progress=60, @@ -1199,8 +1241,14 @@ async def build_source( # Store validated CUDA root for 
later use validated_cuda_root = cuda_root + cmake_exe = self._find_cmake_executable() + if not cmake_exe: + raise Exception( + "CMake was not found. Install CMake or set CMAKE/CMAKE_EXECUTABLE to its path." + ) + # Build CMake arguments - cmake_args = ["cmake", ".."] + cmake_args = [cmake_exe, ".."] # Add build type cmake_args.append(f"-DCMAKE_BUILD_TYPE={build_config.build_type}") @@ -1581,17 +1629,15 @@ def set_flag(flag: str, value: bool): # Log cmake arguments for debugging logger.info(f"CMake command: {' '.join(cmake_args)}") - cmake_process = await asyncio.create_subprocess_exec( + cmake_process = await self._run_command( *cmake_args, cwd=build_dir, env=env, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, + timeout=180, ) - cmake_stdout, cmake_stderr = await asyncio.wait_for( - cmake_process.communicate(), timeout=180 # 3 minute timeout - ) + cmake_stdout = cmake_process.stdout or b"" + cmake_stderr = cmake_process.stderr or b"" if cmake_process.returncode != 0: error_msg = cmake_stderr.decode().strip() @@ -1688,20 +1734,17 @@ def set_flag(flag: str, value: bool): # List available targets for debugging (especially useful for ik_llama.cpp) try: - targets_process = await asyncio.create_subprocess_exec( - "cmake", + targets_process = await self._run_command( + cmake_exe, "--build", ".", "--target", "help", cwd=build_dir, env=env, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - targets_stdout, targets_stderr = await asyncio.wait_for( - targets_process.communicate(), timeout=30 + timeout=30, ) + targets_stdout = targets_process.stdout or b"" if targets_process.returncode == 0: targets_output = targets_stdout.decode( "utf-8", errors="replace" @@ -1727,8 +1770,8 @@ def set_flag(flag: str, value: bool): logger.warning( f"llama-server target not found in available targets. 
Repository: {repo_source_name}" ) - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="configure", progress=65, @@ -1744,8 +1787,8 @@ def set_flag(flag: str, value: bool): raise Exception("CMake configuration timed out") # Stage 5: Build - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="build", progress=70, @@ -1776,89 +1819,43 @@ def set_flag(flag: str, value: bool): ) # Explicitly build llama-server target - build_process = await asyncio.create_subprocess_exec( - "cmake", + build_process = await self._run_command( + cmake_exe, "--build", ".", "--target", "llama-server", - "--", - "-j", + "--parallel", str(thread_count), cwd=build_dir, env=env, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.STDOUT, # Merge stderr into stdout + timeout=1800, + merge_stderr=True, ) - # Stream build output for better diagnostics - build_output_lines = [] - last_progress_update = time.time() - - async def read_output(): - nonlocal last_progress_update - while True: - line = await build_process.stdout.readline() - if not line: - break - decoded_line = line.decode("utf-8", errors="replace").rstrip() - build_output_lines.append(decoded_line) - logger.debug(f"Build output: {decoded_line}") - # Send progress updates for important lines and periodically - if websocket_manager and task_id: - should_send = False - line_lower = decoded_line.lower() - # Always send errors and warnings - if any( - keyword in line_lower - for keyword in ["error", "warning", "fatal", "failed"] - ): - should_send = True - # Send periodic updates (every 5 seconds) to show progress - elif time.time() - last_progress_update > 5: - should_send = True - last_progress_update = time.time() - # Send important build milestones - elif any( - 
keyword in line_lower - for keyword in [ - "building", - "linking", - "built target", - "scanning", - "configuring", - ] - ): - should_send = True - - if should_send: - await websocket_manager.send_build_progress( - task_id=task_id, - stage="build", - progress=70, - message="Building llama.cpp...", - log_lines=[decoded_line], - ) - - # Start reading output - read_task = asyncio.create_task(read_output()) - - # Wait for process to complete - returncode = await asyncio.wait_for( - build_process.wait(), timeout=1800 # 30 minute timeout for build + build_output = (build_process.stdout or b"").decode( + "utf-8", errors="replace" ) + build_output_lines = [ + line.rstrip() for line in build_output.splitlines() if line.strip() + ] + returncode = build_process.returncode - # Wait for output reading to finish - await read_task - - build_output = "\n".join(build_output_lines) + if progress_manager and task_id and build_output_lines: + await progress_manager.send_build_progress( + task_id=task_id, + stage="build", + progress=70, + message="Building llama.cpp...", + log_lines=build_output_lines[-20:], + ) if returncode != 0: logger.error(f"Build failed with return code {returncode}") logger.error(f"Build output:\n{build_output}") - # Send error output via websocket if available - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + # Send error output via SSE if available + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="build", progress=70, @@ -1911,8 +1908,8 @@ async def read_output(): logger.warning( f"Build target 'llama-server' not found, trying 'server' target (for examples/server)..." 
) - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="build", progress=70, @@ -1924,41 +1921,29 @@ async def read_output(): # Try 'server' target (used when server is in examples/) logger.info("Attempting to build 'server' target...") - server_target_process = await asyncio.create_subprocess_exec( - "cmake", + server_target_process = await self._run_command( + cmake_exe, "--build", ".", "--target", "server", - "--", - "-j", + "--parallel", str(thread_count), cwd=build_dir, env=env, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.STDOUT, + timeout=1800, + merge_stderr=True, ) - server_target_output_lines = [] - - async def read_server_target_output(): - while True: - line = await server_target_process.stdout.readline() - if not line: - break - decoded_line = line.decode( - "utf-8", errors="replace" - ).rstrip() - server_target_output_lines.append(decoded_line) - logger.debug(f"Server target build output: {decoded_line}") - - read_server_task = asyncio.create_task(read_server_target_output()) - server_target_returncode = await asyncio.wait_for( - server_target_process.wait(), timeout=1800 + server_target_output = (server_target_process.stdout or b"").decode( + "utf-8", errors="replace" ) - await read_server_task - - server_target_output = "\n".join(server_target_output_lines) + server_target_output_lines = [ + line.rstrip() + for line in server_target_output.splitlines() + if line.strip() + ] + server_target_returncode = server_target_process.returncode if server_target_returncode == 0: logger.info("Successfully built 'server' target") @@ -1971,8 +1956,8 @@ async def read_server_target_output(): logger.error( f"Server target build output:\n{server_target_output}" ) - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await 
progress_manager.send_build_progress( task_id=task_id, stage="build", progress=70, @@ -1983,39 +1968,28 @@ async def read_server_target_output(): ) # Try building all targets as last resort logger.info("Attempting to build all targets as fallback...") - all_targets_process = await asyncio.create_subprocess_exec( - "cmake", + all_targets_process = await self._run_command( + cmake_exe, "--build", ".", - "--", - "-j", + "--parallel", str(thread_count), cwd=build_dir, env=env, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.STDOUT, + timeout=1800, + merge_stderr=True, ) - all_targets_output_lines = [] - - async def read_all_targets_output(): - while True: - line = await all_targets_process.stdout.readline() - if not line: - break - decoded_line = line.decode( - "utf-8", errors="replace" - ).rstrip() - all_targets_output_lines.append(decoded_line) - logger.debug(f"All targets build output: {decoded_line}") - - read_all_task = asyncio.create_task(read_all_targets_output()) - all_targets_returncode = await asyncio.wait_for( - all_targets_process.wait(), timeout=1800 + all_targets_output = (all_targets_process.stdout or b"").decode( + "utf-8", errors="replace" ) - await read_all_task + all_targets_output_lines = [ + line.rstrip() + for line in all_targets_output.splitlines() + if line.strip() + ] + all_targets_returncode = all_targets_process.returncode if all_targets_returncode != 0: - all_targets_output = "\n".join(all_targets_output_lines) logger.error( f"Building all targets failed with return code {all_targets_returncode}" ) @@ -2036,8 +2010,8 @@ async def read_all_targets_output(): f"Build completed with return code 0 but contains errors" ) logger.error(f"Build output:\n{build_output}") - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="build", progress=70, @@ -2085,8 +2059,8 @@ async def read_all_targets_output(): 
logger.warning( "Binary not found in common locations immediately after build - will search more thoroughly" ) - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="build", progress=75, @@ -2100,8 +2074,8 @@ async def read_all_targets_output(): raise Exception("Build timed out") # Stage 6: Find executable - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="verify", progress=90, @@ -2231,9 +2205,9 @@ async def read_all_targets_output(): error_msg += f"3. CMake configuration succeeded\n" error_msg += f"4. Build output for errors" - # Send detailed error via websocket - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + # Send detailed error via SSE + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="error", progress=0, @@ -2255,8 +2229,8 @@ async def read_all_targets_output(): logger.info(f"Build completed, validating binary: {version_server_path}") # Validate the build - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="validate", progress=95, @@ -2265,13 +2239,13 @@ async def read_all_targets_output(): ) is_valid = await self.validate_build( - version_server_path, websocket_manager, task_id + version_server_path, progress_manager, task_id ) if not is_valid: logger.warning("Build validation failed") - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="validate", progress=95, @@ -2281,8 +2255,8 @@ async def read_all_targets_output(): logger.info(f"Build 
completed successfully: {version_server_path}") - if websocket_manager and task_id: - await websocket_manager.send_build_progress( + if progress_manager and task_id: + await progress_manager.send_build_progress( task_id=task_id, stage="complete", progress=100, @@ -2297,9 +2271,9 @@ async def read_all_targets_output(): except Exception as e: logger.error(f"Build failed: {e}") - if websocket_manager and task_id: + if progress_manager and task_id: try: - await websocket_manager.send_build_progress( + await progress_manager.send_build_progress( task_id=task_id, stage="error", progress=0, @@ -2307,7 +2281,7 @@ async def read_all_targets_output(): log_lines=[f"Error: {str(e)}"], ) except Exception as ws_error: - logger.error(f"Failed to send error to WebSocket: {ws_error}") + logger.error(f"Failed to send error via SSE: {ws_error}") raise Exception(f"Failed to build from source {commit_sha}: {e}") async def _apply_patch(self, repo_dir: str, patch_url: str): @@ -2332,17 +2306,16 @@ async def _apply_patch(self, repo_dir: str, patch_url: str): with open(patch_file, "w") as f: f.write(patch_content) - apply_process = await asyncio.create_subprocess_exec( + apply_process = await self._run_command( "git", "apply", patch_file, cwd=repo_dir, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, + timeout=60, ) - apply_stdout, apply_stderr = await apply_process.communicate() if apply_process.returncode != 0: + apply_stderr = apply_process.stderr or b"" raise Exception(f"Failed to apply patch: {apply_stderr.decode()}") os.remove(patch_file) diff --git a/backend/llama_swap_config.py b/backend/llama_swap_config.py index cf717bc..3e071e7 100644 --- a/backend/llama_swap_config.py +++ b/backend/llama_swap_config.py @@ -265,26 +265,20 @@ def is_ik_llama_cpp(llama_server_path: Optional[str]) -> bool: except Exception as e: logger.debug(f"Error detecting ik_llama.cpp via flags: {e}") - # Fallback: Check database for repository_source + # Fallback: Check store for 
repository_source try: - from backend.database import SessionLocal, LlamaVersion - - db = SessionLocal() - try: - active_version = ( - db.query(LlamaVersion).filter(LlamaVersion.is_active == True).first() - ) - if active_version and active_version.repository_source: - is_ik = active_version.repository_source == "ik_llama.cpp" - if is_ik: - logger.debug( - f"Detected ik_llama.cpp via database repository_source: {active_version.repository_source}" - ) - return is_ik - finally: - db.close() + from backend.data_store import get_store + store = get_store() + active_version = store.get_active_engine_version("ik_llama") or store.get_active_engine_version("llama_cpp") + if active_version and active_version.get("repository_source"): + is_ik = active_version.get("repository_source") == "ik_llama.cpp" + if is_ik: + logger.debug( + f"Detected ik_llama.cpp via store repository_source: {active_version.get('repository_source')}" + ) + return is_ik except Exception as e: - logger.debug(f"Error checking database for ik_llama.cpp: {e}") + logger.debug(f"Error checking store for ik_llama.cpp: {e}") return False @@ -395,43 +389,67 @@ def get_param_mapping(is_ik: bool) -> Dict[str, list]: def get_active_binary_path_from_db() -> Optional[str]: """ - Gets the active llama-server binary path from the database. + Gets the active llama-server binary path from the data store. Returns: Absolute path to the llama-server binary, or None if not found. 
""" try: - from backend.database import SessionLocal, LlamaVersion - - db = SessionLocal() - try: - active_version = ( - db.query(LlamaVersion).filter(LlamaVersion.is_active == True).first() - ) - if not active_version or not active_version.binary_path: - logger.warning("No active llama-cpp version found in database") - return None + from backend.data_store import get_store - # Convert to absolute path - binary_path = active_version.binary_path + store = get_store() + for engine in ("llama_cpp", "ik_llama"): + active_version = store.get_active_engine_version(engine) + if not active_version or not active_version.get("binary_path"): + continue + binary_path = active_version["binary_path"] if not os.path.isabs(binary_path): binary_path = os.path.join("/app", binary_path) - - # Verify the path exists if os.path.exists(binary_path): return binary_path - else: - logger.warning( - f"Binary path from database does not exist: {binary_path}" - ) - return None - finally: - db.close() + abs_path = os.path.abspath(binary_path) + if os.path.exists(abs_path): + return abs_path + logger.warning("No active llama-cpp version found in data store") + return None except Exception as e: - logger.error(f"Error getting binary path from database: {e}") + logger.error(f"Error getting binary path from data store: {e}") return None +def _build_lmdeploy_cmd( + model: Any, + config: Dict[str, Any], + lmdeploy_bin: str, + _model_attr: Any, +) -> str: + """Build lmdeploy serve api_server command for llama-swap config.""" + hf_id = _model_attr(model, "huggingface_id") + if not hf_id: + raise ValueError("LMDeploy model must have huggingface_id") + cmd_parts = [lmdeploy_bin, "serve", "api_server", hf_id] + cmd_parts.extend(["--server-port", "${PORT}"]) + cmd_parts.extend(["--backend", "turbomind"]) + if config.get("session_len") is not None: + cmd_parts.extend(["--session-len", str(config["session_len"])]) + if config.get("max_batch_size") is not None: + cmd_parts.extend(["--max-batch-size", 
str(config["max_batch_size"])]) + if config.get("tensor_parallel") is not None: + cmd_parts.extend(["--tp", str(config["tensor_parallel"])]) + if config.get("dtype"): + cmd_parts.extend(["--dtype", str(config["dtype"])]) + if config.get("quant_policy") is not None: + cmd_parts.extend(["--quant-policy", str(config["quant_policy"])]) + if config.get("enable_prefix_caching"): + cmd_parts.append("--enable-prefix-caching") + if config.get("chat_template"): + cmd_parts.extend(["--chat-template", str(config["chat_template"])]) + # Escape single quotes in the command for bash -c '...' + inner_cmd = " ".join(cmd_parts) + inner_cmd = inner_cmd.replace("'", "'\\''") + return f"bash -c '{inner_cmd}'" + + def generate_llama_swap_config( models: Dict[str, Dict[str, Any]], llama_server_path: Optional[str] = None, @@ -486,20 +504,83 @@ def generate_llama_swap_config( "models": {}, } - # First, add all models from the database (if provided) + def _model_attr(m: Any, key: str, default: Any = None) -> Any: + """Get attribute from model (dict or object).""" + if isinstance(m, dict): + return m.get(key, default) + return getattr(m, key, default) + + # Resolve LMDeploy binary and build proxy->model map for overlay (used for both all_models and running overlay) + lmdeploy_bin = None + all_models_by_proxy: Dict[str, Any] = {} + try: + from backend.data_store import get_store as _get_store + store = _get_store() + lmdeploy_status = store.get_lmdeploy_status() + if lmdeploy_status.get("installed") and lmdeploy_status.get("venv_path"): + venv = lmdeploy_status["venv_path"] + lmdeploy_bin = os.path.join(venv, "bin", "lmdeploy") + if not os.path.isabs(lmdeploy_bin): + lmdeploy_bin = os.path.join("/app", lmdeploy_bin) + if not os.path.exists(lmdeploy_bin): + lmdeploy_bin = None + except Exception as e: + logger.debug(f"Could not resolve LMDeploy binary: {e}") + + # First, add all models from the data store (if provided) if all_models: + from backend.data_store import generate_proxy_name as 
_gen_proxy_name + for model in all_models: - # Use the centralized proxy name from the database - if not model.proxy_name: + proxy_model_name = _model_attr(model, "proxy_name") + if not proxy_model_name: + proxy_model_name = _gen_proxy_name( + _model_attr(model, "huggingface_id", ""), + _model_attr(model, "quantization"), + ) + if not proxy_model_name: logger.warning( - f"Model '{model.name}' does not have a proxy_name set, skipping" + f"Model '{_model_attr(model, 'display_name') or _model_attr(model, 'name')}' does not have a proxy_name set, skipping" ) continue + all_models_by_proxy[proxy_model_name] = model + + engine = _model_attr(model, "engine") + model_format = _model_attr(model, "format") or _model_attr(model, "model_format") or "gguf" + is_lmdeploy = engine == "lmdeploy" or model_format == "safetensors" + if is_lmdeploy and lmdeploy_bin: + config = _coerce_model_config(_model_attr(model, "config")) + try: + cmd_with_env = _build_lmdeploy_cmd(model, config, lmdeploy_bin, _model_attr) + config_data["models"][proxy_model_name] = {"cmd": cmd_with_env} + except Exception as e: + logger.warning(f"Failed to build LMDeploy cmd for {proxy_model_name}: {e}") + continue + + hf_id = _model_attr(model, "huggingface_id") + filename = _model_attr(model, "filename") or ( + os.path.basename(_model_attr(model, "file_path") or "") or None + ) + + # Resolve model path: HF cache first, then legacy file_path + model_path = None + if hf_id and filename: + from backend.huggingface import resolve_cached_model_path + model_path = resolve_cached_model_path(hf_id, filename) - proxy_model_name = model.proxy_name - model_path = model.file_path + if not model_path: + # Legacy fallback: stored file_path (old-style records) + legacy = _model_attr(model, "file_path") + if legacy: + model_path = legacy if os.path.isabs(legacy) else f"/app/{legacy}" - # Convert model path to absolute path + if not model_path: + logger.warning( + f"Model '{proxy_model_name}' path could not be resolved 
(hf_id={hf_id}, filename={filename}), skipping" + ) + continue + + # Ensure absolute path (HF cache returns absolute; legacy may not) if not os.path.isabs(model_path): model_path = f"/app/{model_path}" @@ -523,7 +604,7 @@ def generate_llama_swap_config( working_dir = working_dir.replace("/bin/", "/build/bin/") # Parse existing config if available - config = _coerce_model_config(model.config) + config = _coerce_model_config(_model_attr(model, "config")) if proxy_model_name and config.get("jinja") is not None: logger.debug( f"Model {proxy_model_name}: jinja={config.get('jinja')} (type: {type(config.get('jinja'))})" @@ -539,6 +620,16 @@ def generate_llama_swap_config( "--port", "${PORT}", ] + # Vision: if model has mmproj (multimodal projector), add --mmproj so vision is available + mmproj_filename = _model_attr(model, "mmproj_filename") + if mmproj_filename and hf_id: + from backend.huggingface import resolve_cached_model_path + mmproj_path = resolve_cached_model_path(hf_id, mmproj_filename) + if mmproj_path and os.path.exists(mmproj_path): + if not os.path.isabs(mmproj_path): + mmproj_path = f"/app/{mmproj_path}" + quoted_mmproj = _quote_arg_if_needed(mmproj_path) + cmd_args.extend(["--mmproj", quoted_mmproj]) # Default values to skip (these cause errors if flag isn't supported) default_values = { @@ -740,6 +831,19 @@ def generate_llama_swap_config( # Then, add/update with running models (these take precedence for active models) for proxy_model_name, model_data in models.items(): + overlay_model = all_models_by_proxy.get(proxy_model_name) + engine = _model_attr(overlay_model, "engine") if overlay_model else None + model_format = _model_attr(overlay_model, "format") or _model_attr(overlay_model, "model_format") if overlay_model else None + is_lmdeploy_overlay = (engine == "lmdeploy" or model_format == "safetensors") and lmdeploy_bin and overlay_model + if is_lmdeploy_overlay: + config = _coerce_model_config(model_data.get("config")) + try: + cmd_with_env = 
_build_lmdeploy_cmd(overlay_model, config, lmdeploy_bin, _model_attr) + config_data["models"][proxy_model_name] = {"cmd": cmd_with_env} + except Exception as e: + logger.warning(f"Failed to build LMDeploy overlay cmd for {proxy_model_name}: {e}") + continue + model_path = model_data["model_path"] llama_cpp_config = model_data["config"] @@ -753,6 +857,17 @@ def generate_llama_swap_config( "--port", "${PORT}", ] + # Vision: add --mmproj if model has mmproj_filename + if overlay_model: + mmproj_fn = _model_attr(overlay_model, "mmproj_filename") + hf_id_overlay = _model_attr(overlay_model, "huggingface_id") + if mmproj_fn and hf_id_overlay: + from backend.huggingface import resolve_cached_model_path + mmproj_path = resolve_cached_model_path(hf_id_overlay, mmproj_fn) + if mmproj_path and os.path.exists(mmproj_path): + if not os.path.isabs(mmproj_path): + mmproj_path = f"/app/{mmproj_path}" + cmd_args.extend(["--mmproj", _quote_arg_if_needed(mmproj_path)]) # Default values to skip (these cause errors if flag isn't supported) default_values = { diff --git a/backend/llama_swap_manager.py b/backend/llama_swap_manager.py index 59f6544..d08a0c1 100644 --- a/backend/llama_swap_manager.py +++ b/backend/llama_swap_manager.py @@ -3,9 +3,9 @@ import os import yaml import httpx -from typing import Dict, Any, Optional +from typing import Any, Dict, Optional from backend.llama_swap_config import generate_llama_swap_config -from backend.database import Model +from backend.data_store import get_store from backend.logging_config import get_logger logger = get_logger(__name__) @@ -57,17 +57,12 @@ async def _write_config(self, llama_server_path: str = None): "No llama-server binary path provided and none found in database" ) - # Load all models from database to include them in config - from backend.database import get_db, Model - - db = next(get_db()) - try: - all_models = db.query(Model).all() - config_content = generate_llama_swap_config( - self.running_models, llama_server_path, 
all_models - ) - finally: - db.close() + # Load all models from data store to include them in config + store = get_store() + all_models = store.list_models() + config_content = generate_llama_swap_config( + self.running_models, llama_server_path, all_models + ) # Ensure directory exists config_dir = os.path.dirname(self.config_path) @@ -346,32 +341,34 @@ async def restart_proxy(self): await self.start_proxy() logger.info("llama-swap proxy restarted successfully") - async def register_model(self, model: Model, config: Dict[str, Any]) -> str: + async def register_model(self, model: Any, config: Dict[str, Any]) -> str: """ Registers a model with llama-swap by storing its configuration. Returns the proxy_model_name used by llama-swap. Note: This only stores the model info, config is written separately. + model can be a dict or an object with proxy_name, file_path, display_name/name. """ - # Use the centralized proxy name from the database - if not model.proxy_name: - raise ValueError(f"Model '{model.name}' does not have a proxy_name set") + proxy_name = model.get("proxy_name") if isinstance(model, dict) else getattr(model, "proxy_name", None) + file_path = model.get("file_path") if isinstance(model, dict) else getattr(model, "file_path", None) + name = model.get("display_name") or model.get("name") if isinstance(model, dict) else (getattr(model, "display_name", None) or getattr(model, "name", None)) - proxy_model_name = model.proxy_name + if not proxy_name: + raise ValueError(f"Model '{name}' does not have a proxy_name set") - if proxy_model_name in self.running_models: + if proxy_name in self.running_models: raise ValueError( - f"Model '{proxy_model_name}' is already registered with llama-swap." + f"Model '{proxy_name}' is already registered with llama-swap." 
) - self.running_models[proxy_model_name] = { - "model_path": model.file_path, + self.running_models[proxy_name] = { + "model_path": file_path, "config": config, } logger.info( - f"Model '{model.name}' registered as '{proxy_model_name}' with llama-swap" + f"Model '{name}' registered as '{proxy_name}' with llama-swap" ) - return proxy_model_name + return proxy_name def _detect_correct_binary_path(self, version_dir: str) -> str: """ @@ -406,46 +403,35 @@ async def _ensure_correct_binary_path(self): Ensures the active llama-cpp version has the correct binary path. Automatically detects and updates if needed. """ - from backend.database import SessionLocal, LlamaVersion - - db = SessionLocal() - try: - active_version = ( - db.query(LlamaVersion).filter(LlamaVersion.is_active == True).first() - ) + store = get_store() + for engine in ("llama_cpp", "ik_llama"): + active_version = store.get_active_engine_version(engine) if not active_version: - logger.warning("No active llama-cpp version found") - return - - # Convert relative path to absolute - version_dir = active_version.binary_path + continue + version_dir = active_version.get("binary_path") + if not version_dir: + continue if not os.path.isabs(version_dir): version_dir = os.path.join("/app", version_dir) - - # Get the directory containing the binary binary_dir = os.path.dirname(version_dir) - - # Detect the correct binary path correct_binary_path = self._detect_correct_binary_path(binary_dir) - - # Convert back to relative path for database storage relative_path = os.path.relpath(correct_binary_path, "/app") - - # Update database if path has changed - if active_version.binary_path != relative_path: + if active_version.get("binary_path") != relative_path: logger.info( - f"Updating binary path from '{active_version.binary_path}' to '{relative_path}'" + f"Updating binary path from '{active_version.get('binary_path')}' to '{relative_path}'" ) - active_version.binary_path = relative_path - db.commit() + data = 
store._read_yaml("engines.yaml") + engine_data = data.get(engine, {}) + for i, v in enumerate(engine_data.get("versions", [])): + if v.get("version") == active_version.get("version"): + engine_data["versions"][i] = {**v, "binary_path": relative_path} + break + store._save_yaml("engines.yaml", data) logger.info("Binary path updated successfully") else: logger.debug(f"Binary path is already correct: {relative_path}") - - except Exception as e: - logger.error(f"Error ensuring correct binary path: {e}") - finally: - db.close() + return + logger.warning("No active llama-cpp version found") async def regenerate_config_with_active_version(self): """ @@ -454,52 +440,39 @@ async def regenerate_config_with_active_version(self): Automatically detects and fixes binary path if needed. Ensures llama-swap is running if an active version exists. """ - from backend.database import SessionLocal, LlamaVersion - - # First, ensure the binary path is correct await self._ensure_correct_binary_path() - db = SessionLocal() - try: - # Get the active version - active_version = ( - db.query(LlamaVersion).filter(LlamaVersion.is_active == True).first() - ) - if not active_version: - logger.warning( - "No active llama-cpp version found, skipping config regeneration" - ) - return - - # Convert to absolute path for existence check - binary_path = active_version.binary_path - if not os.path.isabs(binary_path): - binary_path = os.path.join("/app", binary_path) - - if not os.path.exists(binary_path): - logger.warning(f"Active version binary not found: {binary_path}") - return - - # Sync running_models with actual llama-swap state - await self.sync_running_models() - - # Regenerate config with active version and synced running_models - await self._write_config(active_version.binary_path) - logger.info( - f"Regenerated llama-swap config with active version: {active_version.version} and {len(self.running_models)} running models" + store = get_store() + active_version = None + for engine in ("llama_cpp", 
"ik_llama"): + active_version = store.get_active_engine_version(engine) + if active_version: + break + if not active_version: + logger.warning( + "No active llama-cpp version found, skipping config regeneration" ) + return - # Ensure llama-swap is running when we have an active version - try: - await self.start_proxy() - logger.info("Ensured llama-swap is running after config regeneration") - except Exception as e: - logger.warning(f"Failed to start llama-swap after config regeneration: {e}") + binary_path = active_version.get("binary_path") + if not binary_path: + return + if not os.path.isabs(binary_path): + binary_path = os.path.join("/app", binary_path) + if not os.path.exists(binary_path): + logger.warning(f"Active version binary not found: {binary_path}") + return + await self.sync_running_models() + await self._write_config(active_version.get("binary_path")) + logger.info( + f"Regenerated llama-swap config with active version: {active_version.get('version')} and {len(self.running_models)} running models" + ) + try: + await self.start_proxy() + logger.info("Ensured llama-swap is running after config regeneration") except Exception as e: - logger.error(f"Failed to regenerate config with active version: {e}") - finally: - db.close() + logger.warning(f"Failed to start llama-swap after config regeneration: {e}") async def unregister_model(self, proxy_model_name: str): """ diff --git a/backend/lmdeploy_installer.py b/backend/lmdeploy_installer.py index 73e7610..875b2f7 100644 --- a/backend/lmdeploy_installer.py +++ b/backend/lmdeploy_installer.py @@ -1,362 +1,416 @@ -import asyncio -import json -import os -import shutil -import subprocess -import sys -from asyncio.subprocess import PIPE, STDOUT -from datetime import datetime, timezone -from typing import Any, Awaitable, Dict, Optional - -from backend.logging_config import get_logger -from backend.websocket_manager import websocket_manager - - -def _utcnow() -> str: - return 
datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") - - -logger = get_logger(__name__) - -_installer_instance: Optional["LMDeployInstaller"] = None - - -def get_lmdeploy_installer() -> "LMDeployInstaller": - global _installer_instance - if _installer_instance is None: - _installer_instance = LMDeployInstaller() - return _installer_instance - - -class LMDeployInstaller: - """Install or remove LMDeploy inside the runtime environment on demand.""" - - def __init__( - self, - *, - log_path: Optional[str] = None, - state_path: Optional[str] = None, - base_dir: Optional[str] = None, - ) -> None: - self._lock = asyncio.Lock() - self._operation: Optional[str] = None - self._operation_started_at: Optional[str] = None - self._current_task: Optional[asyncio.Task] = None - self._last_error: Optional[str] = None - data_root = os.path.abspath("data") - base_path = base_dir or os.path.join(data_root, "lmdeploy") - self._base_dir = os.path.abspath(base_path) - self._venv_path = os.path.join(self._base_dir, "venv") - log_path = log_path or os.path.join(data_root, "logs", "lmdeploy_install.log") - state_path = state_path or os.path.join( - data_root, "configs", "lmdeploy_installer.json" - ) - self._log_path = os.path.abspath(log_path) - self._state_path = os.path.abspath(state_path) - self._ensure_directories() - - def _ensure_directories(self) -> None: - os.makedirs(self._base_dir, exist_ok=True) - os.makedirs(os.path.dirname(self._log_path), exist_ok=True) - os.makedirs(os.path.dirname(self._state_path), exist_ok=True) - - def _venv_bin(self, executable: str) -> str: - if os.name == "nt": - exe = ( - executable - if executable.lower().endswith(".exe") - else f"{executable}.exe" - ) - return os.path.join(self._venv_path, "Scripts", exe) - return os.path.join(self._venv_path, "bin", executable) - - def _venv_python(self) -> str: - return self._venv_bin("python") - - def _ensure_venv(self) -> None: - python_path = self._venv_python() - if os.path.exists(python_path): - 
return - os.makedirs(self._base_dir, exist_ok=True) - try: - subprocess.run([sys.executable, "-m", "venv", self._venv_path], check=True) - except subprocess.CalledProcessError as exc: - raise RuntimeError( - f"Failed to create LMDeploy virtual environment: {exc}" - ) from exc - - def _load_state(self) -> Dict[str, Any]: - if not os.path.exists(self._state_path): - return {} - try: - with open(self._state_path, "r", encoding="utf-8") as handle: - data = json.load(handle) - return data if isinstance(data, dict) else {} - except Exception as exc: - logger.warning(f"Failed to load LMDeploy installer state: {exc}") - return {} - - def _save_state(self, state: Dict[str, Any]) -> None: - tmp_path = f"{self._state_path}.tmp" - with open(tmp_path, "w", encoding="utf-8") as handle: - json.dump(state, handle, indent=2) - os.replace(tmp_path, self._state_path) - - def _detect_installed_version(self) -> Optional[str]: - python_exe = self._venv_python() - if not os.path.exists(python_exe): - return None - script = ( - "import importlib, sys\n" - "try:\n" - " from importlib import metadata\n" - "except ImportError:\n" - " import importlib_metadata as metadata\n" - "try:\n" - " print(metadata.version('lmdeploy'))\n" - "except metadata.PackageNotFoundError:\n" - " sys.exit(1)\n" - ) - try: - output = subprocess.check_output( - [python_exe, "-c", script], text=True - ).strip() - return output or None - except subprocess.CalledProcessError: - return None - except Exception as exc: # pragma: no cover - logger.debug(f"Unable to determine LMDeploy version: {exc}") - return None - - def _resolve_binary_path(self) -> Optional[str]: - override = os.getenv("LMDEPLOY_BIN") - if override: - override_path = os.path.abspath(os.path.expanduser(override)) - if os.path.exists(override_path): - return override_path - resolved_override = shutil.which(override) - if resolved_override: - return resolved_override - - candidate = self._venv_bin("lmdeploy") - if os.path.exists(candidate) and 
os.access(candidate, os.X_OK): - return os.path.abspath(candidate) - - resolved = shutil.which("lmdeploy") - return resolved - - def _update_installed_state( - self, installed: bool, version: Optional[str] = None - ) -> None: - state = self._load_state() - if installed: - state["installed_at"] = _utcnow() - if version: - state["installed_version"] = version - state["venv_path"] = self._venv_path - else: - state["installed_version"] = None - state["installed_at"] = None - state["removed_at"] = _utcnow() - state["venv_path"] = self._venv_path - self._save_state(state) - - def _refresh_state_from_environment(self) -> None: - state = self._load_state() - version = self._detect_installed_version() - state["installed_version"] = version - if version is None: - state["removed_at"] = _utcnow() - state["venv_path"] = self._venv_path - self._save_state(state) - - async def _run_pip( - self, args: list[str], operation: str, ensure_venv: bool = True - ) -> int: - if ensure_venv: - self._ensure_venv() - python_exe = self._venv_python() - if not os.path.exists(python_exe): - raise RuntimeError( - "LMDeploy virtual environment is missing; cannot run pip." 
- ) - header = ( - f"[{_utcnow()}] Starting LMDeploy {operation} via pip {' '.join(args)}\n" - ) - with open(self._log_path, "w", encoding="utf-8") as log_file: - log_file.write(header) - process = await asyncio.create_subprocess_exec( - python_exe, - "-m", - "pip", - *args, - stdout=PIPE, - stderr=STDOUT, - ) - - async def _stream_output() -> None: - if process.stdout is None: - return - with open(self._log_path, "a", encoding="utf-8", buffering=1) as log_file: - while True: - chunk = await process.stdout.readline() - if not chunk: - break - text = chunk.decode("utf-8", errors="replace") - log_file.write(text) - await self._broadcast_log_line(text.rstrip("\n")) - - await asyncio.gather(process.wait(), _stream_output()) - return process.returncode or 0 - - async def _broadcast_log_line(self, line: str) -> None: - try: - await websocket_manager.broadcast( - { - "type": "lmdeploy_install_log", - "line": line, - "timestamp": _utcnow(), - } - ) - except Exception as exc: # pragma: no cover - logger.debug(f"Failed to broadcast LMDeploy log line: {exc}") - - async def _set_operation(self, operation: str) -> None: - self._operation = operation - self._operation_started_at = _utcnow() - self._last_error = None - await websocket_manager.broadcast( - { - "type": "lmdeploy_install_status", - "status": operation, - "started_at": self._operation_started_at, - } - ) - - async def _finish_operation(self, success: bool, message: str = "") -> None: - payload = { - "type": "lmdeploy_install_status", - "status": "completed" if success else "failed", - "operation": self._operation, - "message": message, - "ended_at": _utcnow(), - } - await websocket_manager.broadcast(payload) - self._operation = None - self._operation_started_at = None - - def _create_task(self, coro: Awaitable[Any]) -> None: - loop = asyncio.get_running_loop() - task = loop.create_task(coro) - self._current_task = task - - def _cleanup(fut: asyncio.Future) -> None: - try: - fut.result() - except Exception as exc: # 
pragma: no cover - surfaced via status - logger.error(f"LMDeploy installer task error: {exc}") - finally: - self._current_task = None - - task.add_done_callback(_cleanup) - - async def install( - self, version: Optional[str] = None, force_reinstall: bool = False - ) -> Dict[str, Any]: - async with self._lock: - if self._operation: - raise RuntimeError( - "Another LMDeploy installer operation is already running" - ) - await self._set_operation("install") - args = ["install", "--upgrade"] - if force_reinstall: - args.append("--force-reinstall") - package = "lmdeploy" - if version: - package = f"lmdeploy=={version}" - args.append(package) - - async def _runner(): - try: - code = await self._run_pip(args, "install") - if code != 0: - raise RuntimeError(f"pip exited with status {code}") - detected_version = self._detect_installed_version() - self._update_installed_state(True, detected_version) - await self._finish_operation(True, "LMDeploy installed") - except Exception as exc: - self._last_error = str(exc) - self._refresh_state_from_environment() - await self._finish_operation(False, str(exc)) - - self._create_task(_runner()) - return {"message": "LMDeploy installation started"} - - async def remove(self) -> Dict[str, Any]: - async with self._lock: - if self._operation: - raise RuntimeError( - "Another LMDeploy installer operation is already running" - ) - await self._set_operation("remove") - args = ["uninstall", "-y", "lmdeploy"] - - async def _runner(): - try: - python_exists = os.path.exists(self._venv_python()) - if python_exists: - code = await self._run_pip(args, "remove", ensure_venv=False) - if code != 0: - raise RuntimeError(f"pip exited with status {code}") - shutil.rmtree(self._venv_path, ignore_errors=True) - self._update_installed_state(False) - await self._finish_operation(True, "LMDeploy removed") - except Exception as exc: - self._last_error = str(exc) - self._refresh_state_from_environment() - await self._finish_operation(False, str(exc)) - - 
self._create_task(_runner()) - return {"message": "LMDeploy removal started"} - - def status(self) -> Dict[str, Any]: - version = self._detect_installed_version() - binary_path = self._resolve_binary_path() - installed = version is not None and binary_path is not None - state = self._load_state() - return { - "installed": installed, - "version": version, - "binary_path": binary_path, - "venv_path": state.get("venv_path") or self._venv_path, - "installed_at": state.get("installed_at"), - "removed_at": state.get("removed_at"), - "operation": self._operation, - "operation_started_at": self._operation_started_at, - "last_error": self._last_error, - "log_path": self._log_path, - } - - async def _broadcast_status(self) -> None: - """Broadcast current status via WebSocket.""" - try: - status_data = self.status() - await websocket_manager.send_lmdeploy_status(status_data) - except Exception as exc: - logger.debug(f"Failed to broadcast LMDeploy status: {exc}") - - def is_operation_running(self) -> bool: - return self._operation is not None - - def read_log_tail(self, max_bytes: int = 8192) -> str: - if not os.path.exists(self._log_path): - return "" - with open(self._log_path, "rb") as log_file: - log_file.seek(0, os.SEEK_END) - size = log_file.tell() - log_file.seek(max(0, size - max_bytes)) - data = log_file.read().decode("utf-8", errors="replace") - if size > max_bytes: - data = data.split("\n", 1)[-1] - return data.strip() +import asyncio +import json +import os +import shutil +import subprocess +import sys +from asyncio.subprocess import PIPE, STDOUT +from datetime import datetime, timezone +from typing import Any, Awaitable, Dict, Optional + +from backend.logging_config import get_logger +from backend.progress_manager import get_progress_manager + + +def _utcnow() -> str: + return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + + +logger = get_logger(__name__) + +_installer_instance: Optional["LMDeployInstaller"] = None + + +def 
get_lmdeploy_installer() -> "LMDeployInstaller": + global _installer_instance + if _installer_instance is None: + _installer_instance = LMDeployInstaller() + return _installer_instance + + +class LMDeployInstaller: + """Install or remove LMDeploy inside the runtime environment on demand.""" + + def __init__( + self, + *, + log_path: Optional[str] = None, + state_path: Optional[str] = None, + base_dir: Optional[str] = None, + ) -> None: + self._lock = asyncio.Lock() + self._operation: Optional[str] = None + self._operation_started_at: Optional[str] = None + self._current_task: Optional[asyncio.Task] = None + self._last_error: Optional[str] = None + data_root = os.path.abspath("data") + base_path = base_dir or os.path.join(data_root, "lmdeploy") + self._base_dir = os.path.abspath(base_path) + self._venv_path = os.path.join(self._base_dir, "venv") + log_path = log_path or os.path.join(data_root, "logs", "lmdeploy_install.log") + state_path = state_path or os.path.join( + data_root, "configs", "lmdeploy_installer.json" + ) + self._log_path = os.path.abspath(log_path) + self._state_path = os.path.abspath(state_path) + self._ensure_directories() + + def _ensure_directories(self) -> None: + os.makedirs(self._base_dir, exist_ok=True) + os.makedirs(os.path.dirname(self._log_path), exist_ok=True) + os.makedirs(os.path.dirname(self._state_path), exist_ok=True) + + def _venv_bin(self, executable: str) -> str: + if os.name == "nt": + exe = ( + executable + if executable.lower().endswith(".exe") + else f"{executable}.exe" + ) + return os.path.join(self._venv_path, "Scripts", exe) + return os.path.join(self._venv_path, "bin", executable) + + def _venv_python(self) -> str: + return self._venv_bin("python") + + def _ensure_venv(self) -> None: + python_path = self._venv_python() + if os.path.exists(python_path): + return + os.makedirs(self._base_dir, exist_ok=True) + try: + subprocess.run([sys.executable, "-m", "venv", self._venv_path], check=True) + except 
subprocess.CalledProcessError as exc: + raise RuntimeError( + f"Failed to create LMDeploy virtual environment: {exc}" + ) from exc + + def _load_state(self) -> Dict[str, Any]: + if not os.path.exists(self._state_path): + return {} + try: + with open(self._state_path, "r", encoding="utf-8") as handle: + data = json.load(handle) + return data if isinstance(data, dict) else {} + except Exception as exc: + logger.warning(f"Failed to load LMDeploy installer state: {exc}") + return {} + + def _save_state(self, state: Dict[str, Any]) -> None: + tmp_path = f"{self._state_path}.tmp" + with open(tmp_path, "w", encoding="utf-8") as handle: + json.dump(state, handle, indent=2) + os.replace(tmp_path, self._state_path) + + def _detect_installed_version(self) -> Optional[str]: + python_exe = self._venv_python() + if not os.path.exists(python_exe): + return None + script = ( + "import importlib, sys\n" + "try:\n" + " from importlib import metadata\n" + "except ImportError:\n" + " import importlib_metadata as metadata\n" + "try:\n" + " print(metadata.version('lmdeploy'))\n" + "except metadata.PackageNotFoundError:\n" + " sys.exit(1)\n" + ) + try: + output = subprocess.check_output( + [python_exe, "-c", script], text=True + ).strip() + return output or None + except subprocess.CalledProcessError: + return None + except Exception as exc: # pragma: no cover + logger.debug(f"Unable to determine LMDeploy version: {exc}") + return None + + def _resolve_binary_path(self) -> Optional[str]: + override = os.getenv("LMDEPLOY_BIN") + if override: + override_path = os.path.abspath(os.path.expanduser(override)) + if os.path.exists(override_path): + return override_path + resolved_override = shutil.which(override) + if resolved_override: + return resolved_override + + candidate = self._venv_bin("lmdeploy") + if os.path.exists(candidate) and os.access(candidate, os.X_OK): + return os.path.abspath(candidate) + + resolved = shutil.which("lmdeploy") + return resolved + + def _update_installed_state( 
+ self, installed: bool, version: Optional[str] = None + ) -> None: + state = self._load_state() + if installed: + state["installed_at"] = _utcnow() + if version: + state["installed_version"] = version + state["venv_path"] = self._venv_path + else: + state["installed_version"] = None + state["installed_at"] = None + state["removed_at"] = _utcnow() + state["venv_path"] = self._venv_path + self._save_state(state) + + def _refresh_state_from_environment(self) -> None: + state = self._load_state() + version = self._detect_installed_version() + state["installed_version"] = version + if version is None: + state["removed_at"] = _utcnow() + state["venv_path"] = self._venv_path + self._save_state(state) + + async def _run_pip( + self, + args: list[str], + operation: str, + ensure_venv: bool = True, + cwd: Optional[str] = None, + ) -> int: + if ensure_venv: + self._ensure_venv() + python_exe = self._venv_python() + if not os.path.exists(python_exe): + raise RuntimeError( + "LMDeploy virtual environment is missing; cannot run pip." 
+ ) + header = ( + f"[{_utcnow()}] Starting LMDeploy {operation} via pip {' '.join(args)}\n" + ) + with open(self._log_path, "w", encoding="utf-8") as log_file: + log_file.write(header) + process = await asyncio.create_subprocess_exec( + python_exe, + "-m", + "pip", + *args, + stdout=PIPE, + stderr=STDOUT, + cwd=cwd, + ) + + async def _stream_output() -> None: + if process.stdout is None: + return + with open(self._log_path, "a", encoding="utf-8", buffering=1) as log_file: + while True: + chunk = await process.stdout.readline() + if not chunk: + break + text = chunk.decode("utf-8", errors="replace") + log_file.write(text) + await self._broadcast_log_line(text.rstrip("\n")) + + await asyncio.gather(process.wait(), _stream_output()) + return process.returncode or 0 + + async def _broadcast_log_line(self, line: str) -> None: + try: + await get_progress_manager().broadcast( + { + "type": "lmdeploy_install_log", + "line": line, + "timestamp": _utcnow(), + } + ) + except Exception as exc: # pragma: no cover + logger.debug(f"Failed to broadcast LMDeploy log line: {exc}") + + async def _set_operation(self, operation: str) -> None: + self._operation = operation + self._operation_started_at = _utcnow() + self._last_error = None + await get_progress_manager().broadcast( + { + "type": "lmdeploy_install_status", + "status": operation, + "started_at": self._operation_started_at, + } + ) + + async def _finish_operation(self, success: bool, message: str = "") -> None: + payload = { + "type": "lmdeploy_install_status", + "status": "completed" if success else "failed", + "operation": self._operation, + "message": message, + "ended_at": _utcnow(), + } + await get_progress_manager().broadcast(payload) + self._operation = None + self._operation_started_at = None + + def _create_task(self, coro: Awaitable[Any]) -> None: + loop = asyncio.get_running_loop() + task = loop.create_task(coro) + self._current_task = task + + def _cleanup(fut: asyncio.Future) -> None: + try: + fut.result() + 
except Exception as exc: # pragma: no cover - surfaced via status + logger.error(f"LMDeploy installer task error: {exc}") + finally: + self._current_task = None + + task.add_done_callback(_cleanup) + + async def install( + self, version: Optional[str] = None, force_reinstall: bool = False + ) -> Dict[str, Any]: + async with self._lock: + if self._operation: + raise RuntimeError( + "Another LMDeploy installer operation is already running" + ) + await self._set_operation("install") + args = ["install", "--upgrade"] + if force_reinstall: + args.append("--force-reinstall") + package = "lmdeploy" + if version: + package = f"lmdeploy=={version}" + args.append(package) + + async def _runner(): + try: + code = await self._run_pip(args, "install") + if code != 0: + raise RuntimeError(f"pip exited with status {code}") + detected_version = self._detect_installed_version() + self._update_installed_state(True, detected_version) + await self._finish_operation(True, "LMDeploy installed") + except Exception as exc: + self._last_error = str(exc) + self._refresh_state_from_environment() + await self._finish_operation(False, str(exc)) + + self._create_task(_runner()) + return {"message": "LMDeploy installation started"} + + async def install_from_source( + self, + repo_url: str = "https://github.com/InternLM/lmdeploy.git", + branch: str = "main", + ) -> Dict[str, Any]: + """Install LMDeploy from a git repo and branch (for development).""" + async with self._lock: + if self._operation: + raise RuntimeError( + "Another LMDeploy installer operation is already running" + ) + await self._set_operation("install_source") + clone_dir = os.path.join(self._base_dir, "source") + async def _runner(): + try: + self._ensure_venv() + if os.path.exists(clone_dir): + shutil.rmtree(clone_dir) + os.makedirs(clone_dir, exist_ok=True) + proc = await asyncio.create_subprocess_exec( + "git", "clone", "--depth", "1", "--branch", branch, repo_url, clone_dir, + stdout=PIPE, stderr=STDOUT, + ) + await 
proc.wait() + if proc.returncode != 0: + raise RuntimeError(f"git clone failed with code {proc.returncode}") + code = await self._run_pip( + ["install", "-e", "."], + "install_source", + cwd=clone_dir, + ) + if code != 0: + raise RuntimeError(f"pip install -e . failed with code {code}") + detected = self._detect_installed_version() + self._update_installed_state(True, detected) + from backend.data_store import get_store + get_store().update_lmdeploy({ + "install_type": "source", + "source_repo": repo_url, + "source_branch": branch, + }) + await self._finish_operation(True, f"Installed from {branch}") + except Exception as exc: + self._last_error = str(exc) + self._refresh_state_from_environment() + await self._finish_operation(False, str(exc)) + self._create_task(_runner()) + return {"message": "LMDeploy install from source started", "repo": repo_url, "branch": branch} + + async def remove(self) -> Dict[str, Any]: + async with self._lock: + if self._operation: + raise RuntimeError( + "Another LMDeploy installer operation is already running" + ) + await self._set_operation("remove") + args = ["uninstall", "-y", "lmdeploy"] + + async def _runner(): + try: + python_exists = os.path.exists(self._venv_python()) + if python_exists: + code = await self._run_pip(args, "remove", ensure_venv=False) + if code != 0: + raise RuntimeError(f"pip exited with status {code}") + shutil.rmtree(self._venv_path, ignore_errors=True) + self._update_installed_state(False) + await self._finish_operation(True, "LMDeploy removed") + except Exception as exc: + self._last_error = str(exc) + self._refresh_state_from_environment() + await self._finish_operation(False, str(exc)) + + self._create_task(_runner()) + return {"message": "LMDeploy removal started"} + + def status(self) -> Dict[str, Any]: + version = self._detect_installed_version() + binary_path = self._resolve_binary_path() + installed = version is not None and binary_path is not None + state = self._load_state() + return { + 
"installed": installed, + "version": version, + "binary_path": binary_path, + "venv_path": state.get("venv_path") or self._venv_path, + "installed_at": state.get("installed_at"), + "removed_at": state.get("removed_at"), + "operation": self._operation, + "operation_started_at": self._operation_started_at, + "last_error": self._last_error, + "log_path": self._log_path, + } + + async def _broadcast_status(self) -> None: + """Broadcast current status via SSE.""" + try: + status_data = self.status() + get_progress_manager().emit("lmdeploy_status", {**status_data, "timestamp": _utcnow()}) + except Exception as exc: + logger.debug(f"Failed to broadcast LMDeploy status: {exc}") + + def is_operation_running(self) -> bool: + return self._operation is not None + + def read_log_tail(self, max_bytes: int = 8192) -> str: + if not os.path.exists(self._log_path): + return "" + with open(self._log_path, "rb") as log_file: + log_file.seek(0, os.SEEK_END) + size = log_file.tell() + log_file.seek(max(0, size - max_bytes)) + data = log_file.read().decode("utf-8", errors="replace") + if size > max_bytes: + data = data.split("\n", 1)[-1] + return data.strip() diff --git a/backend/lmdeploy_manager.py b/backend/lmdeploy_manager.py index 59e7d01..6328d71 100644 --- a/backend/lmdeploy_manager.py +++ b/backend/lmdeploy_manager.py @@ -1,877 +1,841 @@ -import asyncio -import json -import os -import shlex -import shutil -from datetime import datetime -from typing import Optional, Dict, Any, List - -import httpx -import psutil -from asyncio.subprocess import Process, STDOUT - -from backend.logging_config import get_logger -from backend.database import SessionLocal, Model, RunningInstance -from backend.huggingface import DEFAULT_LMDEPLOY_CONTEXT, MAX_LMDEPLOY_CONTEXT -from backend.websocket_manager import websocket_manager - -logger = get_logger(__name__) - -_lmdeploy_manager_instance: Optional["LMDeployManager"] = None - - -def get_lmdeploy_manager() -> "LMDeployManager": - """Return singleton 
LMDeploy manager.""" - global _lmdeploy_manager_instance - if _lmdeploy_manager_instance is None: - _lmdeploy_manager_instance = LMDeployManager() - return _lmdeploy_manager_instance - - -class LMDeployManager: - """Manage LMDeploy TurboMind runtime lifecycle.""" - - def __init__( - self, - binary_path: Optional[str] = None, - host: str = "0.0.0.0", - port: int = 2001, - ): - self.binary_path = binary_path or os.getenv("LMDEPLOY_BIN", "lmdeploy") - self.host = host - self.port = int(os.getenv("LMDEPLOY_PORT", port)) - self._process: Optional[Process] = None - self._log_file = None - self._lock = asyncio.Lock() - self._current_instance: Optional[Dict[str, Any]] = None - self._started_at: Optional[str] = None - self._log_path = os.path.join("data", "logs", "lmdeploy.log") - self._health_timeout = 180 # seconds - self._last_health_status: Optional[Dict[str, Any]] = None - self._last_detected_external: Optional[Dict[str, Any]] = None - self._last_broadcast_log_position = 0 - - async def start( - self, model_entry: Dict[str, Any], config: Dict[str, Any] - ) -> Dict[str, Any]: - """Start LMDeploy serving the provided model. Only one model may run at once.""" - async with self._lock: - if self._process and self._process.returncode is None: - raise RuntimeError("LMDeploy runtime is already running") - - model_path = model_entry.get("file_path") - if not model_path or not os.path.exists(model_path): - raise FileNotFoundError(f"Model file not found at {model_path}") - model_dir = model_entry.get("model_dir") or os.path.dirname(model_path) - if not os.path.isdir(model_dir): - raise FileNotFoundError(f"Model directory not found at {model_dir}") - model_dir_abs = os.path.abspath(model_dir) - - # Derive a stable model name for LMDeploy's --model-name flag. 
- # Preference order: - # 1) Explicit model_name passed in model_entry - # 2) Base model / display name from model_entry - # 3) Hugging Face repo id - # 4) Directory name - model_name = ( - model_entry.get("model_name") - or model_entry.get("display_name") - or model_entry.get("huggingface_id") - or os.path.basename(model_dir_abs.rstrip(os.sep)) - ) - - # Inject model_name into config passed to LMDeploy so the command builder - # can add --model-name and we persist it in status/config reflection. - effective_config = dict(config or {}) - if model_name and not effective_config.get("model_name"): - effective_config["model_name"] = model_name - - binary = self._resolve_binary() - command = self._build_command(binary, model_dir_abs, effective_config) - env = os.environ.copy() - env.setdefault("LMDEPLOY_LOG_DIR", os.path.dirname(self._log_path)) - os.makedirs(os.path.dirname(self._log_path), exist_ok=True) - self._log_file = open(self._log_path, "ab", buffering=0) - - logger.info(f"Starting LMDeploy with command: {' '.join(command)}") - self._process = await asyncio.create_subprocess_exec( - *command, - stdout=self._log_file, - stderr=STDOUT, - cwd=model_dir_abs, - env=env, - ) - self._started_at = datetime.utcnow().isoformat() + "Z" - self._current_instance = { - "model_id": model_entry.get("model_id"), - "huggingface_id": model_entry.get("huggingface_id"), - "file_path": model_path, - "config": effective_config, - "pid": self._process.pid, - } - - try: - await self._wait_for_ready() - except Exception as exc: - await self.stop(force=True) - raise exc - - return self.status() - - async def stop(self, force: bool = False) -> None: - """Stop LMDeploy process if running.""" - async with self._lock: - if not self._process: - return - if self._process.returncode is None: - try: - self._process.terminate() - await asyncio.wait_for(self._process.wait(), timeout=30) - except asyncio.TimeoutError: - logger.warning( - "LMDeploy did not terminate gracefully; killing process" - ) 
- self._process.kill() - await self._process.wait() - except ProcessLookupError: - logger.debug("LMDeploy process already stopped") - elif force: - try: - self._process.kill() - except ProcessLookupError: - pass - self._cleanup_process_state() - - async def restart( - self, model_entry: Dict[str, Any], config: Dict[str, Any] - ) -> Dict[str, Any]: - """Restart LMDeploy with a new model/config.""" - await self.stop() - return await self.start(model_entry, config) - - def status(self) -> Dict[str, Any]: - """Return status payload describing the running instance.""" - running = bool(self._process and self._process.returncode is None) - detection = None - if not running: - detection = self._detect_external_process() - if detection: - running = True - self._last_detected_external = detection - if not self._current_instance: - self._current_instance = detection.get("instance") - if not self._started_at: - self._started_at = detection.get("started_at") - else: - self._last_detected_external = None - else: - self._last_detected_external = None - - return { - "running": running, - "port": self.port, - "host": self.host, - "process_id": self._process.pid if running else None, - "started_at": self._started_at, - "current_instance": self._current_instance if running else None, - "health": self._last_health_status, - "binary_path": self._current_binary_path(), - "log_path": self._log_path, - "auto_detected": bool(detection), - "detection": detection, - } - - def _current_binary_path(self) -> Optional[str]: - try: - return self._resolve_binary() - except FileNotFoundError: - return None - - def _resolve_binary(self) -> str: - try: - from backend.lmdeploy_installer import get_lmdeploy_installer - - installer_binary = get_lmdeploy_installer().status().get("binary_path") - if installer_binary and os.path.exists(installer_binary): - return installer_binary - except Exception as exc: - logger.debug( - f"Failed to resolve LMDeploy binary via installer status: {exc}" - ) - - resolved = 
shutil.which(self.binary_path) - if resolved: - return resolved - - candidate = os.path.expanduser(self.binary_path) - if os.path.isabs(candidate) and os.path.exists(candidate): - return candidate - raise FileNotFoundError( - "LMDeploy binary not found in PATH. Install LMDeploy from the LMDeploy page or set LMDEPLOY_BIN." - ) - - def _build_command( - self, binary: str, model_dir: str, config: Dict[str, Any] - ) -> list: - """Convert stored config into lmdeploy CLI arguments.""" - tensor_parallel = max(1, int(config.get("tensor_parallel") or 1)) - base_session_len = max( - 1024, - int( - config.get("session_len") - or config.get("context_length") - or DEFAULT_LMDEPLOY_CONTEXT - ), - ) - rope_scaling_mode = str(config.get("rope_scaling_mode") or "disabled").lower() - rope_scaling_factor = float(config.get("rope_scaling_factor") or 1.0) - scaling_enabled = ( - rope_scaling_mode not in {"", "none", "disabled"} - and rope_scaling_factor > 1.0 - ) - effective_session_len = base_session_len - if scaling_enabled: - scaled = int(base_session_len * rope_scaling_factor) - effective_session_len = max( - base_session_len, min(scaled, MAX_LMDEPLOY_CONTEXT) - ) - max_batch_size = max(1, int(config.get("max_batch_size") or 4)) - base_prefill = int( - config.get("max_prefill_token_num") - or config.get("max_batch_tokens") - or (base_session_len * 2) - ) - if scaling_enabled: - scaled_prefill = int(base_prefill * rope_scaling_factor) - max_prefill_token_num = scaled_prefill - else: - max_prefill_token_num = base_prefill - - command = [ - binary, - "serve", - "api_server", - model_dir, - "--backend", - "turbomind", - "--server-name", - self.host, - "--server-port", - str(self.port), - "--tp", - str(tensor_parallel), - "--session-len", - str(effective_session_len), - "--max-batch-size", - str(max_batch_size), - ] - - # Optional model identity for OpenAI-style /v1/models listing - model_name = config.get("model_name") - if model_name and str(model_name).strip(): - 
command.extend(["--model-name", str(model_name).strip()]) - - # Optional inference settings - dtype = config.get("dtype") - if dtype and str(dtype).strip(): - command.extend(["--dtype", str(dtype).strip()]) - if max_prefill_token_num: - command.extend(["--max-prefill-token-num", str(max_prefill_token_num)]) - cache_max_entry_count = config.get("cache_max_entry_count") - if cache_max_entry_count is not None: - command.extend(["--cache-max-entry-count", str(cache_max_entry_count)]) - cache_block_seq_len = config.get("cache_block_seq_len") - if cache_block_seq_len: - command.extend(["--cache-block-seq-len", str(cache_block_seq_len)]) - if config.get("enable_prefix_caching"): - command.append("--enable-prefix-caching") - quant_policy = config.get("quant_policy") - if quant_policy is not None: - command.extend(["--quant-policy", str(quant_policy)]) - model_format = config.get("model_format") - if model_format and str(model_format).strip(): - command.extend(["--model-format", str(model_format).strip()]) - hf_overrides = config.get("hf_overrides") - if isinstance(hf_overrides, dict) and hf_overrides: - - def _flatten(prefix: str, value: Any): - if isinstance(value, dict): - for key, nested in value.items(): - if not isinstance(key, str) or not key: - continue - new_prefix = f"{prefix}.{key}" if prefix else key - yield from _flatten(new_prefix, nested) - else: - yield prefix, value - - def _format_override_value(val: Any) -> str: - if isinstance(val, bool): - return "true" if val else "false" - if val is None: - return "null" - return str(val) - - for path, value in _flatten("", hf_overrides): - if not path: - continue - command.extend( - [f"--hf-overrides.{path}", _format_override_value(value)] - ) - elif isinstance(hf_overrides, str) and hf_overrides.strip(): - command.extend(["--hf-overrides", hf_overrides.strip()]) - # LMDeploy uses --disable-metrics (inverted logic) - # When enable_metrics=false, send --disable-metrics - # When enable_metrics=true (default), don't 
send anything (metrics enabled by default) - if not config.get("enable_metrics", True): - command.append("--disable-metrics") - if scaling_enabled: - command.extend(["--rope-scaling-factor", str(rope_scaling_factor)]) - num_tokens_per_iter = config.get("num_tokens_per_iter") - if num_tokens_per_iter: - command.extend(["--num-tokens-per-iter", str(num_tokens_per_iter)]) - max_prefill_iters = config.get("max_prefill_iters") - if max_prefill_iters: - command.extend(["--max-prefill-iters", str(max_prefill_iters)]) - communicator = config.get("communicator") - if communicator and str(communicator).strip(): - command.extend(["--communicator", str(communicator).strip()]) - - # Server configuration parameters - allow_origins = config.get("allow_origins") - if allow_origins: - if isinstance(allow_origins, list): - command.extend( - ["--allow-origins"] + [str(origin) for origin in allow_origins] - ) - elif isinstance(allow_origins, str): - command.extend(["--allow-origins", allow_origins]) - if config.get("allow_credentials"): - command.append("--allow-credentials") - allow_methods = config.get("allow_methods") - if allow_methods: - if isinstance(allow_methods, list): - command.extend( - ["--allow-methods"] + [str(method) for method in allow_methods] - ) - elif isinstance(allow_methods, str): - command.extend(["--allow-methods", allow_methods]) - allow_headers = config.get("allow_headers") - if allow_headers: - if isinstance(allow_headers, list): - command.extend( - ["--allow-headers"] + [str(header) for header in allow_headers] - ) - elif isinstance(allow_headers, str): - command.extend(["--allow-headers", allow_headers]) - proxy_url = config.get("proxy_url") - if proxy_url and str(proxy_url).strip(): - command.extend(["--proxy-url", str(proxy_url).strip()]) - max_concurrent_requests = config.get("max_concurrent_requests") - if max_concurrent_requests is not None: - command.extend( - ["--max-concurrent-requests", str(int(max_concurrent_requests))] - ) - log_level = 
config.get("log_level") - if log_level and str(log_level).strip(): - command.extend(["--log-level", str(log_level).strip()]) - api_keys = config.get("api_keys") - if api_keys: - if isinstance(api_keys, list): - command.extend(["--api-keys"] + [str(key) for key in api_keys]) - elif isinstance(api_keys, str): - command.extend(["--api-keys", api_keys]) - if config.get("ssl"): - command.append("--ssl") - max_log_len = config.get("max_log_len") - if max_log_len is not None: - command.extend(["--max-log-len", str(int(max_log_len))]) - if config.get("disable_fastapi_docs"): - command.append("--disable-fastapi-docs") - if config.get("allow_terminate_by_client"): - command.append("--allow-terminate-by-client") - if config.get("enable_abort_handling"): - command.append("--enable-abort-handling") - - # Model configuration parameters - chat_template = config.get("chat_template") - if chat_template and str(chat_template).strip(): - command.extend(["--chat-template", str(chat_template).strip()]) - tool_call_parser = config.get("tool_call_parser") - if tool_call_parser and str(tool_call_parser).strip(): - command.extend(["--tool-call-parser", str(tool_call_parser).strip()]) - reasoning_parser = config.get("reasoning_parser") - if reasoning_parser and str(reasoning_parser).strip(): - command.extend(["--reasoning-parser", str(reasoning_parser).strip()]) - revision = config.get("revision") - if revision and str(revision).strip(): - command.extend(["--revision", str(revision).strip()]) - download_dir = config.get("download_dir") - if download_dir and str(download_dir).strip(): - command.extend(["--download-dir", str(download_dir).strip()]) - adapters = config.get("adapters") - if adapters: - if isinstance(adapters, list): - command.extend(["--adapters"] + [str(adapter) for adapter in adapters]) - elif isinstance(adapters, str): - command.extend(["--adapters", adapters]) - device = config.get("device") - if device and str(device).strip(): - command.extend(["--device", 
str(device).strip()]) - if config.get("eager_mode"): - command.append("--eager-mode") - if config.get("disable_vision_encoder"): - command.append("--disable-vision-encoder") - logprobs_mode = config.get("logprobs_mode") - if logprobs_mode is not None: - command.extend(["--logprobs-mode", str(logprobs_mode)]) - - # DLLM parameters - dllm_block_length = config.get("dllm_block_length") - if dllm_block_length is not None: - command.extend(["--dllm-block-length", str(int(dllm_block_length))]) - dllm_unmasking_strategy = config.get("dllm_unmasking_strategy") - if dllm_unmasking_strategy and str(dllm_unmasking_strategy).strip(): - command.extend( - ["--dllm-unmasking-strategy", str(dllm_unmasking_strategy).strip()] - ) - dllm_denoising_steps = config.get("dllm_denoising_steps") - if dllm_denoising_steps is not None: - command.extend(["--dllm-denoising-steps", str(int(dllm_denoising_steps))]) - dllm_confidence_threshold = config.get("dllm_confidence_threshold") - if dllm_confidence_threshold is not None: - command.extend( - ["--dllm-confidence-threshold", str(float(dllm_confidence_threshold))] - ) - - # Distributed/Multi-node parameters - dp = config.get("dp") - if dp is not None: - command.extend(["--dp", str(int(dp))]) - ep = config.get("ep") - if ep is not None: - command.extend(["--ep", str(int(ep))]) - if config.get("enable_microbatch"): - command.append("--enable-microbatch") - if config.get("enable_eplb"): - command.append("--enable-eplb") - role = config.get("role") - if role and str(role).strip(): - command.extend(["--role", str(role).strip()]) - migration_backend = config.get("migration_backend") - if migration_backend and str(migration_backend).strip(): - command.extend(["--migration-backend", str(migration_backend).strip()]) - node_rank = config.get("node_rank") - if node_rank is not None: - command.extend(["--node-rank", str(int(node_rank))]) - nnodes = config.get("nnodes") - if nnodes is not None: - command.extend(["--nnodes", str(int(nnodes))]) - cp = 
config.get("cp") - if cp is not None: - command.extend(["--cp", str(int(cp))]) - if config.get("enable_return_routed_experts"): - command.append("--enable-return-routed-experts") - distributed_executor_backend = config.get("distributed_executor_backend") - if distributed_executor_backend and str(distributed_executor_backend).strip(): - command.extend( - [ - "--distributed-executor-backend", - str(distributed_executor_backend).strip(), - ] - ) - - # Vision parameters - vision_max_batch_size = config.get("vision_max_batch_size") - if vision_max_batch_size is not None: - command.extend(["--vision-max-batch-size", str(int(vision_max_batch_size))]) - - # Speculative decoding parameters - speculative_algorithm = config.get("speculative_algorithm") - if speculative_algorithm and str(speculative_algorithm).strip(): - command.extend( - ["--speculative-algorithm", str(speculative_algorithm).strip()] - ) - speculative_draft_model = config.get("speculative_draft_model") - if speculative_draft_model and str(speculative_draft_model).strip(): - command.extend( - ["--speculative-draft-model", str(speculative_draft_model).strip()] - ) - speculative_num_draft_tokens = config.get("speculative_num_draft_tokens") - if speculative_num_draft_tokens is not None: - command.extend( - [ - "--speculative-num-draft-tokens", - str(int(speculative_num_draft_tokens)), - ] - ) - - additional_args = config.get("additional_args") - if isinstance(additional_args, str) and additional_args.strip(): - command.extend(shlex.split(additional_args.strip())) - - return command - - async def _wait_for_ready(self) -> None: - """Poll LMDeploy server until healthy or timeout.""" - start_time = asyncio.get_event_loop().time() - url = f"http://{self.host}:{self.port}/v1/models" - async with httpx.AsyncClient(timeout=5.0) as client: - while True: - if self._process and self._process.returncode not in (None, 0): - self._raise_with_logs( - f"LMDeploy exited unexpectedly with code {self._process.returncode}" - ) - 
try: - response = await client.get(url) - if response.status_code == 200: - self._last_health_status = { - "status": "ready", - "checked_at": datetime.utcnow().isoformat() + "Z", - } - return - except Exception as exc: - logger.debug(f"LMDeploy health check pending: {exc}") - if asyncio.get_event_loop().time() - start_time > self._health_timeout: - self._raise_with_logs( - "Timed out waiting for LMDeploy server to become ready" - ) - await asyncio.sleep(2) - - def _cleanup_process_state(self) -> None: - if self._log_file: - try: - self._log_file.close() - except Exception: - pass - self._log_file = None - self._process = None - self._current_instance = None - self._started_at = None - self._last_health_status = { - "status": "stopped", - "checked_at": datetime.utcnow().isoformat() + "Z", - } - - def read_log_tail(self, max_bytes: int = 8192) -> str: - """Return the tail of the lmdeploy log file for debugging.""" - try: - with open(self._log_path, "rb") as log_file: - log_file.seek(0, os.SEEK_END) - file_size = log_file.tell() - seek_pos = max(0, file_size - max_bytes) - log_file.seek(seek_pos) - data = log_file.read().decode("utf-8", errors="replace") - if seek_pos > 0: - # Remove potential partial first line - data = data.split("\n", 1)[-1] - return data.strip() - except Exception as exc: - logger.error(f"Failed to read LMDeploy log tail: {exc}") - return "" - - async def _broadcast_runtime_logs(self) -> None: - """Broadcast new runtime log lines via WebSocket.""" - try: - if not os.path.exists(self._log_path): - return - - # Read new content since last broadcast - current_size = os.path.getsize(self._log_path) - if current_size <= self._last_broadcast_log_position: - return # No new content - - # Read only new content - with open(self._log_path, "rb") as log_file: - log_file.seek(self._last_broadcast_log_position) - new_content = log_file.read().decode("utf-8", errors="replace") - self._last_broadcast_log_position = current_size - - if new_content: - # Split into 
lines and broadcast each non-empty line - lines = new_content.split('\n') - for line in lines: - if line.strip(): # Only send non-empty lines - await websocket_manager.send_lmdeploy_runtime_log(line.strip()) - except Exception as exc: - logger.debug(f"Failed to broadcast LMDeploy runtime logs: {exc}") - - def _read_log_tail(self, max_bytes: int = 8192) -> str: - """Private alias for backward compatibility.""" - return self.read_log_tail(max_bytes) - - def _raise_with_logs(self, message: str) -> None: - """Raise a runtime error that includes the recent LMDeploy logs.""" - log_tail = self.read_log_tail() - if log_tail: - logger.error( - f"{message}\n--- LMDeploy log tail ---\n{log_tail}\n--- end ---" - ) - raise RuntimeError(f"{message}. See logs for details.\n{log_tail}") - raise RuntimeError(message) - - def _detect_external_process(self) -> Optional[Dict[str, Any]]: - """Scan system processes for an LMDeploy server launched outside the manager.""" - try: - for proc in psutil.process_iter(attrs=["pid", "cmdline", "create_time"]): - cmdline: List[str] = proc.info.get("cmdline") or [] - if not cmdline: - continue - lowered = " ".join(cmdline).lower() - if "lmdeploy" not in lowered: - continue - if "serve" not in lowered or "api_server" not in lowered: - continue - - try: - api_server_idx = cmdline.index("api_server") - except ValueError: - continue - model_dir = ( - cmdline[api_server_idx + 1] - if len(cmdline) > api_server_idx + 1 - else None - ) - detection = { - "pid": proc.info["pid"], - "cmdline": cmdline, - "model_dir": model_dir, - "detected_at": datetime.utcnow().isoformat() + "Z", - } - - config = self._config_from_cmdline(cmdline) - model_entry = ( - self._lookup_model_by_dir(model_dir) if model_dir else None - ) - if model_entry: - self._ensure_running_instance_record(model_entry.id, config) - detection["instance"] = { - "model_id": model_entry.id, - "huggingface_id": model_entry.huggingface_id, - "file_path": model_entry.file_path, - "config": config, - 
"pid": proc.info["pid"], - "auto_detected": True, - } - detection["model_id"] = model_entry.id - detection["huggingface_id"] = model_entry.huggingface_id - else: - detection["instance"] = { - "model_id": None, - "huggingface_id": None, - "file_path": model_dir, - "config": config, - "pid": proc.info["pid"], - "auto_detected": True, - } - - started_at = proc.info.get("create_time") - if started_at: - detection["started_at"] = ( - datetime.utcfromtimestamp(started_at).isoformat() + "Z" - ) - else: - detection["started_at"] = datetime.utcnow().isoformat() + "Z" - return detection - except Exception as exc: - logger.debug(f"LMDeploy external scan failed: {exc}") - return None - - def _config_from_cmdline(self, cmdline: List[str]) -> Dict[str, Any]: - """Reconstruct a minimal config dict from lmdeploy CLI arguments.""" - - def _extract(flag: str, cast, default=None): - if flag in cmdline: - idx = cmdline.index(flag) - if idx + 1 < len(cmdline): - try: - return cast(cmdline[idx + 1]) - except (ValueError, TypeError): - return default - return default - - def _extract_list(flag: str, default=None): - """Extract list of values for flags that accept multiple arguments.""" - if flag not in cmdline: - return default - idx = cmdline.index(flag) - result = [] - i = idx + 1 - while i < len(cmdline) and not cmdline[i].startswith("--"): - result.append(cmdline[i]) - i += 1 - return result if result else default - - session_len = _extract("--session-len", int, DEFAULT_LMDEPLOY_CONTEXT) - max_prefill = _extract("--max-prefill-token-num", int, session_len) - # Note: --max-context-token-num doesn't exist in LMDeploy, so derive from session_len - max_context = session_len - - rope_scaling_factor = _extract("--rope-scaling-factor", float, 1.0) - rope_scaling_mode = "disabled" - if rope_scaling_factor and rope_scaling_factor > 1.0: - rope_scaling_mode = "detected" - - hf_overrides: Dict[str, Any] = {} - - def _assign_nested(target: Dict[str, Any], path: List[str], value: Any) -> None: - 
current = target - for segment in path[:-1]: - current = current.setdefault(segment, {}) - current[path[-1]] = value - - def _coerce_override_value(raw: str) -> Any: - lowered = raw.lower() - if lowered in {"true", "false"}: - return lowered == "true" - if lowered == "null": - return None - try: - if "." in raw: - return float(raw) - return int(raw) - except ValueError: - return raw - - i = 0 - while i < len(cmdline): - token = cmdline[i] - if token.startswith("--hf-overrides."): - path_str = token[len("--hf-overrides.") :] - if path_str and i + 1 < len(cmdline): - value = _coerce_override_value(cmdline[i + 1]) - _assign_nested(hf_overrides, path_str.split("."), value) - i += 2 - continue - i += 1 - - config = { - "session_len": session_len, - "tensor_parallel": _extract("--tp", int, 1), - "max_batch_size": _extract("--max-batch-size", int, 4), - "max_prefill_token_num": max_prefill, - "max_context_token_num": max_context, - "dtype": _extract("--dtype", str, "auto"), - "cache_max_entry_count": _extract("--cache-max-entry-count", float, 0.8), - "cache_block_seq_len": _extract("--cache-block-seq-len", int, 64), - "enable_prefix_caching": "--enable-prefix-caching" in cmdline, - "quant_policy": _extract("--quant-policy", int, 0), - "model_format": _extract("--model-format", str, ""), - "hf_overrides": hf_overrides or _extract("--hf-overrides", str, ""), - # LMDeploy uses --disable-metrics, so enable_metrics=True when flag is NOT present - "enable_metrics": "--disable-metrics" not in cmdline, - "rope_scaling_factor": rope_scaling_factor, - "rope_scaling_mode": rope_scaling_mode, - "num_tokens_per_iter": _extract("--num-tokens-per-iter", int, 0), - "max_prefill_iters": _extract("--max-prefill-iters", int, 1), - "communicator": _extract("--communicator", str, "nccl"), - "model_name": _extract("--model-name", str, ""), - # Server configuration - "allow_origins": _extract_list("--allow-origins"), - "allow_credentials": "--allow-credentials" in cmdline, - "allow_methods": 
_extract_list("--allow-methods"), - "allow_headers": _extract_list("--allow-headers"), - "proxy_url": _extract("--proxy-url", str, ""), - "max_concurrent_requests": _extract("--max-concurrent-requests", int), - "log_level": _extract("--log-level", str, ""), - "api_keys": _extract_list("--api-keys"), - "ssl": "--ssl" in cmdline, - "max_log_len": _extract("--max-log-len", int), - "disable_fastapi_docs": "--disable-fastapi-docs" in cmdline, - "allow_terminate_by_client": "--allow-terminate-by-client" in cmdline, - "enable_abort_handling": "--enable-abort-handling" in cmdline, - # Model configuration - "chat_template": _extract("--chat-template", str, ""), - "tool_call_parser": _extract("--tool-call-parser", str, ""), - "reasoning_parser": _extract("--reasoning-parser", str, ""), - "revision": _extract("--revision", str, ""), - "download_dir": _extract("--download-dir", str, ""), - "adapters": _extract_list("--adapters"), - "device": _extract("--device", str, ""), - "eager_mode": "--eager-mode" in cmdline, - "disable_vision_encoder": "--disable-vision-encoder" in cmdline, - "logprobs_mode": _extract("--logprobs-mode", str), - # DLLM parameters - "dllm_block_length": _extract("--dllm-block-length", int), - "dllm_unmasking_strategy": _extract("--dllm-unmasking-strategy", str, ""), - "dllm_denoising_steps": _extract("--dllm-denoising-steps", int), - "dllm_confidence_threshold": _extract("--dllm-confidence-threshold", float), - # Distributed/Multi-node parameters - "dp": _extract("--dp", int), - "ep": _extract("--ep", int), - "enable_microbatch": "--enable-microbatch" in cmdline, - "enable_eplb": "--enable-eplb" in cmdline, - "role": _extract("--role", str, ""), - "migration_backend": _extract("--migration-backend", str, ""), - "node_rank": _extract("--node-rank", int), - "nnodes": _extract("--nnodes", int), - "cp": _extract("--cp", int), - "enable_return_routed_experts": "--enable-return-routed-experts" in cmdline, - "distributed_executor_backend": _extract( - 
"--distributed-executor-backend", str, "" - ), - # Vision parameters - "vision_max_batch_size": _extract("--vision-max-batch-size", int), - # Speculative decoding parameters - "speculative_algorithm": _extract("--speculative-algorithm", str, ""), - "speculative_draft_model": _extract("--speculative-draft-model", str, ""), - "speculative_num_draft_tokens": _extract( - "--speculative-num-draft-tokens", int - ), - "additional_args": "", - } - - return config - - def _lookup_model_by_dir(self, model_dir: Optional[str]) -> Optional[Model]: - if not model_dir: - return None - db = SessionLocal() - try: - candidates = ( - db.query(Model).filter(Model.model_format == "safetensors").all() - ) - for candidate in candidates: - if ( - candidate.file_path - and os.path.dirname(candidate.file_path) == model_dir - ): - return candidate - finally: - db.close() - return None - - def _ensure_running_instance_record( - self, model_id: Optional[int], config: Dict[str, Any] - ) -> None: - if not model_id: - return - db = SessionLocal() - try: - existing = ( - db.query(RunningInstance) - .filter( - RunningInstance.model_id == model_id, - RunningInstance.runtime_type == "lmdeploy", - ) - .first() - ) - if existing: - return - instance = RunningInstance( - model_id=model_id, - llama_version="lmdeploy", - proxy_model_name=f"lmdeploy::{model_id}", - started_at=datetime.utcnow(), - config=json.dumps({"lmdeploy": config}), - runtime_type="lmdeploy", - ) - db.add(instance) - model = db.query(Model).filter(Model.id == model_id).first() - if model: - model.is_active = True - db.commit() - except Exception as exc: - logger.warning(f"Failed to create LMDeploy running instance record: {exc}") - db.rollback() - finally: - db.close() +import asyncio +import json +import os +import shlex +import shutil +from datetime import datetime +from typing import Optional, Dict, Any, List + +import httpx +import psutil +from asyncio.subprocess import Process, STDOUT + +from backend.logging_config import 
get_logger +from backend.data_store import get_store +from backend.huggingface import DEFAULT_LMDEPLOY_CONTEXT, MAX_LMDEPLOY_CONTEXT +from backend.progress_manager import get_progress_manager + +logger = get_logger(__name__) + +_lmdeploy_manager_instance: Optional["LMDeployManager"] = None + + +def get_lmdeploy_manager() -> "LMDeployManager": + """Return singleton LMDeploy manager.""" + global _lmdeploy_manager_instance + if _lmdeploy_manager_instance is None: + _lmdeploy_manager_instance = LMDeployManager() + return _lmdeploy_manager_instance + + +class LMDeployManager: + """Manage LMDeploy TurboMind runtime lifecycle.""" + + def __init__( + self, + binary_path: Optional[str] = None, + host: str = "0.0.0.0", + port: int = 2001, + ): + self.binary_path = binary_path or os.getenv("LMDEPLOY_BIN", "lmdeploy") + self.host = host + self.port = int(os.getenv("LMDEPLOY_PORT", port)) + self._process: Optional[Process] = None + self._log_file = None + self._lock = asyncio.Lock() + self._current_instance: Optional[Dict[str, Any]] = None + self._started_at: Optional[str] = None + self._log_path = os.path.join("data", "logs", "lmdeploy.log") + self._health_timeout = 180 # seconds + self._last_health_status: Optional[Dict[str, Any]] = None + self._last_detected_external: Optional[Dict[str, Any]] = None + self._last_broadcast_log_position = 0 + + async def start( + self, model_entry: Dict[str, Any], config: Dict[str, Any] + ) -> Dict[str, Any]: + """Start LMDeploy serving the provided model. 
Only one model may run at once.""" + async with self._lock: + if self._process and self._process.returncode is None: + raise RuntimeError("LMDeploy runtime is already running") + + model_path = model_entry.get("file_path") + if not model_path or not os.path.exists(model_path): + raise FileNotFoundError(f"Model file not found at {model_path}") + model_dir = model_entry.get("model_dir") or os.path.dirname(model_path) + if not os.path.isdir(model_dir): + raise FileNotFoundError(f"Model directory not found at {model_dir}") + model_dir_abs = os.path.abspath(model_dir) + + # Derive a stable model name for LMDeploy's --model-name flag. + # Preference order: + # 1) Explicit model_name passed in model_entry + # 2) Base model / display name from model_entry + # 3) Hugging Face repo id + # 4) Directory name + model_name = ( + model_entry.get("model_name") + or model_entry.get("display_name") + or model_entry.get("huggingface_id") + or os.path.basename(model_dir_abs.rstrip(os.sep)) + ) + + # Inject model_name into config passed to LMDeploy so the command builder + # can add --model-name and we persist it in status/config reflection. 
+ effective_config = dict(config or {}) + if model_name and not effective_config.get("model_name"): + effective_config["model_name"] = model_name + + binary = self._resolve_binary() + command = self._build_command(binary, model_dir_abs, effective_config) + env = os.environ.copy() + env.setdefault("LMDEPLOY_LOG_DIR", os.path.dirname(self._log_path)) + os.makedirs(os.path.dirname(self._log_path), exist_ok=True) + self._log_file = open(self._log_path, "ab", buffering=0) + + logger.info(f"Starting LMDeploy with command: {' '.join(command)}") + self._process = await asyncio.create_subprocess_exec( + *command, + stdout=self._log_file, + stderr=STDOUT, + cwd=model_dir_abs, + env=env, + ) + self._started_at = datetime.utcnow().isoformat() + "Z" + self._current_instance = { + "model_id": model_entry.get("model_id"), + "huggingface_id": model_entry.get("huggingface_id"), + "file_path": model_path, + "config": effective_config, + "pid": self._process.pid, + } + + try: + await self._wait_for_ready() + except Exception as exc: + await self.stop(force=True) + raise exc + + return self.status() + + async def stop(self, force: bool = False) -> None: + """Stop LMDeploy process if running.""" + async with self._lock: + if not self._process: + return + if self._process.returncode is None: + try: + self._process.terminate() + await asyncio.wait_for(self._process.wait(), timeout=30) + except asyncio.TimeoutError: + logger.warning( + "LMDeploy did not terminate gracefully; killing process" + ) + self._process.kill() + await self._process.wait() + except ProcessLookupError: + logger.debug("LMDeploy process already stopped") + elif force: + try: + self._process.kill() + except ProcessLookupError: + pass + self._cleanup_process_state() + + async def restart( + self, model_entry: Dict[str, Any], config: Dict[str, Any] + ) -> Dict[str, Any]: + """Restart LMDeploy with a new model/config.""" + await self.stop() + return await self.start(model_entry, config) + + def status(self) -> Dict[str, 
Any]: + """Return status payload describing the running instance.""" + running = bool(self._process and self._process.returncode is None) + detection = None + if not running: + detection = self._detect_external_process() + if detection: + running = True + self._last_detected_external = detection + if not self._current_instance: + self._current_instance = detection.get("instance") + if not self._started_at: + self._started_at = detection.get("started_at") + else: + self._last_detected_external = None + else: + self._last_detected_external = None + + return { + "running": running, + "port": self.port, + "host": self.host, + "process_id": self._process.pid if running else None, + "started_at": self._started_at, + "current_instance": self._current_instance if running else None, + "health": self._last_health_status, + "binary_path": self._current_binary_path(), + "log_path": self._log_path, + "auto_detected": bool(detection), + "detection": detection, + } + + def _current_binary_path(self) -> Optional[str]: + try: + return self._resolve_binary() + except FileNotFoundError: + return None + + def _resolve_binary(self) -> str: + try: + from backend.lmdeploy_installer import get_lmdeploy_installer + + installer_binary = get_lmdeploy_installer().status().get("binary_path") + if installer_binary and os.path.exists(installer_binary): + return installer_binary + except Exception as exc: + logger.debug( + f"Failed to resolve LMDeploy binary via installer status: {exc}" + ) + + resolved = shutil.which(self.binary_path) + if resolved: + return resolved + + candidate = os.path.expanduser(self.binary_path) + if os.path.isabs(candidate) and os.path.exists(candidate): + return candidate + raise FileNotFoundError( + "LMDeploy binary not found in PATH. Install LMDeploy from the LMDeploy page or set LMDEPLOY_BIN." 
+ ) + + def _build_command( + self, binary: str, model_dir: str, config: Dict[str, Any] + ) -> list: + """Convert stored config into lmdeploy CLI arguments.""" + tensor_parallel = max(1, int(config.get("tensor_parallel") or 1)) + base_session_len = max( + 1024, + int( + config.get("session_len") + or config.get("context_length") + or DEFAULT_LMDEPLOY_CONTEXT + ), + ) + rope_scaling_mode = str(config.get("rope_scaling_mode") or "disabled").lower() + rope_scaling_factor = float(config.get("rope_scaling_factor") or 1.0) + scaling_enabled = ( + rope_scaling_mode not in {"", "none", "disabled"} + and rope_scaling_factor > 1.0 + ) + effective_session_len = base_session_len + if scaling_enabled: + scaled = int(base_session_len * rope_scaling_factor) + effective_session_len = max( + base_session_len, min(scaled, MAX_LMDEPLOY_CONTEXT) + ) + max_batch_size = max(1, int(config.get("max_batch_size") or 4)) + base_prefill = int( + config.get("max_prefill_token_num") + or config.get("max_batch_tokens") + or (base_session_len * 2) + ) + if scaling_enabled: + scaled_prefill = int(base_prefill * rope_scaling_factor) + max_prefill_token_num = scaled_prefill + else: + max_prefill_token_num = base_prefill + + command = [ + binary, + "serve", + "api_server", + model_dir, + "--backend", + "turbomind", + "--server-name", + self.host, + "--server-port", + str(self.port), + "--tp", + str(tensor_parallel), + "--session-len", + str(effective_session_len), + "--max-batch-size", + str(max_batch_size), + ] + + # Optional model identity for OpenAI-style /v1/models listing + model_name = config.get("model_name") + if model_name and str(model_name).strip(): + command.extend(["--model-name", str(model_name).strip()]) + + # Optional inference settings + dtype = config.get("dtype") + if dtype and str(dtype).strip(): + command.extend(["--dtype", str(dtype).strip()]) + if max_prefill_token_num: + command.extend(["--max-prefill-token-num", str(max_prefill_token_num)]) + cache_max_entry_count = 
config.get("cache_max_entry_count") + if cache_max_entry_count is not None: + command.extend(["--cache-max-entry-count", str(cache_max_entry_count)]) + cache_block_seq_len = config.get("cache_block_seq_len") + if cache_block_seq_len: + command.extend(["--cache-block-seq-len", str(cache_block_seq_len)]) + if config.get("enable_prefix_caching"): + command.append("--enable-prefix-caching") + quant_policy = config.get("quant_policy") + if quant_policy is not None: + command.extend(["--quant-policy", str(quant_policy)]) + model_format = config.get("model_format") + if model_format and str(model_format).strip(): + command.extend(["--model-format", str(model_format).strip()]) + hf_overrides = config.get("hf_overrides") + if isinstance(hf_overrides, dict) and hf_overrides: + + def _flatten(prefix: str, value: Any): + if isinstance(value, dict): + for key, nested in value.items(): + if not isinstance(key, str) or not key: + continue + new_prefix = f"{prefix}.{key}" if prefix else key + yield from _flatten(new_prefix, nested) + else: + yield prefix, value + + def _format_override_value(val: Any) -> str: + if isinstance(val, bool): + return "true" if val else "false" + if val is None: + return "null" + return str(val) + + for path, value in _flatten("", hf_overrides): + if not path: + continue + command.extend( + [f"--hf-overrides.{path}", _format_override_value(value)] + ) + elif isinstance(hf_overrides, str) and hf_overrides.strip(): + command.extend(["--hf-overrides", hf_overrides.strip()]) + # LMDeploy uses --disable-metrics (inverted logic) + # When enable_metrics=false, send --disable-metrics + # When enable_metrics=true (default), don't send anything (metrics enabled by default) + if not config.get("enable_metrics", True): + command.append("--disable-metrics") + if scaling_enabled: + command.extend(["--rope-scaling-factor", str(rope_scaling_factor)]) + num_tokens_per_iter = config.get("num_tokens_per_iter") + if num_tokens_per_iter: + 
command.extend(["--num-tokens-per-iter", str(num_tokens_per_iter)]) + max_prefill_iters = config.get("max_prefill_iters") + if max_prefill_iters: + command.extend(["--max-prefill-iters", str(max_prefill_iters)]) + communicator = config.get("communicator") + if communicator and str(communicator).strip(): + command.extend(["--communicator", str(communicator).strip()]) + + # Server configuration parameters + allow_origins = config.get("allow_origins") + if allow_origins: + if isinstance(allow_origins, list): + command.extend( + ["--allow-origins"] + [str(origin) for origin in allow_origins] + ) + elif isinstance(allow_origins, str): + command.extend(["--allow-origins", allow_origins]) + if config.get("allow_credentials"): + command.append("--allow-credentials") + allow_methods = config.get("allow_methods") + if allow_methods: + if isinstance(allow_methods, list): + command.extend( + ["--allow-methods"] + [str(method) for method in allow_methods] + ) + elif isinstance(allow_methods, str): + command.extend(["--allow-methods", allow_methods]) + allow_headers = config.get("allow_headers") + if allow_headers: + if isinstance(allow_headers, list): + command.extend( + ["--allow-headers"] + [str(header) for header in allow_headers] + ) + elif isinstance(allow_headers, str): + command.extend(["--allow-headers", allow_headers]) + proxy_url = config.get("proxy_url") + if proxy_url and str(proxy_url).strip(): + command.extend(["--proxy-url", str(proxy_url).strip()]) + max_concurrent_requests = config.get("max_concurrent_requests") + if max_concurrent_requests is not None: + command.extend( + ["--max-concurrent-requests", str(int(max_concurrent_requests))] + ) + log_level = config.get("log_level") + if log_level and str(log_level).strip(): + command.extend(["--log-level", str(log_level).strip()]) + api_keys = config.get("api_keys") + if api_keys: + if isinstance(api_keys, list): + command.extend(["--api-keys"] + [str(key) for key in api_keys]) + elif isinstance(api_keys, str): + 
command.extend(["--api-keys", api_keys]) + if config.get("ssl"): + command.append("--ssl") + max_log_len = config.get("max_log_len") + if max_log_len is not None: + command.extend(["--max-log-len", str(int(max_log_len))]) + if config.get("disable_fastapi_docs"): + command.append("--disable-fastapi-docs") + if config.get("allow_terminate_by_client"): + command.append("--allow-terminate-by-client") + if config.get("enable_abort_handling"): + command.append("--enable-abort-handling") + + # Model configuration parameters + chat_template = config.get("chat_template") + if chat_template and str(chat_template).strip(): + command.extend(["--chat-template", str(chat_template).strip()]) + tool_call_parser = config.get("tool_call_parser") + if tool_call_parser and str(tool_call_parser).strip(): + command.extend(["--tool-call-parser", str(tool_call_parser).strip()]) + reasoning_parser = config.get("reasoning_parser") + if reasoning_parser and str(reasoning_parser).strip(): + command.extend(["--reasoning-parser", str(reasoning_parser).strip()]) + revision = config.get("revision") + if revision and str(revision).strip(): + command.extend(["--revision", str(revision).strip()]) + download_dir = config.get("download_dir") + if download_dir and str(download_dir).strip(): + command.extend(["--download-dir", str(download_dir).strip()]) + adapters = config.get("adapters") + if adapters: + if isinstance(adapters, list): + command.extend(["--adapters"] + [str(adapter) for adapter in adapters]) + elif isinstance(adapters, str): + command.extend(["--adapters", adapters]) + device = config.get("device") + if device and str(device).strip(): + command.extend(["--device", str(device).strip()]) + if config.get("eager_mode"): + command.append("--eager-mode") + if config.get("disable_vision_encoder"): + command.append("--disable-vision-encoder") + logprobs_mode = config.get("logprobs_mode") + if logprobs_mode is not None: + command.extend(["--logprobs-mode", str(logprobs_mode)]) + + # DLLM 
parameters + dllm_block_length = config.get("dllm_block_length") + if dllm_block_length is not None: + command.extend(["--dllm-block-length", str(int(dllm_block_length))]) + dllm_unmasking_strategy = config.get("dllm_unmasking_strategy") + if dllm_unmasking_strategy and str(dllm_unmasking_strategy).strip(): + command.extend( + ["--dllm-unmasking-strategy", str(dllm_unmasking_strategy).strip()] + ) + dllm_denoising_steps = config.get("dllm_denoising_steps") + if dllm_denoising_steps is not None: + command.extend(["--dllm-denoising-steps", str(int(dllm_denoising_steps))]) + dllm_confidence_threshold = config.get("dllm_confidence_threshold") + if dllm_confidence_threshold is not None: + command.extend( + ["--dllm-confidence-threshold", str(float(dllm_confidence_threshold))] + ) + + # Distributed/Multi-node parameters + dp = config.get("dp") + if dp is not None: + command.extend(["--dp", str(int(dp))]) + ep = config.get("ep") + if ep is not None: + command.extend(["--ep", str(int(ep))]) + if config.get("enable_microbatch"): + command.append("--enable-microbatch") + if config.get("enable_eplb"): + command.append("--enable-eplb") + role = config.get("role") + if role and str(role).strip(): + command.extend(["--role", str(role).strip()]) + migration_backend = config.get("migration_backend") + if migration_backend and str(migration_backend).strip(): + command.extend(["--migration-backend", str(migration_backend).strip()]) + node_rank = config.get("node_rank") + if node_rank is not None: + command.extend(["--node-rank", str(int(node_rank))]) + nnodes = config.get("nnodes") + if nnodes is not None: + command.extend(["--nnodes", str(int(nnodes))]) + cp = config.get("cp") + if cp is not None: + command.extend(["--cp", str(int(cp))]) + if config.get("enable_return_routed_experts"): + command.append("--enable-return-routed-experts") + distributed_executor_backend = config.get("distributed_executor_backend") + if distributed_executor_backend and 
str(distributed_executor_backend).strip(): + command.extend( + [ + "--distributed-executor-backend", + str(distributed_executor_backend).strip(), + ] + ) + + # Vision parameters + vision_max_batch_size = config.get("vision_max_batch_size") + if vision_max_batch_size is not None: + command.extend(["--vision-max-batch-size", str(int(vision_max_batch_size))]) + + # Speculative decoding parameters + speculative_algorithm = config.get("speculative_algorithm") + if speculative_algorithm and str(speculative_algorithm).strip(): + command.extend( + ["--speculative-algorithm", str(speculative_algorithm).strip()] + ) + speculative_draft_model = config.get("speculative_draft_model") + if speculative_draft_model and str(speculative_draft_model).strip(): + command.extend( + ["--speculative-draft-model", str(speculative_draft_model).strip()] + ) + speculative_num_draft_tokens = config.get("speculative_num_draft_tokens") + if speculative_num_draft_tokens is not None: + command.extend( + [ + "--speculative-num-draft-tokens", + str(int(speculative_num_draft_tokens)), + ] + ) + + additional_args = config.get("additional_args") + if isinstance(additional_args, str) and additional_args.strip(): + command.extend(shlex.split(additional_args.strip())) + + return command + + async def _wait_for_ready(self) -> None: + """Poll LMDeploy server until healthy or timeout.""" + start_time = asyncio.get_event_loop().time() + url = f"http://{self.host}:{self.port}/v1/models" + async with httpx.AsyncClient(timeout=5.0) as client: + while True: + if self._process and self._process.returncode not in (None, 0): + self._raise_with_logs( + f"LMDeploy exited unexpectedly with code {self._process.returncode}" + ) + try: + response = await client.get(url) + if response.status_code == 200: + self._last_health_status = { + "status": "ready", + "checked_at": datetime.utcnow().isoformat() + "Z", + } + return + except Exception as exc: + logger.debug(f"LMDeploy health check pending: {exc}") + if 
def read_log_tail(self, max_bytes: int = 8192) -> str:
    """Return the tail of the lmdeploy log file for debugging.

    At most ``max_bytes`` bytes are taken from the end of ``self._log_path``.
    When the window starts in the middle of a line, that partial first line
    is dropped; when it starts exactly on a line boundary, the first line is
    complete and is kept.

    Args:
        max_bytes: Maximum number of trailing bytes to consider.

    Returns:
        The decoded, whitespace-stripped tail, or ``""`` if the file cannot
        be read (the failure is logged).
    """
    try:
        with open(self._log_path, "rb") as log_file:
            log_file.seek(0, os.SEEK_END)
            file_size = log_file.tell()
            window_start = max(0, file_size - max_bytes)
            # Read one byte before the window so we can tell whether the
            # window begins on a line boundary (that byte is then "\n").
            log_file.seek(max(0, window_start - 1))
            raw = log_file.read()
        if window_start > 0:
            _, newline, remainder = raw.partition(b"\n")
            # With a newline present, everything after it is whole lines.
            # Without one the window is a single long line: keep it, minus
            # the extra look-behind byte.
            raw = remainder if newline else raw[1:]
        # Decode after trimming so a multi-byte character cut at the window
        # edge is discarded with the partial line instead of being mangled.
        return raw.decode("utf-8", errors="replace").strip()
    except Exception as exc:
        logger.error(f"Failed to read LMDeploy log tail: {exc}")
        return ""
logger.debug(f"Failed to broadcast LMDeploy runtime logs: {exc}") + + def _read_log_tail(self, max_bytes: int = 8192) -> str: + """Private alias for backward compatibility.""" + return self.read_log_tail(max_bytes) + + def _raise_with_logs(self, message: str) -> None: + """Raise a runtime error that includes the recent LMDeploy logs.""" + log_tail = self.read_log_tail() + if log_tail: + logger.error( + f"{message}\n--- LMDeploy log tail ---\n{log_tail}\n--- end ---" + ) + raise RuntimeError(f"{message}. See logs for details.\n{log_tail}") + raise RuntimeError(message) + + def _detect_external_process(self) -> Optional[Dict[str, Any]]: + """Scan system processes for an LMDeploy server launched outside the manager.""" + try: + for proc in psutil.process_iter(attrs=["pid", "cmdline", "create_time"]): + cmdline: List[str] = proc.info.get("cmdline") or [] + if not cmdline: + continue + lowered = " ".join(cmdline).lower() + if "lmdeploy" not in lowered: + continue + if "serve" not in lowered or "api_server" not in lowered: + continue + + try: + api_server_idx = cmdline.index("api_server") + except ValueError: + continue + model_dir = ( + cmdline[api_server_idx + 1] + if len(cmdline) > api_server_idx + 1 + else None + ) + detection = { + "pid": proc.info["pid"], + "cmdline": cmdline, + "model_dir": model_dir, + "detected_at": datetime.utcnow().isoformat() + "Z", + } + + config = self._config_from_cmdline(cmdline) + model_entry = ( + self._lookup_model_by_dir(model_dir) if model_dir else None + ) + if model_entry: + self._ensure_running_instance_record(model_entry.get("id"), config) + detection["instance"] = { + "model_id": model_entry.get("id"), + "huggingface_id": model_entry.get("huggingface_id"), + "file_path": model_entry.get("file_path"), + "config": config, + "pid": proc.info["pid"], + "auto_detected": True, + } + detection["model_id"] = model_entry.get("id") + detection["huggingface_id"] = model_entry.get("huggingface_id") + else: + detection["instance"] = { + 
"model_id": None, + "huggingface_id": None, + "file_path": model_dir, + "config": config, + "pid": proc.info["pid"], + "auto_detected": True, + } + + started_at = proc.info.get("create_time") + if started_at: + detection["started_at"] = ( + datetime.utcfromtimestamp(started_at).isoformat() + "Z" + ) + else: + detection["started_at"] = datetime.utcnow().isoformat() + "Z" + return detection + except Exception as exc: + logger.debug(f"LMDeploy external scan failed: {exc}") + return None + + def _config_from_cmdline(self, cmdline: List[str]) -> Dict[str, Any]: + """Reconstruct a minimal config dict from lmdeploy CLI arguments.""" + + def _extract(flag: str, cast, default=None): + if flag in cmdline: + idx = cmdline.index(flag) + if idx + 1 < len(cmdline): + try: + return cast(cmdline[idx + 1]) + except (ValueError, TypeError): + return default + return default + + def _extract_list(flag: str, default=None): + """Extract list of values for flags that accept multiple arguments.""" + if flag not in cmdline: + return default + idx = cmdline.index(flag) + result = [] + i = idx + 1 + while i < len(cmdline) and not cmdline[i].startswith("--"): + result.append(cmdline[i]) + i += 1 + return result if result else default + + session_len = _extract("--session-len", int, DEFAULT_LMDEPLOY_CONTEXT) + max_prefill = _extract("--max-prefill-token-num", int, session_len) + # Note: --max-context-token-num doesn't exist in LMDeploy, so derive from session_len + max_context = session_len + + rope_scaling_factor = _extract("--rope-scaling-factor", float, 1.0) + rope_scaling_mode = "disabled" + if rope_scaling_factor and rope_scaling_factor > 1.0: + rope_scaling_mode = "detected" + + hf_overrides: Dict[str, Any] = {} + + def _assign_nested(target: Dict[str, Any], path: List[str], value: Any) -> None: + current = target + for segment in path[:-1]: + current = current.setdefault(segment, {}) + current[path[-1]] = value + + def _coerce_override_value(raw: str) -> Any: + lowered = raw.lower() + 
if lowered in {"true", "false"}: + return lowered == "true" + if lowered == "null": + return None + try: + if "." in raw: + return float(raw) + return int(raw) + except ValueError: + return raw + + i = 0 + while i < len(cmdline): + token = cmdline[i] + if token.startswith("--hf-overrides."): + path_str = token[len("--hf-overrides.") :] + if path_str and i + 1 < len(cmdline): + value = _coerce_override_value(cmdline[i + 1]) + _assign_nested(hf_overrides, path_str.split("."), value) + i += 2 + continue + i += 1 + + config = { + "session_len": session_len, + "tensor_parallel": _extract("--tp", int, 1), + "max_batch_size": _extract("--max-batch-size", int, 4), + "max_prefill_token_num": max_prefill, + "max_context_token_num": max_context, + "dtype": _extract("--dtype", str, "auto"), + "cache_max_entry_count": _extract("--cache-max-entry-count", float, 0.8), + "cache_block_seq_len": _extract("--cache-block-seq-len", int, 64), + "enable_prefix_caching": "--enable-prefix-caching" in cmdline, + "quant_policy": _extract("--quant-policy", int, 0), + "model_format": _extract("--model-format", str, ""), + "hf_overrides": hf_overrides or _extract("--hf-overrides", str, ""), + # LMDeploy uses --disable-metrics, so enable_metrics=True when flag is NOT present + "enable_metrics": "--disable-metrics" not in cmdline, + "rope_scaling_factor": rope_scaling_factor, + "rope_scaling_mode": rope_scaling_mode, + "num_tokens_per_iter": _extract("--num-tokens-per-iter", int, 0), + "max_prefill_iters": _extract("--max-prefill-iters", int, 1), + "communicator": _extract("--communicator", str, "nccl"), + "model_name": _extract("--model-name", str, ""), + # Server configuration + "allow_origins": _extract_list("--allow-origins"), + "allow_credentials": "--allow-credentials" in cmdline, + "allow_methods": _extract_list("--allow-methods"), + "allow_headers": _extract_list("--allow-headers"), + "proxy_url": _extract("--proxy-url", str, ""), + "max_concurrent_requests": 
_extract("--max-concurrent-requests", int), + "log_level": _extract("--log-level", str, ""), + "api_keys": _extract_list("--api-keys"), + "ssl": "--ssl" in cmdline, + "max_log_len": _extract("--max-log-len", int), + "disable_fastapi_docs": "--disable-fastapi-docs" in cmdline, + "allow_terminate_by_client": "--allow-terminate-by-client" in cmdline, + "enable_abort_handling": "--enable-abort-handling" in cmdline, + # Model configuration + "chat_template": _extract("--chat-template", str, ""), + "tool_call_parser": _extract("--tool-call-parser", str, ""), + "reasoning_parser": _extract("--reasoning-parser", str, ""), + "revision": _extract("--revision", str, ""), + "download_dir": _extract("--download-dir", str, ""), + "adapters": _extract_list("--adapters"), + "device": _extract("--device", str, ""), + "eager_mode": "--eager-mode" in cmdline, + "disable_vision_encoder": "--disable-vision-encoder" in cmdline, + "logprobs_mode": _extract("--logprobs-mode", str), + # DLLM parameters + "dllm_block_length": _extract("--dllm-block-length", int), + "dllm_unmasking_strategy": _extract("--dllm-unmasking-strategy", str, ""), + "dllm_denoising_steps": _extract("--dllm-denoising-steps", int), + "dllm_confidence_threshold": _extract("--dllm-confidence-threshold", float), + # Distributed/Multi-node parameters + "dp": _extract("--dp", int), + "ep": _extract("--ep", int), + "enable_microbatch": "--enable-microbatch" in cmdline, + "enable_eplb": "--enable-eplb" in cmdline, + "role": _extract("--role", str, ""), + "migration_backend": _extract("--migration-backend", str, ""), + "node_rank": _extract("--node-rank", int), + "nnodes": _extract("--nnodes", int), + "cp": _extract("--cp", int), + "enable_return_routed_experts": "--enable-return-routed-experts" in cmdline, + "distributed_executor_backend": _extract( + "--distributed-executor-backend", str, "" + ), + # Vision parameters + "vision_max_batch_size": _extract("--vision-max-batch-size", int), + # Speculative decoding parameters + 
"speculative_algorithm": _extract("--speculative-algorithm", str, ""), + "speculative_draft_model": _extract("--speculative-draft-model", str, ""), + "speculative_num_draft_tokens": _extract( + "--speculative-num-draft-tokens", int + ), + "additional_args": "", + } + + return config + + def _lookup_model_by_dir(self, model_dir: Optional[str]) -> Optional[Dict[str, Any]]: + if not model_dir: + return None + store = get_store() + for candidate in store.list_models(): + if (candidate.get("format") or candidate.get("model_format")) != "safetensors": + continue + fp = candidate.get("file_path") + if fp and os.path.dirname(fp) == model_dir: + return candidate + return None + + def _ensure_running_instance_record( + self, model_id: Optional[Any], config: Dict[str, Any] + ) -> None: + # No-op: running state is not persisted to DB (Phase 1 YAML store) + pass diff --git a/backend/main.py b/backend/main.py index 40fc33b..c692ee4 100644 --- a/backend/main.py +++ b/backend/main.py @@ -3,13 +3,13 @@ import uvicorn import time from datetime import datetime -from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request +from fastapi import FastAPI, Request from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse from contextlib import asynccontextmanager -from backend.database import init_db, LlamaVersion +from backend.data_store import get_store from backend.routes import ( models, llama_versions, @@ -18,9 +18,7 @@ llama_version_manager, lmdeploy, ) -from backend.websocket_manager import websocket_manager from backend.huggingface import set_huggingface_token -from backend.unified_monitor import unified_monitor from backend.logging_config import setup_logging, get_logger from backend.lmdeploy_installer import get_lmdeploy_installer from backend.lmdeploy_manager import get_lmdeploy_manager @@ -38,7 +36,7 @@ def ensure_data_directories(): else: data_dir = "data" - subdirs = ["models", "configs", 
"logs", "llama-cpp", "lmdeploy", "temp"] + subdirs = ["config", "configs", "logs", "llama-cpp", "lmdeploy", "temp"] try: # Ensure main data directory exists @@ -70,7 +68,8 @@ def ensure_data_directories(): logger.info(f"Data directory {data_dir} is writable") except PermissionError as e: logger.error(f"Data directory {data_dir} is not writable: {e}") - logger.warning(f"Current user: {os.getuid() if hasattr(os, 'getuid') else 'unknown'}, directory owner check needed")("Attempting to fix permissions...") + logger.warning(f"Current user: {os.getuid() if hasattr(os, 'getuid') else 'unknown'}, directory owner check needed") + logger.warning("Attempting to fix permissions...") # Try to fix permissions (may fail if not running as root) try: import stat @@ -95,74 +94,72 @@ def ensure_data_directories(): async def register_all_models_with_llama_swap(): """Register all downloaded models with llama-swap on startup""" - from backend.database import SessionLocal, Model - from backend.llama_manager import LlamaManager - - db = SessionLocal() - try: - # Get all downloaded models - models = db.query(Model).all() - if not models: - logger.info("No models found to register with llama-swap") - return - - logger.info(f"Found {len(models)} models to register with llama-swap") - - llama_server_path = None - # Get llama-server path from active version - active_version = ( - db.query(LlamaVersion).filter(LlamaVersion.is_active == True).first() - ) - if active_version and os.path.exists(active_version.binary_path): - llama_server_path = active_version.binary_path - logger.info(f"Using active llama-cpp version: {active_version.version}") - else: - # Fallback: try to find llama-server in the llama-cpp directory - llama_cpp_dir = ( - "data/llama-cpp" if os.path.exists("data") else "/app/data/llama-cpp" - ) - if os.path.exists(llama_cpp_dir): - for version_dir in os.listdir(llama_cpp_dir): - server_path = os.path.join( - llama_cpp_dir, version_dir, "build", "bin", "llama-server" - ) - if 
os.path.exists(server_path) and os.access(server_path, os.X_OK): - llama_server_path = server_path - logger.info(f"Found llama-server at: {llama_server_path}") - break - - if not llama_server_path: - logger.warning("llama-server not found, skipping model registration") - return - - # Register each model with llama-swap (without binary path) - for model in models: - try: - # Create a basic config for the model - config = { - "model": model.file_path, - "host": "0.0.0.0", - "ctx_size": 2048, - "batch_size": 512, - "threads": 4, - } - - # Register with llama-swap (no binary path needed) - proxy_name = await llama_swap_manager.register_model(model, config) - logger.info( - f"Registered model '{model.name}' as '{proxy_name}' with llama-swap" + store = get_store() + model_list = store.list_models() + if not model_list: + logger.info("No models found to register with llama-swap") + return + + logger.info(f"Found {len(model_list)} models to register with llama-swap") + + llama_server_path = None + for engine in ("llama_cpp", "ik_llama"): + active_version = store.get_active_engine_version(engine) + if active_version and active_version.get("binary_path"): + path = active_version["binary_path"] + if os.path.isabs(path) and os.path.exists(path): + llama_server_path = path + else: + abs_path = os.path.abspath(path) + if os.path.exists(abs_path): + llama_server_path = abs_path + if llama_server_path: + logger.info(f"Using active {engine} version: {active_version.get('version')}") + break + + if not llama_server_path: + llama_cpp_dir = "data/llama-cpp" if os.path.exists("data") else "/app/data/llama-cpp" + if os.path.exists(llama_cpp_dir): + for version_dir in os.listdir(llama_cpp_dir): + server_path = os.path.join( + llama_cpp_dir, version_dir, "build", "bin", "llama-server" ) + if os.path.exists(server_path) and os.access(server_path, os.X_OK): + llama_server_path = server_path + logger.info(f"Found llama-server at: {llama_server_path}") + break + + if not llama_server_path: + 
logger.warning("llama-server not found, skipping model registration") + return + + from backend.routes.models import _get_model_file_path + from backend.data_store import generate_proxy_name + + for model in model_list: + file_path = _get_model_file_path(model) + if not file_path or not os.path.exists(file_path): + logger.debug(f"Model '{model.get('id')}' not found in HF cache, skipping") + continue + try: + proxy_name = generate_proxy_name( + model.get("huggingface_id", ""), + model.get("quantization"), + ) + config = (model.get("config") or {}).copy() + config.setdefault("host", "0.0.0.0") + config.setdefault("ctx_size", 2048) + config.setdefault("batch_size", 512) + config.setdefault("threads", 4) + model_with_proxy = dict(model, proxy_name=proxy_name) + await llama_swap_manager.register_model(model_with_proxy, config) + logger.info( + f"Registered model '{model.get('display_name', model.get('id'))}' as '{proxy_name}' with llama-swap" + ) + except Exception as e: + logger.error(f"Failed to register model '{model.get('id')}' with llama-swap: {e}") - except Exception as e: - logger.error( - f"Failed to register model '{model.name}' with llama-swap: {e}" - ) - - # Generate config with the active version - await llama_swap_manager.regenerate_config_with_active_version() - - finally: - db.close() + await llama_swap_manager.regenerate_config_with_active_version() @asynccontextmanager @@ -171,7 +168,7 @@ async def lifespan(app: FastAPI): # Startup ensure_data_directories() - await init_db() + get_store() # Ensure YAML config files exist huggingface_api_key = os.getenv("HUGGINGFACE_API_KEY") if huggingface_api_key: @@ -182,15 +179,20 @@ async def lifespan(app: FastAPI): llama_swap_manager = get_llama_swap_manager() - from backend.database import SessionLocal, LlamaVersion, RunningInstance, Model - - session = SessionLocal() - active_version = ( - session.query(LlamaVersion).filter(LlamaVersion.is_active == True).first() - ) - session.close() - - if active_version and 
active_version.binary_path: + store = get_store() + active_version = None + for engine in ("llama_cpp", "ik_llama"): + v = store.get_active_engine_version(engine) + if v and v.get("binary_path"): + path = v["binary_path"] + if os.path.isabs(path) and os.path.exists(path): + active_version = v + break + if os.path.exists(os.path.abspath(path)): + active_version = v + break + + if active_version and active_version.get("binary_path"): try: await llama_swap_manager.start_proxy() logger.info("llama-swap proxy started on port 2000") @@ -203,64 +205,14 @@ async def lifespan(app: FastAPI): "Install or activate a llama.cpp build to enable multi-model serving." ) - db = SessionLocal() - try: - stale_instances = db.query(RunningInstance).all() - if stale_instances: - logger.info(f"Cleaning {len(stale_instances)} stale instances") - for instance in stale_instances: - model = db.query(Model).filter(Model.id == instance.model_id).first() - if model: - model.is_active = False - db.delete(instance) - db.commit() - finally: - db.close() - try: await register_all_models_with_llama_swap() except Exception as e: logger.error(f"Failed to register models with llama-swap: {e}") - await unified_monitor.start_monitoring() - - # Start background task for LMDeploy status and logs broadcasting - lmdeploy_broadcast_task = None - - async def broadcast_lmdeploy_updates(): - """Periodically broadcast LMDeploy status and runtime logs.""" - installer = get_lmdeploy_installer() - manager = get_lmdeploy_manager() - last_runtime_log_position = 0 - - while True: - try: - # Broadcast status every 2 seconds - await installer._broadcast_status() - - # Broadcast new runtime log lines every 1 second - await manager._broadcast_runtime_logs() - - await asyncio.sleep(1) # Check every 1 second - except Exception as e: - logger.debug(f"Error in LMDeploy broadcast task: {e}") - await asyncio.sleep(2) # Wait longer on error - - lmdeploy_broadcast_task = asyncio.create_task(broadcast_lmdeploy_updates()) - 
logger.info("Started LMDeploy WebSocket broadcasting task") - yield # Shutdown - if lmdeploy_broadcast_task: - lmdeploy_broadcast_task.cancel() - try: - await lmdeploy_broadcast_task - except asyncio.CancelledError: - pass - logger.info("Stopped LMDeploy WebSocket broadcasting task") - - await unified_monitor.stop_monitoring() # Stop llama-swap (automatically stops all models) if llama_swap_manager: @@ -282,9 +234,16 @@ async def broadcast_lmdeploy_updates(): # CORS configuration via environment variables (safer defaults) # BACKEND_CORS_ORIGINS: comma-separated list of origins. Example: "http://localhost:5173,http://localhost:8080" # BACKEND_CORS_ALLOW_CREDENTIALS: "true"/"false" (default false; forced false when origins == ["*"]) -cors_origins_env = os.getenv("BACKEND_CORS_ORIGINS", "http://localhost:5173").strip() +cors_origins_env = os.getenv( + "BACKEND_CORS_ORIGINS", + "http://localhost:5173,http://localhost:5174,http://localhost:5175,http://localhost:5176,http://localhost:8080", +).strip() allow_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()] or [ - "http://localhost:5173" + "http://localhost:5173", + "http://localhost:5174", + "http://localhost:5175", + "http://localhost:5176", + "http://localhost:8080", ] allow_credentials_env = ( @@ -302,8 +261,6 @@ async def broadcast_lmdeploy_updates(): allow_headers=["*"], ) -# Use the global WebSocket manager instance - # Include routers app.include_router(models.router, prefix="/api/models", tags=["models"]) app.include_router( @@ -316,42 +273,34 @@ async def broadcast_lmdeploy_updates(): app.include_router(gpu_info.router, prefix="/api", tags=["gpu"]) app.include_router(lmdeploy.router, prefix="/api", tags=["lmdeploy"]) -# Include monitoring routes -from backend.routes import unified_monitoring - -app.include_router(unified_monitoring.router, prefix="/api", tags=["monitoring"]) - - -# WebSocket endpoint for real-time updates (must be before static file serving) -@app.websocket("/ws") -async 
def websocket_endpoint(websocket: WebSocket): - import json - - try: - logger.info("New WebSocket connection attempt") - await websocket_manager.connect(websocket) - logger.info( - f"WebSocket connected successfully. Total connections: {len(websocket_manager.active_connections)}" - ) - - try: - while True: - # Keep connection alive and handle any incoming messages - data = await websocket.receive_text() - message = json.loads(data) - - # Handle any client messages if needed - logger.debug(f"Received WebSocket message: {message}") - - except WebSocketDisconnect: - logger.info("WebSocket disconnected by client") - websocket_manager.disconnect(websocket) - except Exception as e: - logger.error(f"WebSocket error: {e}") - websocket_manager.disconnect(websocket) - except Exception as e: - logger.error(f"Failed to establish WebSocket connection: {e}") - websocket_manager.disconnect(websocket) +# SSE endpoint for progress tracking +from backend.progress_manager import get_progress_manager +from fastapi.responses import StreamingResponse + + +@app.get("/api/events") +async def sse_events(request: Request): + """Server-Sent Events endpoint for progress tracking.""" + logger.info("SSE /api/events: client connected") + pm = get_progress_manager() + + async def logged_stream(): + first = True + async for chunk in pm.subscribe(): + if first: + logger.info("SSE: sending first chunk to client") + first = False + yield chunk + + return StreamingResponse( + logged_stream(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache, no-store, must-revalidate", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", # Disable proxy buffering (nginx, etc.) 
+ }, + ) # Serve static files (built frontend) @@ -399,8 +348,8 @@ async def serve_favicon(): # Catch-all route for Vue Router (must be after API routes) @app.get("/{full_path:path}") async def serve_spa(full_path: str): - # If it's an API route or WebSocket route, let it pass through - if full_path.startswith("api/") or full_path.startswith("ws"): + # If it's an API route, let it pass through + if full_path.startswith("api/"): return {"error": "Not found"} # Serve index.html for all other routes (Vue Router will handle routing) @@ -455,9 +404,12 @@ async def serve_spa(full_path: str): if __name__ == "__main__": - # Enable hot reload in development (set RELOAD=true environment variable) - enable_reload = os.getenv("RELOAD", "false").lower() in ("true", "1", "yes") - reload_dirs = ["/app/backend"] if enable_reload else None + # Auto-reload in development: on by default when not in Docker; set RELOAD=false to disable + in_docker = os.path.exists("/app/data") + enable_reload = os.getenv("RELOAD", "true" if not in_docker else "false").lower() in ("true", "1", "yes") + # Watch the backend package directory (works when run from repo root with --app-dir backend) + backend_dir = os.path.abspath(os.path.dirname(__file__)) + reload_dirs = [backend_dir] if enable_reload else None uvicorn.run( "main:app", diff --git a/backend/param_registry.py b/backend/param_registry.py new file mode 100644 index 0000000..9b8e5bb --- /dev/null +++ b/backend/param_registry.py @@ -0,0 +1,117 @@ +""" +Registry of model config parameters for llama.cpp (and optionally LMDeploy). +Used by the frontend to render basic vs advanced settings from a single source of truth. 
+""" + +import copy +from typing import Any, Dict, List + +# Param entry: key, label, type ("int"|"float"|"bool"|"string"), default, min, max (optional), description (optional) +ParamDef = Dict[str, Any] + +# Basic params shown by default (most common for chat/embedding) +# Host and port are not included: they are managed by llama-swap (--port ${PORT}, host default 0.0.0.0) +LLAMA_CPP_BASIC: List[ParamDef] = [ + {"key": "ctx_size", "label": "Context size", "type": "int", "default": 2048, "min": 512, "max": 1_000_000, "description": "Maximum context length in tokens"}, + {"key": "n_gpu_layers", "label": "GPU layers", "type": "int", "default": -1, "min": -1, "max": 1000, "description": "Number of layers to offload to GPU (-1 = all)"}, + {"key": "batch_size", "label": "Batch size", "type": "int", "default": 512, "min": 1, "max": 2048, "description": "Batch size for prompt processing"}, + {"key": "threads", "label": "Threads", "type": "int", "default": 4, "min": 1, "max": 64, "description": "Number of threads"}, + {"key": "embedding", "label": "Embedding mode", "type": "bool", "default": False, "description": "Enable embedding-only mode"}, +] + +# Advanced params (shown in expandable "Advanced" section) +LLAMA_CPP_ADVANCED: List[ParamDef] = [ + {"key": "n_predict", "label": "Max tokens to predict", "type": "int", "default": -1, "min": -1, "max": 100_000}, + {"key": "ubatch_size", "label": "Ubatch size", "type": "int", "default": 512, "min": 1, "max": 2048}, + {"key": "temp", "label": "Temperature", "type": "float", "default": 0.8, "min": 0, "max": 2}, + {"key": "top_k", "label": "Top K", "type": "int", "default": 40, "min": 0, "max": 1000}, + {"key": "top_p", "label": "Top P", "type": "float", "default": 0.9, "min": 0, "max": 1}, + {"key": "min_p", "label": "Min P", "type": "float", "default": 0.0, "min": 0, "max": 1}, + {"key": "typical_p", "label": "Typical P", "type": "float", "default": 1.0, "min": 0, "max": 1}, + {"key": "repeat_penalty", "label": "Repeat 
penalty", "type": "float", "default": 1.1, "min": 1, "max": 2}, + {"key": "presence_penalty", "label": "Presence penalty", "type": "float", "default": 0, "min": -2, "max": 2}, + {"key": "frequency_penalty", "label": "Frequency penalty", "type": "float", "default": 0, "min": -2, "max": 2}, + {"key": "seed", "label": "Seed", "type": "int", "default": -1, "min": -1, "max": 2**31 - 1}, + {"key": "threads_batch", "label": "Threads (batch)", "type": "int", "default": -1, "min": -1, "max": 64}, + {"key": "parallel", "label": "Parallel", "type": "int", "default": 1, "min": 1, "max": 64}, + {"key": "rope_freq_base", "label": "RoPE freq base", "type": "float", "default": 0, "min": 0}, + {"key": "rope_freq_scale", "label": "RoPE freq scale", "type": "float", "default": 0, "min": 0}, + {"key": "flash_attn", "label": "Flash attention", "type": "bool", "default": False}, + {"key": "yarn_ext_factor", "label": "YaRN ext factor", "type": "float", "default": -1, "min": -1}, + {"key": "yarn_attn_factor", "label": "YaRN attn factor", "type": "float", "default": 1, "min": 0}, + {"key": "no_mmap", "label": "No mmap", "type": "bool", "default": False}, + {"key": "mlock", "label": "MLock", "type": "bool", "default": False}, + {"key": "low_vram", "label": "Low VRAM", "type": "bool", "default": False}, + {"key": "logits_all", "label": "Logits all", "type": "bool", "default": False}, + {"key": "cont_batching", "label": "Continuous batching", "type": "bool", "default": True}, + {"key": "no_kv_offload", "label": "No KV offload", "type": "bool", "default": False}, + {"key": "tensor_split", "label": "Tensor split", "type": "string", "default": ""}, + {"key": "main_gpu", "label": "Main GPU", "type": "int", "default": 0, "min": 0}, + {"key": "split_mode", "label": "Split mode", "type": "string", "default": ""}, + {"key": "cache_type_k", "label": "Cache type K", "type": "string", "default": ""}, + {"key": "cache_type_v", "label": "Cache type V", "type": "string", "default": ""}, + {"key": 
"grammar", "label": "Grammar", "type": "string", "default": ""}, + {"key": "json_schema", "label": "JSON schema", "type": "string", "default": ""}, + {"key": "cpu_moe", "label": "CPU MoE", "type": "bool", "default": False}, + {"key": "n_cpu_moe", "label": "N CPU MoE", "type": "int", "default": 0, "min": 0}, + {"key": "override_tensor", "label": "Override tensor", "type": "string", "default": ""}, + {"key": "rope_scaling", "label": "RoPE scaling", "type": "string", "default": ""}, + {"key": "mirostat", "label": "Mirostat", "type": "int", "default": 0, "min": 0, "max": 2}, + {"key": "mirostat_tau", "label": "Mirostat tau", "type": "float", "default": 5.0, "min": 0}, + {"key": "mirostat_eta", "label": "Mirostat eta", "type": "float", "default": 0.1, "min": 0}, +] + +# ik_llama.cpp: same as llama_cpp plus these extras (and different mirostat flag names) +IK_LLAMA_EXTRA: List[ParamDef] = [ + {"key": "mla_attn", "label": "MLA attention", "type": "bool", "default": False, "description": "Enable MLA attention"}, + {"key": "attn_max_batch", "label": "Attention max batch", "type": "int", "default": 0, "min": 0, "description": "Max attention batch size"}, + {"key": "fused_moe", "label": "Fused MoE", "type": "bool", "default": True, "description": "Enable fused MoE"}, + {"key": "smart_expert_reduction", "label": "Smart expert reduction", "type": "bool", "default": False, "description": "Enable smart expert reduction"}, +] + +# LMDeploy (safetensors / TurboMind) +LMDEPLOY_BASIC: List[ParamDef] = [ + {"key": "session_len", "label": "Session length", "type": "int", "default": 2048, "min": 512, "max": 1_000_000, "description": "Maximum session length"}, + {"key": "max_batch_size", "label": "Max batch size", "type": "int", "default": 128, "min": 1, "max": 1024, "description": "Maximum batch size"}, + {"key": "tensor_parallel", "label": "Tensor parallel", "type": "int", "default": 1, "min": 1, "max": 8, "description": "Tensor parallelism degree"}, +] +LMDEPLOY_ADVANCED: 
List[ParamDef] = [ + {"key": "dtype", "label": "Dtype", "type": "string", "default": "auto", "description": "Model dtype (auto, float16, bfloat16)"}, + {"key": "quant_policy", "label": "Quantization policy", "type": "int", "default": 0, "min": 0, "max": 8, "description": "KV cache quantization (0=off, 4=4bit, 8=8bit)"}, + {"key": "enable_prefix_caching", "label": "Prefix caching", "type": "bool", "default": False, "description": "Enable prefix caching"}, + {"key": "chat_template", "label": "Chat template", "type": "string", "default": "", "description": "Override chat template"}, +] + + +def get_llama_cpp_param_registry() -> Dict[str, List[ParamDef]]: + """Return basic and advanced param definitions for llama.cpp config forms.""" + return { + "basic": LLAMA_CPP_BASIC, + "advanced": LLAMA_CPP_ADVANCED, + } + + +def get_ik_llama_param_registry() -> Dict[str, List[ParamDef]]: + """Return param definitions for ik_llama.cpp (llama_cpp params plus ik_llama extras).""" + basic = copy.deepcopy(LLAMA_CPP_BASIC) + advanced = copy.deepcopy(LLAMA_CPP_ADVANCED) + copy.deepcopy(IK_LLAMA_EXTRA) + return {"basic": basic, "advanced": advanced} + + +def get_lmdeploy_param_registry() -> Dict[str, List[ParamDef]]: + """Return param definitions for LMDeploy (safetensors / TurboMind).""" + return { + "basic": LMDEPLOY_BASIC, + "advanced": LMDEPLOY_ADVANCED, + } + + +def get_param_registry(engine: str = "llama_cpp") -> Dict[str, List[ParamDef]]: + """Return param registry for the given engine.""" + if engine == "llama_cpp": + return get_llama_cpp_param_registry() + if engine == "ik_llama": + return get_ik_llama_param_registry() + if engine == "lmdeploy": + return get_lmdeploy_param_registry() + return {"basic": [], "advanced": []} diff --git a/backend/presets.py b/backend/presets.py deleted file mode 100644 index 517a9cf..0000000 --- a/backend/presets.py +++ /dev/null @@ -1,170 +0,0 @@ -from typing import Dict, Any, Tuple -import os - -from backend.gguf_reader import get_model_layer_info 
-from backend.logging_config import get_logger - -logger = get_logger(__name__) - - -def _detect_architecture_from_name(model_name: str) -> str: - """Detect model architecture from model name""" - name = (model_name or "").lower() - - if "llama" in name: - if "codellama" in name: - return "codellama" - elif "llama3" in name or "llama-3" in name: - return "llama3" - elif "llama2" in name or "llama-2" in name: - return "llama2" - return "llama" - elif "mistral" in name: - return "mistral" - elif "phi" in name: - return "phi" - elif "glm" in name or "chatglm" in name: - if "glm-4" in name or "glm4" in name: - return "glm4" - return "glm" - elif "deepseek" in name: - if "v3" in name or "v3.1" in name: - return "deepseek-v3" - return "deepseek" - elif "qwen" in name: - if "qwen3" in name or "qwen-3" in name: - return "qwen3" - elif "qwen2" in name or "qwen-2" in name: - return "qwen2" - return "qwen" - elif "gemma" in name: - if "gemma3" in name or "gemma-3" in name: - return "gemma3" - return "gemma" - - return "unknown" - - -def get_architecture_and_presets(model) -> Tuple[str, Dict[str, Dict[str, Any]]]: - """ - Source of truth for presets. Returns (architecture, presets dict). - Presets include keys like temp, top_p, top_k, repeat_penalty. 
- """ - # Import normalize_architecture from model_metadata to ensure consistency - from backend.smart_auto.architecture_config import ( - normalize_architecture, - detect_architecture_from_name, - ) - - # Try GGUF metadata - architecture = "unknown" - try: - if model.file_path and os.path.exists(model.file_path): - layer_info = get_model_layer_info(model.file_path) - if layer_info: - raw_architecture = layer_info.get("architecture", "") - architecture = normalize_architecture(raw_architecture) - if architecture != "unknown" and raw_architecture != architecture: - logger.debug( - f"Normalized architecture for presets: '{raw_architecture}' -> '{architecture}'" - ) - except Exception as e: - logger.warning(f"Failed to get layer info for presets: {e}") - - # Fallback to name-based detection if architecture is still unknown or empty - if not architecture or architecture == "unknown": - architecture = detect_architecture_from_name(model.name) - if architecture != "unknown": - logger.debug( - f"Detected architecture from name for presets: '{architecture}'" - ) - - # Defaults - presets: Dict[str, Dict[str, Any]] = { - "coding": {}, - "conversational": {}, - } - - model_lower = (model.name or "").lower() - is_coding_model = "code" in model_lower or architecture in ["codellama", "deepseek"] - - if architecture in ["glm", "glm4"]: - presets["coding"] = { - "temp": 1.0, - "top_p": 0.95, - "top_k": 40, - "repeat_penalty": 1.05, - } - presets["conversational"] = { - "temp": 1.0, - "top_p": 0.95, - "top_k": 40, - "repeat_penalty": 1.1, - } - elif architecture in ["deepseek", "deepseek-v3"]: - presets["coding"] = { - "temp": 1.0, - "top_p": 0.95, - "top_k": 40, - "repeat_penalty": 1.05, - } - presets["conversational"] = { - "temp": 0.9, - "top_p": 0.95, - "top_k": 40, - "repeat_penalty": 1.1, - } - elif architecture in ["qwen", "qwen2", "qwen3"]: - presets["coding"] = { - "temp": 0.7, - "top_p": 0.8, - "top_k": 20, - "repeat_penalty": 1.05, - } - presets["conversational"] = { - 
"temp": 0.7, - "top_p": 0.8, - "top_k": 20, - "repeat_penalty": 1.05, - } - elif architecture in ["gemma", "gemma3"]: - presets["coding"] = { - "temp": 0.9, - "top_p": 0.95, - "top_k": 40, - "repeat_penalty": 1.05, - } - presets["conversational"] = { - "temp": 0.9, - "top_p": 0.95, - "top_k": 40, - "repeat_penalty": 1.1, - } - elif is_coding_model: - presets["coding"] = { - "temp": 0.1, - "top_p": 0.95, - "top_k": 40, - "repeat_penalty": 1.05, - } - presets["conversational"] = { - "temp": 0.7, - "top_p": 0.95, - "top_k": 40, - "repeat_penalty": 1.1, - } - else: - presets["coding"] = { - "temp": 0.7, - "top_p": 0.95, - "top_k": 40, - "repeat_penalty": 1.1, - } - presets["conversational"] = { - "temp": 0.8, - "top_p": 0.95, - "top_k": 40, - "repeat_penalty": 1.1, - } - - return architecture, presets diff --git a/backend/progress_manager.py b/backend/progress_manager.py new file mode 100644 index 0000000..46269e4 --- /dev/null +++ b/backend/progress_manager.py @@ -0,0 +1,236 @@ +"""SSE-based progress tracking.""" + +import asyncio +import json +import time +import uuid +from datetime import datetime +from typing import Any, AsyncGenerator, Dict, List, Optional + + +class ProgressManager: + """In-memory task tracker with SSE streaming.""" + + def __init__(self): + self._tasks: Dict[str, dict] = {} + self._subscribers: list[asyncio.Queue] = [] + + def create_task( + self, + task_type: str, + description: str, + metadata: Optional[dict] = None, + task_id: Optional[str] = None, + ) -> str: + """Create a new tracked task. 
Returns task_id (uses provided task_id if given).""" + task_id = task_id or str(uuid.uuid4())[:8] + self._tasks[task_id] = { + "task_id": task_id, + "type": task_type, + "description": description, + "progress": 0.0, + "status": "running", + "message": "", + "metadata": metadata or {}, + "created_at": time.time(), + } + self._broadcast({"event": "task_created", "data": self._tasks[task_id]}) + return task_id + + def update_task( + self, + task_id: str, + progress: Optional[float] = None, + message: Optional[str] = None, + status: Optional[str] = None, + metadata_update: Optional[dict] = None, + ): + """Update a task's progress/status.""" + task = self._tasks.get(task_id) + if not task: + return + if progress is not None: + task["progress"] = min(100.0, max(0.0, progress)) + if message is not None: + task["message"] = message + if status is not None: + task["status"] = status + if metadata_update: + task["metadata"].update(metadata_update) + self._broadcast({"event": "task_updated", "data": task}) + + def complete_task(self, task_id: str, message: str = "Done"): + self.update_task(task_id, progress=100.0, status="completed", message=message) + + def fail_task(self, task_id: str, error: str): + self.update_task(task_id, status="failed", message=error) + + def get_task(self, task_id: str) -> Optional[dict]: + return self._tasks.get(task_id) + + def get_active_tasks(self) -> list: + return [t for t in self._tasks.values() if t["status"] == "running"] + + def _broadcast(self, event: dict): + dead = [] + for q in self._subscribers: + try: + q.put_nowait(event) + except asyncio.QueueFull: + dead.append(q) + for q in dead: + self._subscribers.remove(q) + + def emit(self, event_type: str, data: Any): + """Emit a generic event (e.g. 
log, notification, model_status) to SSE subscribers.""" + self._broadcast({"event": event_type, "data": data}) + + @property + def active_connections(self) -> List: + """SSE has no persistent connection list; returns empty for compatibility.""" + return [] + + async def send_download_progress( + self, + task_id: str, + progress: int, + message: str = "", + bytes_downloaded: int = 0, + total_bytes: int = 0, + speed_mbps: float = 0, + eta_seconds: int = 0, + filename: str = "", + model_format: str = "gguf", + files_completed: int = None, + files_total: int = None, + current_filename: str = None, + huggingface_id: str = None, + **kwargs, + ): + self.update_task( + task_id, + progress=float(progress), + message=message or filename, + metadata_update=kwargs, + ) + self.emit( + "download_progress", + { + "task_id": task_id, + "progress": progress, + "message": message, + "bytes_downloaded": bytes_downloaded, + "total_bytes": total_bytes, + "speed_mbps": speed_mbps, + "eta_seconds": eta_seconds, + "filename": filename, + "model_format": model_format, + "files_completed": files_completed, + "files_total": files_total, + "current_filename": current_filename or filename, + "huggingface_id": huggingface_id, + "timestamp": datetime.utcnow().isoformat(), + **kwargs, + }, + ) + + async def broadcast(self, message: dict): + msg_type = message.get("type", "broadcast") + self.emit(msg_type, message) + + async def send_model_status_update( + self, model_id: Any, status: str, details: dict = None + ): + self.emit( + "model_status", + { + "model_id": model_id, + "status": status, + "details": details or {}, + "timestamp": datetime.utcnow().isoformat(), + }, + ) + + async def send_notification( + self, + title: str = "", + message: str = "", + type: str = "info", + actions: List[dict] = None, + *args, + **kwargs, + ): + # Support (title, message, type) keyword and (type, title, message, task_id) positional + if args and len(args) >= 3: + type, title, message = args[0], args[1], args[2] 
+ else: + type = kwargs.get("type", type) + title = kwargs.get("title", title) + message = kwargs.get("message", message) + self.emit( + "notification", + { + "title": title, + "message": message, + "type": type, + "notification_type": type, + "actions": actions or [], + "timestamp": datetime.utcnow().isoformat(), + **{k: v for k, v in kwargs.items() if k not in ("title", "message", "type", "actions")}, + }, + ) + + async def send_build_progress( + self, + task_id: str, + stage: str, + progress: int, + message: str = "", + log_lines: List[str] = None, + ): + self.update_task( + task_id, + progress=float(progress), + message=message, + metadata_update={"stage": stage, "log_lines": log_lines or []}, + ) + self.emit( + "build_progress", + { + "task_id": task_id, + "stage": stage, + "progress": progress, + "message": message, + "log_lines": log_lines or [], + "timestamp": datetime.utcnow().isoformat(), + }, + ) + + async def subscribe(self) -> AsyncGenerator[str, None]: + """Yields SSE-formatted strings. 
Sends an initial comment so the client connection opens.""" + queue: asyncio.Queue = asyncio.Queue(maxsize=100) + self._subscribers.append(queue) + try: + # Send immediate heartbeat so EventSource receives data and fires onopen + yield ": heartbeat\n\n" + await asyncio.sleep(0) # Allow first chunk to be flushed to the client + for task in self.get_active_tasks(): + yield f"event: task_updated\ndata: {json.dumps(task)}\n\n" + while True: + event = await queue.get() + yield f"event: {event['event']}\ndata: {json.dumps(event['data'])}\n\n" + except asyncio.CancelledError: + pass + finally: + if queue in self._subscribers: + self._subscribers.remove(queue) + + +_progress_manager: Optional[ProgressManager] = None + + +def get_progress_manager() -> ProgressManager: + global _progress_manager + if _progress_manager is None: + _progress_manager = ProgressManager() + return _progress_manager diff --git a/backend/routes/llama_version_manager.py b/backend/routes/llama_version_manager.py index 47e3bd5..6b9ee9a 100644 --- a/backend/routes/llama_version_manager.py +++ b/backend/routes/llama_version_manager.py @@ -1,13 +1,10 @@ -from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks -from sqlalchemy.orm import Session -from typing import List, Dict, Any +from fastapi import APIRouter, HTTPException import os import shutil import stat import time -from datetime import datetime -from backend.database import get_db, LlamaVersion, Model +from backend.data_store import get_store from backend.logging_config import get_logger logger = get_logger(__name__) @@ -15,7 +12,6 @@ def _remove_readonly(func, path, exc): - """Helper function to handle readonly files on Windows""" try: os.chmod(path, stat.S_IWRITE) func(path) @@ -24,198 +20,140 @@ def _remove_readonly(func, path, exc): def _robust_rmtree(path: str, max_retries: int = 3) -> None: - """Robustly remove a directory tree, handling Windows file locks""" if not os.path.exists(path): return - for attempt in 
range(max_retries): try: - # Use onerror callback to handle readonly files (common on Windows) shutil.rmtree(path, onerror=_remove_readonly) logger.info(f"Successfully deleted directory: {path}") return - except PermissionError as e: - if attempt < max_retries - 1: - logger.warning( - f"Permission error deleting {path}, attempt {attempt + 1}/{max_retries}: {e}" - ) - time.sleep(0.5) # Wait a bit before retrying - else: - logger.error( - f"Failed to delete {path} after {max_retries} attempts: {e}" - ) - raise - except OSError as e: + except (PermissionError, OSError) as e: if attempt < max_retries - 1: - logger.warning( - f"OS error deleting {path}, attempt {attempt + 1}/{max_retries}: {e}" - ) time.sleep(0.5) else: - logger.error( - f"Failed to delete {path} after {max_retries} attempts: {e}" - ) + logger.error(f"Failed to delete {path} after {max_retries} attempts: {e}") raise -@router.get("/llama-versions") -async def list_llama_versions(db: Session = Depends(get_db)): - """List all installed llama-cpp versions""" - versions = db.query(LlamaVersion).all() - - # Also scan the filesystem for any versions not in the database - llama_cpp_dir = ( - "data/llama-cpp" if os.path.exists("data") else "/app/data/llama-cpp" - ) - if os.path.exists(llama_cpp_dir): - for version_dir in os.listdir(llama_cpp_dir): - if os.path.isdir(os.path.join(llama_cpp_dir, version_dir)): - # Check if this version is already in the database - existing_version = ( - db.query(LlamaVersion) - .filter(LlamaVersion.version == version_dir) - .first() - ) - if not existing_version: - # Add to database - binary_path = os.path.join( - llama_cpp_dir, version_dir, "build", "bin", "llama-server" - ) - if os.path.exists(binary_path): - new_version = LlamaVersion( - version=version_dir, - install_type="source", - source_commit=version_dir, - is_active=False, - binary_path=binary_path, - ) - db.add(new_version) - db.commit() - logger.info( - f"Added llama-cpp version {version_dir} to database" - ) - - # 
Refresh the list - versions = db.query(LlamaVersion).all() +def _resolve_binary_path(binary_path: str) -> str: + if not binary_path: + return "" + if os.path.isabs(binary_path): + return binary_path + return os.path.join("/app", binary_path) - return { - "versions": [ - { - "id": v.id, - "version": v.version, - "install_type": v.install_type, - "source_commit": v.source_commit, - "is_active": v.is_active, - "installed_at": v.installed_at.isoformat() if v.installed_at else None, - "binary_path": v.binary_path, - "exists": os.path.exists(v.binary_path) if v.binary_path else False, - } - for v in versions - ] - } + +@router.get("/llama-versions") +async def list_llama_versions(): + """List all installed llama-cpp versions (llama_cpp engine).""" + store = get_store() + versions = store.get_engine_versions("llama_cpp") + result = [] + for i, v in enumerate(versions): + binary_path = _resolve_binary_path(v.get("binary_path")) + result.append({ + "id": i, + "version": v.get("version"), + "install_type": v.get("type", "source"), + "source_commit": v.get("source_commit"), + "is_active": store.get_active_engine_version("llama_cpp") and store.get_active_engine_version("llama_cpp").get("version") == v.get("version"), + "installed_at": v.get("installed_at"), + "binary_path": v.get("binary_path"), + "exists": os.path.exists(binary_path) if binary_path else False, + }) + return {"versions": result} @router.post("/llama-versions/{version_id}/activate") -async def activate_llama_version(version_id: int, db: Session = Depends(get_db)): - """Activate a specific llama-cpp version""" - # Deactivate all versions first - db.query(LlamaVersion).update({"is_active": False}) - - # Activate the selected version - version = db.query(LlamaVersion).filter(LlamaVersion.id == version_id).first() - if not version: +async def activate_llama_version(version_id: str): + """Activate a specific llama-cpp version (version_id can be index, version string, or "llama_cpp:version").""" + store = get_store() 
+ versions = store.get_engine_versions("llama_cpp") + # Frontend may send id from list endpoint: "llama_cpp:version_str" + lookup_id = version_id + if ":" in str(version_id): + parts = str(version_id).split(":", 1) + if parts[0] == "llama_cpp": + lookup_id = parts[1] + version_entry = None + try: + idx = int(lookup_id) + if 0 <= idx < len(versions): + version_entry = versions[idx] + except ValueError: + pass + if not version_entry: + version_entry = next((v for v in versions if str(v.get("version")) == str(lookup_id)), None) + if not version_entry: raise HTTPException(status_code=404, detail="Version not found") - - if not os.path.exists(version.binary_path): + binary_path = _resolve_binary_path(version_entry.get("binary_path")) + if not os.path.exists(binary_path): raise HTTPException(status_code=400, detail="Binary file does not exist") - - version.is_active = True - db.commit() - - # Ensure binary path is correct for the newly activated version + version_str = str(version_entry.get("version")) + store.set_active_engine_version("llama_cpp", version_str) try: from backend.llama_swap_manager import get_llama_swap_manager - llama_swap_manager = get_llama_swap_manager() - - # Check and fix binary path if needed await llama_swap_manager._ensure_correct_binary_path() - logger.info(f"Binary path verified for activated version: {version.version}") - - # Regenerate llama-swap configuration with new binary path - # This will also ensure llama-swap is started await llama_swap_manager.regenerate_config_with_active_version() - - logger.info( - f"Regenerated llama-swap config with new active version: {version.version}" - ) - - # Explicitly ensure llama-swap is running after activation try: await llama_swap_manager.start_proxy() - logger.info("Ensured llama-swap is running after version activation") except Exception as e: logger.warning(f"Failed to start llama-swap after version activation: {e}") except Exception as e: logger.error(f"Failed to regenerate llama-swap config: 
{e}") - # Don't fail the activation if config regeneration fails - - logger.info(f"Activated llama-cpp version: {version.version}") - return {"message": f"Activated llama-cpp version {version.version}"} + logger.info(f"Activated llama-cpp version: {version_str}") + return {"message": f"Activated llama-cpp version {version_str}"} @router.delete("/llama-versions/{version_id}") -async def delete_llama_version(version_id: int, db: Session = Depends(get_db)): - """Delete a llama-cpp version""" - version = db.query(LlamaVersion).filter(LlamaVersion.id == version_id).first() - if not version: +async def delete_llama_version(version_id: str): + """Delete a llama-cpp version (version_id can be index or version string).""" + store = get_store() + versions = store.get_engine_versions("llama_cpp") + version_entry = None + try: + idx = int(version_id) + if 0 <= idx < len(versions): + version_entry = versions[idx] + except ValueError: + pass + if not version_entry: + version_entry = next((v for v in versions if str(v.get("version")) == str(version_id)), None) + if not version_entry: raise HTTPException(status_code=404, detail="Version not found") - - if version.is_active: + version_str = str(version_entry.get("version")) + active = store.get_active_engine_version("llama_cpp") + if active and str(active.get("version")) == version_str: raise HTTPException(status_code=400, detail="Cannot delete active version") - - # Delete the directory - version_dir = os.path.dirname( - os.path.dirname(version.binary_path) - ) # Go up from build/bin/llama-server - if os.path.exists(version_dir): + binary_path = _resolve_binary_path(version_entry.get("binary_path")) + version_dir = os.path.dirname(os.path.dirname(binary_path)) if binary_path else None + if version_dir and os.path.exists(version_dir): try: _robust_rmtree(version_dir) except Exception as e: logger.error(f"Failed to delete directory {version_dir}: {e}") - raise HTTPException( - status_code=500, detail=f"Failed to delete directory: 
{e}" - ) - - # Remove from database - db.delete(version) - db.commit() - - logger.info(f"Deleted llama-cpp version: {version.version}") - return {"message": f"Deleted llama-cpp version {version.version}"} + raise HTTPException(status_code=500, detail=f"Failed to delete directory: {e}") + store.delete_engine_version("llama_cpp", version_str) + logger.info(f"Deleted llama-cpp version: {version_str}") + return {"message": f"Deleted llama-cpp version {version_str}"} @router.get("/llama-versions/active") -async def get_active_llama_version(db: Session = Depends(get_db)): - """Get the currently active llama-cpp version""" - active_version = ( - db.query(LlamaVersion).filter(LlamaVersion.is_active == True).first() - ) - +async def get_active_llama_version(): + """Get the currently active llama-cpp version.""" + store = get_store() + active_version = store.get_active_engine_version("llama_cpp") if not active_version: return {"active_version": None} - + binary_path = _resolve_binary_path(active_version.get("binary_path")) return { "active_version": { - "id": active_version.id, - "version": active_version.version, - "install_type": active_version.install_type, - "source_commit": active_version.source_commit, - "binary_path": active_version.binary_path, - "exists": ( - os.path.exists(active_version.binary_path) - if active_version.binary_path - else False - ), + "id": 0, + "version": active_version.get("version"), + "install_type": active_version.get("type"), + "source_commit": active_version.get("source_commit"), + "binary_path": active_version.get("binary_path"), + "exists": os.path.exists(binary_path) if binary_path else False, } } diff --git a/backend/routes/llama_versions.py b/backend/routes/llama_versions.py index 4f296c8..4941c5c 100644 --- a/backend/routes/llama_versions.py +++ b/backend/routes/llama_versions.py @@ -1,6 +1,6 @@ -from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks -from sqlalchemy.orm import Session +from fastapi import APIRouter, 
HTTPException, Body from typing import List, Optional +import asyncio import json import os import subprocess @@ -11,9 +11,9 @@ import stat from datetime import datetime -from backend.database import get_db, LlamaVersion +from backend.data_store import get_store from backend.llama_manager import LlamaManager, BuildConfig -from backend.websocket_manager import websocket_manager +from backend.progress_manager import get_progress_manager from backend.logging_config import get_logger from backend.gpu_detector import get_gpu_info, detect_build_capabilities from backend.cuda_installer import get_cuda_installer @@ -69,62 +69,77 @@ def _robust_rmtree(path: str, max_retries: int = 3) -> None: @router.get("") @router.get("/") -async def list_llama_versions(db: Session = Depends(get_db)): - """List all installed llama.cpp versions""" - versions = db.query(LlamaVersion).all() - return [ - { - "id": version.id, - "version": version.version, - "install_type": version.install_type, - "binary_path": version.binary_path, - "source_commit": version.source_commit, - "patches": json.loads(version.patches) if version.patches else [], - "installed_at": version.installed_at, - "is_active": version.is_active, - "build_config": version.build_config, - "repository_source": version.repository_source or "llama.cpp", - } - for version in versions - ] +async def list_llama_versions(): + """List all installed llama.cpp and ik_llama versions""" + store = get_store() + result = [] + for engine, repo_label in [("llama_cpp", "llama.cpp"), ("ik_llama", "ik_llama.cpp")]: + active = store.get_active_engine_version(engine) + active_version = active.get("version") if active else None + for i, v in enumerate(store.get_engine_versions(engine)): + version_str = v.get("version") + result.append({ + "id": f"{engine}:{version_str}", + "version": version_str, + "install_type": v.get("type", "source"), + "binary_path": v.get("binary_path"), + "source_commit": v.get("source_commit"), + "patches": [], # No longer 
storing patches in YAML + "installed_at": v.get("installed_at"), + "is_active": v.get("version") == active_version, + "build_config": v.get("build_config"), + "repository_source": v.get("repository_source") or repo_label, + }) + return result @router.get("/check-updates") -async def check_updates(): - """Check for llama.cpp updates (both releases and source)""" +async def check_updates(source: str | None = None): + """Check for llama.cpp or ik_llama.cpp updates (releases and/or source). + source: None or 'llama_cpp' for ggerganov/llama.cpp; 'ik_llama' for ikawrakow/ik_llama.cpp. + """ try: - # Use the original URLs with redirect handling - releases_url = "https://api.github.com/repos/ggerganov/llama.cpp/releases" - commits_url = ( - "https://api.github.com/repos/ggerganov/llama.cpp/commits?per_page=1" - ) - - # Check GitHub releases - releases_response = requests.get(releases_url, allow_redirects=True) - releases_response.raise_for_status() - releases = releases_response.json() - - latest_release = releases[0] if releases else None + is_ik = source == "ik_llama" + if is_ik: + commits_url = ( + "https://api.github.com/repos/ikawrakow/ik_llama.cpp/commits?per_page=1" + ) + latest_release = None + else: + # ai-dock/llama.cpp-cuda: pre-built releases with CUDA support + releases_url = "https://api.github.com/repos/ai-dock/llama.cpp-cuda/releases" + commits_url = ( + "https://api.github.com/repos/ggerganov/llama.cpp/commits?per_page=1" + ) + releases_response = requests.get(releases_url, allow_redirects=True) + releases_response.raise_for_status() + releases = releases_response.json() + latest_release = releases[0] if releases else None - # Check latest commit from main branch commits_response = requests.get(commits_url, allow_redirects=True) commits_response.raise_for_status() commits = commits_response.json() latest_commit = commits[0] if commits else None return { - "latest_release": { - "tag_name": latest_release["tag_name"] if latest_release else None, - 
"published_at": ( - latest_release["published_at"] if latest_release else None - ), - "html_url": latest_release["html_url"] if latest_release else None, - }, - "latest_commit": { - "sha": latest_commit["sha"], - "commit_date": latest_commit["commit"]["committer"]["date"], - "message": latest_commit["commit"]["message"], - }, + "latest_release": ( + { + "tag_name": latest_release["tag_name"], + "published_at": latest_release["published_at"], + "html_url": latest_release["html_url"], + } + if latest_release + else None + ), + "latest_commit": ( + { + "sha": latest_commit["sha"], + "commit_date": latest_commit["commit"]["committer"]["date"], + "message": latest_commit["commit"]["message"], + } + if latest_commit + else None + ), } except requests.exceptions.HTTPError as e: if e.response.status_code == 403: @@ -199,10 +214,8 @@ async def get_build_capabilities_endpoint(): @router.post("/install-release") -async def install_release( - request: dict, background_tasks: BackgroundTasks, db: Session = Depends(get_db) -): - """Install llama.cpp from GitHub release""" +async def install_release(request: dict): + """Install llama.cpp from ai-dock/llama.cpp-cuda release (CUDA builds).""" try: tag_name = request.get("tag_name") if not tag_name: @@ -241,17 +254,12 @@ async def install_release( version_name = preview.get("version_name") - # Check if version already exists - if version_name: - existing = ( - db.query(LlamaVersion) - .filter(LlamaVersion.version == version_name) - .first() - ) - else: - existing = ( - db.query(LlamaVersion).filter(LlamaVersion.version == tag_name).first() - ) + store = get_store() + existing_versions = store.get_engine_versions("llama_cpp") + existing = next( + (v for v in existing_versions if v.get("version") in (version_name, tag_name)), + None, + ) if existing: detail = "400: Version already installed" if version_name: @@ -261,10 +269,10 @@ async def install_release( # Generate task ID for tracking task_id = 
f"install_release_{tag_name}_{int(time.time())}" - # Start installation in background - background_tasks.add_task( - install_release_task, tag_name, websocket_manager, task_id, asset_id - ) + # Start installation in background (asyncio.create_task so it runs regardless of middleware) + pm = get_progress_manager() + pm.create_task("install_release", f"Install {tag_name}", {"tag_name": tag_name}, task_id=task_id) + asyncio.create_task(install_release_task(tag_name, pm, task_id, asset_id)) return { "message": f"Installing release {tag_name}", @@ -280,19 +288,15 @@ async def install_release( async def install_release_task( tag_name: str, - websocket_manager=None, + progress_manager=None, task_id: str = None, asset_id: Optional[int] = None, ): - """Background task to install release with WebSocket progress updates""" - # Create a new database session for the background task - from backend.database import SessionLocal - - db = SessionLocal() - + """Background task to install release with SSE progress updates""" + store = get_store() try: install_result = await llama_manager.install_release( - tag_name, websocket_manager, task_id, asset_id + tag_name, progress_manager, task_id, asset_id ) binary_path = install_result.get("binary_path") asset_info = install_result.get("asset") @@ -301,39 +305,36 @@ async def install_release_task( if not binary_path: raise Exception("Installation completed without returning a binary path.") - # Save to database - version = LlamaVersion( - version=version_name, - install_type="release", - binary_path=binary_path, - installed_at=datetime.utcnow(), - build_config=( + version_data = { + "version": version_name, + "type": "release", + "binary_path": binary_path, + "installed_at": datetime.utcnow().isoformat() + "Z", + "build_config": ( {"release_asset": asset_info, "tag_name": tag_name} if asset_info else None ), - ) - db.add(version) - db.commit() + "repository_source": "llama.cpp", + } + store.add_engine_version("llama_cpp", version_data) - # 
If this is the first version or if there's an active version, ensure llama-swap is running from backend.llama_swap_manager import get_llama_swap_manager - active_version = db.query(LlamaVersion).filter(LlamaVersion.is_active == True).first() - if active_version and os.path.exists(active_version.binary_path): + active_version = store.get_active_engine_version("llama_cpp") + if active_version and active_version.get("binary_path") and os.path.exists(active_version.get("binary_path", "")): try: llama_swap_manager = get_llama_swap_manager() - # Regenerate config to include any new models, and ensure llama-swap is running await llama_swap_manager.regenerate_config_with_active_version() logger.info("Ensured llama-swap is running after release installation") except Exception as e: logger.warning(f"Failed to ensure llama-swap is running after release installation: {e}") - # Send success notification - if websocket_manager: + if progress_manager: asset_label = "" if asset_info and asset_info.get("name"): asset_label = f" ({asset_info['name']})" - await websocket_manager.send_notification( + progress_manager.complete_task(task_id, f"Installed {version_name}") + await progress_manager.send_notification( title="Installation Complete", message=f"Successfully installed llama.cpp release {version_name}{asset_label}", type="success", @@ -341,21 +342,18 @@ async def install_release_task( except Exception as e: logger.error(f"Release installation failed: {e}") - if websocket_manager: - await websocket_manager.send_notification( + if progress_manager and task_id: + progress_manager.fail_task(task_id, str(e)) + if progress_manager: + await progress_manager.send_notification( title="Installation Failed", message=f"Failed to install llama.cpp release: {str(e)}", type="error", ) - finally: - # Always close the database session - db.close() @router.post("/build-source") -async def build_source( - request: dict, background_tasks: BackgroundTasks, db: Session = Depends(get_db) -): +async def 
build_source(request: dict): """Build llama.cpp from source with optional patches""" try: commit_sha = request.get("commit_sha") @@ -367,19 +365,17 @@ async def build_source( if not commit_sha: raise HTTPException(status_code=400, detail="commit_sha is required") - # Generate unique version name commit_short = commit_sha[:8] if version_suffix: version_name = f"source-{commit_short}-{version_suffix}" else: - # Use timestamp for unique naming timestamp = int(time.time()) version_name = f"source-{commit_short}-{timestamp}" - # Check if version already exists (still check to prevent accidental duplicates) - existing = ( - db.query(LlamaVersion).filter(LlamaVersion.version == version_name).first() - ) + store = get_store() + engine = "ik_llama" if repository_source == "ik_llama.cpp" else "llama_cpp" + existing_versions = store.get_engine_versions(engine) + existing = next((v for v in existing_versions if v.get("version") == version_name), None) if existing: raise HTTPException( status_code=400, detail=f"Version '{version_name}' already installed" @@ -393,25 +389,48 @@ async def build_source( detail=f"Unknown repository source: {repository_source}", ) - # Parse build_config if provided + # Parse build_config if provided (map frontend keys to BuildConfig field names) build_config = None - if build_config_dict: - build_config = BuildConfig(**build_config_dict) + if build_config_dict and isinstance(build_config_dict, dict): + def _bool(v): + if isinstance(v, bool): + return v + if isinstance(v, str): + return v.strip().lower() in ("1", "true", "yes", "on") + return bool(v) + + # Frontend sends cuda, flash_attention, native, backend_dl, cpu_all_variants + mapped = { + "enable_cuda": _bool(build_config_dict.get("cuda", False)), + "enable_flash_attention": _bool(build_config_dict.get("flash_attention", False)), + "enable_native": _bool(build_config_dict.get("native", True)), + "enable_backend_dl": _bool(build_config_dict.get("backend_dl", False)), + "enable_cpu_all_variants": 
_bool(build_config_dict.get("cpu_all_variants", False)), + "cuda_architectures": str(build_config_dict.get("cuda_architectures") or ""), + } + try: + build_config = BuildConfig(**mapped) + except (TypeError, ValueError) as e: + logger.warning("BuildConfig from request failed (%s), using defaults", e) + build_config = BuildConfig() # Generate task ID for tracking task_id = f"build_{version_name}_{int(time.time())}" - # Start build in background - background_tasks.add_task( - build_source_task, - commit_sha, - patches, - build_config, - version_name, - repository_source, - repository_url, - websocket_manager, - task_id, + # Start build in background (asyncio.create_task so it runs regardless of middleware) + pm = get_progress_manager() + pm.create_task("build", f"Build {repository_source} {commit_sha[:8]}", {"version_name": version_name}, task_id=task_id) + asyncio.create_task( + build_source_task( + commit_sha, + patches, + build_config or BuildConfig(), + version_name, + repository_source, + repository_url, + pm, + task_id, + ) ) return { @@ -433,106 +452,94 @@ async def build_source_task( version_name: str, repository_source: str, repository_url: str, - websocket_manager=None, + progress_manager=None, task_id: str = None, ): - """Background task to build from source with WebSocket progress""" - # Create a new database session for the background task - from backend.database import SessionLocal - from dataclasses import asdict - - db = SessionLocal() - + """Background task to build from source with SSE progress""" + logger.info( + "Build task started: version_name=%s, repository_source=%s, commit_sha=%s", + version_name, repository_source, commit_sha[:8] if commit_sha else "", + ) try: + from dataclasses import asdict + store = get_store() + engine = "ik_llama" if repository_source == "ik_llama.cpp" else "llama_cpp" + binary_path = await llama_manager.build_source( commit_sha, patches, build_config, - websocket_manager, + progress_manager, task_id, 
repository_url=repository_url, version_name=version_name, ) - # Save to database with build_config build_config_dict = None if build_config: build_config_dict = asdict(build_config) - # Add repository_source to build_config for completeness build_config_dict["repository_source"] = repository_source - version = LlamaVersion( - version=version_name, - install_type="patched" if patches else "source", - binary_path=binary_path, - source_commit=commit_sha, - patches=json.dumps(patches), - build_config=build_config_dict, - repository_source=repository_source, - installed_at=datetime.utcnow(), - ) - db.add(version) - db.commit() + version_data = { + "version": version_name, + "type": "patched" if patches else "source", + "binary_path": binary_path, + "source_commit": commit_sha, + "build_config": build_config_dict, + "repository_source": repository_source, + "installed_at": datetime.utcnow().isoformat() + "Z", + } + store.add_engine_version(engine, version_data) - # If there's an active version, ensure llama-swap is running from backend.llama_swap_manager import get_llama_swap_manager - active_version = db.query(LlamaVersion).filter(LlamaVersion.is_active == True).first() - if active_version and os.path.exists(active_version.binary_path): + active_version = store.get_active_engine_version(engine) + if active_version and active_version.get("binary_path") and os.path.exists(active_version.get("binary_path", "")): try: llama_swap_manager = get_llama_swap_manager() - # Regenerate config to include any new models, and ensure llama-swap is running await llama_swap_manager.regenerate_config_with_active_version() logger.info("Ensured llama-swap is running after source build") except Exception as e: logger.warning(f"Failed to ensure llama-swap is running after source build: {e}") - # Send success notification - if websocket_manager: - await websocket_manager.send_notification( + if progress_manager: + if task_id: + progress_manager.complete_task(task_id, f"Built {version_name}") + 
await progress_manager.send_notification( title="Build Complete", message=f"Successfully built {repository_source} from source {commit_sha[:8]}", type="success", ) except Exception as e: - logger.error(f"Source build failed: {e}") - if websocket_manager: + logger.exception("Source build failed: %s", e) + if progress_manager: try: - logger.info(f"Sending build failure notification for task {task_id}") - await websocket_manager.send_notification( + if task_id: + progress_manager.fail_task(task_id, str(e)) + await progress_manager.send_notification( title="Build Failed", message=f"Failed to build llama.cpp from source: {str(e)}", type="error", ) - # Also send a build progress error message if task_id: - await websocket_manager.send_build_progress( + await progress_manager.send_build_progress( task_id=task_id, stage="error", progress=0, message=f"Build task failed: {str(e)}", - log_lines=[ - f"Task error: {str(e)}", - f"Error type: {type(e).__name__}", - ], + log_lines=[f"Task error: {str(e)}", f"Error type: {type(e).__name__}"], ) - logger.info(f"Build failure notifications sent successfully") except Exception as ws_error: logger.error(f"Failed to send build failure notification: {ws_error}") - finally: - # Always close the database session - db.close() @router.get("/task-status/{task_id}") async def get_task_status(task_id: str): """Get the status of a background task""" - # This is a simple implementation - in production you might want to store task status in Redis or database - # For now, we'll just return a basic response since the WebSocket provides real-time updates return { "task_id": task_id, - "status": "running", # Could be "running", "completed", "failed" - "message": "Task is running. Use WebSocket for real-time progress updates.", + "status": "running", + "message": "Task is running. 
Subscribe to GET /api/events for real-time SSE progress updates.", } @@ -563,33 +570,130 @@ async def get_version_commands(version: str): raise HTTPException(status_code=500, detail=str(e)) -@router.delete("/{version_id}") -async def delete_version(version_id: int, db: Session = Depends(get_db)): - """Delete llama.cpp version""" - version = db.query(LlamaVersion).filter(LlamaVersion.id == version_id).first() - if not version: +def _resolve_binary_path(binary_path: str) -> str: + if not binary_path: + return "" + if os.path.isabs(binary_path): + return binary_path + # Docker: paths relative to /app; local: relative to project root + if os.path.exists("/app/data"): + return os.path.normpath(os.path.join("/app", binary_path)) + cwd = os.getcwd() + resolved = os.path.normpath(os.path.join(cwd, binary_path)) + if os.path.exists(resolved): + return resolved + # When run with --app-dir backend, cwd may be backend/; project root is parent + parent = os.path.dirname(cwd) + return os.path.normpath(os.path.join(parent, binary_path)) + + +def _find_version_entry(store, version_id: str): + """Resolve version_id ('engine:version' or plain version) to (version_entry, engine). 
Returns (None, None) if not found.""" + version_entry = None + engine = None + if ":" in version_id: + parts = version_id.split(":", 1) + eng, version_str = parts[0], parts[1] + if eng in ("llama_cpp", "ik_llama"): + version_entry = next( + (v for v in store.get_engine_versions(eng) if str(v.get("version")) == version_str), + None, + ) + if version_entry: + engine = eng + if not version_entry: + for eng in ("llama_cpp", "ik_llama"): + versions = store.get_engine_versions(eng) + version_entry = next((v for v in versions if str(v.get("version")) == str(version_id)), None) + if version_entry: + engine = eng + break + return version_entry, engine + + +@router.post("/versions/activate") +async def activate_version_body(payload: dict = Body(...)): + """Activate a version; body: { \"version_id\": \"llama_cpp:version\" or \"version\" }.""" + version_id = (payload or {}).get("version_id") + if not version_id: + raise HTTPException(status_code=400, detail="version_id required") + return await _do_activate_version(version_id) + + +async def _do_activate_version(version_id: str): + store = get_store() + version_entry, engine = _find_version_entry(store, version_id) + if not version_entry or not engine: + logger.warning( + "activate_version: version not found, version_id=%r, llama_cpp versions=%s", + version_id, + [v.get("version") for v in store.get_engine_versions("llama_cpp")], + ) raise HTTPException(status_code=404, detail="Version not found") + version_str = str(version_entry.get("version")) + binary_path = _resolve_binary_path(version_entry.get("binary_path")) + if not os.path.exists(binary_path): + raise HTTPException(status_code=400, detail="Binary file does not exist") + store.set_active_engine_version(engine, version_str) + if engine == "llama_cpp": + try: + from backend.llama_swap_manager import get_llama_swap_manager + llama_swap_manager = get_llama_swap_manager() + await llama_swap_manager._ensure_correct_binary_path() + await 
llama_swap_manager.regenerate_config_with_active_version() + try: + await llama_swap_manager.start_proxy() + except Exception as e: + logger.warning("Failed to start llama-swap after version activation: %s", e) + except Exception as e: + logger.error("Failed to regenerate llama-swap config: %s", e) + logger.info("Activated %s version: %s", engine, version_str) + return {"message": f"Activated {engine} version {version_str}"} - # Prevent deletion of active version - if version.is_active: - raise HTTPException(status_code=400, detail="Cannot delete active version") +@router.delete("/{version_id}") +async def delete_version(version_id: str): + """Delete llama.cpp version (version_id is 'engine:version' or version string).""" + store = get_store() + version_entry = None + if ":" in version_id: + parts = version_id.split(":", 1) + engine, version_str = parts[0], parts[1] + if engine in ("llama_cpp", "ik_llama"): + version_entry = next( + (v for v in store.get_engine_versions(engine) if str(v.get("version")) == version_str), + None, + ) + if version_entry: + version_entry["_engine"] = engine + if not version_entry: + for engine in ("llama_cpp", "ik_llama"): + versions = store.get_engine_versions(engine) + version_entry = next((v for v in versions if str(v.get("version")) == str(version_id)), None) + if version_entry: + version_entry["_engine"] = engine + break + if not version_entry: + raise HTTPException(status_code=404, detail="Version not found") + engine = version_entry.get("_engine", "llama_cpp") + version_str = str(version_entry.get("version")) + active = store.get_active_engine_version(engine) + if active and str(active.get("version")) == version_str: + raise HTTPException(status_code=400, detail="Cannot delete active version") try: - # Delete the entire version directory - if version.binary_path and os.path.exists(version.binary_path): - # Go up two levels from build/bin/llama-server to get the version directory - version_dir = 
os.path.dirname(os.path.dirname(version.binary_path)) - if os.path.exists(version_dir): - _robust_rmtree(version_dir) - - # Delete from database - db.delete(version) - db.commit() - - logger.info(f"Deleted llama-cpp version: {version.version}") - return {"message": f"Deleted llama-cpp version {version.version}"} + binary_path = version_entry.get("binary_path") + if binary_path: + if not os.path.isabs(binary_path): + binary_path = os.path.join("/app", binary_path) + if os.path.exists(binary_path): + version_dir = os.path.dirname(os.path.dirname(binary_path)) + if os.path.exists(version_dir): + _robust_rmtree(version_dir) + store.delete_engine_version(engine, version_str) + logger.info(f"Deleted version: {version_str}") + return {"message": f"Deleted version {version_str}"} except Exception as e: - logger.error(f"Failed to delete version {version.version}: {e}") + logger.error(f"Failed to delete version {version_str}: {e}") raise HTTPException(status_code=500, detail=f"Failed to delete version: {e}") diff --git a/backend/routes/lmdeploy.py b/backend/routes/lmdeploy.py index 946096f..d24e5cf 100644 --- a/backend/routes/lmdeploy.py +++ b/backend/routes/lmdeploy.py @@ -1,51 +1,82 @@ -from typing import Dict, Optional - -from fastapi import APIRouter, HTTPException - -from backend.lmdeploy_installer import get_lmdeploy_installer -from backend.lmdeploy_manager import get_lmdeploy_manager - -router = APIRouter() - - -@router.get("/lmdeploy/status") -async def lmdeploy_installer_status() -> Dict: - installer = get_lmdeploy_installer() - return installer.status() - - -@router.post("/lmdeploy/install") -async def lmdeploy_install(request: Optional[Dict[str, str]] = None) -> Dict: - installer = get_lmdeploy_installer() - payload = request or {} - version = payload.get("version") - force_reinstall = bool(payload.get("force_reinstall")) - try: - return await installer.install(version=version, force_reinstall=force_reinstall) - except RuntimeError as exc: - raise 
HTTPException(status_code=409, detail=str(exc)) - - -@router.post("/lmdeploy/remove") -async def lmdeploy_remove() -> Dict: - installer = get_lmdeploy_installer() - try: - return await installer.remove() - except RuntimeError as exc: - raise HTTPException(status_code=409, detail=str(exc)) - - -@router.get("/lmdeploy/logs") -async def lmdeploy_logs(max_bytes: int = 8192) -> Dict[str, str]: - """Get LMDeploy installer logs.""" - installer = get_lmdeploy_installer() - max_bytes = max(1024, min(max_bytes, 1024 * 1024)) - return {"log": installer.read_log_tail(max_bytes)} - - -@router.get("/lmdeploy/runtime-logs") -async def lmdeploy_runtime_logs(max_bytes: int = 8192) -> Dict[str, str]: - """Get LMDeploy runtime logs (from running server instances).""" - manager = get_lmdeploy_manager() - max_bytes = max(1024, min(max_bytes, 1024 * 1024)) - return {"log": manager.read_log_tail(max_bytes)} +from typing import Dict, Optional + +import httpx +from fastapi import APIRouter, HTTPException + +from backend.lmdeploy_installer import get_lmdeploy_installer +from backend.lmdeploy_manager import get_lmdeploy_manager + +router = APIRouter() + + +@router.get("/lmdeploy/check-updates") +async def lmdeploy_check_updates() -> Dict: + """Check PyPI for latest LMDeploy version.""" + try: + async with httpx.AsyncClient() as client: + r = await client.get("https://pypi.org/pypi/lmdeploy/json", timeout=10.0) + r.raise_for_status() + data = r.json() + info = data.get("info", {}) + return { + "latest_version": info.get("version"), + "releases": list(data.get("releases", {}).keys()), + } + except Exception as exc: + raise HTTPException(status_code=500, detail=f"Failed to check PyPI: {exc}") + + +@router.get("/lmdeploy/status") +async def lmdeploy_installer_status() -> Dict: + installer = get_lmdeploy_installer() + return installer.status() + + +@router.post("/lmdeploy/install") +async def lmdeploy_install(request: Optional[Dict[str, str]] = None) -> Dict: + installer = 
get_lmdeploy_installer() + payload = request or {} + version = payload.get("version") + force_reinstall = bool(payload.get("force_reinstall")) + try: + return await installer.install(version=version, force_reinstall=force_reinstall) + except RuntimeError as exc: + raise HTTPException(status_code=409, detail=str(exc)) + + +@router.post("/lmdeploy/install-source") +async def lmdeploy_install_source(request: Optional[Dict[str, str]] = None) -> Dict: + """Install LMDeploy from a git repo and branch (for development).""" + installer = get_lmdeploy_installer() + payload = request or {} + repo_url = payload.get("repo_url", "https://github.com/InternLM/lmdeploy.git") + branch = payload.get("branch", "main") + try: + return await installer.install_from_source(repo_url=repo_url, branch=branch) + except RuntimeError as exc: + raise HTTPException(status_code=409, detail=str(exc)) + + +@router.post("/lmdeploy/remove") +async def lmdeploy_remove() -> Dict: + installer = get_lmdeploy_installer() + try: + return await installer.remove() + except RuntimeError as exc: + raise HTTPException(status_code=409, detail=str(exc)) + + +@router.get("/lmdeploy/logs") +async def lmdeploy_logs(max_bytes: int = 8192) -> Dict[str, str]: + """Get LMDeploy installer logs.""" + installer = get_lmdeploy_installer() + max_bytes = max(1024, min(max_bytes, 1024 * 1024)) + return {"log": installer.read_log_tail(max_bytes)} + + +@router.get("/lmdeploy/runtime-logs") +async def lmdeploy_runtime_logs(max_bytes: int = 8192) -> Dict[str, str]: + """Get LMDeploy runtime logs (from running server instances).""" + manager = get_lmdeploy_manager() + max_bytes = max(1024, min(max_bytes, 1024 * 1024)) + return {"log": manager.read_log_tail(max_bytes)} diff --git a/backend/routes/models.py b/backend/routes/models.py index e9eee78..5859db4 100644 --- a/backend/routes/models.py +++ b/backend/routes/models.py @@ -1,6 +1,4 @@ -from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks -from sqlalchemy.orm 
import Session -from sqlalchemy import or_ +from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query from typing import List, Optional, Dict, Any, Tuple from pydantic import BaseModel import json @@ -10,17 +8,12 @@ import re from datetime import datetime -from backend.database import ( - get_db, - Model, - RunningInstance, - generate_proxy_name, - LlamaVersion, -) +from backend.data_store import get_store, generate_proxy_name +from backend.progress_manager import get_progress_manager from backend.huggingface import ( search_models, download_model, - download_model_with_websocket_progress, + download_model_with_progress, set_huggingface_token, get_huggingface_token, get_model_details, @@ -40,27 +33,21 @@ DEFAULT_LMDEPLOY_CONTEXT, MAX_LMDEPLOY_CONTEXT, MAX_ROPE_SCALING_FACTOR, -) -from backend.smart_auto import SmartAutoConfig -from backend.smart_auto.model_metadata import get_model_metadata -from backend.smart_auto.architecture_config import ( - normalize_architecture, - detect_architecture_from_name, + get_model_disk_size, + get_accurate_file_sizes, + get_mmproj_f16_filename, ) from backend.gpu_detector import get_gpu_info from backend.gguf_reader import get_model_layer_info -from backend.presets import get_architecture_and_presets from backend.logging_config import get_logger logger = get_logger(__name__) from backend.llama_swap_config import get_supported_flags -from backend.logging_config import get_logger from backend.lmdeploy_manager import get_lmdeploy_manager from backend.lmdeploy_installer import get_lmdeploy_installer import psutil router = APIRouter() -logger = get_logger(__name__) # Common embedding indicators for automatic detection EMBEDDING_PIPELINE_TAGS = { @@ -100,19 +87,76 @@ def _looks_like_embedding_model( return any(keyword in combined for keyword in EMBEDDING_KEYWORDS) -def _model_is_embedding(model: Model) -> bool: +def _model_is_embedding(model: dict) -> bool: """Determine if a stored model should run in embedding 
mode.""" - config = _coerce_model_config(model.config) + config = _coerce_model_config(model.get("config")) if config.get("embedding"): return True return _looks_like_embedding_model( - model.pipeline_tag, - model.huggingface_id, - model.name, - model.base_model_name, + model.get("pipeline_tag"), + model.get("huggingface_id"), + model.get("display_name") or model.get("name"), + model.get("base_model_name"), ) +def _get_model_or_404(store, model_id: str) -> dict: + """Return model dict from store or raise 404. Accepts str model_id (YAML id).""" + if model_id is None: + raise HTTPException(status_code=404, detail="Model not found") + model_id = str(model_id) + model = store.get_model(model_id) + if not model: + raise HTTPException(status_code=404, detail="Model not found") + return model + + +def _get_actual_file_size(file_path: Optional[str]) -> Optional[int]: + """Return actual file size in bytes from disk, or None if not available.""" + if not file_path: + return None + path = _normalize_model_path(file_path) + if not path or not os.path.exists(path): + return None + try: + real = os.path.realpath(path) + return os.path.getsize(real if os.path.exists(real) else path) + except OSError: + return None + + +def _get_model_filename(model: dict) -> Optional[str]: + """Return the filename for a model record. + + Prefers the dedicated ``filename`` field (new records). Falls back to + deriving it from the legacy ``file_path`` field (old records). + """ + fname = model.get("filename") + if fname: + return fname + return _extract_filename(model.get("file_path")) or None + + +def _get_model_file_path(model: dict) -> Optional[str]: + """Return the actual filesystem path for a model file. + + Resolution order: + 1. HF cache via huggingface_id + filename (new records). + 2. Stored file_path (legacy records that still reference custom storage). 
+ """ + from backend.huggingface import resolve_cached_model_path + + hf_id = model.get("huggingface_id") + filename = _get_model_filename(model) + + if hf_id and filename: + cached = resolve_cached_model_path(hf_id, filename) + if cached: + return cached + + return _normalize_model_path(model.get("file_path")) or None + + def _normalize_model_path(file_path: Optional[str]) -> Optional[str]: if not file_path: return None @@ -129,45 +173,28 @@ def _extract_filename(file_path: Optional[str]) -> str: return parts[-1] if parts else normalized -def _cleanup_model_folder_if_no_quantizations( - db: Session, - huggingface_id: Optional[str], - model_format: Optional[str], -) -> None: - """ - If there are no remaining quantizations for a given Hugging Face repo and format, - delete the corresponding local model folder (e.g. data/models/gguf/). - """ - if not huggingface_id or not model_format: - return - - model_format = (model_format or "").lower() - if model_format not in ("gguf", "safetensors"): - return +def normalize_architecture(raw_architecture: str) -> str: + """Normalize GGUF architecture string (stub after smart_auto removal).""" + if not raw_architecture or not isinstance(raw_architecture, str): + return "unknown" + return raw_architecture.strip() or "unknown" - # Check for remaining models of this repo/format, excluding any pending deletions - remaining = ( - db.query(Model) - .filter( - Model.huggingface_id == huggingface_id, - Model.model_format == model_format, - ) - .count() - ) - if remaining > 0: - return - safe_repo = (huggingface_id or "unknown").replace("/", "_") or "unknown" - base_dir = os.path.join("data", "models", model_format) - repo_dir = os.path.join(base_dir, safe_repo) +def detect_architecture_from_name(name: str) -> str: + """Infer architecture from model name (stub after smart_auto removal).""" + if not name or not isinstance(name, str): + return "unknown" + name_lower = name.lower() + if "llama" in name_lower: + return "llama" + if "qwen" in 
name_lower: + return "qwen2" + if "mistral" in name_lower: + return "mistral" + if "phi" in name_lower: + return "phi-2" + return "unknown" - if os.path.isdir(repo_dir): - try: - if not os.listdir(repo_dir): - os.rmdir(repo_dir) - logger.info(f"Removed empty model folder: {repo_dir}") - except Exception as exc: - logger.warning(f"Failed to remove model folder {repo_dir}: {exc}") def _derive_hf_defaults(metadata: Dict[str, Any]) -> Dict[str, Any]: @@ -206,13 +233,13 @@ def _assign_numeric(src_key: str, dest_keys): return defaults -def _apply_hf_defaults_to_model(model: Model, metadata: Dict[str, Any], db: Session): +def _apply_hf_defaults_to_model(model: dict, metadata: Dict[str, Any], store) -> None: if not metadata: return defaults = _derive_hf_defaults(metadata) if not defaults: return - config = _coerce_model_config(model.config) + config = _coerce_model_config(model.get("config")) changed = False for key, value in defaults.items(): if value is None: @@ -222,9 +249,7 @@ def _apply_hf_defaults_to_model(model: Model, metadata: Dict[str, Any], db: Sess config[key] = value changed = True if changed: - model.config = config - db.commit() - db.refresh(model) + store.update_model(model["id"], {"config": config}) def _coerce_model_config(config_value: Optional[Any]) -> Dict[str, Any]: @@ -242,12 +267,12 @@ def _coerce_model_config(config_value: Optional[Any]) -> Dict[str, Any]: return {} -def _refresh_model_metadata_from_file(model: Model, db: Session) -> Dict[str, Any]: +def _refresh_model_metadata_from_file(model: dict, store) -> Dict[str, Any]: """ - Re-read GGUF metadata from disk and update the model record similar to the refresh endpoint. + Re-read GGUF metadata from disk and update the model record. Returns metadata details for downstream consumers. 
""" - normalized_path = _normalize_model_path(model.file_path) + normalized_path = _get_model_file_path(model) if not normalized_path or not os.path.exists(normalized_path): raise FileNotFoundError("Model file not found on disk") @@ -259,26 +284,19 @@ def _refresh_model_metadata_from_file(model: Model, db: Session) -> Dict[str, An normalized_architecture = normalize_architecture(raw_architecture) if not normalized_architecture or normalized_architecture == "unknown": normalized_architecture = detect_architecture_from_name( - model.name or model.huggingface_id or "" + model.get("display_name") or model.get("name") or model.get("huggingface_id") or "" ) update_fields = {} if ( normalized_architecture and normalized_architecture != "unknown" - and normalized_architecture != model.model_type + and normalized_architecture != model.get("model_type") ): update_fields["model_type"] = normalized_architecture - file_size = os.path.getsize(model.file_path) - if file_size != model.file_size: - update_fields["file_size"] = file_size - if update_fields: - for key, value in update_fields.items(): - setattr(model, key, value) - db.commit() - db.refresh(model) + store.update_model(model["id"], update_fields) return { "updated_fields": update_fields, @@ -286,9 +304,7 @@ def _refresh_model_metadata_from_file(model: Model, db: Session) -> Dict[str, An "architecture": normalized_architecture, "layer_count": layer_info.get("layer_count", 0), "context_length": layer_info.get("context_length", 0), - "parameter_count": layer_info.get( - "parameter_count" - ), # Formatted as "32B", "36B", etc. 
+ "parameter_count": layer_info.get("parameter_count"), "vocab_size": layer_info.get("vocab_size", 0), "embedding_length": layer_info.get("embedding_length", 0), "attention_head_count": layer_info.get("attention_head_count", 0), @@ -419,97 +435,68 @@ def _coerce_positive_float(value: Any) -> Optional[float]: async def _save_safetensors_download( - db: Session, + store, huggingface_id: str, filename: str, file_path: str, file_size: int, pipeline_tag: Optional[str] = None, -) -> Model: +) -> dict: """ - Persist safetensors download information using a single logical Model row per repo. - - Historically we created one Model row per .safetensors file. This caused - multi‑file repositories to appear as multiple independent models. The new - behavior is: - * Exactly one Model row per Hugging Face repo (huggingface_id) with - model_format == "safetensors". - * All individual .safetensors files for that repo are tracked in the - safetensors manifest and share the same model_id. - * The logical Model.file_size reflects the aggregate size of all files. + Persist safetensors download information using a single logical model entry per repo. + Returns the model dict with "id" (string, YAML model id). 
""" safetensors_metadata, tensor_summary, max_context = ( await _collect_safetensors_runtime_metadata(huggingface_id, filename) ) - # Determine / reuse logical Model for this Hugging Face repo detected_pipeline = pipeline_tag or safetensors_metadata.get("pipeline_tag") is_embedding_like = _looks_like_embedding_model( detected_pipeline, huggingface_id, filename ) - - # Try to find an existing logical model for this repo - model_record = ( - db.query(Model) - .filter( - Model.huggingface_id == huggingface_id, Model.model_format == "safetensors" - ) - .first() - ) + model_id = huggingface_id.replace("/", "--") + model_record = store.get_model(model_id) if not model_record: - # Create a single logical model entry for the whole repo - model_record = Model( - name=filename.replace(".safetensors", ""), - huggingface_id=huggingface_id, - base_model_name=extract_base_model_name(filename), - file_path=file_path, - file_size=file_size, - quantization=os.path.splitext(filename)[0], - model_type=extract_model_type(filename), - downloaded_at=datetime.utcnow(), - model_format="safetensors", - pipeline_tag=detected_pipeline, - ) - if is_embedding_like: - model_record.config = {"embedding": True} - db.add(model_record) - db.commit() - db.refresh(model_record) + from datetime import timezone as _tz + model_record = { + "id": model_id, + "huggingface_id": huggingface_id, + "filename": filename, + "display_name": filename.replace(".safetensors", ""), + "base_model_name": extract_base_model_name(filename), + "file_size": file_size, + "quantization": os.path.splitext(filename)[0], + "model_type": extract_model_type(filename), + "downloaded_at": datetime.now(_tz.utc).isoformat(), + "format": "safetensors", + "model_format": "safetensors", + "pipeline_tag": detected_pipeline, + "config": {"embedding": True} if is_embedding_like else {}, + } + store.add_model(model_record) else: - # Update existing logical model with any missing metadata and aggregate size - updated = False - if not 
model_record.pipeline_tag and detected_pipeline: - model_record.pipeline_tag = detected_pipeline - updated = True - if is_embedding_like and not (model_record.config or {}).get("embedding"): - # Ensure embedding flag is propagated - current_config = _coerce_model_config(model_record.config) - current_config["embedding"] = True - model_record.config = current_config - updated = True - # Aggregate size across all files for this repo by summing manifest entries. - # This avoids double‑counting if a file is redownloaded. + updates = {} + if not model_record.get("pipeline_tag") and detected_pipeline: + updates["pipeline_tag"] = detected_pipeline + if is_embedding_like and not _coerce_model_config(model_record.get("config")).get("embedding"): + cfg = _coerce_model_config(model_record.get("config")) + cfg["embedding"] = True + updates["config"] = cfg try: from backend.huggingface import list_safetensors_downloads - manifests = list_safetensors_downloads() total_size = 0 for manifest in manifests: if manifest.get("huggingface_id") == huggingface_id: - total_size = sum( - (f.get("file_size") or 0) for f in manifest.get("files", []) - ) + total_size = sum((f.get("file_size") or 0) for f in manifest.get("files", [])) break - if total_size and total_size != (model_record.file_size or 0): - model_record.file_size = total_size - updated = True + if total_size and total_size != (model_record.get("file_size") or 0): + updates["file_size"] = total_size except Exception as exc: - logger.warning( - f"Failed to aggregate safetensors file sizes for {huggingface_id}: {exc}" - ) - if updated: - db.commit() - db.refresh(model_record) + logger.warning(f"Failed to aggregate safetensors file sizes for {huggingface_id}: {exc}") + if updates: + store.update_model(model_id, updates) + model_record = store.get_model(model_id) or model_record lmdeploy_config = get_default_lmdeploy_config(max_context) record_safetensors_download( @@ -520,33 +507,28 @@ async def _save_safetensors_download( 
metadata=safetensors_metadata, tensor_summary=tensor_summary, lmdeploy_config=lmdeploy_config, - model_id=model_record.id, - ) - logger.info( - f"Safetensors download recorded for {huggingface_id}/{filename} (model_id={model_record.id})" + model_id=model_record.get("id"), ) + logger.info(f"Safetensors download recorded for {huggingface_id}/{filename} (model_id={model_record.get('id')})") return model_record -def _get_safetensors_model(model_id: int, db: Session) -> Model: - model = db.query(Model).filter(Model.id == model_id).first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") - model_format = (model.model_format or "gguf").lower() +def _get_safetensors_model(store, model_id: str) -> dict: + model = _get_model_or_404(store, model_id) + model_format = (model.get("model_format") or model.get("format") or "gguf").lower() if model_format != "safetensors": - raise HTTPException( - status_code=400, detail="Model is not a safetensors download" - ) - normalized_path = _normalize_model_path(model.file_path) - if not normalized_path or not os.path.exists(normalized_path): + raise HTTPException(status_code=400, detail="Model is not a safetensors download") + resolved_path = _get_model_file_path(model) + if not resolved_path or not os.path.exists(resolved_path): raise HTTPException(status_code=400, detail="Model file not found on disk") - model.file_path = normalized_path + model = dict(model) + model["file_path"] = resolved_path return model -def _load_manifest_entry_for_model(model: Model) -> Dict[str, Any]: +def _load_manifest_entry_for_model(model: dict) -> Dict[str, Any]: """Load unified manifest for a safetensors model (repo-level, not per-file).""" - manifest = get_safetensors_manifest_entries(model.huggingface_id) + manifest = get_safetensors_manifest_entries(model.get("huggingface_id")) if not manifest: raise HTTPException(status_code=404, detail="Safetensors manifest not found") return manifest @@ -1143,7 +1125,7 @@ def 
_as_list(key: str) -> list: class BundleProgressProxy: - """Proxy websocket manager that converts per-file progress into bundle-level updates.""" + """Proxy progress manager that converts per-file progress into bundle-level updates.""" def __init__( self, @@ -1247,7 +1229,7 @@ async def get_cached_gpu_info() -> Dict[str, Any]: class EstimationRequest(BaseModel): - model_id: int + model_id: str # YAML model id config: dict usage_mode: Optional[str] = "single_user" @@ -1258,76 +1240,82 @@ class SafetensorsBundleRequest(BaseModel): files: List[Dict[str, Any]] +@router.get("/param-registry") +async def get_param_registry_endpoint(engine: str = "llama_cpp"): + """Return param definitions (basic + advanced) for config forms.""" + from backend.param_registry import get_param_registry + return get_param_registry(engine) + + @router.get("") @router.get("/") -async def list_models(db: Session = Depends(get_db)): +async def list_models(): """List all managed models grouped by base model""" - # Sync is_active status before returning models - from backend.database import sync_model_active_status - - sync_model_active_status(db) + from backend.llama_swap_client import LlamaSwapClient - models = ( - db.query(Model) - .filter(or_(Model.model_format.is_(None), Model.model_format == "gguf")) - .all() - ) + store = get_store() + models = [m for m in store.list_models() if (m.get("format") or m.get("model_format") or "gguf") == "gguf"] + try: + running_data = await LlamaSwapClient().get_running_models() + running_list = running_data.get("running") or [] + running_names = {item.get("model") for item in running_list if item.get("state") in ("running", "ready")} + except Exception: + running_names = set() - # Group models by huggingface_id and base_model_name grouped_models = {} for model in models: + hf_id = model.get("huggingface_id") or "" + base_name = model.get("base_model_name") or (hf_id.split("/")[-1] if hf_id else model.get("display_name") or "unknown") + proxy_name = 
generate_proxy_name(hf_id, model.get("quantization")) + is_active = proxy_name in running_names is_embedding = _model_is_embedding(model) - key = f"{model.huggingface_id}_{model.base_model_name}" + key = f"{hf_id}_{base_name}" if key not in grouped_models: - # derive author/owner from huggingface_id - hf_id = model.huggingface_id or "" - author = ( - hf_id.split("/")[0] if isinstance(hf_id, str) and "/" in hf_id else "" - ) + author = hf_id.split("/")[0] if isinstance(hf_id, str) and "/" in hf_id else "" grouped_models[key] = { - "base_model_name": model.base_model_name, - "huggingface_id": model.huggingface_id, - "model_type": model.model_type, + "base_model_name": base_name, + "huggingface_id": hf_id, + "model_type": model.get("model_type"), "author": author, - "pipeline_tag": model.pipeline_tag, + "pipeline_tag": model.get("pipeline_tag"), "is_embedding_model": is_embedding, "quantizations": [], } else: - if model.pipeline_tag and not grouped_models[key].get("pipeline_tag"): - grouped_models[key]["pipeline_tag"] = model.pipeline_tag + if model.get("pipeline_tag") and not grouped_models[key].get("pipeline_tag"): + grouped_models[key]["pipeline_tag"] = model.get("pipeline_tag") if is_embedding and not grouped_models[key].get("is_embedding_model"): grouped_models[key]["is_embedding_model"] = True - grouped_models[key]["quantizations"].append( - { - "id": model.id, - "name": model.name, - "file_path": model.file_path, - "file_size": model.file_size, - "quantization": model.quantization, - "downloaded_at": model.downloaded_at, - "is_active": model.is_active, - "has_config": bool(model.config), - "huggingface_id": model.huggingface_id, - "base_model_name": model.base_model_name, - "model_type": model.model_type, - "config": _coerce_model_config(model.config), - "proxy_name": model.proxy_name, - "pipeline_tag": model.pipeline_tag, - "is_embedding_model": is_embedding, - } - ) + # Resolve actual disk size: prefer HF cache, fall back to stored value + resolved_path = 
_get_model_file_path(model) + file_size = _get_actual_file_size(resolved_path) or model.get("file_size") or 0 + + grouped_models[key]["quantizations"].append({ + "id": model.get("id"), + "name": model.get("display_name") or model.get("name"), + "filename": _get_model_filename(model), + "file_size": file_size, + "quantization": model.get("quantization"), + "format": model.get("format") or model.get("model_format") or "gguf", + "engine": model.get("engine") or "llama_cpp", + "downloaded_at": model.get("downloaded_at"), + "is_active": is_active, + "has_config": bool(model.get("config")), + "huggingface_id": hf_id, + "base_model_name": base_name, + "model_type": model.get("model_type"), + "config": _coerce_model_config(model.get("config")), + "proxy_name": proxy_name, + "pipeline_tag": model.get("pipeline_tag"), + "is_embedding_model": is_embedding, + }) - # Convert to list and sort quantizations by file size (smallest first) result = [] for group in grouped_models.values(): - group["quantizations"].sort(key=lambda x: x["file_size"] or 0) + group["quantizations"].sort(key=lambda x: x.get("file_size") or 0) result.append(group) - - # Sort groups by base model name - result.sort(key=lambda x: x["base_model_name"]) - + result.sort(key=lambda x: x.get("base_model_name") or "") return result @@ -1363,6 +1351,19 @@ async def clear_search_cache_endpoint(): raise HTTPException(status_code=500, detail=str(e)) +@router.get("/search/{model_id:path}/file-sizes") +async def get_search_file_sizes( + model_id: str, + filenames: str = Query(..., description="Comma-separated list of file paths in the repo"), +): + """Get accurate file sizes for specific files in a repo via HuggingFace API.""" + file_list = [f.strip() for f in filenames.split(",") if f.strip()] + if not file_list: + raise HTTPException(status_code=400, detail="At least one filename is required") + sizes = get_accurate_file_sizes(model_id, file_list) + return {"sizes": sizes} + + @router.get("/search/{model_id}/details") 
async def get_model_details_endpoint(model_id: str): """Get detailed model information including config and architecture""" @@ -1419,58 +1420,43 @@ async def list_safetensors_models(): @router.delete("/safetensors") -async def delete_safetensors_model(request: dict, db: Session = Depends(get_db)): +async def delete_safetensors_model(request: dict): """Delete entire safetensors model (all files for the repo).""" try: huggingface_id = request.get("huggingface_id") if not huggingface_id: raise HTTPException(status_code=400, detail="huggingface_id is required") - # Prevent deletion while runtime is active for this logical model - active_instance = ( - db.query(RunningInstance) - .filter(RunningInstance.runtime_type == "lmdeploy") - .first() - ) - target_model = ( - db.query(Model) - .filter( - Model.huggingface_id == huggingface_id, - Model.model_format == "safetensors", - ) - .first() - ) - if ( - active_instance - and target_model - and active_instance.model_id == target_model.id - ): - raise HTTPException( - status_code=400, - detail="Cannot delete a model currently served by LMDeploy", - ) + store = get_store() + model_id = huggingface_id.replace("/", "--") + target_model = store.get_model(model_id) + if not target_model or (target_model.get("format") or target_model.get("model_format")) != "safetensors": + raise HTTPException(status_code=404, detail="Safetensors model not found") + + manager = get_lmdeploy_manager() + status = manager.status() + if status.get("running"): + current = status.get("current_instance") or {} + if str(current.get("model_id")) == str(model_id): + raise HTTPException( + status_code=400, + detail="Cannot delete a model currently served by LMDeploy", + ) - # Get unified manifest and delete all files from backend.huggingface import ( get_safetensors_manifest_entries, delete_safetensors_download, ) - manifest = get_safetensors_manifest_entries(huggingface_id) if not manifest or not manifest.get("files"): raise HTTPException(status_code=404, 
detail="Safetensors model not found") - # Delete all files in the unified manifest for file_entry in manifest.get("files", []): entry_filename = file_entry.get("filename") if entry_filename: delete_safetensors_download(huggingface_id, entry_filename) - # Delete the single logical Model row - if target_model: - db.delete(target_model) - db.commit() - + store.delete_model(model_id) return {"message": f"Safetensors model {huggingface_id} deleted"} except HTTPException: raise @@ -1479,8 +1465,8 @@ async def delete_safetensors_model(request: dict, db: Session = Depends(get_db)) @router.post("/safetensors/reload-from-disk") -async def reload_safetensors_from_disk(db: Session = Depends(get_db)): - """Reset all safetensors database entries and reload them from file storage.""" +async def reload_safetensors_from_disk(): + """Reset all safetensors store entries and reload them from file storage.""" try: from backend.huggingface import ( SAFETENSORS_DIR, @@ -1488,27 +1474,22 @@ async def reload_safetensors_from_disk(db: Session = Depends(get_db)): get_default_lmdeploy_config, ) - # Prevent reload while runtime is active - active_instance = ( - db.query(RunningInstance) - .filter(RunningInstance.runtime_type == "lmdeploy") - .first() - ) - if active_instance: + manager = get_lmdeploy_manager() + if manager.status().get("running"): raise HTTPException( status_code=400, detail="Cannot reload safetensors models while LMDeploy runtime is active. 
Please stop the runtime first.", ) - # Delete all existing safetensors Model entries - safetensors_models = ( - db.query(Model).filter(Model.model_format == "safetensors").all() - ) + store = get_store() + safetensors_models = [ + m for m in store.list_models() + if (m.get("format") or m.get("model_format")) == "safetensors" + ] deleted_count = len(safetensors_models) for model in safetensors_models: - db.delete(model) - db.commit() - logger.info(f"Deleted {deleted_count} safetensors model entries from database") + store.delete_model(model.get("id")) + logger.info(f"Deleted {deleted_count} safetensors model entries from store") # Delete all existing manifest files to regenerate from HuggingFace with defaults from backend.huggingface import _get_manifest_path @@ -1567,99 +1548,25 @@ async def reload_safetensors_from_disk(db: Session = Depends(get_db)): if not safetensors_files: continue - # Process each file to rebuild database entries - model_record = None + # Process each file to rebuild store entries (one model per repo via _save_safetensors_download) for file_info in safetensors_files: try: filename = file_info["filename"] file_path = file_info["file_path"] file_size = file_info["file_size"] - - # Collect metadata (this will also create/update the manifest) - safetensors_metadata, tensor_summary, max_context = ( - await _collect_safetensors_runtime_metadata( - huggingface_id, filename - ) - ) - - # Get or create model record (one per repo) - if not model_record: - detected_pipeline = safetensors_metadata.get("pipeline_tag") - is_embedding_like = _looks_like_embedding_model( - detected_pipeline, huggingface_id, filename - ) - - model_record = ( - db.query(Model) - .filter( - Model.huggingface_id == huggingface_id, - Model.model_format == "safetensors", - ) - .first() - ) - - if not model_record: - model_record = Model( - name=filename.replace(".safetensors", ""), - huggingface_id=huggingface_id, - base_model_name=extract_base_model_name(filename), - 
file_path=file_path, # Use first file's path - file_size=0, # Will be aggregated below - quantization=os.path.splitext(filename)[0], - model_type=extract_model_type(filename), - downloaded_at=datetime.utcnow(), - model_format="safetensors", - pipeline_tag=detected_pipeline, - ) - if is_embedding_like: - model_record.config = {"embedding": True} - db.add(model_record) - db.commit() - db.refresh(model_record) - - # Record file in manifest - lmdeploy_config = get_default_lmdeploy_config(max_context) - record_safetensors_download( - huggingface_id=huggingface_id, - filename=filename, - file_path=file_path, - file_size=file_size, - metadata=safetensors_metadata, - tensor_summary=tensor_summary, - lmdeploy_config=lmdeploy_config, - model_id=model_record.id, + await _save_safetensors_download( + store, + huggingface_id, + filename, + file_path, + file_size, ) - except Exception as exc: error_msg = f"Failed to reload {huggingface_id}/{file_info.get('filename', 'unknown')}: {exc}" logger.error(error_msg) errors.append(error_msg) continue - - # Update model record with aggregated size - if model_record: - try: - from backend.huggingface import list_safetensors_downloads - - manifests = list_safetensors_downloads() - total_size = 0 - for manifest in manifests: - if manifest.get("huggingface_id") == huggingface_id: - total_size = sum( - (f.get("file_size") or 0) - for f in manifest.get("files", []) - ) - break - if total_size: - model_record.file_size = total_size - db.commit() - db.refresh(model_record) - except Exception as exc: - logger.warning( - f"Failed to update aggregate size for {huggingface_id}: {exc}" - ) - - reloaded_count += 1 + reloaded_count += 1 result = { "message": f"Reloaded {reloaded_count} safetensors models from disk", @@ -1680,10 +1587,11 @@ async def reload_safetensors_from_disk(db: Session = Depends(get_db)): raise HTTPException(status_code=500, detail=str(e)) -@router.get("/safetensors/{model_id}/lmdeploy/config") -async def 
get_lmdeploy_config_endpoint(model_id: int, db: Session = Depends(get_db)): +@router.get("/safetensors/{model_id:path}/lmdeploy/config") +async def get_lmdeploy_config_endpoint(model_id: str): """Return stored LMDeploy config and metadata for a safetensors model.""" - model = _get_safetensors_model(model_id, db) + store = get_store() + model = _get_safetensors_model(store, model_id) manifest_entry = _load_manifest_entry_for_model(model) metadata = manifest_entry.get("metadata") or {} tensor_summary = manifest_entry.get("tensor_summary") or {} @@ -1705,28 +1613,26 @@ async def get_lmdeploy_config_endpoint(model_id: int, db: Session = Depends(get_ } -@router.put("/safetensors/{model_id}/lmdeploy/config") -async def update_lmdeploy_config_endpoint( - model_id: int, request: Dict[str, Any], db: Session = Depends(get_db) -): +@router.put("/safetensors/{model_id:path}/lmdeploy/config") +async def update_lmdeploy_config_endpoint(model_id: str, request: Dict[str, Any]): """Persist LMDeploy configuration changes for a safetensors model.""" - model = _get_safetensors_model(model_id, db) + store = get_store() + model = _get_safetensors_model(store, model_id) manifest_entry = _load_manifest_entry_for_model(model) validated_config = _validate_lmdeploy_config(request, manifest_entry) - updated_entry = update_lmdeploy_config(model.huggingface_id, validated_config) + updated_entry = update_lmdeploy_config(model.get("huggingface_id"), validated_config) return { "config": updated_entry.get("lmdeploy", {}).get("config", validated_config), "updated_at": updated_entry.get("lmdeploy", {}).get("updated_at"), } -@router.post("/safetensors/{model_id}/metadata/regenerate") -async def regenerate_safetensors_metadata_endpoint( - model_id: int, db: Session = Depends(get_db) -): +@router.post("/safetensors/{model_id:path}/metadata/regenerate") +async def regenerate_safetensors_metadata_endpoint(model_id: str): """Refresh safetensors metadata/manifest entries without redownloading files.""" - 
model = _get_safetensors_model(model_id, db) - huggingface_id = model.huggingface_id + store = get_store() + model = _get_safetensors_model(store, model_id) + huggingface_id = model.get("huggingface_id") manifest = get_safetensors_manifest_entries(huggingface_id) if not manifest or not manifest.get("files"): raise HTTPException( @@ -1813,7 +1719,7 @@ async def regenerate_safetensors_metadata_endpoint( @router.get("/safetensors/lmdeploy/status") -async def get_lmdeploy_status(db: Session = Depends(get_db)): +async def get_lmdeploy_status(): """Return LMDeploy runtime status and running instance info.""" installer = get_lmdeploy_installer() installer_status = installer.status() @@ -1829,45 +1735,18 @@ async def get_lmdeploy_status(db: Session = Depends(get_db)): ) manager = get_lmdeploy_manager() - installer = get_lmdeploy_installer() manager_status = manager.status() - # Only return running_instance if LMDeploy is actually running + # Use manager's in-memory current_instance (no DB) instance_payload = None if manager_status.get("running"): - running_instance = ( - db.query(RunningInstance) - .filter(RunningInstance.runtime_type == "lmdeploy") - .first() - ) - if running_instance: + current_instance = manager_status.get("current_instance") + if current_instance: instance_payload = { - "model_id": running_instance.model_id, - "started_at": ( - running_instance.started_at.isoformat() - if running_instance.started_at - else None - ), - "config": ( - json.loads(running_instance.config) - if running_instance.config - else {} - ), + "model_id": current_instance.get("model_id"), + "started_at": current_instance.get("started_at"), + "config": current_instance.get("config") if isinstance(current_instance.get("config"), dict) else {}, } - else: - # Clean up stale RunningInstance records if LMDeploy is not running - stale_instances = ( - db.query(RunningInstance) - .filter(RunningInstance.runtime_type == "lmdeploy") - .all() - ) - if stale_instances: - for instance in 
stale_instances: - model = db.query(Model).filter(Model.id == instance.model_id).first() - if model: - model.is_active = False - db.delete(instance) - db.commit() return { "manager": manager_status, @@ -1876,27 +1755,25 @@ async def get_lmdeploy_status(db: Session = Depends(get_db)): } -@router.post("/safetensors/{model_id}/lmdeploy/start") +@router.post("/safetensors/{model_id:path}/lmdeploy/start") async def start_lmdeploy_runtime( - model_id: int, + model_id: str, request: Optional[Dict[str, Any]] = None, - db: Session = Depends(get_db), ): """Start LMDeploy runtime for a safetensors model.""" - model = _get_safetensors_model(model_id, db) + store = get_store() + model = _get_safetensors_model(store, model_id) manifest_entry = _load_manifest_entry_for_model(model) requested_config = ( (request or {}).get("config") if isinstance(request, dict) else None ) validated_config = _validate_lmdeploy_config(requested_config, manifest_entry) - existing_instance = ( - db.query(RunningInstance) - .filter(RunningInstance.runtime_type == "lmdeploy") - .first() - ) - if existing_instance: - if existing_instance.model_id == model.id: + manager = get_lmdeploy_manager() + status = manager.status() + current_instance = status.get("current_instance") or {} + if status.get("running"): + if current_instance.get("model_id") == model.get("id"): raise HTTPException( status_code=400, detail="LMDeploy is already running for this model" ) @@ -1905,41 +1782,30 @@ async def start_lmdeploy_runtime( detail="Another safetensors model is already running via LMDeploy", ) - manager = get_lmdeploy_manager() - status = manager.status() - current_instance = status.get("current_instance") or {} - if status.get("running") and current_instance.get("model_id") not in ( - None, - model.id, - ): - raise HTTPException( - status_code=400, detail="LMDeploy runtime is already serving another model" - ) - - update_lmdeploy_config(model.huggingface_id, validated_config) + 
update_lmdeploy_config(model.get("huggingface_id"), validated_config) - from backend.main import websocket_manager - - await websocket_manager.send_model_status_update( - model_id=model.id, - status="starting", - details={ - "runtime": "lmdeploy", - "message": f"Starting LMDeploy for {model.name}", - }, - ) + try: + pm = get_progress_manager() + await pm.send_model_status_update( + model_id=model.get("id"), + status="starting", + details={ + "runtime": "lmdeploy", + "message": f"Starting LMDeploy for {model.get('display_name') or model.get('name')}", + }, + ) + except Exception: + pass try: - # Derive a human-friendly model name for LMDeploy (used by --model-name). - # For unified safetensors models, use the Hugging Face repo id. - display_name = model.huggingface_id or model.base_model_name or model.name - # For unified manifests, use the model directory (contains all files) - model_dir = os.path.dirname(model.file_path) + display_name = model.get("huggingface_id") or model.get("base_model_name") or model.get("display_name") or model.get("name") + resolved_file_path = _get_model_file_path(model) + model_dir = os.path.dirname(resolved_file_path or "") runtime_status = await manager.start( { - "model_id": model.id, - "huggingface_id": model.huggingface_id, - "file_path": model.file_path, + "model_id": model.get("id"), + "huggingface_id": model.get("huggingface_id"), + "file_path": resolved_file_path, "model_dir": model_dir, "model_name": display_name, "display_name": display_name, @@ -1947,81 +1813,61 @@ async def start_lmdeploy_runtime( validated_config, ) except Exception as exc: - await websocket_manager.send_model_status_update( - model_id=model.id, - status="error", - details={"runtime": "lmdeploy", "message": str(exc)}, - ) + try: + await get_progress_manager().send_model_status_update( + model_id=model.get("id"), + status="error", + details={"runtime": "lmdeploy", "message": str(exc)}, + ) + except Exception: + pass raise HTTPException(status_code=500, 
detail=str(exc)) - running_instance = RunningInstance( - model_id=model.id, - llama_version="lmdeploy", - proxy_model_name=f"lmdeploy::{model.id}", - started_at=datetime.utcnow(), - config=json.dumps({"lmdeploy": validated_config}), - runtime_type="lmdeploy", - ) - db.add(running_instance) - model.is_active = True - db.commit() - - from backend.unified_monitor import unified_monitor - - await unified_monitor._collect_and_send_unified_data() - await websocket_manager.send_model_status_update( - model_id=model.id, - status="running", - details={"runtime": "lmdeploy", "message": "LMDeploy is ready"}, - ) + try: + await get_progress_manager().send_model_status_update( + model_id=model.get("id"), + status="running", + details={"runtime": "lmdeploy", "message": "LMDeploy is ready"}, + ) + except Exception: + pass return {"manager": runtime_status, "config": validated_config} -@router.post("/safetensors/{model_id}/lmdeploy/stop") -async def stop_lmdeploy_runtime(model_id: int, db: Session = Depends(get_db)): +@router.post("/safetensors/{model_id:path}/lmdeploy/stop") +async def stop_lmdeploy_runtime(model_id: str): """Stop the LMDeploy runtime if it is running.""" - running_instance = ( - db.query(RunningInstance) - .filter(RunningInstance.runtime_type == "lmdeploy") - .first() - ) - if not running_instance: + manager = get_lmdeploy_manager() + status = manager.status() + if not status.get("running"): raise HTTPException(status_code=404, detail="No LMDeploy runtime is active") - if running_instance.model_id != model_id: + current_instance = status.get("current_instance") or {} + if str(current_instance.get("model_id")) != str(model_id): raise HTTPException( status_code=400, detail="A different model is currently running in LMDeploy" ) - manager = get_lmdeploy_manager() try: await manager.stop() except Exception as exc: raise HTTPException(status_code=500, detail=str(exc)) - db.delete(running_instance) - model = db.query(Model).filter(Model.id == model_id).first() - if 
model: - model.is_active = False - db.commit() - - from backend.unified_monitor import unified_monitor - - await unified_monitor._collect_and_send_unified_data() - from backend.main import websocket_manager - - await websocket_manager.send_model_status_update( - model_id=model_id, - status="stopped", - details={"runtime": "lmdeploy", "message": "LMDeploy runtime stopped"}, - ) + try: + await get_progress_manager().send_model_status_update( + model_id=model_id, + status="stopped", + details={"runtime": "lmdeploy", "message": "LMDeploy runtime stopped"}, + ) + except Exception: + pass return {"message": "LMDeploy runtime stopped"} @router.post("/download") async def download_huggingface_model( - request: dict, background_tasks: BackgroundTasks, db: Session = Depends(get_db) + request: dict, background_tasks: BackgroundTasks ): """Download model from HuggingFace""" try: @@ -2053,17 +1899,12 @@ async def download_huggingface_model( detail="filename must end with .safetensors for Safetensors downloads", ) - # Check if this specific quantization already exists in database + store = get_store() + # Check if this specific quantization already exists if model_format == "gguf": - existing = ( - db.query(Model) - .filter( - Model.huggingface_id == huggingface_id, - Model.name == filename.replace(".gguf", ""), - ) - .first() - ) - if existing: + quantization = _extract_quantization(filename) + model_id = f"{huggingface_id.replace('/', '--')}--{quantization}" + if store.get_model(model_id): raise HTTPException( status_code=400, detail="This quantization is already downloaded" ) @@ -2100,15 +1941,14 @@ async def download_huggingface_model( "model_format": model_format, } - # Get websocket manager from main app - from backend.main import websocket_manager - - # Start download in background (REMOVE db parameter, pass task_id) + # Start download in background with progress_manager for SSE + pm = get_progress_manager() + pm.create_task("download", f"Download {filename}", 
{"huggingface_id": huggingface_id, "filename": filename}, task_id=task_id) background_tasks.add_task( download_model_task, huggingface_id, filename, - websocket_manager, + pm, task_id, total_bytes, model_format, @@ -2171,26 +2011,24 @@ async def set_huggingface_token_endpoint(request: dict): async def download_model_task( huggingface_id: str, filename: str, - websocket_manager=None, + progress_manager=None, task_id: str = None, total_bytes: int = 0, model_format: str = "gguf", pipeline_tag: Optional[str] = None, ): - """Background task to download model with WebSocket progress""" - from backend.database import SessionLocal - - db = SessionLocal() + """Background task to download model with SSE progress""" + store = get_store() try: model_record = None metadata_result = None - if websocket_manager and task_id: - file_path, file_size = await download_model_with_websocket_progress( + if progress_manager and task_id: + file_path, file_size = await download_model_with_progress( huggingface_id, filename, - websocket_manager, + progress_manager, task_id, total_bytes, model_format, @@ -2203,16 +2041,38 @@ async def download_model_task( if model_format == "gguf": model_record, metadata_result = await _record_gguf_download_post_fetch( - db, + store, huggingface_id, filename, file_path, file_size, pipeline_tag=pipeline_tag, ) + # If vision (mmproj) is available, download F16 projector so the model can run with vision + if model_record: + mmproj_filename = get_mmproj_f16_filename(huggingface_id) + if mmproj_filename: + try: + await download_model( + huggingface_id, mmproj_filename, "gguf" + ) + store.update_model( + model_record["id"], {"mmproj_filename": mmproj_filename} + ) + model_record = store.get_model(model_record["id"]) or model_record + if progress_manager and task_id: + await progress_manager.send_notification( + title="Vision extension", + message=f"Downloaded {mmproj_filename} for vision support", + type="info", + ) + except Exception as mmproj_err: + 
logger.warning( + f"Could not download vision projector {mmproj_filename} for {huggingface_id}: {mmproj_err}" + ) else: model_record = await _save_safetensors_download( - db, + store, huggingface_id, filename, file_path, @@ -2220,20 +2080,21 @@ async def download_model_task( pipeline_tag=pipeline_tag, ) - # Send download complete WebSocket event (NEW) - if websocket_manager: + # Send download complete via SSE + if progress_manager and task_id: + progress_manager.complete_task(task_id, f"Downloaded {filename}") payload = { "type": "download_complete", "huggingface_id": huggingface_id, "filename": filename, "model_format": model_format, - "quantization": model_record.quantization if model_record else None, - "model_id": model_record.id if model_record else None, + "quantization": model_record.get("quantization") if model_record else None, + "model_id": model_record.get("id") if model_record else None, "base_model_name": ( - model_record.base_model_name if model_record else None + model_record.get("base_model_name") if model_record else None ), "pipeline_tag": ( - model_record.pipeline_tag if model_record else pipeline_tag + model_record.get("pipeline_tag") if model_record else pipeline_tag ), "is_embedding_model": ( _model_is_embedding(model_record) if model_record else False @@ -2248,40 +2109,38 @@ async def download_model_task( "file_size": file_size, "file_path": file_path, } - await websocket_manager.broadcast({**payload}) - - await websocket_manager.send_notification( + await progress_manager.broadcast({**payload}) + await progress_manager.send_notification( title="Download Complete", message=f"Successfully downloaded {filename} ({model_format})", type="success", ) except Exception as e: - if websocket_manager: - await websocket_manager.send_notification( + if progress_manager and task_id: + progress_manager.fail_task(task_id, str(e)) + await progress_manager.send_notification( title="Download Failed", message=f"Failed to download {filename}: {str(e)}", 
type="error", ) finally: - # Cleanup: remove from active downloads and close session if task_id: async with download_lock: active_downloads.pop(task_id, None) - db.close() async def _record_gguf_download_post_fetch( - db: Session, + store, huggingface_id: str, filename: str, file_path: str, file_size: int, pipeline_tag: Optional[str] = None, -) -> Tuple[Model, Optional[Dict[str, Any]]]: +) -> Tuple[dict, Optional[Dict[str, Any]]]: """ - Shared helper to create GGUF Model rows and manifest entries after a file has been downloaded. - Returns (model_record, metadata_result). + Shared helper to create GGUF model entries and manifest after a file has been downloaded. + Returns (model_record dict, metadata_result). """ quantization = _extract_quantization(filename) base_model_name = extract_base_model_name(filename) @@ -2296,89 +2155,68 @@ async def _record_gguf_download_post_fetch( detected_pipeline = "text-embedding" metadata_result: Optional[Dict[str, Any]] = None - # Reuse a single logical Model row per (huggingface_id, quantization) to avoid - # creating one entry per GGUF shard. Additional shards for the same quantization - # simply update size/metadata and are tracked in the GGUF manifest. 
- model_record = ( - db.query(Model) - .filter( - Model.huggingface_id == huggingface_id, - Model.quantization == quantization, - Model.model_format == "gguf", - ) - .first() - ) + model_id = f"{huggingface_id.replace('/', '--')}--{quantization}" + model_record = store.get_model(model_id) if not model_record: - model_record = Model( - name=filename.replace(".gguf", ""), - huggingface_id=huggingface_id, - base_model_name=base_model_name, - file_path=file_path, - file_size=file_size, - quantization=quantization, - model_type=extract_model_type(filename), - proxy_name=generate_proxy_name(huggingface_id, quantization), - model_format="gguf", - downloaded_at=datetime.utcnow(), - pipeline_tag=detected_pipeline, - ) - if is_embedding_like: - model_record.config = {"embedding": True} - db.add(model_record) - db.commit() - db.refresh(model_record) + from datetime import timezone as _tz + model_record = { + "id": model_id, + "huggingface_id": huggingface_id, + "filename": filename, + "display_name": filename.replace(".gguf", ""), + "base_model_name": base_model_name, + "file_size": file_size, + "quantization": quantization, + "model_type": extract_model_type(filename), + "proxy_name": generate_proxy_name(huggingface_id, quantization), + "format": "gguf", + "model_format": "gguf", + "downloaded_at": datetime.now(_tz.utc).isoformat(), + "pipeline_tag": detected_pipeline, + "config": {"embedding": True} if is_embedding_like else {}, + } + store.add_model(model_record) else: - updated = False - # Keep first file_path as canonical; just update aggregate size. 
+ updates = {} if file_size and file_size > 0: - current_size = model_record.file_size or 0 - model_record.file_size = current_size + file_size - updated = True - if not model_record.pipeline_tag and detected_pipeline: - model_record.pipeline_tag = detected_pipeline - updated = True + current_size = model_record.get("file_size") or 0 + updates["file_size"] = current_size + file_size + if not model_record.get("pipeline_tag") and detected_pipeline: + updates["pipeline_tag"] = detected_pipeline if is_embedding_like: - current_config = _coerce_model_config(model_record.config) + current_config = _coerce_model_config(model_record.get("config")) if not current_config.get("embedding"): current_config["embedding"] = True - model_record.config = current_config - updated = True - if updated: - db.commit() - db.refresh(model_record) + updates["config"] = current_config + if updates: + store.update_model(model_id, updates) + model_record = store.get_model(model_id) or model_record + metadata_result = None try: - metadata_result = _refresh_model_metadata_from_file(model_record, db) + metadata_result = _refresh_model_metadata_from_file(model_record, store) except FileNotFoundError: - logger.warning( - f"Model file missing during metadata refresh for {model_record.id}" - ) + logger.warning(f"Model file missing during metadata refresh for {model_record.get('id')}") except Exception as meta_exc: - logger.warning( - f"Failed to refresh metadata for model {model_record.id}: {meta_exc}" - ) + logger.warning(f"Failed to refresh metadata for model {model_record.get('id')}: {meta_exc}") manifest_entry = None try: manifest_entry = await create_gguf_manifest_entry( - model_record.huggingface_id, + model_record.get("huggingface_id"), file_path, file_size, - model_id=model_record.id, + model_id=model_record.get("id"), ) except Exception as manifest_exc: - logger.warning( - f"Failed to record GGUF manifest entry for {filename}: {manifest_exc}" - ) + logger.warning(f"Failed to record GGUF 
manifest entry for {filename}: {manifest_exc}") if manifest_entry: metadata_for_defaults = manifest_entry.get("metadata") or {} try: - _apply_hf_defaults_to_model(model_record, metadata_for_defaults, db) + _apply_hf_defaults_to_model(model_record, metadata_for_defaults, store) except Exception as default_exc: - logger.warning( - f"Failed to apply HF defaults for model {model_record.id}: {default_exc}" - ) + logger.warning(f"Failed to apply HF defaults for model {model_record.get('id')}: {default_exc}") return model_record, metadata_result @@ -2386,13 +2224,11 @@ async def _record_gguf_download_post_fetch( async def download_safetensors_bundle_task( huggingface_id: str, files: List[Dict[str, Any]], - websocket_manager, + progress_manager, task_id: str, total_bundle_bytes: int = 0, ): - from backend.database import SessionLocal - - db = SessionLocal() + store = get_store() try: total_files = len(files) bytes_completed = 0 @@ -2405,7 +2241,7 @@ async def download_safetensors_bundle_task( filename = file_info["filename"] size_hint = max(file_info.get("size") or 0, 0) proxy = BundleProgressProxy( - websocket_manager, + progress_manager, task_id, bytes_completed, aggregate_total or 0, @@ -2416,7 +2252,7 @@ async def download_safetensors_bundle_task( "safetensors-bundle", ) - file_path, file_size = await download_model_with_websocket_progress( + file_path, file_size = await download_model_with_progress( huggingface_id, filename, proxy, @@ -2429,7 +2265,7 @@ async def download_safetensors_bundle_task( if filename.endswith(".safetensors"): try: await _save_safetensors_download( - db, huggingface_id, filename, file_path, file_size + store, huggingface_id, filename, file_path, file_size ) except Exception as exc: logger.error( @@ -2439,7 +2275,7 @@ async def download_safetensors_bundle_task( bytes_completed += file_size final_total = aggregate_total or bytes_completed - await websocket_manager.send_download_progress( + await progress_manager.send_download_progress( 
task_id=task_id, progress=100, message=f"Safetensors bundle downloaded ({total_files} files)", @@ -2454,8 +2290,9 @@ async def download_safetensors_bundle_task( current_filename=files[-1]["filename"] if files else "", huggingface_id=huggingface_id, ) - - await websocket_manager.broadcast( + if progress_manager: + progress_manager.complete_task(task_id, "Safetensors bundle downloaded") + await progress_manager.broadcast( { "type": "download_complete", "huggingface_id": huggingface_id, @@ -2466,14 +2303,15 @@ async def download_safetensors_bundle_task( ) except Exception as exc: logger.error(f"Safetensors bundle download failed: {exc}") - if websocket_manager: - await websocket_manager.send_notification( + if progress_manager: + await progress_manager.send_notification( "error", "Download Failed", f"Safetensors bundle failed: {str(exc)}", task_id, ) - await websocket_manager.broadcast( + progress_manager.fail_task(task_id, str(exc)) + await progress_manager.broadcast( { "type": "download_complete", "huggingface_id": huggingface_id, @@ -2483,36 +2321,23 @@ async def download_safetensors_bundle_task( "error": str(exc), } ) - else: - await websocket_manager.broadcast( - { - "type": "download_complete", - "huggingface_id": huggingface_id, - "model_format": "safetensors_bundle", - "filenames": [f["filename"] for f in files], - "timestamp": datetime.utcnow().isoformat(), - } - ) finally: if task_id: async with download_lock: active_downloads.pop(task_id, None) - db.close() async def download_gguf_bundle_task( huggingface_id: str, quantization: str, files: List[Dict[str, Any]], - websocket_manager, + progress_manager, task_id: str, total_bundle_bytes: int = 0, pipeline_tag: Optional[str] = None, ): - from backend.database import SessionLocal - - db = SessionLocal() + store = get_store() try: total_files = len(files) bytes_completed = 0 @@ -2525,7 +2350,7 @@ async def download_gguf_bundle_task( filename = file_info["filename"] size_hint = max(file_info.get("size") or 0, 0) 
proxy = BundleProgressProxy( - websocket_manager, + progress_manager, task_id, bytes_completed, aggregate_total or 0, @@ -2536,7 +2361,7 @@ async def download_gguf_bundle_task( "gguf-bundle", ) - file_path, file_size = await download_model_with_websocket_progress( + file_path, file_size = await download_model_with_progress( huggingface_id, filename, proxy, @@ -2546,10 +2371,9 @@ async def download_gguf_bundle_task( huggingface_id, ) - # Reuse the standard GGUF recording path to keep DB and manifest consistent try: await _record_gguf_download_post_fetch( - db, + store, huggingface_id, filename, file_path, @@ -2562,7 +2386,7 @@ async def download_gguf_bundle_task( bytes_completed += file_size final_total = aggregate_total or bytes_completed - await websocket_manager.send_download_progress( + await progress_manager.send_download_progress( task_id=task_id, progress=100, message=f"GGUF bundle downloaded ({total_files} files)", @@ -2577,8 +2401,9 @@ async def download_gguf_bundle_task( current_filename=files[-1]["filename"] if files else "", huggingface_id=huggingface_id, ) - - await websocket_manager.broadcast( + if progress_manager: + progress_manager.complete_task(task_id, "GGUF bundle downloaded") + await progress_manager.broadcast( { "type": "download_complete", "huggingface_id": huggingface_id, @@ -2590,14 +2415,15 @@ async def download_gguf_bundle_task( ) except Exception as exc: logger.error(f"GGUF bundle download failed: {exc}") - if websocket_manager: - await websocket_manager.send_notification( + if progress_manager: + await progress_manager.send_notification( "error", "Download Failed", f"GGUF bundle failed: {str(exc)}", task_id, ) - await websocket_manager.broadcast( + progress_manager.fail_task(task_id, str(exc)) + await progress_manager.broadcast( { "type": "download_complete", "huggingface_id": huggingface_id, @@ -2612,7 +2438,6 @@ async def download_gguf_bundle_task( if task_id: async with download_lock: active_downloads.pop(task_id, None) - db.close() 
@router.post("/safetensors/download-bundle") @@ -2656,13 +2481,13 @@ async def download_safetensors_bundle( "model_format": "safetensors_bundle", } - from backend.main import websocket_manager - + pm = get_progress_manager() + pm.create_task("download", f"Safetensors bundle {huggingface_id}", {"huggingface_id": huggingface_id}, task_id=task_id) background_tasks.add_task( download_safetensors_bundle_task, huggingface_id, sanitized_files, - websocket_manager, + pm, task_id, declared_total, ) @@ -2724,14 +2549,14 @@ async def download_gguf_bundle( "model_format": "gguf-bundle", } - from backend.main import websocket_manager - + pm = get_progress_manager() + pm.create_task("download", f"GGUF bundle {huggingface_id} ({quantization})", {"huggingface_id": huggingface_id, "quantization": quantization}, task_id=task_id) background_tasks.add_task( download_gguf_bundle_task, huggingface_id, quantization, sanitized_files, - websocket_manager, + pm, task_id, declared_total, pipeline_tag, @@ -2787,342 +2612,174 @@ def extract_base_model_name(filename: str) -> str: return name if name else filename -@router.get("/{model_id}/config") -async def get_model_config(model_id: int, db: Session = Depends(get_db)): +@router.get("/{model_id:path}/config") +async def get_model_config(model_id: str): """Get model's llama.cpp configuration""" - model = db.query(Model).filter(Model.id == model_id).first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") + store = get_store() + model = _get_model_or_404(store, model_id) + return _coerce_model_config(model.get("config")) - return _coerce_model_config(model.config) - -@router.put("/{model_id}/config") -async def update_model_config( - model_id: int, config: dict, db: Session = Depends(get_db) -): +@router.put("/{model_id:path}/config") +async def update_model_config(model_id: str, config: dict): """Update model's llama.cpp configuration""" - model = db.query(Model).filter(Model.id == model_id).first() - if not 
model: - raise HTTPException(status_code=404, detail="Model not found") - - model.config = config - db.commit() + store = get_store() + model = _get_model_or_404(store, model_id) + store.update_model(model_id, {"config": config}) - # Regenerate llama-swap configuration to reflect the updated model config try: from backend.llama_swap_manager import get_llama_swap_manager - llama_swap_manager = get_llama_swap_manager() await llama_swap_manager.regenerate_config_with_active_version() logger.info( - f"Regenerated llama-swap config after updating model {model.name} configuration" + f"Regenerated llama-swap config after updating model {model.get('display_name') or model.get('name')} configuration" ) except Exception as e: - logger.warning( - f"Failed to regenerate llama-swap config after model config update: {e}" - ) + logger.warning(f"Failed to regenerate llama-swap config after model config update: {e}") return {"message": "Configuration updated"} -@router.post("/{model_id}/auto-config") -async def generate_auto_config(model_id: int, db: Session = Depends(get_db)): - """Generate optimal configuration using Smart-Auto""" - model = db.query(Model).filter(Model.id == model_id).first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") - - try: - gpu_info = await get_gpu_info() - smart_auto = SmartAutoConfig() - config = await smart_auto.generate_config(model, gpu_info) - - # Save the generated config - model.config = config - db.commit() +# DEPRECATED: remove with ModelConfig.vue rewrite +@router.post("/{model_id:path}/auto-config") +async def generate_auto_config(model_id: str): + """Stub: return current config (Smart Auto removed). 
Optionally apply defaults.""" + store = get_store() + model = _get_model_or_404(store, model_id) + config = (model.get("config") or {}).copy() + config.setdefault("ctx_size", 2048) + config.setdefault("batch_size", 512) + config.setdefault("threads", 4) + config.setdefault("n_gpu_layers", -1) + store.update_model(model_id, {"config": config}) + return config - # Regenerate llama-swap configuration to reflect the updated model config - try: - from backend.llama_swap_manager import get_llama_swap_manager - llama_swap_manager = get_llama_swap_manager() - await llama_swap_manager.regenerate_config_with_active_version() - logger.info( - f"Regenerated llama-swap config after auto-config for model {model.name}" - ) - except Exception as e: - logger.warning( - f"Failed to regenerate llama-swap config after auto-config: {e}" - ) - - return config - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/{model_id}/smart-auto") +# DEPRECATED: remove with ModelConfig.vue rewrite +@router.post("/{model_id:path}/smart-auto") async def generate_smart_auto_config( - model_id: int, + model_id: str, preset: Optional[str] = None, usage_mode: str = "single_user", speed_quality: Optional[int] = None, use_case: Optional[str] = None, debug: Optional[bool] = False, - db: Session = Depends(get_db), ): - """ - Generate smart auto configuration with optional preset tuning, speed/quality balance, and use case. - - preset: Optional preset name (coding, conversational, long_context) to use as tuning parameters - usage_mode: 'single_user' (sequential, peak KV cache) or 'multi_user' (server, typical usage) - speed_quality: Speed/quality balance (0-100), where 0 = max speed, 100 = max quality. 
Default: 50 (balanced) - use_case: Optional use case ('chat', 'code', 'creative', 'analysis') for targeted optimization - """ - model = db.query(Model).filter(Model.id == model_id).first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") - - try: - gpu_info = await get_gpu_info() - smart_auto = SmartAutoConfig() - debug_map = {} if debug else None - - # Validate usage_mode - if usage_mode not in ["single_user", "multi_user"]: - usage_mode = "single_user" # Default to single_user if invalid - - # Validate and normalize speed_quality (0-100, default 50) - if speed_quality is not None: - speed_quality = max(0, min(100, int(speed_quality))) - else: - speed_quality = 50 - - # Validate use_case - if use_case is not None and use_case not in [ - "chat", - "code", - "creative", - "analysis", - ]: - use_case = None # Invalid use case, ignore it - - # If preset is provided, pass it to generate_config for tuning - # Also pass speed_quality and use_case for wizard-based configuration - config = await smart_auto.generate_config( - model, - gpu_info, - preset=preset, - usage_mode=usage_mode, - speed_quality=speed_quality, - use_case=use_case, - debug=debug_map, - ) - - if debug_map is not None: - return {"config": config, "debug": debug_map} - return config - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/{model_id}/start") -async def start_model(model_id: int, db: Session = Depends(get_db)): + """Stub: apply defaults (Smart Auto removed).""" + store = get_store() + model = _get_model_or_404(store, model_id) + config = (model.get("config") or {}).copy() + config.setdefault("ctx_size", 2048) + config.setdefault("batch_size", 512) + config.setdefault("threads", 4) + config.setdefault("n_gpu_layers", -1) + store.update_model(model_id, {"config": config}) + return config + + +@router.post("/{model_id:path}/start") +async def start_model(model_id: str): """Start model via llama-swap""" - model = 
db.query(Model).filter(Model.id == model_id).first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") + from backend.llama_swap_client import LlamaSwapClient - # Check if already running - existing = ( - db.query(RunningInstance).filter(RunningInstance.model_id == model_id).first() + store = get_store() + model = _get_model_or_404(store, model_id) + proxy_model_name = model.get("proxy_name") or generate_proxy_name( + model.get("huggingface_id"), model.get("quantization") ) - if existing: - raise HTTPException(status_code=400, detail="Model already running") try: - from backend.unified_monitor import unified_monitor - from backend.main import websocket_manager + running_data = await LlamaSwapClient().get_running_models() + running_list = running_data.get("running") or [] + running_names = {item.get("model") for item in running_list if item.get("state") in ("running", "ready")} + except Exception: + running_names = set() + if proxy_model_name in running_names: + raise HTTPException(status_code=400, detail="Model already running") - await websocket_manager.send_model_status_update( + try: + await get_progress_manager().send_model_status_update( model_id=model_id, status="starting", - details={"message": f"Starting {model.name}"}, - ) - - # Get proxy name from database (config already contains this model) - if not model.proxy_name: - raise ValueError(f"Model '{model.name}' does not have a proxy_name set") - proxy_model_name = model.proxy_name - - # Get model configuration (for database record, not config file) - config = _coerce_model_config(model.config) - if _looks_like_embedding_model( - model.pipeline_tag, model.huggingface_id, model.name, model.base_model_name - ) and not config.get("embedding"): - config["embedding"] = True - model.config = config - db.commit() - - # NOTE: We do NOT trigger model loading here. - # The model will load on-demand when the first API request is made. 
- # This avoids memory issues from making inference requests during load. - # - # With sendLoadingState: true (llama-swap v171+), the first request will - # stream loading progress to the user. - logger.info( - f"Model {proxy_model_name} registered - will load on first API request" + details={"message": f"Starting {model.get('display_name') or model.get('name')}"}, ) + except Exception: + pass - # Save to database - running_instance = RunningInstance( - model_id=model_id, - llama_version=config.get("llama_version", "default"), - proxy_model_name=proxy_model_name, - started_at=datetime.utcnow(), - config=json.dumps(config), - runtime_type="llama_cpp", - ) - db.add(running_instance) - model.is_active = True - db.commit() - - # Broadcast ready event - model is registered and available for requests - # The actual loading happens on first API request (on-demand) - await unified_monitor.broadcast_model_event( - "ready", proxy_model_name, {"model_id": model_id, "model_name": model.name} - ) - await unified_monitor.trigger_status_update() - - return { - "model_id": model_id, - "proxy_model_name": proxy_model_name, - "port": 2000, - "api_endpoint": f"http://localhost:2000/v1/chat/completions", - } - - except Exception as e: - # Clear loading state on error - if model.proxy_name: - unified_monitor.mark_model_stopped(model.proxy_name) - - await websocket_manager.send_model_status_update( - model_id=model_id, - status="error", - details={"message": f"Failed to start: {str(e)}"}, - ) - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/{model_id}/stop") -async def stop_model(model_id: int, db: Session = Depends(get_db)): - """Stop model via llama-swap""" - running_instance = ( - db.query(RunningInstance).filter(RunningInstance.model_id == model_id).first() - ) - if not running_instance: - raise HTTPException(status_code=404, detail="No running instance found") + config = _coerce_model_config(model.get("config")) + if _model_is_embedding(model) and not 
config.get("embedding"): + config["embedding"] = True + store.update_model(model_id, {"config": config}) try: from backend.llama_swap_manager import get_llama_swap_manager - from backend.main import websocket_manager - from backend.unified_monitor import unified_monitor - llama_swap_manager = get_llama_swap_manager() - - proxy_name = running_instance.proxy_model_name - - # Clear loading state if model was still loading - if proxy_name: - unified_monitor.mark_model_stopped(proxy_name) - - # Unregister from llama-swap (it stops the process) - if proxy_name: - logger.info(f"Calling unregister_model with proxy_model_name: {proxy_name}") - await llama_swap_manager.unregister_model(proxy_name) - logger.info("unregister_model call completed") - - # Update database - db.delete(running_instance) - model = db.query(Model).filter(Model.id == model_id).first() - if model: - model.is_active = False - db.commit() - - # Broadcast stopped event immediately (event-driven, no polling) - if proxy_name: - await unified_monitor.broadcast_model_event( - "stopped", proxy_name, {"model_id": model_id} - ) - await unified_monitor.trigger_status_update() - - return {"message": "Model stopped"} - + await llama_swap_manager.regenerate_config_with_active_version() + model_with_proxy = {**(model or {}), "proxy_name": proxy_model_name} + await llama_swap_manager.register_model(model_with_proxy, config) except Exception as e: - await websocket_manager.send_model_status_update( - model_id=model_id, - status="error", - details={"message": f"Failed to stop: {str(e)}"}, - ) + try: + await get_progress_manager().send_model_status_update( + model_id=model_id, + status="error", + details={"message": f"Failed to start: {str(e)}"}, + ) + except Exception: + pass raise HTTPException(status_code=500, detail=str(e)) + try: + get_progress_manager().emit("model_event", {"event": "ready", "proxy_name": proxy_model_name, "model_id": model_id, "model_name": model.get("display_name") or model.get("name")}) + except 
Exception: + pass -@router.post("/vram-estimate") -async def estimate_vram_usage( - request: EstimationRequest, db: Session = Depends(get_db) -): - """Estimate VRAM usage for given configuration""" - model = db.query(Model).filter(Model.id == request.model_id).first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") + return { + "model_id": model_id, + "proxy_model_name": proxy_model_name, + "port": 2000, + "api_endpoint": "http://localhost:2000/v1/chat/completions", + } - try: - gpu_info = await get_cached_gpu_info() - smart_auto = SmartAutoConfig() - usage_mode = ( - request.usage_mode - if request.usage_mode in ["single_user", "multi_user"] - else "single_user" - ) - metadata = get_model_metadata(model) - vram_estimate = smart_auto.estimate_vram_usage( - model, - request.config, - gpu_info, - usage_mode=usage_mode, - metadata=metadata, - ) - return vram_estimate - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) +@router.post("/{model_id:path}/stop") +async def stop_model(model_id: str): + """Stop model via llama-swap""" + from backend.llama_swap_client import LlamaSwapClient + store = get_store() + model = _get_model_or_404(store, model_id) + proxy_name = model.get("proxy_name") or generate_proxy_name( + model.get("huggingface_id"), model.get("quantization") + ) -@router.post("/ram-estimate") -async def estimate_ram_usage(request: EstimationRequest, db: Session = Depends(get_db)): - """Estimate RAM usage for given configuration""" try: - model = db.query(Model).filter(Model.id == request.model_id).first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") - - smart_auto = SmartAutoConfig() - usage_mode = ( - request.usage_mode - if request.usage_mode in ["single_user", "multi_user"] - else "single_user" - ) - metadata = get_model_metadata(model) - ram_estimate = smart_auto.estimate_ram_usage( - model, - request.config, - usage_mode=usage_mode, - metadata=metadata, - ) + 
running_data = await LlamaSwapClient().get_running_models() + running_list = running_data.get("running") or [] + running_names = {item.get("model") for item in running_list if item.get("state") in ("running", "ready", "loading")} + except Exception: + running_names = set() + if proxy_name not in running_names: + raise HTTPException(status_code=404, detail="No running instance found") - return ram_estimate + try: + from backend.llama_swap_manager import get_llama_swap_manager + llama_swap_manager = get_llama_swap_manager() + logger.info(f"Calling unregister_model with proxy_model_name: {proxy_name}") + await llama_swap_manager.unregister_model(proxy_name) + try: + get_progress_manager().emit("model_event", {"event": "stopped", "proxy_name": proxy_name, "model_id": model_id}) + except Exception: + pass + return {"message": "Model stopped"} except Exception as e: + try: + await get_progress_manager().send_model_status_update( + model_id=model_id, + status="error", + details={"message": f"Failed to stop: {str(e)}"}, + ) + except Exception: + pass raise HTTPException(status_code=500, detail=str(e)) @@ -3186,116 +2843,96 @@ class DeleteGroupRequest(BaseModel): @router.post("/delete-group") -async def delete_model_group( - request: DeleteGroupRequest, db: Session = Depends(get_db) -): +async def delete_model_group(request: DeleteGroupRequest): """Delete all quantizations of a model group""" + from backend.llama_swap_client import LlamaSwapClient + huggingface_id = request.huggingface_id - models = db.query(Model).filter(Model.huggingface_id == huggingface_id).all() + store = get_store() + models = [m for m in store.list_models() if m.get("huggingface_id") == huggingface_id] if not models: raise HTTPException(status_code=404, detail="Model group not found") + try: + running_data = await LlamaSwapClient().get_running_models() + running_list = running_data.get("running") or [] + running_names = {item.get("model") for item in running_list if item.get("state") in ("running", 
"ready", "loading")} + except Exception: + running_names = set() + deleted_count = 0 for model in models: - # Stop if running - running_instance = ( - db.query(RunningInstance) - .filter(RunningInstance.model_id == model.id) - .first() - ) - if running_instance: - # Stop via llama-swap + proxy_name = model.get("proxy_name") or generate_proxy_name(model.get("huggingface_id"), model.get("quantization")) + if proxy_name in running_names: try: from backend.llama_swap_manager import get_llama_swap_manager - - llama_swap_manager = get_llama_swap_manager() - if running_instance.proxy_model_name: - await llama_swap_manager.unregister_model( - running_instance.proxy_model_name - ) + await get_llama_swap_manager().unregister_model(proxy_name) except Exception as e: - logger.warning( - f"Failed to stop model {running_instance.proxy_model_name}: {e}" - ) - db.delete(running_instance) - - # Delete file - normalized_path = _normalize_model_path(model.file_path) - if normalized_path and os.path.exists(normalized_path): - os.remove(normalized_path) - - # Delete from database - db.delete(model) + logger.warning(f"Failed to stop model {proxy_name}: {e}") + + fname = _get_model_filename(model) + if model.get("huggingface_id") and fname: + from backend.huggingface import delete_cached_model_file + deleted_file = delete_cached_model_file(model.get("huggingface_id"), fname) + if not deleted_file: + legacy_path = _normalize_model_path(model.get("file_path")) + if legacy_path and os.path.exists(legacy_path): + os.remove(legacy_path) + + store.delete_model(model.get("id")) deleted_count += 1 - db.commit() - - # If this was a GGUF group and no models remain, clean up the repo folder - remaining_gguf = ( - db.query(Model) - .filter(Model.huggingface_id == huggingface_id, Model.model_format == "gguf") - .count() - ) - if remaining_gguf == 0: - _cleanup_model_folder_if_no_quantizations(db, huggingface_id, "gguf") - return {"message": f"Deleted {deleted_count} quantizations"} 
-@router.delete("/{model_id}") -async def delete_model(model_id: int, db: Session = Depends(get_db)): +@router.delete("/{model_id:path}") +async def delete_model(model_id: str): """Delete individual model quantization and its files""" - model = db.query(Model).filter(Model.id == model_id).first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") + from backend.llama_swap_client import LlamaSwapClient - # Stop if running - running_instance = ( - db.query(RunningInstance).filter(RunningInstance.model_id == model_id).first() - ) - if running_instance: - # Stop via llama-swap + store = get_store() + model = _get_model_or_404(store, model_id) + proxy_name = model.get("proxy_name") or generate_proxy_name(model.get("huggingface_id"), model.get("quantization")) + + try: + running_data = await LlamaSwapClient().get_running_models() + running_list = running_data.get("running") or [] + running_names = {item.get("model") for item in running_list if item.get("state") in ("running", "ready", "loading")} + except Exception: + running_names = set() + if proxy_name in running_names: try: from backend.llama_swap_manager import get_llama_swap_manager - - llama_swap_manager = get_llama_swap_manager() - if running_instance.proxy_model_name: - await llama_swap_manager.unregister_model( - running_instance.proxy_model_name - ) + await get_llama_swap_manager().unregister_model(proxy_name) except Exception as e: - logger.warning( - f"Failed to stop model {running_instance.proxy_model_name}: {e}" - ) - db.delete(running_instance) - - huggingface_id = model.huggingface_id - model_format = (model.model_format or "gguf").lower() - - # Delete file - normalized_path = _normalize_model_path(model.file_path) - if normalized_path and os.path.exists(normalized_path): - os.remove(normalized_path) - - # Delete from database - db.delete(model) - db.commit() - - # If this was the last quantization for this repo/format, remove its folder - 
_cleanup_model_folder_if_no_quantizations(db, huggingface_id, model_format) - + logger.warning(f"Failed to stop model {proxy_name}: {e}") + + huggingface_id = model.get("huggingface_id") + filename = _get_model_filename(model) + + if huggingface_id and filename: + from backend.huggingface import delete_cached_model_file + deleted = delete_cached_model_file(huggingface_id, filename) + if not deleted: + # Fall back to direct removal for legacy records with file_path + legacy_path = _normalize_model_path(model.get("file_path")) + if legacy_path and os.path.exists(legacy_path): + os.remove(legacy_path) + logger.info(f"Removed legacy model file: {legacy_path}") + + store.delete_model(model_id) return {"message": "Model quantization deleted"} -@router.get("/{model_id}/layer-info") -async def get_model_layer_info_endpoint(model_id: int, db: Session = Depends(get_db)): +# DEPRECATED: remove with ModelConfig.vue rewrite +@router.get("/{model_id:path}/layer-info") +async def get_model_layer_info_endpoint(model_id: str): """Get model layer information from GGUF metadata""" - model = db.query(Model).filter(Model.id == model_id).first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") + store = get_store() + model = _get_model_or_404(store, model_id) layer_info = None - normalized_path = _normalize_model_path(model.file_path) + normalized_path = _get_model_file_path(model) if normalized_path and os.path.exists(normalized_path): try: layer_info = get_model_layer_info(normalized_path) @@ -3338,85 +2975,59 @@ async def get_model_layer_info_endpoint(model_id: int, db: Session = Depends(get } -@router.get("/{model_id}/recommendations") -async def get_model_recommendations_endpoint( - model_id: int, db: Session = Depends(get_db) -): - """Get configuration recommendations for a model based on its architecture""" - from backend.smart_auto.recommendations import get_model_recommendations - - model = db.query(Model).filter(Model.id == model_id).first() - 
if not model: - raise HTTPException(status_code=404, detail="Model not found") +# DEPRECATED: remove with ModelConfig.vue rewrite +@router.get("/{model_id:path}/recommendations") +async def get_model_recommendations_endpoint(model_id: str): + """Stub: recommendations removed with smart_auto. Returns empty defaults.""" + return {"gpu_layers": None, "context_size": None, "batch_size": None} - normalized_path = _normalize_model_path(model.file_path) - file_path = ( - normalized_path if normalized_path and os.path.exists(normalized_path) else None - ) - try: - # Get layer info from GGUF metadata (if available) - layer_info = get_model_layer_info(file_path) if file_path else None - except Exception as e: - logger.error( - f"Failed to get layer info for recommendations (model {model_id}): {e}" - ) - layer_info = None +# DEPRECATED: remove with ModelConfig.vue rewrite +@router.get("/{model_id:path}/architecture-presets") +async def get_architecture_presets_endpoint(model_id: str): + """Stub: presets removed. 
Returns minimal structure.""" + return {"architecture": "unknown", "presets": {}, "available_presets": []} - if not layer_info: - layer_info = { - "layer_count": 32, - "architecture": "unknown", - "context_length": 0, - "attention_head_count": 0, - "embedding_length": 0, - } - try: - recommendations = await get_model_recommendations( - model_layer_info=layer_info, - model_name=model.name or model.huggingface_id or "", - file_path=file_path, - ) - return recommendations - except Exception as e: - logger.error(f"Failed to get recommendations for model {model_id}: {e}") - raise HTTPException( - status_code=500, detail=f"Failed to get recommendations: {str(e)}" - ) - - -@router.get("/{model_id}/architecture-presets") -async def get_architecture_presets_endpoint( - model_id: int, db: Session = Depends(get_db) -): - """Get architecture-specific presets for a model""" - model = db.query(Model).filter(Model.id == model_id).first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") - - architecture, presets = get_architecture_and_presets(model) - return { - "architecture": architecture, - "presets": presets, - "available_presets": list(presets.keys()), - } +# DEPRECATED: remove with ModelConfig.vue rewrite +@router.post("/vram-estimate") +async def estimate_vram_usage(request: EstimationRequest): + """Stub: simple VRAM estimate (smart_auto removed).""" + store = get_store() + _get_model_or_404(store, request.model_id) + cfg = request.config or {} + ngl = int(cfg.get("n_gpu_layers") or -1) + ctx = int(cfg.get("ctx_size") or 2048) + # Very rough: ~1GB base + per-layer and context + estimate_mb = 1024 + (abs(ngl) * 50 if ngl != -1 else 2000) + (ctx // 64) + return {"vram_estimate_mb": min(estimate_mb, 96 * 1024), "vram_estimate_gb": round(estimate_mb / 1024, 2)} + + +# DEPRECATED: remove with ModelConfig.vue rewrite +@router.post("/ram-estimate") +async def estimate_ram_usage(request: EstimationRequest): + """Stub: simple RAM estimate (smart_auto 
removed).""" + store = get_store() + _get_model_or_404(store, request.model_id) + cfg = request.config or {} + ctx = int(cfg.get("ctx_size") or 2048) + estimate_mb = 512 + (ctx // 32) + return {"ram_estimate_mb": estimate_mb, "ram_estimate_gb": round(estimate_mb / 1024, 2)} -@router.get("/{model_id}/hf-metadata") -async def get_model_hf_metadata(model_id: int, db: Session = Depends(get_db)): - model = db.query(Model).filter(Model.id == model_id).first() - if not model: - raise HTTPException(status_code=404, detail="Model not found") +@router.get("/{model_id:path}/hf-metadata") +async def get_model_hf_metadata(model_id: str): + store = get_store() + model = _get_model_or_404(store, model_id) metadata_entry = None - if (model.model_format or "gguf").lower() == "safetensors": + if (model.get("model_format") or model.get("format") or "gguf").lower() == "safetensors": metadata_entry = _load_manifest_entry_for_model(model) else: - filename = _extract_filename(model.file_path) + filename = _get_model_filename(model) if not filename: - raise HTTPException(status_code=400, detail="Model file path is not set") - metadata_entry = get_gguf_manifest_entry(model.huggingface_id, filename) + raise HTTPException(status_code=400, detail="Model filename is not set") + metadata_entry = get_gguf_manifest_entry(model.get("huggingface_id"), filename) if not metadata_entry: raise HTTPException(status_code=404, detail="Metadata not found for model") @@ -3432,19 +3043,16 @@ async def get_model_hf_metadata(model_id: int, db: Session = Depends(get_db)): } -@router.post("/{model_id}/regenerate-info") -async def regenerate_model_info_endpoint(model_id: int, db: Session = Depends(get_db)): +@router.post("/{model_id:path}/regenerate-info") +async def regenerate_model_info_endpoint(model_id: str): """ - Regenerate model information from GGUF metadata and update the database. - This will re-read the model file and update architecture, layer count, and other metadata. 
+ Regenerate model information from GGUF metadata and update the store. """ - model = db.query(Model).filter(Model.id == model_id).first() - - if not model: - raise HTTPException(status_code=404, detail="Model not found") + store = get_store() + model = _get_model_or_404(store, model_id) try: - metadata = _refresh_model_metadata_from_file(model, db) + metadata = _refresh_model_metadata_from_file(model, store) return { "success": True, "model_id": model_id, @@ -3456,32 +3064,27 @@ async def regenerate_model_info_endpoint(model_id: int, db: Session = Depends(ge except ValueError as ve: raise HTTPException(status_code=500, detail=str(ve)) except Exception as e: - logger.error( - f"Failed to regenerate model info for model {model_id}: {e}", exc_info=True - ) - db.rollback() - raise HTTPException( - status_code=500, detail=f"Failed to regenerate model info: {str(e)}" - ) + logger.error(f"Failed to regenerate model info for model {model_id}: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Failed to regenerate model info: {str(e)}") @router.get("/supported-flags") -async def get_supported_flags_endpoint(db: Session = Depends(get_db)): +async def get_supported_flags_endpoint(): """Get the list of supported flags for the active llama-server binary""" try: - # Get the active llama-cpp version - active_version = ( - db.query(LlamaVersion).filter(LlamaVersion.is_active == True).first() - ) + store = get_store() + active_version = store.get_active_engine_version("llama_cpp") + if not active_version: + active_version = store.get_active_engine_version("ik_llama") - if not active_version or not active_version.binary_path: + if not active_version or not active_version.get("binary_path"): return { "supported_flags": [], "binary_path": None, "error": "No active llama-cpp version found", } - binary_path = active_version.binary_path + binary_path = active_version.get("binary_path") # Convert to absolute path if needed if not os.path.isabs(binary_path): diff --git 
a/backend/routes/status.py b/backend/routes/status.py index bef69ae..c4211d0 100644 --- a/backend/routes/status.py +++ b/backend/routes/status.py @@ -1,50 +1,55 @@ -from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session +from fastapi import APIRouter import psutil import os -from backend.database import get_db, RunningInstance +from backend.llama_swap_client import LlamaSwapClient from backend.lmdeploy_manager import get_lmdeploy_manager from backend.lmdeploy_installer import get_lmdeploy_installer router = APIRouter() +DEFAULT_PROXY_PORT = 2000 +LMDEPLOY_PORT = 2001 -@router.get("/status") -async def get_system_status(db: Session = Depends(get_db)): - """Get system status and running instances""" - running_instances = db.query(RunningInstance).all() - # Get system info - cpu_percent = psutil.cpu_percent(interval=1) - memory = psutil.virtual_memory() - # Use data directory at project root or /app/data for Docker - data_dir = "data" if os.path.exists("data") else "/app/data" +@router.get("/status") +async def get_system_status(): + """Get system status and running instances (from llama-swap).""" + client = LlamaSwapClient() try: - disk = psutil.disk_usage(data_dir) - except FileNotFoundError: - # Fallback to root directory if data doesn't exist - disk = psutil.disk_usage("/") + running_data = await client.get_running_models() + except Exception: + running_data = {"running": []} + if isinstance(running_data, list): + running_list = running_data + else: + running_list = running_data.get("running") or [] - # Format running instances (no process checking needed) - DEFAULT_PROXY_PORT = 2000 - LMDEPLOY_PORT = 2001 active_instances = [] - for instance in running_instances: - port = ( - LMDEPLOY_PORT if instance.runtime_type == "lmdeploy" else DEFAULT_PROXY_PORT - ) + for i, item in enumerate(running_list): + proxy_model_name = item.get("model", "") + state = item.get("state", "") + runtime_type = "lmdeploy" if state == "lmdeploy" else "llama_cpp" + port 
= LMDEPLOY_PORT if runtime_type == "lmdeploy" else DEFAULT_PROXY_PORT active_instances.append( { - "id": instance.id, - "model_id": instance.model_id, + "id": i, + "model_id": proxy_model_name, "port": port, - "runtime_type": instance.runtime_type, - "proxy_model_name": instance.proxy_model_name, - "started_at": instance.started_at, + "runtime_type": runtime_type, + "proxy_model_name": proxy_model_name, + "started_at": None, } ) + cpu_percent = psutil.cpu_percent(interval=1) + memory = psutil.virtual_memory() + data_dir = "data" if os.path.exists("data") else "/app/data" + try: + disk = psutil.disk_usage(data_dir) + except FileNotFoundError: + disk = psutil.disk_usage("/") + lmdeploy_manager = get_lmdeploy_manager() lmdeploy_status = lmdeploy_manager.status() installer_status = get_lmdeploy_installer().status() diff --git a/backend/routes/unified_monitoring.py b/backend/routes/unified_monitoring.py deleted file mode 100644 index 8faf061..0000000 --- a/backend/routes/unified_monitoring.py +++ /dev/null @@ -1,75 +0,0 @@ -from fastapi import APIRouter, WebSocket, WebSocketDisconnect -from backend.unified_monitor import unified_monitor -from backend.logging_config import get_logger - -logger = get_logger(__name__) - -router = APIRouter() - - -@router.get("/monitoring/status") -async def get_system_status(): - """Get comprehensive system status""" - return await unified_monitor.get_system_status() - - -@router.get("/monitoring/models") -async def get_running_models(): - """Get currently running models from llama-swap""" - return await unified_monitor.get_running_models() - - -@router.post("/monitoring/unload-all") -async def unload_all_models(): - """Unload all models via llama-swap""" - return await unified_monitor.unload_all_models() - - -@router.get("/monitoring/health") -async def get_system_health(): - """Get llama-swap and system health status""" - return await unified_monitor.get_system_health() - - -@router.get("/monitoring/debug") -async def 
debug_monitoring_data(): - """Debug endpoint to see what data is being collected""" - from backend.unified_monitor import unified_monitor - from backend.llama_swap_client import LlamaSwapClient - - # Get raw data from external source - external_client = LlamaSwapClient() - try: - external_models = await external_client.get_running_models() - except Exception as e: - external_models = {"error": str(e)} - - # Get system status - try: - system_status = await unified_monitor.get_system_status() - except Exception as e: - system_status = {"error": str(e)} - - return { - "running_models": external_models, - "system_status": system_status, - "timestamp": "2024-01-01T00:00:00Z", - } - - -@router.websocket("/monitoring/ws") -async def monitoring_websocket(websocket: WebSocket): - """WebSocket endpoint for real-time monitoring data""" - await unified_monitor.add_subscriber(websocket) - - try: - while True: - # Keep the connection alive and handle any incoming messages - data = await websocket.receive_text() - # Echo back any received data (for testing) - await websocket.send_text(f"Echo: {data}") - except WebSocketDisconnect: - await unified_monitor.remove_subscriber(websocket) - except Exception as e: - logger.error(f"WebSocket error: {e}") - await unified_monitor.remove_subscriber(websocket) diff --git a/backend/smart_auto/__init__.py b/backend/smart_auto/__init__.py deleted file mode 100644 index 57742a3..0000000 --- a/backend/smart_auto/__init__.py +++ /dev/null @@ -1,480 +0,0 @@ -from typing import Dict, Any, Optional -import psutil -from backend.database import Model -from backend.logging_config import get_logger - -# Import all required modules at module level for better performance -from .model_metadata import get_model_metadata -from .architecture_config import get_architecture_default_context -from .cpu_config import generate_cpu_config -from .gpu_config import generate_gpu_config, parse_compute_capability -from .memory_estimator import get_cpu_memory_gb, 
estimate_vram_usage, estimate_ram_usage -from .kv_cache import get_optimal_kv_cache_quant -from .moe_handler import get_architecture_specific_flags -from .generation_params import build_generation_params -from .config_builder import generate_server_params, sanitize_config, apply_preset_tuning -from .models import SystemResources, ModelMetadata - -logger = get_logger(__name__) - - -class SmartAutoConfig: - """Smart configuration optimizer for llama.cpp parameters""" - - def __init__(self): - self.current_preset = None - - def _generate_cpu_config( - self, - model_size_mb: float, - metadata, - architecture: str, - layer_count: int, - is_moe: bool, - expert_count: int, - ) -> Dict[str, Any]: - """Generate CPU-only configuration with MoE and architecture-specific flags.""" - cpu_cfg = generate_cpu_config( - model_size_mb, - architecture, - layer_count, - metadata.context_length, - metadata.vocab_size, - metadata.embedding_length, - metadata.attention_head_count, - debug=None, - ) - - # Add MoE parameters for CPU-only mode (MoE layers stay on CPU) - if is_moe: - cpu_cfg["moe_offload_pattern"] = "none" - cpu_cfg["moe_offload_custom"] = "" - logger.debug("MoE model in CPU-only mode - MoE layers will run on CPU") - - # Add jinja flag if needed (for architectures that require it) - if is_moe or architecture in ["glm", "glm4", "qwen3"]: - layer_info_for_flags = { - "is_moe": is_moe, - "expert_count": expert_count, - "model_size_mb": model_size_mb, - "available_vram_gb": 0, - "architecture": architecture, - } - moe_config = get_architecture_specific_flags( - architecture, layer_info_for_flags - ) - if moe_config.get("jinja"): - cpu_cfg["jinja"] = True - - return cpu_cfg - - def _apply_moe_optimizations( - self, - config: Dict[str, Any], - metadata, - model_size_mb: float, - system_resources: SystemResources, - ) -> None: - """Apply MoE-specific optimizations to configuration.""" - if not metadata.is_moe: - return - - layer_info_for_flags = { - "is_moe": metadata.is_moe, - 
"expert_count": metadata.expert_count, - "model_size_mb": model_size_mb, - "available_vram_gb": system_resources.available_vram_gb, - "architecture": metadata.architecture, - } - moe_config = get_architecture_specific_flags( - metadata.architecture, layer_info_for_flags - ) - - # Set MoE parameters in config - if moe_config.get("moe_offload_custom"): - config["moe_offload_pattern"] = "custom" - config["moe_offload_custom"] = moe_config["moe_offload_custom"] - else: - config["moe_offload_pattern"] = "none" - config["moe_offload_custom"] = "" - - # Set jinja flag if needed - if moe_config.get("jinja"): - config["jinja"] = True - - async def generate_config( - self, - model: Model, - gpu_info: Dict[str, Any], - preset: Optional[str] = None, - usage_mode: str = "single_user", - speed_quality: Optional[int] = None, - use_case: Optional[str] = None, - debug: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - """Generate optimal configuration based on model and GPU capabilities - - Args: - model: The model to configure - gpu_info: GPU information dictionary - preset: Optional preset name (coding, conversational, long_context) to use as tuning parameters - usage_mode: 'single_user' (sequential, peak KV cache) or 'multi_user' (server, typical usage) - speed_quality: Speed/quality balance (0-100), where 0 = max speed, 100 = max quality. 
Default: 50 - use_case: Optional use case ('chat', 'code', 'creative', 'analysis') for targeted optimization - """ - from backend.presets import get_architecture_and_presets - - try: - config = {} - # Store preset for later use in generation params - self.current_preset = preset - - # Get model metadata - model_size_mb = model.file_size / (1024 * 1024) if model.file_size else 0 - model_name = model.name.lower() - - # Get comprehensive model layer information from unified helper - metadata = get_model_metadata(model) - - # Now that get_model_metadata returns dataclass with architecture detection already done - layer_count = metadata.layer_count - architecture = metadata.architecture - context_length = metadata.context_length - vocab_size = metadata.vocab_size - embedding_length = metadata.embedding_length - attention_head_count = metadata.attention_head_count - attention_head_count_kv = metadata.attention_head_count_kv - is_moe = metadata.is_moe - expert_count = metadata.expert_count - - if debug is not None: - debug.update( - { - "model_name": model.name, - "model_size_mb": model_size_mb, - "layer_info": metadata.to_dict(), - } - ) - - # Prepare system resources (GPU capabilities calculated once in SystemResources) - cpu_memory = get_cpu_memory_gb() - cpu_cores = psutil.cpu_count(logical=False) or 1 - - system_resources = SystemResources.from_gpu_info( - gpu_info, cpu_memory, cpu_cores - ) - - # Check flash attention availability using pre-parsed compute capabilities - flash_attn_available = ( - all(cc >= 8.0 for cc in system_resources.compute_capabilities) - if system_resources.compute_capabilities - else False - ) - system_resources.flash_attn_available = flash_attn_available - - if debug is not None: - debug.update( - { - "gpu_count": system_resources.gpu_count, - "total_vram": system_resources.total_vram, - "available_vram_gb": system_resources.available_vram_gb, - "flash_attn_available": flash_attn_available, - } - ) - - # CPU-only configuration path - if not 
system_resources.gpus: - cpu_cfg = self._generate_cpu_config( - model_size_mb, - metadata, - architecture, - layer_count, - is_moe, - expert_count, - ) - return cpu_cfg - - # Select KV cache quantization BEFORE GPU config generation - # This affects M_kv which influences context and batch size calculations - # Use architecture default context_length for initial selection (will be refined later) - kv_cache_config = get_optimal_kv_cache_quant( - system_resources.available_vram_gb, - context_length, - architecture, - system_resources.flash_attn_available, - ) - cache_type_k = kv_cache_config.get("cache_type_k", "f16") - cache_type_v = kv_cache_config.get("cache_type_v") - - # GPU configuration - pass selected KV cache quantization and usage mode - config.update( - generate_gpu_config( - model_size_mb, - architecture, - system_resources.gpus, - system_resources.total_vram, - system_resources.gpu_count, - system_resources.nvlink_topology, - layer_count, - context_length, - vocab_size, - embedding_length, - attention_head_count, - attention_head_count_kv=attention_head_count_kv, - compute_capabilities=system_resources.compute_capabilities, - cache_type_k=cache_type_k, - cache_type_v=cache_type_v, - usage_mode=usage_mode, - debug=debug, - ) - ) - - # Apply KV cache quantization to config - config.update(kv_cache_config) - - # Hybrid consideration: if VRAM is tight, keep KV cache partly on CPU - try: - # If we have some CPU RAM headroom and low VRAM, prefer no_kv_offload False only when enough VRAM - if system_resources.available_vram_gb < (model_size_mb / 1024) * 1.2: - # Signal to avoid KV offload to VRAM when VRAM is tight - config["no_kv_offload"] = True - else: - config.setdefault("no_kv_offload", False) - except Exception: - pass - - # Apply MoE optimizations - self._apply_moe_optimizations( - config, metadata, model_size_mb, system_resources - ) - - # Use the computed ctx_size from GPU/CPU config when generating params - effective_ctx = int( - config.get("ctx_size", 
context_length) or context_length - ) - if debug is not None: - debug["effective_ctx_before_gen_params"] = effective_ctx - config.update(build_generation_params(architecture, effective_ctx, None)) - - # Apply speed/quality balancing if provided (modifies config in-place) - if speed_quality is not None: - self._apply_speed_quality_balancing( - config, speed_quality, use_case, metadata, system_resources, debug - ) - - # Apply preset tuning if provided (modifies config in-place) - # Note: preset takes precedence over use_case if both are provided - if self.current_preset: - apply_preset_tuning(config, self.current_preset) - elif use_case: - # Apply use_case-specific tuning if no preset - self._apply_use_case_tuning(config, use_case) - - # Add server parameters - config.update(generate_server_params()) - - # Final sanitation and clamping - config = sanitize_config(config, system_resources.gpu_count) - - return config - - except Exception as e: - raise Exception(f"Failed to generate smart config: {e}") - - def _apply_speed_quality_balancing( - self, - config: Dict[str, Any], - speed_quality: int, - use_case: Optional[str], - metadata, - system_resources, - debug: Optional[Dict[str, Any]] = None, - ) -> None: - """Apply speed/quality balancing to configuration. 
- - Args: - config: Configuration dictionary to modify in-place - speed_quality: Speed/quality balance (0-100), where 0 = max speed, 100 = max quality - use_case: Optional use case for additional tuning - metadata: Model metadata - system_resources: SystemResources object - debug: Optional debug dictionary - """ - quality_factor = speed_quality / 100.0 # 0.0 = max speed, 1.0 = max quality - - # Context size adjustment - # Speed (0-33): reduce context, Balanced (34-66): moderate, Quality (67-100): maximize - current_ctx = config.get("ctx_size", 4096) - max_context = metadata.context_length - - if speed_quality < 34: - # Max speed: reduce context (2048-4096 range) - target_ctx = 2048 + int((speed_quality / 34) * 2048) - elif speed_quality < 67: - # Balanced: moderate context (4096-8192 range) - target_ctx = 4096 + int(((speed_quality - 34) / 33) * 4096) - else: - # Max quality: maximize context (8192-max range) - min_quality_ctx = 8192 - target_ctx = min_quality_ctx + int( - ((speed_quality - 67) / 33) * (max_context - min_quality_ctx) - ) - - # Respect use_case minimums - if use_case == "code" and target_ctx < 8192: - target_ctx = 8192 - elif use_case == "analysis" and target_ctx < 16384: - target_ctx = 16384 - - config["ctx_size"] = min(target_ctx, max_context) - - # Batch size adjustment - # Speed-focused: larger batches for throughput - # Quality-focused: smaller batches for lower latency per request - current_batch = config.get("batch_size", 256) - current_ubatch = config.get("ubatch_size", 128) - - if speed_quality < 34: - # Max speed: large batches - config["batch_size"] = 512 + int((speed_quality / 34) * 256) # 512-768 - config["ubatch_size"] = 256 + int((speed_quality / 34) * 128) # 256-384 - elif speed_quality < 67: - # Balanced: medium batches - config["batch_size"] = 384 + int( - ((speed_quality - 34) / 33) * 128 - ) # 384-512 - config["ubatch_size"] = 192 + int( - ((speed_quality - 34) / 33) * 64 - ) # 192-256 - else: - # Max quality: smaller batches - 
config["batch_size"] = 256 + int( - ((speed_quality - 67) / 33) * 128 - ) # 256-384 - config["ubatch_size"] = 128 + int( - ((speed_quality - 67) / 33) * 64 - ) # 128-192 - - # GPU layers adjustment - # Quality factor affects how many layers to offload - if config.get("n_gpu_layers", 0) > 0: - layer_count = metadata.layer_count - base_layers = config["n_gpu_layers"] - # Adjust based on quality factor (70-100% of base) - adjusted_layers = int(base_layers * (0.7 + (quality_factor * 0.3))) - config["n_gpu_layers"] = min(adjusted_layers, layer_count) - - # Parallel processing adjustment - # Higher for speed-focused, lower for quality-focused - if speed_quality < 50: - config["parallel"] = max(1, int(3 - (speed_quality / 50) * 2)) # 3 to 1 - else: - config["parallel"] = 1 # Quality-focused: sequential processing - - # Threads optimization - cpu_threads = system_resources.cpu_cores or 4 - if speed_quality < 50: - # Speed-focused: use more threads - config["threads"] = cpu_threads - config["threads_batch"] = min(cpu_threads, 8) - else: - # Quality-focused: optimize threads - config["threads"] = max(2, int(cpu_threads * 0.8)) - config["threads_batch"] = max(2, int(cpu_threads * 0.8)) - - # Flash Attention: enable for quality-focused configs when available - if system_resources.flash_attn_available and quality_factor > 0.6: - config["flash_attn"] = True - # Flash attention enables V cache quantization - if quality_factor < 0.7: - config["cache_type_v"] = "q8_0" # Moderate quantization for balanced - else: - config["cache_type_v"] = "f16" # Better quality - - # KV Cache quantization adjustment - available_vram_gb = system_resources.available_vram_gb - total_vram_gb = ( - system_resources.total_vram / (1024**3) - if system_resources.total_vram - else 0 - ) - - if quality_factor < 0.5 and available_vram_gb < total_vram_gb * 0.5: - # Low VRAM or speed-focused: use quantization - if ( - config.get("cache_type_k") is None - or config.get("cache_type_k") == "f16" - ): - 
config["cache_type_k"] = "q8_0" - if config.get("flash_attn") and config.get("cache_type_v") is None: - config["cache_type_v"] = "q8_0" - elif quality_factor > 0.7: - # Quality-focused: use full precision - config["cache_type_k"] = "f16" - if config.get("flash_attn"): - config["cache_type_v"] = "f16" - - # Low VRAM mode for tight memory situations - if available_vram_gb < total_vram_gb * 0.3 or ( - quality_factor < 0.4 and available_vram_gb < total_vram_gb * 0.5 - ): - config["low_vram"] = True - - if debug is not None: - debug["speed_quality"] = speed_quality - debug["quality_factor"] = quality_factor - debug["use_case"] = use_case - debug["adjusted_ctx_size"] = config["ctx_size"] - debug["adjusted_batch_size"] = config["batch_size"] - - def _apply_use_case_tuning(self, config: Dict[str, Any], use_case: str) -> None: - """Apply use-case-specific generation parameter tuning. - - Args: - config: Configuration dictionary to modify in-place - use_case: Use case ('chat', 'code', 'creative', 'analysis') - """ - if use_case == "code": - config["temp"] = 0.3 - config["temperature"] = 0.3 - config["top_k"] = 30 - if config.get("ctx_size", 4096) < 8192: - config["ctx_size"] = 8192 - elif use_case == "creative": - config["temp"] = 1.2 - config["temperature"] = 1.2 - config["top_k"] = 50 - config["top_p"] = 0.95 - elif use_case == "analysis": - config["temp"] = 0.7 - config["temperature"] = 0.7 - if config.get("ctx_size", 4096) < 16384: - config["ctx_size"] = 16384 - elif use_case == "chat": - config["temp"] = 0.8 - config["temperature"] = 0.8 - - def estimate_vram_usage( - self, - model: Model, - config: Dict[str, Any], - gpu_info: Dict[str, Any], - usage_mode: str = "single_user", - metadata: Optional[ModelMetadata] = None, - ) -> Dict[str, Any]: - """Estimate VRAM usage for given configuration using comprehensive model metadata""" - return estimate_vram_usage( - model, config, gpu_info, metadata=metadata, usage_mode=usage_mode - ) - - def estimate_ram_usage( - self, - 
model: Model, - config: Dict[str, Any], - usage_mode: str = "single_user", - metadata: Optional[ModelMetadata] = None, - ) -> Dict[str, Any]: - """Estimate RAM usage for given configuration""" - return estimate_ram_usage( - model, config, metadata=metadata, usage_mode=usage_mode - ) diff --git a/backend/smart_auto/architecture_config.py b/backend/smart_auto/architecture_config.py deleted file mode 100644 index 18e6c19..0000000 --- a/backend/smart_auto/architecture_config.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -Architecture configuration and detection module. -Consolidates architecture detection and default configuration values. -""" - -from functools import lru_cache -from .constants import ARCHITECTURE_CONTEXT_DEFAULTS, DEFAULT_CONTEXT_LENGTH - - -@lru_cache(maxsize=128) -def resolve_architecture(architecture_or_name: str) -> str: - """ - Unified function to resolve architecture from either a model name or architecture string. - - Handles both detection from model names and normalization of GGUF metadata. - This replaces the separate detect_architecture_from_name() and normalize_architecture() functions. - - Args: - architecture_or_name: Either a model name or architecture string from GGUF metadata - - Returns: - Normalized architecture name (e.g., "llama3", "qwen3", etc.) 
- """ - if not architecture_or_name or not architecture_or_name.strip(): - return "unknown" - - # Normalize input - text = architecture_or_name.lower().strip() - - # Check architectures in order of specificity (most specific first) - - # Qwen architectures - if "qwen" in text: - if "qwen3" in text or "qwen-3" in text: - return "qwen3" - if "qwen2" in text or "qwen-2" in text: - return "qwen2" - return "qwen" - - # Llama architectures (CodeLlama before other Llama variants) - if "codellama" in text: - return "codellama" - if "llama3" in text or "llama-3" in text: - return "llama3" - if "llama2" in text or "llama-2" in text: - return "llama2" - if "llama" in text: - return "llama" - - # Gemma architectures - if "gemma3" in text or "gemma-3" in text: - return "gemma3" - if "gemma" in text: - return "gemma" - - # GLM architectures - if "glm-4" in text or "glm4" in text: - return "glm4" - if "glm" in text or "chatglm" in text: - return "glm" - - # DeepSeek architectures - if "deepseek" in text: - if "v3" in text or "v3.1" in text: - return "deepseek-v3" - return "deepseek" - - # Other architectures - if "mistral" in text: - return "mistral" - if "phi" in text: - return "phi" - - # If text contains something but not recognized, return as generic for model names - # or unknown for invalid architecture strings - if text and text not in ["unknown", "generic"]: - return "generic" - - return "unknown" - - -def get_architecture_default_context(architecture: str) -> int: - """ - Get default context length for an architecture. 
- - Args: - architecture: Normalized architecture name - - Returns: - Default context length in tokens - """ - return ARCHITECTURE_CONTEXT_DEFAULTS.get(architecture, DEFAULT_CONTEXT_LENGTH) - - -# Backward compatibility aliases -def detect_architecture_from_name(model_name: str) -> str: - """Deprecated: Use resolve_architecture() instead.""" - return resolve_architecture(model_name) - - -def normalize_architecture(architecture: str) -> str: - """Deprecated: Use resolve_architecture() instead.""" - return resolve_architecture(architecture) diff --git a/backend/smart_auto/calculators.py b/backend/smart_auto/calculators.py deleted file mode 100644 index 10efab8..0000000 --- a/backend/smart_auto/calculators.py +++ /dev/null @@ -1,402 +0,0 @@ -""" -Calculation utilities for smart_auto module. -Pure functions for batch size, context size, and GPU layer calculations. -""" - -from functools import lru_cache -from typing import Tuple, Optional -from .constants import ( - VRAM_FRAGMENTATION_MARGIN, - CONTEXT_SAFETY_MARGIN, - LAYERS_PER_GB_SMALL_MODEL, - LAYERS_PER_GB_LARGE_MODEL, - GPU_LAYER_BUFFER, - CONTEXT_RAM_OVERHEAD_GB, - MIN_CONTEXT_SIZE, - MAX_CONTEXT_SIZE, - MIN_BATCH_SIZE, - ARCHITECTURE_CPU_BATCH_LIMITS, -) - - -def calculate_ubatch_size(batch_size: int) -> int: - """ - Calculate optimal ubatch_size from batch_size. - - Unified helper to derive ubatch_size consistently across GPU and CPU modes. - """ - return max(1, min(batch_size, max(1, batch_size // 2))) - - -def calculate_optimal_batch_size_gpu( - available_vram_gb: float, - model_size_mb: float, - context_size: int, - embedding_length: int, - layer_count: int, - cache_type_k: Optional[str] = None, - cache_type_v: Optional[str] = None, -) -> int: - """ - Calculate optimal batch size for GPU based on memory requirements. - - Uses data-driven approach when possible, falls back to VRAM-based estimation. 
- """ - # Memory requirements per batch item - model_memory_gb = model_size_mb / 1024 - - # KV cache memory per batch item - # Note: KV cache is shared across batch items in continuous batching, but we estimate - # as if each item needs its own context window (conservative estimate) - if embedding_length > 0 and layer_count > 0: - # Use actual quantization bytes-per-value if provided, otherwise default to fp16 - from .constants import KV_CACHE_QUANT_FACTORS - - quant_factor_k = KV_CACHE_QUANT_FACTORS.get( - cache_type_k or "f16", 0.5 - ) # f16 = 0.5 - quant_factor_v = KV_CACHE_QUANT_FACTORS.get( - cache_type_v or cache_type_k or "f16", quant_factor_k - ) - bytes_per_k = quant_factor_k * 4 # Convert factor to bytes (f32=4, f16=2, etc.) - bytes_per_v = quant_factor_v * 4 - bytes_per_element = (bytes_per_k + bytes_per_v) / 2 # Average for K+V - # Conservative estimate: use embedding_length directly (overestimates slightly for GQA) - # This is a simplified calculation for batch sizing - precise GQA calculation done in memory_estimator - kv_cache_per_item_gb = ( - context_size * embedding_length * layer_count * bytes_per_element - ) / (1024**3) - else: - # Conservative estimate: 64 bytes per token - kv_cache_per_item_gb = context_size * 64 / (1024**3) - - total_per_item_gb = model_memory_gb + kv_cache_per_item_gb - - if total_per_item_gb <= 0: - return MIN_BATCH_SIZE - - # Calculate max batch size based on available memory - max_batch_size = int( - available_vram_gb * VRAM_FRAGMENTATION_MARGIN / total_per_item_gb - ) - - # Apply reasonable limits based on model size - if embedding_length > 2048: # Large models (7B+) - max_batch_size = min(max_batch_size, 512) - elif embedding_length > 1024: # Medium models (3B-7B) - max_batch_size = min(max_batch_size, 1024) - else: # Small models (<3B) - max_batch_size = min(max_batch_size, 2048) - - return max(MIN_BATCH_SIZE, max_batch_size) - - -def calculate_optimal_batch_size_cpu( - available_ram_gb: float, model_size_mb: float, 
context_size: int, architecture: str -) -> Tuple[int, int]: - """ - Calculate optimal batch sizes for CPU mode using dict-based architecture profiles. - - Returns: - Tuple of (batch_size, ubatch_size) - """ - model_ram_gb = model_size_mb / 1024 - - # Calculate available RAM for batching after model and context - reserved_ram_gb = model_ram_gb + (context_size / 1000) + CONTEXT_RAM_OVERHEAD_GB - available_for_batch = max(0, available_ram_gb - reserved_ram_gb) - - # Estimate batch memory usage (rough: 1MB per batch item) - max_batch_size = int(available_for_batch * 1000) # 1GB = ~1000 batch items - - # Get architecture-specific limits or use defaults - limits = ARCHITECTURE_CPU_BATCH_LIMITS.get( - architecture, ARCHITECTURE_CPU_BATCH_LIMITS["default"] - ) - - batch_size = min(limits["max_batch"], max(limits["min_batch"], max_batch_size)) - ubatch_size = min(limits["max_ubatch"], max(limits["min_ubatch"], batch_size // 2)) - - return batch_size, ubatch_size - - -@lru_cache(maxsize=128) -def calculate_max_context_size_gpu( - available_vram_gb: float, - model_size_mb: float, - layer_count: int, - embedding_length: int, - attention_head_count: int, - attention_head_count_kv: int, - cache_type_k: Optional[str] = None, - cache_type_v: Optional[str] = None, - usage_mode: str = "single_user", -) -> int: - """ - Calculate maximum context size for GPU based on memory requirements. - - Cached with LRU to avoid redundant calculations for same parameters. 
- - Returns: - Maximum context size in tokens - """ - # Reserve memory for model - model_memory_gb = model_size_mb / 1024 - reserved_memory_gb = model_memory_gb + 1.0 # Model + 1GB overhead - available_for_context_gb = max(0, available_vram_gb - reserved_memory_gb) - - if available_for_context_gb <= 0: - return MIN_CONTEXT_SIZE - - # Calculate KV cache memory per token based on transformer architecture - # GQA-aware formula: M_kv = n_ctx × N_layers × N_head_kv × d_head × (p_a_k + p_a_v) - # where d_head = N_embd / N_head - # Use actual quantization bytes-per-value instead of hardcoded fp16 - if embedding_length > 0 and layer_count > 0: - # Get actual quantization bytes-per-value - from .constants import KV_CACHE_QUANT_FACTORS - - quant_factor_k = KV_CACHE_QUANT_FACTORS.get( - cache_type_k or "f16", 0.5 - ) # f16 = 0.5 - quant_factor_v = KV_CACHE_QUANT_FACTORS.get( - cache_type_v or cache_type_k or "f16", quant_factor_k - ) - bytes_per_k = quant_factor_k * 4 # Convert factor to bytes (f32=4, f16=2, etc.) 
- bytes_per_v = quant_factor_v * 4 - - if attention_head_count_kv > 0 and attention_head_count > 0: - # GQA-aware calculation - d_head = embedding_length / attention_head_count - # KV cache per token: K and V cache per layer, each storing N_head_kv heads - kv_cache_per_layer_k = attention_head_count_kv * d_head * bytes_per_k - kv_cache_per_layer_v = attention_head_count_kv * d_head * bytes_per_v - kv_cache_per_token_bytes = ( - kv_cache_per_layer_k + kv_cache_per_layer_v - ) * layer_count - else: - # Fallback for non-GQA models (MHA: N_head_kv = N_head) - kv_cache_per_token_bytes = ( - layer_count * embedding_length * (bytes_per_k + bytes_per_v) - ) - - # Apply usage mode factor for multi_user (allows larger context since KV cache is lower) - # For max context calculation: n_ctx = available_vram / (kv_cache_per_token * usage_factor) - # So: tokens_per_gb = 1GB / (kv_cache_per_token * usage_factor) - from .constants import KV_CACHE_SINGLE_USER_FACTOR, KV_CACHE_MULTI_USER_FACTOR - - if usage_mode == "multi_user": - # In multi_user mode, KV cache usage is lower (typical usage), so we can fit more context - usage_factor = KV_CACHE_MULTI_USER_FACTOR - # Calculate tokens per GB: divide by (bytes_per_token * usage_factor) - # This gives more tokens since usage_factor < 1.0 - tokens_per_gb = ( - (1024**3) / (kv_cache_per_token_bytes * usage_factor) - if kv_cache_per_token_bytes > 0 - else 0 - ) - else: - # Single user mode: full KV cache (peak usage), standard calculation - usage_factor = KV_CACHE_SINGLE_USER_FACTOR - tokens_per_gb = ( - (1024**3) / (kv_cache_per_token_bytes * usage_factor) - if kv_cache_per_token_bytes > 0 - else 0 - ) - - # Calculate max context size with safety margin - if tokens_per_gb > 0: - max_context_tokens = int( - available_for_context_gb * tokens_per_gb * CONTEXT_SAFETY_MARGIN - ) - # Ensure minimum context size - return max(MIN_CONTEXT_SIZE, max_context_tokens) - else: - # Fallback if calculation fails (e.g., kv_cache_per_token_bytes is 0) - 
return MIN_CONTEXT_SIZE - else: - # Fallback to conservative estimate: ~1000 tokens per GB - estimated = int(available_for_context_gb * 1000) - return max(MIN_CONTEXT_SIZE, min(MAX_CONTEXT_SIZE, estimated)) - - -def calculate_optimal_context_size_gpu( - architecture: str, - available_vram: int, - model_size_mb: float = 0, - layer_count: int = 32, - embedding_length: int = 0, - attention_head_count: int = 0, - attention_head_count_kv: int = 0, - base_context: Optional[int] = None, - cache_type_k: Optional[str] = None, - cache_type_v: Optional[str] = None, - usage_mode: str = "single_user", -) -> int: - """ - Calculate optimal context size for GPU based on VRAM and architecture defaults. - - Returns: - Optimal context size in tokens - """ - from .architecture_config import get_architecture_default_context - - base_ctx = base_context or get_architecture_default_context(architecture) - - if available_vram == 0: - # CPU mode - conservative context - return max(MIN_CONTEXT_SIZE, min(base_ctx, 2048)) - - # Use data-driven calculation if we have model parameters - if model_size_mb > 0 and layer_count > 0 and embedding_length > 0: - vram_gb = available_vram / (1024**3) - calculated_max = calculate_max_context_size_gpu( - vram_gb, - model_size_mb, - layer_count, - embedding_length, - attention_head_count, - attention_head_count_kv, - cache_type_k=cache_type_k, - cache_type_v=cache_type_v, - usage_mode=usage_mode, - ) - result = min(base_ctx, calculated_max) if calculated_max > 0 else base_ctx - return max(MIN_CONTEXT_SIZE, min(result, MAX_CONTEXT_SIZE)) - - # Fallback to architecture-based limits if no model data - vram_gb = available_vram / (1024**3) - - # Conservative scaling based on VRAM capacity - if vram_gb >= 24: # High-end GPU - return max(MIN_CONTEXT_SIZE, min(base_ctx, MAX_CONTEXT_SIZE)) - elif vram_gb >= 12: # Mid-range GPU - return max(MIN_CONTEXT_SIZE, min(base_ctx, int(base_ctx * 0.75))) - elif vram_gb >= 8: # Lower-end GPU - return max(MIN_CONTEXT_SIZE, 
min(base_ctx, int(base_ctx * 0.5))) - else: # Very limited VRAM - return max(MIN_CONTEXT_SIZE, min(base_ctx, 2048)) - - -def calculate_optimal_gpu_layers( - free_vram_gb: float, - model_size_mb: float, - total_layers: int, - context_size: int = 4096, - cache_type_k: Optional[str] = None, - cache_type_v: Optional[str] = None, - ubatch_size: int = 512, - attention_head_count: int = 0, - attention_head_count_kv: int = 0, - embedding_length: int = 0, - layer_count: int = 0, - usage_mode: str = "single_user", -) -> int: - """ - Calculate optimal number of layers to offload to GPU. - - Uses exact M_kv and M_compute calculations according to theoretical model: - n_ngl_max = floor((VRAM_available - M_kv - M_compute) / (M_weights_total / N_layers)) - - Args: - free_vram_gb: Available VRAM in GB - model_size_mb: Model size in MB (GGUF file size) - total_layers: Total number of layers in model - context_size: Context size in tokens (default: 4096) - cache_type_k: K cache quantization type (default: f16) - cache_type_v: V cache quantization type (default: same as cache_type_k) - ubatch_size: Micro-batch size (default: 512) - attention_head_count: Number of attention heads (for GQA calculation) - attention_head_count_kv: Number of KV attention heads (for GQA calculation) - embedding_length: Embedding dimension (for GQA calculation) - layer_count: Layer count (alias for total_layers, for compatibility) - - Returns: - Number of GPU layers - """ - # Use total_layers if provided, otherwise layer_count - actual_layer_count = ( - total_layers if total_layers > 0 else (layer_count if layer_count > 0 else 0) - ) - - if actual_layer_count == 0: - # Fallback to old heuristic if layer count unknown - estimated_layers_per_gb = ( - LAYERS_PER_GB_SMALL_MODEL - if model_size_mb < 1000 - else LAYERS_PER_GB_LARGE_MODEL - ) - max_layers = int(free_vram_gb * estimated_layers_per_gb * GPU_LAYER_BUFFER) - return max_layers - - # Calculate exact M_kv and M_compute - free_vram_bytes = free_vram_gb * 
(1024**3) - model_size_bytes = model_size_mb * (1024**2) - - # Calculate M_kv using exact formula - from .memory_estimator import calculate_kv_cache_size - - # Use default values if not provided - cache_type_k_actual = cache_type_k or "f16" - cache_type_v_actual = cache_type_v or cache_type_k_actual - - # If we have architecture parameters, use precise calculation - if embedding_length > 0 and attention_head_count > 0: - kv_cache_bytes = calculate_kv_cache_size( - context_size, - 1, # parallel=1 for layer calculation - actual_layer_count, - embedding_length, - attention_head_count, - attention_head_count_kv or attention_head_count, - cache_type_k_actual, - cache_type_v_actual if cache_type_v else None, - usage_mode=usage_mode, - ) - else: - # Fallback: estimate KV cache size (conservative) - # Assume fp16, use embedding_length if available, otherwise estimate - if embedding_length > 0: - # Simplified estimate: assume MHA (not GQA) - bytes_per_token = actual_layer_count * embedding_length * 4 # K+V at fp16 - kv_cache_bytes = context_size * bytes_per_token - else: - # Very conservative fallback: ~64 bytes per token per layer - kv_cache_bytes = context_size * actual_layer_count * 64 - - # Calculate M_compute: Fixed overhead + variable scratch buffer - from .constants import COMPUTE_FIXED_OVERHEAD_MB, COMPUTE_SCRATCH_PER_UBATCH_MB - - compute_overhead_mb = COMPUTE_FIXED_OVERHEAD_MB + ( - ubatch_size * COMPUTE_SCRATCH_PER_UBATCH_MB - ) - compute_overhead_bytes = int(compute_overhead_mb * (1024**2)) - - # Formula from theoretical model: - # n_ngl_max = floor((VRAM_available - M_kv - M_compute) / (M_weights_total / N_layers)) - available_for_weights_bytes = ( - free_vram_bytes - kv_cache_bytes - compute_overhead_bytes - ) - - if available_for_weights_bytes <= 0: - # Not enough VRAM even for M_kv and M_compute - return 0 - - mb_per_layer = ( - model_size_bytes / actual_layer_count if actual_layer_count > 0 else 0 - ) - if mb_per_layer <= 0: - # Fallback if calculation 
fails - estimated_layers_per_gb = ( - LAYERS_PER_GB_SMALL_MODEL - if model_size_mb < 1000 - else LAYERS_PER_GB_LARGE_MODEL - ) - max_layers = int(free_vram_gb * estimated_layers_per_gb * GPU_LAYER_BUFFER) - return min(max_layers, actual_layer_count) - - max_layers = ( - int(available_for_weights_bytes / mb_per_layer) if mb_per_layer > 0 else 0 - ) - - return min(max_layers, actual_layer_count) diff --git a/backend/smart_auto/config_builder.py b/backend/smart_auto/config_builder.py deleted file mode 100644 index 56ac40b..0000000 --- a/backend/smart_auto/config_builder.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Configuration builder module. -Handles configuration sanitization, server parameters, and preset tuning. -""" - -from typing import Dict, Any, Optional -from backend.logging_config import get_logger - -logger = get_logger(__name__) - - -def clamp_int(name: str, val: Any, lo: int, hi: int, default: int) -> int: - """Helper to clamp integer values.""" - try: - iv = int(val) - except (ValueError, TypeError): - iv = default - return max(lo, min(hi, iv)) - - -def generate_server_params() -> Dict[str, Any]: - """Generate server-specific parameters""" - return { - "host": "0.0.0.0", # Allow external connections - "timeout": 300, # 5 minutes timeout - } - - -def sanitize_config(config: Dict[str, Any], gpu_count: int) -> Dict[str, Any]: - """Clamp and sanitize final config values to enforce invariants and avoid edge-case crashes.""" - sanitized = dict(config) - - # Clamp integer values - sanitized["ctx_size"] = clamp_int( - "ctx_size", sanitized.get("ctx_size", 4096), 512, 262144, 4096 - ) - sanitized["batch_size"] = clamp_int( - "batch_size", sanitized.get("batch_size", 512), 1, 4096, 512 - ) - sanitized["ubatch_size"] = clamp_int( - "ubatch_size", - sanitized.get("ubatch_size"), - 1, - sanitized.get("batch_size", 512), - max(1, sanitized.get("batch_size", 512) // 2), - ) - sanitized["parallel"] = clamp_int( - "parallel", - sanitized.get("parallel", 1), - 1, - max(1, 
gpu_count if gpu_count > 0 else 1), - 1, - ) - - # Ensure boolean fields are properly typed - boolean_fields = ["no_mmap", "mlock", "low_vram", "logits_all", "flash_attn"] - sanitized.update({b: bool(sanitized[b]) for b in boolean_fields if b in sanitized}) - - return sanitized - - -def apply_preset_tuning(config: Dict[str, Any], preset_name: str) -> None: - """ - Apply preset-specific tuning to configuration parameters. - - Consolidates both generation parameter adjustments and config factor tuning - into a single clear function. - """ - if preset_name == "coding": - config["temperature"] = 0.7 - config["repeat_penalty"] = 1.05 - if "batch_size" in config: - config["batch_size"] = max(1, int(config["batch_size"] * 0.8)) - if "ubatch_size" in config: - config["ubatch_size"] = max(1, int(config["ubatch_size"] * 0.8)) - if "parallel" in config: - config["parallel"] = max(1, int(config["parallel"] * 1.2)) - logger.debug("Applied preset 'coding' tuning") - # conversational preset has no changes (factors = 1.0), so skip diff --git a/backend/smart_auto/constants.py b/backend/smart_auto/constants.py deleted file mode 100644 index c94dee6..0000000 --- a/backend/smart_auto/constants.py +++ /dev/null @@ -1,184 +0,0 @@ -""" -Constants used across the smart_auto module. -Centralizes magic numbers and configuration limits. 
-""" - -from typing import Dict, Any - -# ============================================================================ -# Memory optimization factors -# ============================================================================ - -KV_CACHE_OPTIMIZATION_FACTOR = 1.0 # Use actual memory (no optimization factor) - memory mapping doesn't reduce peak usage - -# Usage mode factors for KV cache estimation -# Based on theoretical model: single_user accumulates context (peak), multi_user clears context (typical usage) -KV_CACHE_SINGLE_USER_FACTOR = 1.0 # Peak estimate (full context window) -KV_CACHE_MULTI_USER_FACTOR = ( - 0.4 # Typical usage (context cleared between requests, ~40% of peak) -) - -MOE_OFFLOAD_ALL_RATIO = 0.3 # 30% of model for all MoE offloaded -MOE_OFFLOAD_UP_DOWN_RATIO = 0.2 # 20% of model for up/down MoE offloaded -MOE_OFFLOAD_UP_RATIO = 0.1 # 10% of model for up MoE offloaded - -LLAMA_CPP_OVERHEAD_MB = 256 # 256MB overhead for llama.cpp - -# Compute buffer constants (M_compute) -COMPUTE_FIXED_OVERHEAD_MB = ( - 550 # Fixed CUDA overhead (~550MB for CUDA context, cuBLAS workspace, etc.) 
-) -COMPUTE_SCRATCH_PER_UBATCH_MB = ( - 0.5 # Variable scratch buffer per ubatch size (rough estimate) -) - -# VRAM pressure thresholds for MoE offloading -VRAM_RATIO_VERY_TIGHT = 1.2 # Very tight VRAM - offload all MoE -VRAM_RATIO_TIGHT = 1.5 # Tight VRAM - offload up/down projections -VRAM_RATIO_MODERATE = 2.0 # Moderate VRAM - offload only up projection - -# ============================================================================ -# KV Cache quantization factors -# ============================================================================ - -KV_CACHE_QUANT_FACTORS: Dict[str, float] = { - "f32": 1.0, # Full precision (no reduction) - "f16": 0.5, # Half precision - "bf16": 0.5, # Bfloat16 - "q8_0": 0.25, # 8-bit quant - "q5_1": 0.156, # 5-bit high quality - "q5_0": 0.156, # 5-bit - "q4_1": 0.125, # 4-bit high quality - "q4_0": 0.125, # 4-bit - "iq4_nl": 0.125, # 4-bit non-linear -} - -QUANTIZATION_AVERAGE_FACTOR = 0.5 # Average of K and V cache quantization factors - -# ============================================================================ -# Architecture context defaults -# ============================================================================ - -ARCHITECTURE_CONTEXT_DEFAULTS: Dict[str, int] = { - "llama2": 4096, - "llama3": 8192, - "llama": 4096, - "codellama": 16384, - "mistral": 32768, - "phi": 2048, - "glm": 8192, - "glm4": 204800, # 200K context for GLM-4.6 - "deepseek": 32768, - "deepseek-v3": 32768, - "qwen": 32768, # 32K context - "qwen2": 32768, # 32K context - "qwen3": 131072, # 128K context for Qwen3 - "gemma": 8192, - "gemma3": 8192, - "generic": 4096, -} - -DEFAULT_CONTEXT_LENGTH = 4096 - -# ============================================================================ -# Memory calculation defaults -# ============================================================================ - -DEFAULT_BYTES_PER_ELEMENT = 2 # Assume fp16 for activations -BATCH_INTERMEDIATE_FACTOR = 0.08 # 8% factor for intermediate activations -BATCH_QKV_FACTOR = 
0.04 # 4% factor for QKV projections -BATCH_COMPUTATION_OVERHEAD_KB = 400 # ~400KB per batch item -BATCH_FALLBACK_MB = 1.5 # 1.5MB per batch item fallback - -BATCH_VRAM_OVERHEAD_RATIO = 0.1 # 10% of KV cache VRAM for batch overhead -BATCH_RAM_OVERHEAD_RATIO = 0.1 # 10% of KV cache RAM for batch overhead - -# Layer estimation defaults -FALLBACK_LAYER_COUNT = 32 -FALLBACK_EMBEDDING_LENGTH = 4096 -FALLBACK_KV_CACHE_PER_TOKEN_BYTES = 60 * (4096 * 2 + 4096 * 2) # ~960 KB per token - -# ============================================================================ -# Context size limits -# ============================================================================ - -MIN_CONTEXT_SIZE = 512 -MAX_CONTEXT_SIZE = 262144 # 256K -MAX_CPU_CONTEXT_SIZE = 8192 # Conservative limit for CPU mode - -# ============================================================================ -# Batch size limits -# ============================================================================ - -MIN_BATCH_SIZE = 1 -MAX_BATCH_SIZE = 4096 - -# ============================================================================ -# GPU/VRAM calculation constants -# ============================================================================ - -# VRAM safety margins -VRAM_SAFETY_MARGIN = 0.9 # 90% of available VRAM -VRAM_FRAGMENTATION_MARGIN = 0.7 # 70% for batch size calculations -CONTEXT_SAFETY_MARGIN = 0.8 # 80% for context size calculations - -# GPU layer estimation -LAYERS_PER_GB_SMALL_MODEL = 8 # Models < 1GB -LAYERS_PER_GB_LARGE_MODEL = 4 # Models >= 1GB -GPU_LAYER_BUFFER = 0.8 # Leave 20% buffer - -# ============================================================================ -# CPU calculation constants -# ============================================================================ - -# RAM reservation overhead -MODEL_RAM_OVERHEAD_GB = 2.0 # Overhead for model loading -CONTEXT_RAM_OVERHEAD_GB = 1.0 # Additional overhead for context - -# 
============================================================================ -# Architecture-specific configuration profiles -# ============================================================================ - -# CPU architecture optimization profiles -# Maps architecture to dict of optimization settings -ARCHITECTURE_CPU_PROFILES: Dict[str, Dict[str, Any]] = { - "mistral": { - "use_mmap": True, - }, - "llama3": { - "use_mmap": "dynamic", # Special flag for conditional mmap - }, - "llama2": { - "use_mmap": "dynamic", # Special flag for conditional mmap - }, - "codellama": { - "use_mmap": True, - "logits_all": False, - }, - "phi": { - "use_mmap": True, - }, -} - -# CPU batch size limits per architecture -ARCHITECTURE_CPU_BATCH_LIMITS: Dict[str, Dict[str, int]] = { - "mistral": { - "max_batch": 2048, - "max_ubatch": 1024, - "min_batch": 64, - "min_ubatch": 32, - }, - "llama3": {"max_batch": 1536, "max_ubatch": 768, "min_batch": 64, "min_ubatch": 32}, - "codellama": { - "max_batch": 1536, - "max_ubatch": 768, - "min_batch": 64, - "min_ubatch": 32, - }, - "default": { - "max_batch": 1024, - "max_ubatch": 512, - "min_batch": 32, - "min_ubatch": 16, - }, -} diff --git a/backend/smart_auto/cpu_config.py b/backend/smart_auto/cpu_config.py deleted file mode 100644 index 9a5ce86..0000000 --- a/backend/smart_auto/cpu_config.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -CPU configuration module. -Handles all CPU-specific configuration logic for model inference. 
-""" - -from typing import Dict, Any, Optional, Tuple -import psutil - -from .architecture_config import get_architecture_default_context -from .memory_estimator import ( - get_cpu_memory_gb, - tokens_per_gb_by_model_size, - ctx_tokens_budget_greedy, -) -from .calculators import calculate_optimal_batch_size_cpu, calculate_ubatch_size -from .constants import ( - MODEL_RAM_OVERHEAD_GB, - CONTEXT_RAM_OVERHEAD_GB, - MAX_CPU_CONTEXT_SIZE, - MIN_CONTEXT_SIZE, - ARCHITECTURE_CPU_PROFILES, - ARCHITECTURE_CPU_BATCH_LIMITS, -) - - -def get_optimal_cpu_context_size( - architecture: str, available_ram_gb: float, model_size_mb: float -) -> int: - """Calculate optimal context size for CPU-only mode based on available RAM.""" - base_context = get_architecture_default_context(architecture) - - # Calculate how much RAM we can allocate for context - # Reserve space for model + overhead - model_ram_gb = model_size_mb / 1024 - reserved_ram_gb = model_ram_gb + MODEL_RAM_OVERHEAD_GB - available_for_context = max(0, available_ram_gb - reserved_ram_gb) - - # Estimate context memory usage (rough: 1MB per 1000 tokens) - max_context_tokens = int(available_for_context * 1000) # 1GB = ~1000 tokens - - # Apply architecture-specific limits - if architecture == "mistral": - # Mistral can handle very large contexts - optimal_context = min(base_context, max_context_tokens) - elif architecture in ["llama3", "codellama"]: - # Llama3 and CodeLlama have good context handling - optimal_context = min(base_context, max_context_tokens) - else: - # Conservative for other architectures - optimal_context = min(base_context, max_context_tokens, MAX_CPU_CONTEXT_SIZE) - - # Ensure minimum context size - return max(MIN_CONTEXT_SIZE, optimal_context) - - -def calculate_optimal_batch_sizes( - available_ram_gb: float, model_size_mb: float, ctx_size: int, architecture: str -) -> Tuple[int, int]: - """Calculate optimal batch sizes for CPU mode.""" - return calculate_optimal_batch_size_cpu( - available_ram_gb, 
model_size_mb, ctx_size, architecture - ) - - -def get_optimal_parallel_cpu(available_ram_gb: float, model_size_mb: float) -> int: - """Calculate optimal parallel sequences for CPU mode.""" - model_ram_gb = model_size_mb / 1024 - - # Calculate how many parallel sequences we can run - # Each parallel sequence needs roughly 1GB of RAM - max_parallel = int(available_ram_gb / (model_ram_gb + 1.0)) - - # Apply reasonable limits - if available_ram_gb >= 32: # High RAM system - return min(8, max(1, max_parallel)) - elif available_ram_gb >= 16: # Mid RAM system - return min(4, max(1, max_parallel)) - else: # Low RAM system - return min(2, max(1, max_parallel)) - - -def get_cpu_architecture_optimizations( - architecture: str, available_ram_gb: float -) -> Dict[str, Any]: - """Get architecture-specific optimizations for CPU mode using dict-based profiles.""" - # Get architecture-specific profile, or empty dict if not found - profile = ARCHITECTURE_CPU_PROFILES.get(architecture, {}) - optimizations = dict(profile) # Copy to avoid mutating the original - - # Handle dynamic mmap setting for llama architectures - if optimizations.get("use_mmap") == "dynamic": - optimizations["use_mmap"] = available_ram_gb < 16 - - # Common CPU optimizations applied to all architectures - optimizations.update( - { - "embedding": False, # Disable embedding mode for inference - "cont_batching": True, # Enable continuous batching for efficiency - "no_kv_offload": True, # Don't offload KV cache (CPU mode) - } - ) - - return optimizations - - -def generate_cpu_config( - model_size_mb: float, - architecture: str, - layer_count: int = 32, - context_length: int = 4096, - vocab_size: int = 0, - embedding_length: int = 0, - attention_head_count: int = 0, - debug: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: - """Generate CPU-only configuration optimized for available RAM.""" - # Get system memory info (from centralized helper) - total_ram_gb, used_ram_gb, available_ram_gb = get_cpu_memory_gb() - 
if debug is not None: - debug.update( - { - "cpu_total_ram_gb": total_ram_gb, - "cpu_available_ram_gb": available_ram_gb, - } - ) - - # Estimate CPU threads (leave some cores free for system) - cpu_count_phys = psutil.cpu_count(logical=False) or 1 - logical_cpu_count = psutil.cpu_count(logical=True) or cpu_count_phys - threads = max(1, cpu_count_phys - 1) # Leave 1 core for system - threads_batch = max( - 1, min(threads, max(1, logical_cpu_count - 2)) - ) # Guard negatives - - # Calculate optimal context size based on model's max and available RAM (no hard cap) - base_ctx = max(512, context_length or 4096) - model_gb = max(0.001, model_size_mb / 1024.0) - # Tokens per GB heuristic (centralized) - tokens_per_gb = tokens_per_gb_by_model_size(model_gb) - # Reserve RAM for model + overhead using actual available RAM - reserved_ram_gb = model_gb + 2.0 - available_for_ctx_gb = max(0.0, available_ram_gb - reserved_ram_gb) - # Provide a small minimum window so we don't quantize to zero - if available_for_ctx_gb <= 0: - available_for_ctx_gb = max(0.25, available_ram_gb * 0.1) - if debug is not None: - debug.update( - { - "model_gb": model_gb, - "tokens_per_gb": tokens_per_gb, - "reserved_ram_gb": reserved_ram_gb, - "available_for_ctx_gb": available_for_ctx_gb, - } - ) - # Initial cap ignoring batch/parallel - max_tokens_by_ram = ctx_tokens_budget_greedy( - model_gb, available_ram_gb, reserve_overhead_gb=2.0 - ) - optimal_ctx_size = max(512, min(base_ctx, max_tokens_by_ram)) - - # Calculate optimal batch sizes using centralized function - batch_size, ubatch_size = calculate_optimal_batch_size_cpu( - available_ram_gb, model_size_mb, optimal_ctx_size, architecture - ) - - # Adjust ctx_size to account for batch and parallel (ctx * batch * parallel <= tokens_budget) - parallel = 1 - tokens_budget = int(tokens_per_gb * available_for_ctx_gb) - if tokens_budget > 0: - # Budget ctx tokens directly from available RAM; batch is handled separately - safe_ctx = int(tokens_budget) - 
optimal_ctx_size = max(512, min(optimal_ctx_size, safe_ctx)) - if debug is not None: - debug.update( - { - "tokens_budget": tokens_budget, - "batch_size": batch_size, - "ubatch_size": ubatch_size, - "parallel": parallel, - "optimal_ctx_size": optimal_ctx_size, - } - ) - - config = { - "threads": threads, - "threads_batch": threads_batch, - "ctx_size": optimal_ctx_size, - "batch_size": batch_size, - "ubatch_size": ubatch_size, - "parallel": parallel, - "no_mmap": False, - "mlock": False, - "low_vram": False, - "logits_all": False, # Don't compute all logits to save memory - } - - # Add architecture-specific optimizations - config.update(get_cpu_architecture_optimizations(architecture, available_ram_gb)) - - return config diff --git a/backend/smart_auto/generation_params.py b/backend/smart_auto/generation_params.py deleted file mode 100644 index 6cdc1b4..0000000 --- a/backend/smart_auto/generation_params.py +++ /dev/null @@ -1,87 +0,0 @@ -from typing import Dict, Any, Optional - - -def safe_float(val: Any, default: float = 0.0) -> float: - """Safely convert value to float, returning default on failure.""" - try: - return float(val) - except (ValueError, TypeError, OverflowError): - return default - - -def safe_int(val: Any, default: int = 0) -> int: - """Safely convert value to int, returning default on failure.""" - try: - return int(val) - except (ValueError, TypeError, OverflowError): - return default - - -def clamp_float(val: Any, lo: float, hi: float, default: float) -> float: - try: - fv = float(val) - except Exception: - return default - return max(lo, min(hi, fv)) - - -def clamp_int(val: Any, lo: int, hi: int, default: int) -> int: - try: - iv = int(val) - except Exception: - return default - return max(lo, min(hi, iv)) - - -def build_generation_params( - architecture: str, context_length: int, preset_overrides: Dict[str, Any] | None -) -> Dict[str, Any]: - params: Dict[str, Any] = {} - - params.update( - { - "temperature": 0.8, - "top_p": 0.9, - "typical_p": 
1.0, - "min_p": 0.0, - "tfs_z": 1.0, - "top_k": 40, - "repeat_penalty": 1.1, - "presence_penalty": 0.0, - "frequency_penalty": 0.0, - "mirostat": 0, - "mirostat_tau": 5.0, - "mirostat_eta": 0.1, - "ctx_size": max(512, int(context_length or 0)), - "stop": [], - } - ) - - if preset_overrides: - params.update(preset_overrides) - - params["temp"] = params.get("temperature", params.get("temp", 0.8)) - - params["temperature"] = clamp_float(params.get("temperature", 0.8), 0.0, 2.0, 0.8) - params["top_p"] = clamp_float(params.get("top_p", 0.9), 0.0, 1.0, 0.9) - params["min_p"] = clamp_float(params.get("min_p", 0.0), 0.0, 1.0, 0.0) - params["typical_p"] = clamp_float(params.get("typical_p", 1.0), 0.0, 1.0, 1.0) - params["tfs_z"] = clamp_float(params.get("tfs_z", 1.0), 0.0, 1.0, 1.0) - params["top_k"] = max(0, int(params.get("top_k", 40) or 0)) - params["repeat_penalty"] = max(0.0, float(params.get("repeat_penalty", 1.1) or 1.1)) - params["presence_penalty"] = float(params.get("presence_penalty", 0.0) or 0.0) - params["frequency_penalty"] = float(params.get("frequency_penalty", 0.0) or 0.0) - params["mirostat"] = max(0, min(2, int(params.get("mirostat", 0) or 0))) - params["mirostat_tau"] = clamp_float( - params.get("mirostat_tau", 5.0), 0.1, 20.0, 5.0 - ) - params["mirostat_eta"] = clamp_float( - params.get("mirostat_eta", 0.1), 0.01, 2.0, 0.1 - ) - params["ctx_size"] = max( - 512, int(params.get("ctx_size", context_length) or context_length) - ) - if not isinstance(params.get("stop", []), list): - params["stop"] = [] - - return params diff --git a/backend/smart_auto/gpu_config.py b/backend/smart_auto/gpu_config.py deleted file mode 100644 index af6b423..0000000 --- a/backend/smart_auto/gpu_config.py +++ /dev/null @@ -1,469 +0,0 @@ -""" -GPU configuration module. -Handles all GPU-specific configuration logic including single GPU, multi-GPU, and NVLink topologies. 
-""" - -from typing import Dict, Any, Optional, List -import psutil - -from .architecture_config import get_architecture_default_context -from .calculators import ( - calculate_optimal_batch_size_gpu, - calculate_max_context_size_gpu, - calculate_optimal_context_size_gpu, - calculate_optimal_gpu_layers, - calculate_ubatch_size, -) -from .constants import ( - VRAM_FRAGMENTATION_MARGIN, - VRAM_SAFETY_MARGIN, - MIN_CONTEXT_SIZE, - MAX_CONTEXT_SIZE, - MIN_BATCH_SIZE, - MAX_BATCH_SIZE, -) - - -def parse_compute_capability(value: str) -> float: - """Parse compute capability like '8.0', '7.5' to a float safely.""" - try: - parts = str(value).split(".") - major = int(parts[0]) if parts and parts[0].isdigit() else 0 - minor = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0 - return major + minor / 10.0 - except Exception: - return 0.0 - - -def calculate_optimal_batch_size( - available_vram_gb: float, - model_size_mb: float, - context_size: int, - embedding_length: int, - layer_count: int, - cache_type_k: Optional[str] = None, - cache_type_v: Optional[str] = None, -) -> int: - """Calculate optimal batch size based on memory and throughput analysis.""" - return calculate_optimal_batch_size_gpu( - available_vram_gb, - model_size_mb, - context_size, - embedding_length, - layer_count, - cache_type_k=cache_type_k, - cache_type_v=cache_type_v, - ) - - -def calculate_max_context_size( - available_vram_gb: float, - model_size_mb: float, - layer_count: int, - embedding_length: int, - attention_head_count: int, - attention_head_count_kv: int, -) -> int: - """Calculate maximum context size based on actual memory requirements.""" - return calculate_max_context_size_gpu( - available_vram_gb, - model_size_mb, - layer_count, - embedding_length, - attention_head_count, - attention_head_count_kv, - ) - - -def get_optimal_context_size( - architecture: str, - available_vram: int, - model_size_mb: float = 0, - layer_count: int = 32, - embedding_length: int = 0, - 
attention_head_count: int = 0, - attention_head_count_kv: int = 0, - cache_type_k: Optional[str] = None, - cache_type_v: Optional[str] = None, - usage_mode: str = "single_user", -) -> int: - """Calculate optimal context size based on actual memory requirements and architecture.""" - base_context = get_architecture_default_context(architecture) - return calculate_optimal_context_size_gpu( - architecture, - available_vram, - model_size_mb, - layer_count, - embedding_length, - attention_head_count, - attention_head_count_kv, - base_context, - cache_type_k=cache_type_k, - cache_type_v=cache_type_v, - usage_mode=usage_mode, - ) - - -def single_gpu_config( - model_size_mb: float, - architecture: str, - gpu: Dict, - layer_count: int = 32, - embedding_length: int = 0, - attention_head_count: int = 0, - attention_head_count_kv: int = 0, - compute_capability: float = 0.0, - context_length: int = 4096, - cache_type_k: Optional[str] = None, - cache_type_v: Optional[str] = None, - usage_mode: str = "single_user", -) -> Dict[str, Any]: - """Configuration for single GPU. - - Args: - compute_capability: Pre-parsed compute capability (e.g., 8.0, 7.5). - Use 0.0 if not available. 
- """ - # Extract frequently accessed values to avoid repeated dict lookups - gpu_memory = gpu.get("memory", {}) - vram_gb = gpu_memory.get("total", 0) / (1024**3) - free_vram_gb = gpu_memory.get("free", 0) / (1024**3) - gpu_index = gpu.get("index", 0) - - # Use calculator for GPU layer estimation - # Use exact M_kv/M_compute calculation with estimated context/batch values - # These will be refined later, but using estimates here gives better initial calculation - n_gpu_layers = calculate_optimal_gpu_layers( - free_vram_gb, - model_size_mb, - layer_count, - context_size=context_length, # Use architecture default context - cache_type_k=cache_type_k, - cache_type_v=cache_type_v, - ubatch_size=512, # Reasonable estimate for initial calculation - attention_head_count=attention_head_count, - attention_head_count_kv=attention_head_count_kv, - embedding_length=embedding_length, - usage_mode=usage_mode, - ) - - config = { - "n_gpu_layers": n_gpu_layers, - "main_gpu": gpu_index, - "threads": max(1, (psutil.cpu_count(logical=False) or 2) - 2), - "threads_batch": max(1, (psutil.cpu_count(logical=False) or 2) - 2), - } - - # Calculate optimal batch sizes based on actual memory requirements - # Note: Use architecture default context_length (will be refined later in generate_gpu_config) - # Use selected KV cache quantization if provided - if embedding_length > 0 and layer_count > 0: - # Use data-driven calculation with architecture default context length - optimal_batch_size = calculate_optimal_batch_size( - free_vram_gb, - model_size_mb, - context_length, - embedding_length, - layer_count, - cache_type_k=cache_type_k, - cache_type_v=cache_type_v, - ) - config["batch_size"] = max( - MIN_BATCH_SIZE, min(MAX_BATCH_SIZE, optimal_batch_size) - ) - config["ubatch_size"] = calculate_ubatch_size(config["batch_size"]) - else: - # Fallback to VRAM-based estimation - if vram_gb >= 24: # High-end GPU - config["batch_size"] = min(1024, max(256, int(vram_gb * 30))) - config["ubatch_size"] = 
min(512, max(128, int(vram_gb * 15))) - elif vram_gb >= 12: # Mid-range GPU - config["batch_size"] = min(512, max(128, int(vram_gb * 25))) - config["ubatch_size"] = min(256, max(64, int(vram_gb * 12))) - elif vram_gb >= 8: # Lower-end GPU - config["batch_size"] = min(256, max(64, int(vram_gb * 20))) - config["ubatch_size"] = min(128, max(32, int(vram_gb * 10))) - else: # Very limited VRAM - config["batch_size"] = min(128, max(32, int(vram_gb * 15))) - config["ubatch_size"] = min(64, max(16, int(vram_gb * 7))) - - # Enable flash attention for supported GPUs (Ampere and newer: >= 8.0) - if compute_capability >= 8.0: - config["flash_attn"] = True - - return config - - -def multi_gpu_config( - model_size_mb: float, - architecture: str, - gpus: list, - nvlink_topology: Dict, - layer_count: int = 32, - compute_capabilities: Optional[List[float]] = None, -) -> Dict[str, Any]: - """Configuration for multiple GPUs with NVLink awareness. - - Args: - compute_capabilities: Pre-parsed compute capabilities list. If None, will parse from gpus. 
- """ - config = { - "main_gpu": 0, - "n_gpu_layers": -1, # Use all layers - "threads": max(1, psutil.cpu_count(logical=False) - 2), - "threads_batch": max(1, psutil.cpu_count(logical=False) - 2), - } - - # Enable flash attention if all GPUs support it (Ampere and newer: >= 8.0) - if compute_capabilities: - # Use pre-parsed compute capabilities - if all(cc >= 8.0 for cc in compute_capabilities): - config["flash_attn"] = True - else: - # Fallback: parse from gpus if not provided - if all( - parse_compute_capability(gpu.get("compute_capability", "0.0")) >= 8.0 - for gpu in gpus - ): - config["flash_attn"] = True - - # Configure based on NVLink topology - strategy = nvlink_topology.get("recommended_strategy", "pcie_only") - - if strategy == "nvlink_unified": - # All GPUs connected via NVLink - use unified memory approach - config.update(nvlink_unified_config(gpus, nvlink_topology)) - elif strategy == "nvlink_clustered": - # Multiple NVLink clusters - optimize per cluster - config.update(nvlink_clustered_config(gpus, nvlink_topology)) - elif strategy == "nvlink_partial": - # Partial NVLink connectivity - hybrid approach - config.update(nvlink_partial_config(gpus, nvlink_topology)) - else: - # PCIe only - traditional tensor splitting - config.update(pcie_only_config(gpus)) - - return config - - -def nvlink_unified_config(gpus: list, nvlink_topology: Dict) -> Dict[str, Any]: - """Configuration for unified NVLink cluster.""" - # With NVLink, we can use more aggressive tensor splitting - # Extract memory values once to avoid repeated dict lookups - vram_sizes = [gpu.get("memory", {}).get("total", 0) for gpu in gpus] - total_vram = sum(vram_sizes) - total_vram_gb = total_vram / (1024**3) - - # Pre-calculate ratios as floats, format only at the end - tensor_split = [ - f"{vram / total_vram:.3f}" if total_vram > 0 else "0.000" for vram in vram_sizes - ] - - return { - "tensor_split": ",".join(tensor_split), - "parallel": min(8, len(gpus) * 2), # Higher parallelism with NVLink 
- "batch_size": min( - 4096, max(512, int(total_vram_gb * 150)) - ), # Larger batches for high VRAM - "ubatch_size": min(2048, max(256, int(total_vram_gb * 75))), - } - - -def nvlink_clustered_config(gpus: list, nvlink_topology: Dict) -> Dict[str, Any]: - """Configuration for multiple NVLink clusters.""" - # Extract clusters once to avoid repeated dict lookup - clusters = nvlink_topology.get("clusters", []) - - if not clusters: - return pcie_only_config(gpus) - - # Use the largest cluster for primary processing - largest_cluster = max(clusters, key=lambda c: len(c["gpus"])) - cluster_gpu_indices = set(largest_cluster["gpus"]) - - # Configure tensor split for the largest cluster - # Pre-extract all GPU memory values once to avoid repeated dict lookups - gpu_memories = [gpu.get("memory", {}) for gpu in gpus] - cluster_vram_sizes = [gpu_memories[i].get("total", 0) for i in cluster_gpu_indices] - total_vram = sum(cluster_vram_sizes) - total_vram_gb = total_vram / (1024**3) - - # Pre-calculate ratios as floats, format only at the end - tensor_split_ratios = [] - for i, gpu_memory in enumerate(gpu_memories): - if i in cluster_gpu_indices: - ratio = gpu_memory.get("total", 0) / total_vram if total_vram > 0 else 0.0 - tensor_split_ratios.append(ratio) - else: - tensor_split_ratios.append(0.0) - - # Format all ratios in a single pass - tensor_split = [f"{ratio:.3f}" for ratio in tensor_split_ratios] - - return { - "tensor_split": ",".join(tensor_split), - "parallel": min(6, len(largest_cluster["gpus"]) * 2), - "batch_size": min(3072, max(384, int(total_vram_gb * 120))), - "ubatch_size": min(1536, max(192, int(total_vram_gb * 60))), - } - - -def nvlink_partial_config(gpus: list, nvlink_topology: Dict) -> Dict[str, Any]: - """Configuration for partial NVLink connectivity.""" - # Use conservative approach for partial NVLink - vram_sizes = [gpu.get("memory", {}).get("total", 0) for gpu in gpus] - total_vram = sum(vram_sizes) - total_vram_gb = total_vram / (1024**3) - - # 
Pre-calculate ratios as floats, format only at the end - tensor_split = [ - f"{vram / total_vram:.2f}" if total_vram > 0 else "0.00" for vram in vram_sizes - ] - - return { - "tensor_split": ",".join(tensor_split), - "parallel": min(4, len(gpus)), - "batch_size": min(2048, max(256, int(total_vram_gb * 100))), - "ubatch_size": min(1024, max(128, int(total_vram_gb * 50))), - } - - -def pcie_only_config(gpus: list) -> Dict[str, Any]: - """Configuration for PCIe-only multi-GPU setup.""" - # Calculate tensor split based on VRAM - vram_sizes = [gpu.get("memory", {}).get("total", 0) for gpu in gpus] - total_vram = sum(vram_sizes) - total_vram_gb = total_vram / (1024**3) - - # Pre-calculate ratios as floats, format only at the end - tensor_split = [ - f"{vram / total_vram:.2f}" if total_vram > 0 else "0.00" for vram in vram_sizes - ] - - return { - "tensor_split": ",".join(tensor_split), - "parallel": min(2, len(gpus)), # Conservative parallelism for PCIe - "batch_size": min(1024, max(128, int(total_vram_gb * 80))), - "ubatch_size": min(512, max(64, int(total_vram_gb * 40))), - } - - -def generate_gpu_config( - model_size_mb: float, - architecture: str, - gpus: list, - total_vram: int, - gpu_count: int, - nvlink_topology: Dict, - layer_count: int = 32, - context_length: int = 4096, - vocab_size: int = 0, - embedding_length: int = 0, - attention_head_count: int = 0, - attention_head_count_kv: int = 0, - compute_capabilities: Optional[List[float]] = None, - cache_type_k: Optional[str] = None, - cache_type_v: Optional[str] = None, - usage_mode: str = "single_user", - debug: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: - """Generate GPU-optimized configuration. - - Args: - compute_capabilities: Pre-parsed compute capabilities list from SystemResources. 
- """ - config = {} - - # Calculate optimal GPU layers - available_vram = sum(gpu.get("memory", {}).get("free", 0) for gpu in gpus) - available_vram_gb = available_vram / (1024**3) - - if gpu_count == 1: - # Use pre-parsed compute capability for single GPU - gpu_cc = ( - compute_capabilities[0] - if compute_capabilities and len(compute_capabilities) > 0 - else 0.0 - ) - config.update( - single_gpu_config( - model_size_mb, - architecture, - gpus[0], - layer_count, - embedding_length, - attention_head_count, - attention_head_count_kv, - gpu_cc, - context_length, - cache_type_k=cache_type_k, - cache_type_v=cache_type_v, - usage_mode=usage_mode, - ) - ) - else: - config.update( - multi_gpu_config( - model_size_mb, - architecture, - gpus, - nvlink_topology, - layer_count, - compute_capabilities, - ) - ) - - # Context size based on available VRAM and model parameters - # Use selected KV cache quantization if provided - ctx_size = get_optimal_context_size( - architecture, - available_vram, - model_size_mb, - layer_count, - embedding_length, - attention_head_count, - attention_head_count_kv, - cache_type_k=cache_type_k, - cache_type_v=cache_type_v, - usage_mode=usage_mode, - ) - # Clamp GPU ctx size to sane bounds - config["ctx_size"] = max(MIN_CONTEXT_SIZE, min(ctx_size, MAX_CONTEXT_SIZE)) - if debug is not None: - debug.update( - { - "gpu_available_vram_bytes": int(available_vram), - "gpu_ctx_size": config["ctx_size"], - } - ) - - # Batch sizes based on actual memory requirements - # Use selected KV cache quantization if provided - if embedding_length > 0 and layer_count > 0: - optimal_batch_size = calculate_optimal_batch_size( - available_vram_gb, - model_size_mb, - config["ctx_size"], - embedding_length, - layer_count, - cache_type_k=cache_type_k, - cache_type_v=cache_type_v, - ) - config["batch_size"] = max( - MIN_BATCH_SIZE, min(MAX_BATCH_SIZE, optimal_batch_size) - ) - config["ubatch_size"] = calculate_ubatch_size(config["batch_size"]) - else: - # Fallback to 
size-based estimation - config["batch_size"] = min(1024, max(64, int(model_size_mb / 50))) - config["ubatch_size"] = min( - config["batch_size"], max(16, int(model_size_mb / 100)) - ) - - # Parallel sequences (conservative for multi-GPU) - if gpu_count > 1: - config["parallel"] = max(1, min(4, gpu_count)) - else: - config["parallel"] = 1 - - return config diff --git a/backend/smart_auto/kv_cache.py b/backend/smart_auto/kv_cache.py deleted file mode 100644 index ed45efb..0000000 --- a/backend/smart_auto/kv_cache.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Dict, Any - - -def get_optimal_kv_cache_quant( - available_vram_gb: float, - context_length: int, - architecture: str, - flash_attn_available: bool = False, -) -> Dict[str, Any]: - """Determine optimal KV cache quantization to balance memory usage and quality.""" - if context_length > 32768: - cache_type_k = "q5_1" if available_vram_gb > 40 else "q4_1" - cache_type_v = cache_type_k if flash_attn_available else None - return {"cache_type_k": cache_type_k, "cache_type_v": cache_type_v} - - if context_length > 8192: - cache_type_k = "q8_0" if available_vram_gb > 24 else "q4_1" - cache_type_v = cache_type_k if flash_attn_available else None - return {"cache_type_k": cache_type_k, "cache_type_v": cache_type_v} - - if available_vram_gb > 16: - return { - "cache_type_k": "f16", - "cache_type_v": "f16" if flash_attn_available else None, - } - - return { - "cache_type_k": "q8_0", - "cache_type_v": "q8_0" if flash_attn_available else None, - } diff --git a/backend/smart_auto/memory_estimator.py b/backend/smart_auto/memory_estimator.py deleted file mode 100644 index 1d10b12..0000000 --- a/backend/smart_auto/memory_estimator.py +++ /dev/null @@ -1,478 +0,0 @@ -""" -Memory estimation module. -Consolidates RAM and VRAM estimation with shared KV cache calculation logic. -Also provides CPU memory utilities. 
-""" - -from typing import Dict, Any, Tuple, Optional -from functools import lru_cache -import psutil - -from backend.database import Model -from backend.logging_config import get_logger -from .model_metadata import get_model_metadata -from .models import ModelMetadata -from .constants import ( - KV_CACHE_QUANT_FACTORS, - KV_CACHE_OPTIMIZATION_FACTOR, - KV_CACHE_SINGLE_USER_FACTOR, - KV_CACHE_MULTI_USER_FACTOR, - MODEL_RAM_OVERHEAD_GB, - FALLBACK_LAYER_COUNT, - FALLBACK_EMBEDDING_LENGTH, - FALLBACK_KV_CACHE_PER_TOKEN_BYTES, - BATCH_VRAM_OVERHEAD_RATIO, - BATCH_RAM_OVERHEAD_RATIO, - VRAM_SAFETY_MARGIN, - MOE_OFFLOAD_ALL_RATIO, - MOE_OFFLOAD_UP_DOWN_RATIO, - MOE_OFFLOAD_UP_RATIO, - LLAMA_CPP_OVERHEAD_MB, - DEFAULT_BYTES_PER_ELEMENT, - BATCH_INTERMEDIATE_FACTOR, - BATCH_QKV_FACTOR, - BATCH_COMPUTATION_OVERHEAD_KB, - BATCH_FALLBACK_MB, - QUANTIZATION_AVERAGE_FACTOR, - COMPUTE_FIXED_OVERHEAD_MB, - COMPUTE_SCRATCH_PER_UBATCH_MB, -) - -logger = get_logger(__name__) - - -# CPU memory utilities -def get_cpu_memory_gb() -> Tuple[float, float, float]: - """Return (total_gb, used_gb, available_gb) where available = total - used. - Uses actual values, no 60% approximations. - """ - mem = psutil.virtual_memory() - total = mem.total / (1024**3) - used = mem.used / (1024**3) - available = max(0.0, total - used) - return total, used, available - - -@lru_cache(maxsize=64) -def tokens_per_gb_by_model_size(model_size_gb: float) -> int: - """Heuristic tokens per GB for KV budget by model size.""" - if model_size_gb < 2: - return 3000 - if model_size_gb < 6: - return 2000 - if model_size_gb < 12: - return 1300 - return 400 - - -def ctx_tokens_budget_greedy( - model_size_gb: float, available_cpu_ram_gb: float, reserve_overhead_gb: float = None -) -> int: - """Compute context token budget from CPU RAM after reserving model + overhead. - Returns total tokens budget (not divided by batch/parallel). 
- """ - if reserve_overhead_gb is None: - reserve_overhead_gb = MODEL_RAM_OVERHEAD_GB - reserved = model_size_gb + max(0.0, reserve_overhead_gb) - for_ctx = max(0.0, available_cpu_ram_gb - reserved) - tpg = tokens_per_gb_by_model_size(model_size_gb) - return max(0, int(for_ctx * tpg)) - - -def get_kv_cache_quant_factor(cache_type: str) -> float: - """Get memory reduction factor for KV cache quantization.""" - return KV_CACHE_QUANT_FACTORS.get(cache_type, 1.0) - - -@lru_cache(maxsize=512) -def calculate_kv_cache_size( - ctx_size: int, - parallel: int, - layer_count: int, - embedding_length: int, - attention_head_count: int, - attention_head_count_kv: int, - cache_type_k: str, - cache_type_v: Optional[str] = None, - usage_mode: str = "single_user", -) -> int: - """ - Calculate KV cache size in bytes for memory estimation. - - Cached with LRU to avoid redundant calculations for same parameters. - Cache size increased to 512 for production workloads. - - Returns: - Total KV cache bytes - """ - # Get quantization factors - quant_factor_k = get_kv_cache_quant_factor(cache_type_k) - quant_factor_v = ( - get_kv_cache_quant_factor(cache_type_v) if cache_type_v else quant_factor_k - ) - - if embedding_length > 0 and layer_count > 0: - # Calculate bytes per element for K and V cache - bytes_per_k = ( - quant_factor_k * 4 - ) # Convert factor to actual bytes (f32=4, f16=2, etc.) 
- bytes_per_v = quant_factor_v * 4 if cache_type_v else bytes_per_k - - # GQA-aware KV cache calculation (correct formula from theoretical model) - # M_kv = n_ctx × N_layers × N_head_kv × d_head × (p_a_k + p_a_v) - # where d_head = N_embd / N_head - if attention_head_count_kv > 0 and attention_head_count > 0: - # Dimension per head - d_head = embedding_length / attention_head_count - # KV cache stores N_head_kv heads per layer, each of size d_head - kv_cache_per_layer_k = attention_head_count_kv * d_head * bytes_per_k - kv_cache_per_layer_v = attention_head_count_kv * d_head * bytes_per_v - else: - # Fallback for non-GQA models (MHA: N_head_kv = N_head) - # In this case, use full embedding dimension - kv_cache_per_layer_k = embedding_length * bytes_per_k - kv_cache_per_layer_v = embedding_length * bytes_per_v - - # Total per token: (Key + Value) * layers - kv_cache_per_token = (kv_cache_per_layer_k + kv_cache_per_layer_v) * layer_count - else: - # Fallback using constants - kv_cache_per_token = FALLBACK_KV_CACHE_PER_TOKEN_BYTES - - # KV cache: ctx_size tokens, each with kv_cache_per_token bytes - # Parallel might create multiple context copies, so multiply by parallel - # Use actual memory (optimization factor is 1.0 - memory mapping doesn't reduce peak usage) - base_kv_cache_bytes = int( - ctx_size * kv_cache_per_token * parallel * KV_CACHE_OPTIMIZATION_FACTOR - ) - - # Apply usage mode factor based on theoretical model: - # - single_user: Peak estimate (full context accumulates, full KV cache) - # - multi_user: Typical usage (context cleared between requests, lower estimate) - if usage_mode == "multi_user": - usage_factor = KV_CACHE_MULTI_USER_FACTOR - else: # single_user or default - usage_factor = KV_CACHE_SINGLE_USER_FACTOR - - kv_cache_bytes = int(base_kv_cache_bytes * usage_factor) - - return kv_cache_bytes - - -def estimate_vram_usage( - model: Model, - config: Dict[str, Any], - gpu_info: Dict[str, Any], - metadata: Optional[ModelMetadata] = None, - 
usage_mode: str = "single_user", -) -> Dict[str, Any]: - """Estimate VRAM usage for given configuration using comprehensive model metadata - - Args: - model: The model to estimate for - config: Configuration dictionary - gpu_info: GPU information dictionary - metadata: Optional pre-computed ModelMetadata to avoid redundant calls - """ - try: - model_size = model.file_size if model.file_size else 0 - - # Extract frequently accessed config values early to avoid repeated dict lookups - n_gpu_layers = int(config.get("n_gpu_layers", 0) or 0) - ctx_size = int(config.get("ctx_size", 4096) or 4096) - parallel = max(1, int(config.get("parallel", 1) or 1)) - cache_type_k = config.get("cache_type_k", "f16") - cache_type_v = config.get("cache_type_v") - - # Use provided metadata or fetch it (cached internally) - layer_info = metadata if metadata is not None else get_model_metadata(model) - total_layers = max(1, layer_info.layer_count or FALLBACK_LAYER_COUNT) - embedding_length = layer_info.embedding_length or 0 - attention_head_count = layer_info.attention_head_count or 0 - attention_head_count_kv = layer_info.attention_head_count_kv or 0 - - # Layer split between GPU and CPU - layer_ratio = min( - 1.0, max(0.0, (n_gpu_layers / total_layers) if total_layers > 0 else 0.0) - ) - model_vram = int(model_size * layer_ratio) - model_ram = max(0, int(model_size - model_vram)) - - # Use shared KV cache calculation - kv_cache_bytes = calculate_kv_cache_size( - ctx_size, - parallel, - total_layers, - embedding_length, - attention_head_count, - attention_head_count_kv, - cache_type_k, - cache_type_v, - usage_mode=usage_mode, - ) - - # Determine if KV cache goes to VRAM or RAM - # According to theoretical model: when n_gpu_layers > 0, M_kv goes to VRAM by default - # The "VRAM Trap": in hybrid mode, M_kv and M_compute both go to VRAM - if n_gpu_layers > 0: - # In GPU mode (including hybrid), KV cache goes to VRAM - kv_cache_vram = kv_cache_bytes - kv_cache_ram = 0 - else: - # CPU-only 
mode: KV cache goes to RAM - kv_cache_vram = 0 - kv_cache_ram = kv_cache_bytes - - # M_compute: Fixed overhead + variable scratch buffer - # According to theoretical model: M_compute = M_overhead_fixed + M_scratch_variable(n_ubatch) - ubatch_size = config.get("ubatch_size", 512) - compute_overhead_mb = COMPUTE_FIXED_OVERHEAD_MB + ( - ubatch_size * COMPUTE_SCRATCH_PER_UBATCH_MB - ) - compute_overhead_bytes = int(compute_overhead_mb * 1024 * 1024) - - # Allocate M_compute to VRAM if GPU layers > 0 (VRAM Trap) - if n_gpu_layers > 0: - batch_vram = compute_overhead_bytes - batch_ram = 0 - else: - batch_vram = 0 - batch_ram = compute_overhead_bytes - - estimated_vram = model_vram + kv_cache_vram + batch_vram - estimated_ram = model_ram + kv_cache_ram + batch_ram - - # System RAM usage snapshot - try: - vm = psutil.virtual_memory() - system_ram_used = int(vm.used) - system_ram_total = int(vm.total) - except Exception: - system_ram_used = 0 - system_ram_total = 0 - - # VRAM headroom check - # Extract gpus list once to avoid repeated dict lookup - gpus = gpu_info.get("gpus", []) - total_free_vram = sum(g.get("memory", {}).get("free", 0) for g in gpus) - fits_in_gpu = (n_gpu_layers == 0) or ( - estimated_vram <= max(0, total_free_vram * VRAM_SAFETY_MARGIN) - ) - - memory_mode = "ram_only" - if n_gpu_layers > 0: - if estimated_ram > 0: - memory_mode = "mixed" - else: - memory_mode = "vram_only" - - return { - "memory_mode": memory_mode, - # VRAM - "estimated_vram": estimated_vram, - "model_vram": model_vram, - "kv_cache_vram": kv_cache_vram, - "batch_vram": batch_vram, - # RAM - "estimated_ram": estimated_ram, - "model_ram": model_ram, - "kv_cache_ram": kv_cache_ram, - "batch_ram": batch_ram, - # System RAM snapshot - "system_ram_used": system_ram_used, - "system_ram_total": system_ram_total, - # Fit flag - "fits_in_gpu": fits_in_gpu, - } - except Exception: - try: - vm = psutil.virtual_memory() - system_ram_used = int(vm.used) - system_ram_total = int(vm.total) - except 
Exception: - system_ram_used = 0 - system_ram_total = 0 - return { - "memory_mode": "unknown", - "estimated_vram": 0, - "model_vram": 0, - "kv_cache_vram": 0, - "batch_vram": 0, - "estimated_ram": 0, - "model_ram": 0, - "kv_cache_ram": 0, - "batch_ram": 0, - "system_ram_used": system_ram_used, - "system_ram_total": system_ram_total, - "fits_in_gpu": True, - } - - -def estimate_ram_usage( - model: Model, - config: Dict[str, Any], - metadata: Optional[ModelMetadata] = None, - usage_mode: str = "single_user", -) -> Dict[str, Any]: - """Estimate RAM usage for given configuration - - Args: - model: The model to estimate for - config: Configuration dictionary - metadata: Optional pre-computed ModelMetadata to avoid redundant calls - """ - try: - model_size = model.file_size if model.file_size else 0 - - # Extract frequently accessed config values early to avoid repeated dict lookups - n_gpu_layers = config.get("n_gpu_layers", 0) - ctx_size = config.get("ctx_size", 4096) - batch_size = config.get("batch_size", 512) - parallel = config.get("parallel", 1) - cache_type_k = config.get("cache_type_k", "f16") - cache_type_v = config.get("cache_type_v") - - # Get system RAM info (extract once) - vm = psutil.virtual_memory() - total_memory = vm.total - available_memory = vm.available - - # Use provided metadata or fetch it (cached internally) - # Use ModelMetadata dataclass attributes directly - layer_info = metadata if metadata is not None else get_model_metadata(model) - total_layers = layer_info.layer_count or FALLBACK_LAYER_COUNT - embedding_length = layer_info.embedding_length or 0 - attention_head_count = layer_info.attention_head_count or 0 - attention_head_count_kv = layer_info.attention_head_count_kv or 0 - is_moe = layer_info.is_moe - - cpu_layers = total_layers - n_gpu_layers if n_gpu_layers > 0 else total_layers - - if n_gpu_layers > 0: - # GPU layers: full model loaded in RAM for GPU transfer - model_ram = model_size - else: - # CPU-only: only CPU layers in RAM - 
layer_ratio = cpu_layers / total_layers if cpu_layers > 0 else 1 - model_ram = int(model_size * layer_ratio) - - # Enhanced KV cache estimation using model architecture - # Use shared KV cache calculation - kv_cache_ram = calculate_kv_cache_size( - ctx_size, - parallel, - total_layers, - embedding_length, - attention_head_count, - attention_head_count_kv, - cache_type_k, - cache_type_v, - usage_mode=usage_mode, - ) - - # MoE models with CPU offloading use RAM for offloaded layers - moe_cpu_ram = 0 - if is_moe and n_gpu_layers > 0: - moe_pattern = config.get("moe_offload_custom", "") - if moe_pattern: - # Estimate RAM usage for offloaded MoE layers - if ".*_exps" in moe_pattern: - # All MoE offloaded - moe_cpu_ram = int(model_size * MOE_OFFLOAD_ALL_RATIO) - elif "up|down" in moe_pattern: - # Up/Down offloaded - moe_cpu_ram = int(model_size * MOE_OFFLOAD_UP_DOWN_RATIO) - elif "_up_" in moe_pattern: - # Only Up offloaded - moe_cpu_ram = int(model_size * MOE_OFFLOAD_UP_RATIO) - - # Batch processing overhead - if embedding_length > 0: - bytes_per_element = DEFAULT_BYTES_PER_ELEMENT - - # Intermediate activations: batch_size tokens * embedding_length - intermediate_ram = int( - batch_size - * embedding_length - * bytes_per_element - * BATCH_INTERMEDIATE_FACTOR - ) - - # QKV projections are also temporary and reused - qkv_ram = int( - batch_size * 3 * embedding_length * bytes_per_element * BATCH_QKV_FACTOR - ) - - # Additional buffers are minimal and reused - computation_overhead = batch_size * BATCH_COMPUTATION_OVERHEAD_KB * 1024 - - batch_ram = intermediate_ram + qkv_ram + computation_overhead - else: - # Fallback: reduced estimate based on actual usage - batch_ram = batch_size * int(BATCH_FALLBACK_MB * 1024 * 1024) - - # Additional overhead for llama.cpp - llama_overhead = LLAMA_CPP_OVERHEAD_MB * 1024 * 1024 - - total_ram = model_ram + kv_cache_ram + batch_ram + llama_overhead + moe_cpu_ram - - # Check if fits in available RAM - fits_in_ram = total_ram <= 
available_memory - - # Calculate quantization savings - quant_factor_k = get_kv_cache_quant_factor(cache_type_k) - quant_factor_v = ( - get_kv_cache_quant_factor(cache_type_v) if cache_type_v else quant_factor_k - ) - - # Calculate raw KV cache size (for savings calculation) using correct GQA-aware formula - if embedding_length > 0 and total_layers > 0: - bytes_per_k = quant_factor_k * 4 - bytes_per_v = quant_factor_v * 4 if cache_type_v else bytes_per_k - if attention_head_count_kv > 0 and attention_head_count > 0: - # GQA-aware calculation - d_head = embedding_length / attention_head_count - kv_cache_per_layer_k = attention_head_count_kv * d_head * bytes_per_k - kv_cache_per_layer_v = attention_head_count_kv * d_head * bytes_per_v - else: - # Fallback for non-GQA models - kv_cache_per_layer_k = embedding_length * bytes_per_k - kv_cache_per_layer_v = embedding_length * bytes_per_v - kv_cache_per_token = ( - kv_cache_per_layer_k + kv_cache_per_layer_v - ) * total_layers - else: - kv_cache_per_token = FALLBACK_KV_CACHE_PER_TOKEN_BYTES - - # Calculate savings (difference between f32 and current quantization) - kv_cache_savings = int( - ctx_size - * parallel - * kv_cache_per_token - * ( - 1 - - ( - QUANTIZATION_AVERAGE_FACTOR * quant_factor_k - + QUANTIZATION_AVERAGE_FACTOR * quant_factor_v - ) - ) - ) - - return { - "estimated_ram": total_ram, - "model_ram": model_ram, - "kv_cache_ram": kv_cache_ram, - "batch_ram": batch_ram, - "moe_cpu_ram": moe_cpu_ram, - "llama_overhead": llama_overhead, - "fits_in_ram": fits_in_ram, - "available_ram": available_memory, - "total_ram": total_memory, - "utilization_percent": ( - (total_ram / total_memory * 100) if total_memory > 0 else 0 - ), - "kv_cache_savings": kv_cache_savings, - } - - except Exception as e: - return {"error": str(e), "estimated_ram": 0, "fits_in_ram": False} diff --git a/backend/smart_auto/model_metadata.py b/backend/smart_auto/model_metadata.py deleted file mode 100644 index 925dcc3..0000000 --- 
a/backend/smart_auto/model_metadata.py +++ /dev/null @@ -1,106 +0,0 @@ -from typing import Dict, Any -import os -from functools import lru_cache - -from backend.logging_config import get_logger -from backend.gguf_reader import get_model_layer_info -from .architecture_config import resolve_architecture -from .models import ModelMetadata - -logger = get_logger(__name__) - - -@lru_cache(maxsize=256) -def _get_layer_info_from_file(file_path: str, mtime: float) -> Dict[str, Any]: - """ - Get layer info from GGUF file with LRU caching. - Uses mtime as part of cache key for invalidation. - """ - try: - return get_model_layer_info(file_path) or {} - except Exception as e: - logger.warning(f"Failed to read layer info from {file_path}: {e}") - return {} - - -@lru_cache(maxsize=64) -def _estimate_layer_count_cached(model_name: str) -> int: - """Cached version of layer count estimation from model name.""" - if "7b" in model_name or "7B" in model_name: - return 32 - elif "3b" in model_name or "3B" in model_name: - return 28 - elif "1b" in model_name or "1B" in model_name: - return 22 - elif "13b" in model_name or "13B" in model_name: - return 40 - elif "30b" in model_name or "30B" in model_name: - return 60 - elif "65b" in model_name or "65B" in model_name: - return 80 - else: - return 32 # Default fallback - - -def get_model_metadata(model) -> ModelMetadata: - """ - Get comprehensive model metadata with caching. - - Uses LRU cache with mtime-based invalidation to prevent redundant file I/O. - This is the single source of truth for model layer information. 
- """ - # Default metadata structure - meta: Dict[str, Any] = { - "layer_count": 32, - "architecture": "unknown", - "context_length": 0, - "vocab_size": 0, - "embedding_length": 0, - "attention_head_count": 0, - "attention_head_count_kv": 0, - "block_count": 0, - "is_moe": False, - "expert_count": 0, - "experts_used_count": 0, - } - - try: - if model.file_path and os.path.exists(model.file_path): - # Use LRU cache with mtime-based invalidation - mtime = os.path.getmtime(model.file_path) - layer_info = _get_layer_info_from_file(model.file_path, mtime) - if layer_info: - meta.update(layer_info) - - # Resolve architecture from GGUF metadata - raw_architecture = meta.get("architecture", "") - normalized = resolve_architecture(raw_architecture) - meta["architecture"] = normalized - - if ( - normalized not in ("unknown", "generic") - and raw_architecture != normalized - ): - logger.debug( - f"Resolved architecture: '{raw_architecture}' -> '{normalized}'" - ) - except Exception as e: - logger.warning( - f"Failed to read GGUF metadata for model {getattr(model, 'id', 'unknown')}: {e}" - ) - - # Fallback to name-based detection if architecture is still unknown - current_arch = meta.get("architecture", "").strip() - if not current_arch or current_arch == "unknown": - detected = resolve_architecture(getattr(model, "name", "")) - meta["architecture"] = detected - if detected not in ("unknown", "generic"): - logger.debug(f"Detected architecture from model name: '{detected}'") - - # Fallback to name-based layer count estimation if needed - if meta.get("layer_count", 0) == 32 and current_arch == "unknown": - model_name = getattr(model, "name", "").lower() - meta["layer_count"] = _estimate_layer_count_cached(model_name) - - # Return as ModelMetadata dataclass - return ModelMetadata.from_dict(meta) diff --git a/backend/smart_auto/models.py b/backend/smart_auto/models.py deleted file mode 100644 index d3dec12..0000000 --- a/backend/smart_auto/models.py +++ /dev/null @@ -1,216 +0,0 @@ 
-""" -Data models for smart_auto module. -Provides type-safe data classes to replace dictionary passing throughout the module. -""" - -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Any, Tuple - - -@dataclass -class ModelMetadata: - """Comprehensive model metadata extracted from GGUF file or name.""" - - layer_count: int - architecture: str - context_length: int - vocab_size: int - embedding_length: int - attention_head_count: int - attention_head_count_kv: int - block_count: int = 0 - is_moe: bool = False - expert_count: int = 0 - experts_used_count: int = 0 - parameter_count: Optional[str] = None # Formatted as "32B", "36B", etc. - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ModelMetadata": - """Create ModelMetadata from a dictionary (e.g., from get_model_metadata result).""" - return cls( - layer_count=data.get("layer_count", 32), - architecture=data.get("architecture", "unknown"), - context_length=data.get("context_length", 0), - vocab_size=data.get("vocab_size", 0), - embedding_length=data.get("embedding_length", 0), - attention_head_count=data.get("attention_head_count", 0), - attention_head_count_kv=data.get("attention_head_count_kv", 0), - block_count=data.get("block_count", 0), - is_moe=data.get("is_moe", False), - expert_count=data.get("expert_count", 0), - experts_used_count=data.get("experts_used_count", 0), - parameter_count=data.get("parameter_count"), - ) - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary for backward compatibility.""" - return { - "layer_count": self.layer_count, - "architecture": self.architecture, - "context_length": self.context_length, - "vocab_size": self.vocab_size, - "embedding_length": self.embedding_length, - "attention_head_count": self.attention_head_count, - "attention_head_count_kv": self.attention_head_count_kv, - "block_count": self.block_count, - "is_moe": self.is_moe, - "expert_count": self.expert_count, - "experts_used_count": 
self.experts_used_count, - "parameter_count": self.parameter_count, - } - - -@dataclass -class SystemResources: - """System resources information.""" - - gpus: List[Dict[str, Any]] - total_vram: int - available_vram_gb: float - gpu_count: int - nvlink_topology: Dict[str, Any] - cpu_cores: int - cpu_memory_gb: Tuple[float, float, float] # total, used, available - flash_attn_available: bool = False - compute_capabilities: List[float] = field( - default_factory=list - ) # Pre-parsed compute capabilities - - @classmethod - def from_gpu_info( - cls, - gpu_info: Dict[str, Any], - cpu_memory: Tuple[float, float, float], - cpu_cores: int, - flash_attn_available: bool = False, - ) -> "SystemResources": - """Create SystemResources from gpu_info and system data.""" - gpus = gpu_info.get("gpus", []) - - # Pre-parse compute capabilities to avoid repeated string parsing - compute_capabilities = [] - for gpu in gpus: - cc_str = gpu.get("compute_capability", "0.0") - try: - parts = str(cc_str).split(".") - major = int(parts[0]) if parts and parts[0].isdigit() else 0 - minor = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0 - compute_capabilities.append(major + minor / 10.0) - except Exception: - compute_capabilities.append(0.0) - - return cls( - gpus=gpus, - total_vram=gpu_info.get("total_vram", 0), - available_vram_gb=( - sum(gpu.get("memory", {}).get("free", 0) for gpu in gpus) / (1024**3) - if gpus - else 0.0 - ), - gpu_count=gpu_info.get("device_count", 0), - nvlink_topology=gpu_info.get("nvlink_topology", {}), - cpu_cores=cpu_cores, - cpu_memory_gb=cpu_memory, - flash_attn_available=flash_attn_available, - compute_capabilities=compute_capabilities, - ) - - -@dataclass -class GenerationConfig: - """Complete generation configuration with type-safe fields.""" - - # GPU configuration - n_gpu_layers: int = 0 - main_gpu: int = 0 - tensor_split: str = "" - flash_attn: bool = False - - # Memory and context - ctx_size: int = 4096 - batch_size: int = 512 - ubatch_size: 
int = 256 - parallel: int = 1 - - # CPU configuration - threads: int = 4 - threads_batch: int = 4 - - # Memory optimization - no_mmap: bool = False - mlock: bool = False - low_vram: bool = False - logits_all: bool = False - cont_batching: bool = True - no_kv_offload: bool = False - - # Generation parameters - temperature: float = 0.8 - temp: float = 0.8 - top_p: float = 0.9 - top_k: int = 40 - typical_p: float = 1.0 - min_p: float = 0.0 - tfs_z: float = 1.0 - repeat_penalty: float = 1.1 - presence_penalty: float = 0.0 - frequency_penalty: float = 0.0 - mirostat: int = 0 - mirostat_tau: float = 5.0 - mirostat_eta: float = 0.1 - n_predict: int = -1 - stop: List[str] = field(default_factory=list) - seed: int = -1 - - # KV cache optimization - cache_type_k: str = "f16" - cache_type_v: Optional[str] = None - - # Architecture-specific - rope_freq_base: Optional[float] = None - rope_freq_scale: Optional[float] = None - rope_scaling: str = "" - yarn_ext_factor: float = 1.0 - yarn_attn_factor: float = 1.0 - - # MoE configuration - moe_offload_pattern: str = "none" - moe_offload_custom: str = "" - - # Special flags - embedding: bool = False - jinja: bool = False - - # Server parameters - host: str = "0.0.0.0" - port: int = 0 - timeout: int = 300 - - # Additional fields for backward compatibility - yaml: str = "" - customArgs: str = "" - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary for llama-swap integration.""" - result = {} - for key, value in self.__dict__.items(): - # Skip None values to keep config clean - if value is not None and value != []: - result[key] = value - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GenerationConfig": - """Create GenerationConfig from dictionary.""" - # Filter out only fields that exist in the dataclass - field_names = {f.name for f in cls.__dataclass_fields__.values()} - filtered_data = {k: v for k, v in data.items() if k in field_names} - return cls(**filtered_data) - - def update(self, 
updates: Dict[str, Any]) -> "GenerationConfig": - """Create a new config with updates applied.""" - new_dict = self.to_dict() - new_dict.update(updates) - # Filter out None values and empty lists - new_dict = {k: v for k, v in new_dict.items() if v is not None and v != []} - return self.from_dict(new_dict) diff --git a/backend/smart_auto/moe_handler.py b/backend/smart_auto/moe_handler.py deleted file mode 100644 index fa44757..0000000 --- a/backend/smart_auto/moe_handler.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -MoE (Mixture of Experts) handling module. -Handles MoE model offloading patterns and architecture-specific flags. -""" - -from typing import Dict, Any, Tuple -from backend.logging_config import get_logger -from .constants import VRAM_RATIO_VERY_TIGHT, VRAM_RATIO_TIGHT, VRAM_RATIO_MODERATE - -logger = get_logger(__name__) - - -# MoE offload strategies: (vram_ratio_threshold, pattern) -MOE_OFFLOAD_STRATEGIES: list[Tuple[float, str]] = [ - (VRAM_RATIO_VERY_TIGHT, ".ffn_.*_exps.=CPU"), # Very tight: all MoE offloaded - (VRAM_RATIO_TIGHT, ".ffn_(up|down)_exps.=CPU"), # Tight: up/down offloaded - (VRAM_RATIO_MODERATE, ".ffn_(up)_exps.=CPU"), # Moderate: only up offloaded - (float("inf"), ""), # Ample: no offloading -] - - -def generate_moe_offload_pattern( - architecture: str, - available_vram_gb: float, - model_size_mb: float, - is_moe: bool = False, - expert_count: int = 0, -) -> str: - """Generate optimal MoE offloading pattern based on VRAM availability - - Returns regex pattern for the -ot (offload type) parameter to control MoE layer placement - """ - if not is_moe or expert_count == 0: - return "" # No MoE offloading for non-MoE models - - model_size_gb = model_size_mb / 1024 - - # Calculate VRAM pressure - vram_ratio = available_vram_gb / model_size_gb if model_size_gb > 0 else 1.0 - - # Find the appropriate strategy based on VRAM ratio - for threshold, pattern in MOE_OFFLOAD_STRATEGIES: - if vram_ratio < threshold: - return pattern - - return "" # 
Fallback: no offloading needed - - -def needs_jinja_template(architecture: str, layer_info: Dict[str, Any]) -> bool: - """Determine if architecture requires jinja template.""" - # GLM architectures always need jinja - if architecture in ["glm", "glm4"]: - return True - # Qwen3 coder variants need jinja - if architecture == "qwen3": - arch_str = layer_info.get("architecture", "").lower() - if "coder" in arch_str: - return True - return False - - -def get_architecture_specific_flags( - architecture: str, layer_info: Dict[str, Any] -) -> Dict[str, Any]: - """Get architecture-specific flags and settings. - - Returns dict with flags like jinja, moe_offload_custom, etc. - """ - flags = {"jinja": False, "moe_offload_custom": ""} - - # Check jinja requirement - if needs_jinja_template(architecture, layer_info): - flags["jinja"] = True - logger.info(f"{architecture} architecture detected - enabling jinja template") - - # Generate MoE offloading pattern if applicable - is_moe = layer_info.get("is_moe", False) - available_vram_gb = layer_info.get("available_vram_gb", 0) - - if not is_moe or available_vram_gb == 0: - if is_moe and available_vram_gb == 0: - logger.debug( - "MoE model detected but no GPU available - MoE layers will run on CPU" - ) - return flags - - # Generate MoE offload pattern for GPU mode - expert_count = layer_info.get("expert_count", 0) - model_size_mb = layer_info.get("model_size_mb", 0) - moe_pattern = generate_moe_offload_pattern( - architecture, available_vram_gb, model_size_mb, is_moe, expert_count - ) - - if moe_pattern: - flags["moe_offload_custom"] = moe_pattern - logger.debug(f"Generated MoE offload pattern: {moe_pattern}") - - return flags diff --git a/backend/smart_auto/optimizer.py b/backend/smart_auto/optimizer.py deleted file mode 100644 index 8622abf..0000000 --- a/backend/smart_auto/optimizer.py +++ /dev/null @@ -1,199 +0,0 @@ -""" -Joint optimization algorithm for llama.cpp configuration. 
- -Implements the theoretical model algorithm from Section III-B that jointly -optimizes (n_ngl, n_ctx, ubatch_size) given VRAM and RAM constraints. - -Prioritizes "Full Offload" (Max_Speed) regime before falling back to "Hybrid" mode. -""" - -from typing import Dict, Any, Optional, Tuple -from .memory_estimator import calculate_kv_cache_size -from .constants import ( - COMPUTE_FIXED_OVERHEAD_MB, - COMPUTE_SCRATCH_PER_UBATCH_MB, - MIN_CONTEXT_SIZE, - MIN_BATCH_SIZE, -) - - -def find_optimal_config( - model_size_bytes: int, - total_layers: int, - embedding_length: int, - attention_head_count: int, - attention_head_count_kv: int, - available_vram_bytes: int, - available_ram_bytes: int, - cache_type_k: str = "f16", - cache_type_v: Optional[str] = None, - ubatch_size: int = 512, - desired_performance: str = "Max_Speed", - min_context_size: int = MIN_CONTEXT_SIZE, -) -> Dict[str, Any]: - """ - Find optimal configuration using joint optimization algorithm. - - Implements the theoretical model algorithm that: - 1. Prioritizes "Full Offload" (Max_Speed) regime first - 2. Falls back to "Hybrid" mode if full offload fails - 3. 
Maximizes context length (n_ctx) given VRAM constraint - - Args: - model_size_bytes: Total model size in bytes (GGUF file size) - total_layers: Total number of layers (N_layers) - embedding_length: Hidden embedding dimension (N_embd) - attention_head_count: Number of attention heads (N_head) - attention_head_count_kv: Number of KV attention heads (N_head_kv) - available_vram_bytes: Available VRAM in bytes - available_ram_bytes: Available RAM in bytes - cache_type_k: K cache quantization type - cache_type_v: V cache quantization type (default: same as cache_type_k) - ubatch_size: Micro-batch size for M_compute calculation - desired_performance: 'Max_Speed' or 'Max_Context' - min_context_size: Minimum acceptable context size - - Returns: - Dictionary with: - - mode: "Full_Offload", "Hybrid_Mode", "Full_Offload_Failed", or "Insufficient_Memory" - - n_ngl_best: Optimal number of GPU layers - - n_ctx_best: Optimal context size - - cache_type_k: Selected K cache quantization - - cache_type_v: Selected V cache quantization - - ubatch_size: Optimal micro-batch size - """ - # Step 1: Calculate model constants - N_layers = total_layers - M_weights_total = model_size_bytes - - # Calculate KV cache cost per token (C_kv_per_token) - # Using GQA-aware formula: M_kv = n_ctx × N_layers × N_head_kv × d_head × (p_a_k + p_a_v) - if embedding_length > 0 and attention_head_count > 0: - from .constants import KV_CACHE_QUANT_FACTORS - - quant_factor_k = KV_CACHE_QUANT_FACTORS.get(cache_type_k, 0.5) - quant_factor_v = KV_CACHE_QUANT_FACTORS.get( - cache_type_v or cache_type_k, quant_factor_k - ) - bytes_per_k = quant_factor_k * 4 - bytes_per_v = quant_factor_v * 4 - - if attention_head_count_kv > 0: - d_head = embedding_length / attention_head_count - kv_cache_per_layer_k = attention_head_count_kv * d_head * bytes_per_k - kv_cache_per_layer_v = attention_head_count_kv * d_head * bytes_per_v - else: - # Fallback for non-GQA models - kv_cache_per_layer_k = embedding_length * bytes_per_k - 
kv_cache_per_layer_v = embedding_length * bytes_per_v - - C_kv_per_token = (kv_cache_per_layer_k + kv_cache_per_layer_v) * N_layers - else: - # Fallback: conservative estimate - C_kv_per_token = 1024 # ~1KB per token fallback - - # Calculate M_compute - M_compute_bytes = int( - (COMPUTE_FIXED_OVERHEAD_MB + (ubatch_size * COMPUTE_SCRATCH_PER_UBATCH_MB)) - * (1024**2) - ) - - # Step 2: Try "Full Offload" (Max_Speed) regime first - M_weights_vram_full = M_weights_total # n_ngl = N_layers - M_weights_ram_full = 0 - - VRAM_fixed_cost_full = M_weights_vram_full + M_compute_bytes - - if ( - VRAM_fixed_cost_full < available_vram_bytes - and M_weights_ram_full < available_ram_bytes - ): - # Model fits in VRAM. Calculate max context size. - VRAM_remaining_for_kv = available_vram_bytes - VRAM_fixed_cost_full - - if C_kv_per_token > 0: - n_ctx_candidate = VRAM_remaining_for_kv // C_kv_per_token - else: - n_ctx_candidate = 0 - - if n_ctx_candidate >= min_context_size: - return { - "mode": "Full_Offload (Max_Speed)", - "n_ngl_best": N_layers, - "n_ctx_best": n_ctx_candidate, - "cache_type_k": cache_type_k, - "cache_type_v": cache_type_v, - "ubatch_size": ubatch_size, - } - - # If full offload failed and user wants Max_Speed only, return failure - if desired_performance == "Max_Speed": - return { - "mode": "Full_Offload_Failed", - "n_ngl_best": 0, - "n_ctx_best": 0, - "cache_type_k": cache_type_k, - "cache_type_v": cache_type_v, - "ubatch_size": ubatch_size, - } - - # Step 3: Try "Hybrid" (Max_Context) regime - # Find minimum n_ngl required to fit remaining weights in RAM - if M_weights_total > available_ram_bytes: - n_ngl_min = max( - 1, int((M_weights_total - available_ram_bytes) * N_layers / M_weights_total) - ) - else: - n_ngl_min = 0 - - n_ngl_best = 0 - n_ctx_best = 0 - - # Iterate from full offload down to minimum - for n_ngl_candidate in range(N_layers, n_ngl_min - 1, -1): - if n_ngl_candidate <= 0: - continue - - M_weights_vram_hybrid = (n_ngl_candidate / N_layers) * 
M_weights_total - M_weights_ram_hybrid = M_weights_total - M_weights_vram_hybrid - VRAM_fixed_cost_hybrid = M_weights_vram_hybrid + M_compute_bytes - - # Check if this n_ngl is possible (weights + compute must fit in VRAM) - if VRAM_fixed_cost_hybrid >= available_vram_bytes: - continue # This n_ngl is too high - - # Check if remaining weights fit in RAM - if M_weights_ram_hybrid > available_ram_bytes: - continue # This n_ngl requires too much RAM - - # Calculate max context for this n_ngl - VRAM_remaining_for_kv = available_vram_bytes - VRAM_fixed_cost_hybrid - - if C_kv_per_token > 0: - n_ctx_candidate = VRAM_remaining_for_kv // C_kv_per_token - else: - n_ctx_candidate = 0 - - # We're looking for the combination that yields the highest n_ctx - if n_ctx_candidate > n_ctx_best: - n_ctx_best = n_ctx_candidate - n_ngl_best = n_ngl_candidate - - if n_ctx_best >= min_context_size: - return { - "mode": "Hybrid_Mode (Max_Context)", - "n_ngl_best": n_ngl_best, - "n_ctx_best": n_ctx_best, - "cache_type_k": cache_type_k, - "cache_type_v": cache_type_v, - "ubatch_size": ubatch_size, - } - else: - return { - "mode": "Insufficient_Memory", - "n_ngl_best": 0, - "n_ctx_best": 0, - "cache_type_k": cache_type_k, - "cache_type_v": cache_type_v, - "ubatch_size": ubatch_size, - } diff --git a/backend/smart_auto/recommendations.py b/backend/smart_auto/recommendations.py deleted file mode 100644 index 3e010fd..0000000 --- a/backend/smart_auto/recommendations.py +++ /dev/null @@ -1,372 +0,0 @@ -""" -Recommendation engine for model configuration parameters. -Uses smart_auto logic with balanced preset (speed_quality=50, conversational). 
-""" - -from typing import Dict, Any, Optional -import asyncio -from backend.database import Model -from .architecture_config import resolve_architecture - - -def _create_minimal_model( - layer_info: Dict[str, Any], model_name: str = "", file_path: Optional[str] = None -) -> Model: - """Create a minimal Model object from layer info for smart_auto.""" - model = Model() - model.name = model_name or "Unknown" - model.file_path = ( - file_path # Use provided file_path if available for metadata reading - ) - model.file_size = 0 - model.huggingface_id = model_name - return model - - -def _create_minimal_gpu_info() -> Dict[str, Any]: - """Create minimal GPU info for smart_auto (assumes GPU available but will work without).""" - return { - "gpus": [], - "total_vram": 0, - "available_vram": 0, - "compute_capabilities": [], - "nvlink_topology": None, - } - - -def _extract_recommendation_from_config( - config: Dict[str, Any], - key: str, - layer_info: Dict[str, Any], - recommendation_type: str, -) -> Dict[str, Any]: - """Extract recommendation structure from generated config value.""" - - if recommendation_type == "gpu_layers": - layer_count = layer_info.get("layer_count", 32) - value = config.get("n_gpu_layers", layer_count) - # Clamp value to max - value = min(value, layer_count) - return { - "recommended_value": value, - "description": f"Recommended {value} layers" - + (" (full offload)" if value == layer_count else ""), - "balanced_value": layer_count // 2 if layer_count > 0 else 0, - "balanced_description": f"{layer_count // 2 if layer_count > 0 else 0} layers (balanced)", - "min": 0, - "max": layer_count, - "ranges": [ - {"value": 0, "description": "CPU-only mode (slowest, lowest VRAM)"}, - { - "value": layer_count // 2 if layer_count > 0 else 0, - "description": "Balanced (good performance, moderate VRAM)", - }, - { - "value": layer_count, - "description": "Full offload (fastest, highest VRAM)", - }, - ], - } - - elif recommendation_type == "context_size": - 
context_length = layer_info.get("context_length", 131072) - value = config.get("ctx_size", context_length) - # Clamp value to max - value = min(value, context_length) - return { - "recommended_value": value, - "description": f"Recommended {value:,} tokens", - "min": 512, - "max": context_length, - "ranges": [ - {"min": 512, "max": 2048, "description": "Short conversations"}, - {"min": 4096, "max": 8192, "description": "Standard conversations"}, - { - "min": 16384, - "max": context_length, - "description": f"Long documents (max {context_length:,})", - }, - ], - } - - elif recommendation_type == "batch_size": - value = config.get("batch_size", 512) - # Calculate max based on attention heads, clamp to reasonable range - attention_heads = layer_info.get("attention_head_count", 32) - max_val = min(1024, max(512, attention_heads * 16)) - # Clamp value to max - value = min(value, max_val) - return { - "recommended_value": value, - "description": f"Recommended {value}", - "min": 1, - "max": max_val, - "ranges": [ - {"min": 1, "max": 128, "description": "Low memory usage"}, - {"min": 256, "max": 512, "description": "Balanced (recommended)"}, - {"min": max_val, "max": max_val, "description": "Maximum throughput"}, - ], - } - - elif recommendation_type == "temperature": - value = config.get("temperature", config.get("temp", 0.8)) - # Clamp value to max - value = min(value, 2.0) - arch = layer_info.get("architecture", "").lower() - recommended_str = f"{value:.1f} for balanced conversation" - - if "glm" in arch or "deepseek" in arch: - recommended_str = f"{value:.1f} for GLM/DeepSeek models" - elif "qwen" in arch: - recommended_str = f"{value:.1f} for Qwen models" - elif "codellama" in arch: - recommended_str = f"{value:.1f} for code generation" - - return { - "recommended_value": value, - "description": recommended_str, - "min": 0.0, - "max": 2.0, - "ranges": [ - { - "min": 0.1, - "max": 0.3, - "description": "Code generation, technical tasks", - }, - { - "min": 0.7, - "max": 
1.0, - "description": "General conversation (recommended)", - }, - { - "min": 1.5, - "max": 2.0, - "description": "Creative writing, brainstorming", - }, - ], - } - - elif recommendation_type == "top_k": - value = config.get("top_k", 40) - # Clamp value to max - value = min(value, 200) - arch = layer_info.get("architecture", "").lower() - recommended_str = f"{value} for most models" - - if "glm" in arch or "deepseek" in arch: - recommended_str = f"{value} for GLM/DeepSeek models" - - return { - "recommended_value": value, - "description": recommended_str, - "min": 0, - "max": 200, - "ranges": [ - {"min": 10, "max": 30, "description": "Focused, code-like outputs"}, - {"min": 40, "max": 50, "description": "Balanced (recommended)"}, - { - "min": 100, - "max": 200, - "description": "High diversity, creative writing", - }, - ], - } - - elif recommendation_type == "top_p": - value = config.get("top_p", 0.9) - # Clamp value to max - value = min(value, 1.0) - arch = layer_info.get("architecture", "").lower() - recommended_str = f"{value:.2f} for most models" - - if "glm" in arch or "deepseek" in arch: - recommended_str = f"{value:.2f} for GLM/DeepSeek models" - elif "qwen" in arch: - recommended_str = f"{value:.2f} for Qwen models" - - return { - "recommended_value": value, - "description": recommended_str, - "min": 0.0, - "max": 1.0, - "ranges": [ - {"min": 0.7, "max": 0.8, "description": "More conservative"}, - {"min": 0.9, "max": 0.95, "description": "Balanced (recommended)"}, - {"min": 0.95, "max": 1.0, "description": "Higher diversity"}, - ], - } - - elif recommendation_type == "parallel": - value = config.get("parallel", 1) - # Clamp value to max - value = min(value, 8) - attention_heads = layer_info.get("attention_head_count", 32) - return { - "recommended_value": value, - "description": f"Recommended {value} based on {attention_heads} attention heads", - "min": 1, - "max": 8, - } - - # Fallback - return { - "recommended_value": config.get(key, 0), - "description": 
f"Recommended {config.get(key, 0)}", - "min": 0, - "max": 100, - } - - -async def _generate_balanced_config( - model_layer_info: Dict[str, Any], - model_name: str = "", - file_path: Optional[str] = None, -) -> Dict[str, Any]: - """Generate balanced configuration using smart_auto with speed_quality=50 and conversational preset.""" - from backend.smart_auto import SmartAutoConfig - - # Create minimal model object with file_path if available - model = _create_minimal_model(model_layer_info, model_name, file_path) - - # Create minimal GPU info (will work for CPU-only too) - gpu_info = _create_minimal_gpu_info() - - # Create smart_auto config generator - smart_config = SmartAutoConfig() - - # Generate config with balanced settings: - # - speed_quality=50 (balanced between speed and quality) - # - preset="conversational" (balanced preset) - # - usage_mode="single_user" (standard usage) - config = await smart_config.generate_config( - model=model, - gpu_info=gpu_info, - preset="conversational", # Balanced preset - usage_mode="single_user", - speed_quality=50, # Balanced (50 = equal speed/quality) - use_case=None, - debug=None, - ) - - return config - - -async def get_model_recommendations( - model_layer_info: Dict[str, Any], - model_name: str = "", - file_path: Optional[str] = None, -) -> Dict[str, Any]: - """ - Get all recommendations using smart_auto logic with balanced preset. 
- - Uses smart_auto's generate_config with: - - speed_quality=50 (balanced) - - preset="conversational" (balanced preset) - - usage_mode="single_user" - - Args: - model_layer_info: Layer information dict from GGUF metadata - model_name: Optional model name for fallback - file_path: Optional file path for metadata reading - - Returns: - Dict with all recommendations extracted from smart_auto generated config - """ - try: - # Generate balanced config using smart_auto - config = await _generate_balanced_config( - model_layer_info, model_name, file_path - ) - - # Extract recommendations from generated config - return { - "gpu_layers": _extract_recommendation_from_config( - config, "n_gpu_layers", model_layer_info, "gpu_layers" - ), - "context_size": _extract_recommendation_from_config( - config, "ctx_size", model_layer_info, "context_size" - ), - "batch_size": _extract_recommendation_from_config( - config, "batch_size", model_layer_info, "batch_size" - ), - "temperature": _extract_recommendation_from_config( - config, "temperature", model_layer_info, "temperature" - ), - "top_k": _extract_recommendation_from_config( - config, "top_k", model_layer_info, "top_k" - ), - "top_p": _extract_recommendation_from_config( - config, "top_p", model_layer_info, "top_p" - ), - "parallel": _extract_recommendation_from_config( - config, "parallel", model_layer_info, "parallel" - ), - } - except Exception as e: - # Fallback to basic recommendations if smart_auto fails - from backend.logging_config import get_logger - - logger = get_logger(__name__) - logger.warning( - f"Failed to generate recommendations with smart_auto: {e}. Using fallback." 
- ) - - # Return basic fallback recommendations - layer_count = model_layer_info.get("layer_count", 32) - context_length = model_layer_info.get("context_length", 131072) - attention_heads = model_layer_info.get("attention_head_count", 32) - - return { - "gpu_layers": { - "recommended_value": layer_count, - "description": f"Recommended {layer_count} layers (full offload)", - "min": 0, - "max": layer_count, - "ranges": [ - {"value": 0, "description": "CPU-only mode"}, - {"value": layer_count // 2, "description": "Balanced"}, - {"value": layer_count, "description": "Full offload"}, - ], - }, - "context_size": { - "recommended_value": context_length, - "description": f"Max {context_length:,} tokens", - "min": 512, - "max": context_length, - "ranges": [], - }, - "batch_size": { - "recommended_value": 512, - "description": "Recommended 512", - "min": 1, - "max": 1024, - "ranges": [], - }, - "temperature": { - "recommended_value": 0.8, - "description": "0.8 for balanced conversation", - "min": 0.0, - "max": 2.0, - "ranges": [], - }, - "top_k": { - "recommended_value": 40, - "description": "40 for most models", - "min": 0, - "max": 200, - "ranges": [], - }, - "top_p": { - "recommended_value": 0.9, - "description": "0.9 for most models", - "min": 0.0, - "max": 1.0, - "ranges": [], - }, - "parallel": { - "recommended_value": min(8, max(1, attention_heads // 4)), - "description": f"Recommended based on {attention_heads} attention heads", - "min": 1, - "max": 8, - }, - } diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 0000000..bd24e8a --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,8 @@ +"""Pytest configuration and fixtures.""" +import sys +from pathlib import Path + +# Ensure backend is importable when running from repo root +root = Path(__file__).resolve().parent.parent.parent +if str(root) not in sys.path: + sys.path.insert(0, str(root)) diff --git a/backend/tests/test_app_smoke.py b/backend/tests/test_app_smoke.py new 
file mode 100644 index 0000000..c77c4eb --- /dev/null +++ b/backend/tests/test_app_smoke.py @@ -0,0 +1,58 @@ +""" +Smoke tests to verify the app and key routes work after refactoring. +Run with: PYTHONPATH=. pytest backend/tests/test_app_smoke.py -v +(Requires: pip install -r requirements.txt) +""" +import pytest +from fastapi.testclient import TestClient + +from backend.main import app + + +@pytest.fixture +def client(): + return TestClient(app) + + +def test_app_starts(client): + """App should start and respond.""" + # Root or health-style endpoint may redirect; use API + response = client.get("/api/status") + assert response.status_code == 200 + data = response.json() + assert "system" in data + + +def test_param_registry_route(client): + """Param registry should return basic/advanced params.""" + response = client.get("/api/models/param-registry") + assert response.status_code == 200 + data = response.json() + assert "basic" in data + assert "advanced" in data + assert isinstance(data["basic"], list) + assert isinstance(data["advanced"], list) + + +def test_models_list_route(client): + """Models list should return 200 and a list (possibly empty).""" + response = client.get("/api/models/") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + +def test_models_list_route_no_trailing_slash(client): + """GET /api/models (no trailing slash) should return model list, not param-registry.""" + response = client.get("/api/models") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + +@pytest.mark.skip(reason="SSE stream never ends; TestClient blocks on full response") +def test_sse_events_route(client): + """SSE events endpoint returns 200 and event-stream content-type.""" + response = client.get("/api/events") + assert response.status_code == 200 + assert "text/event-stream" in response.headers.get("content-type", "") diff --git a/backend/tests/test_architecture_profiles.py 
b/backend/tests/test_architecture_profiles.py deleted file mode 100644 index a955ce9..0000000 --- a/backend/tests/test_architecture_profiles.py +++ /dev/null @@ -1,70 +0,0 @@ -from backend.architecture_profiles import compute_layers_for_architecture - - -def test_glm4moe_profile_uses_block_and_nextn(): - metadata = { - "general.architecture": "glm4moe", - "glm4moe.block_count": 47, - "glm4moe.nextn_predict_layers": 1, - } - result = compute_layers_for_architecture( - architecture="glm4moe", - metadata=metadata, - base_block_count=47, - ) - assert result["block_count"] == 47 - assert result["effective_layer_count"] == 48 - - -def test_llama_like_profile_adds_output_head(): - metadata = { - "general.architecture": "llama", - "llama.block_count": 32, - } - result = compute_layers_for_architecture( - architecture="llama", - metadata=metadata, - base_block_count=32, - ) - assert result["block_count"] == 32 - assert result["effective_layer_count"] == 33 - - -def test_qwen_family_profile_adds_output_head(): - metadata = { - "general.architecture": "qwen2", - "qwen2.block_count": 28, - } - result = compute_layers_for_architecture( - architecture="qwen2", - metadata=metadata, - base_block_count=28, - ) - assert result["block_count"] == 28 - assert result["effective_layer_count"] == 29 - - -def test_generic_profile_uses_base_block_count_plus_one(): - metadata = { - "general.architecture": "some-new-arch", - } - result = compute_layers_for_architecture( - architecture="some-new-arch", - metadata=metadata, - base_block_count=40, - ) - assert result["block_count"] == 40 - assert result["effective_layer_count"] == 41 - - -def test_generic_profile_falls_back_to_32_when_no_block_count(): - metadata = { - "general.architecture": "unknown-arch", - } - result = compute_layers_for_architecture( - architecture="unknown-arch", - metadata=metadata, - base_block_count=0, - ) - assert result["block_count"] == 0 - assert result["effective_layer_count"] == 32 diff --git 
a/backend/unified_monitor.py b/backend/unified_monitor.py deleted file mode 100644 index 633dfa4..0000000 --- a/backend/unified_monitor.py +++ /dev/null @@ -1,716 +0,0 @@ -import asyncio -import psutil -import time -import yaml -import os -from collections import deque -from datetime import datetime -from typing import Dict, Any, Optional, List -from sqlalchemy.orm import Session - -from backend.websocket_manager import websocket_manager -from backend.gpu_detector import get_gpu_info -from backend.llama_swap_client import LlamaSwapClient -from backend.database import SessionLocal, RunningInstance, Model -from backend.logging_config import get_logger - -try: - import pynvml # type: ignore -except ImportError: - pynvml = None # type: ignore[assignment] - -DEFAULT_PROXY_PORT = 2000 -LMDEPLOY_PORT = 2001 - -logger = get_logger(__name__) - - -class UnifiedMonitor: - """Unified monitoring service with smart llama-swap integration. - - Key insight: The /running endpoint is ALWAYS safe to poll (never returns 503). - Only /v1/chat/completions returns 503 during model loading. - - We poll /running frequently to: - 1. Detect model state changes (loading → running) - 2. Detect external model starts (via llama-swap UI) - 3. 
Sync database with actual llama-swap state - """ - - def __init__(self): - self.is_running = False - self.monitor_task: Optional[asyncio.Task] = None - self.update_interval = 2.0 # Poll /running every 2 seconds (safe, no 503s) - - # Data storage - self.recent_logs = deque(maxlen=100) # Keep last 100 log entries - - # llama-swap client - self.llama_swap_client = LlamaSwapClient() - - # Model mapping cache - self.model_mapping = {} - self._load_model_mapping() - - # Optional direct WS subscribers (used by routes/unified_monitoring.py) - self.subscribers: List[Any] = [] - - # Loading state tracking - self._models_loading: Dict[str, datetime] = {} # model_name -> start_time - self._loading_timeout = 300 # 5 minutes max loading time - - # Previous model states for change detection - self._previous_model_states: Dict[str, str] = {} # model_name -> state - - def _load_model_mapping(self): - """Load model mapping from llama-swap configuration""" - try: - config_path = "/app/data/llama-swap-config.yaml" - if os.path.exists(config_path): - with open(config_path, "r") as f: - config = yaml.safe_load(f) - - # Extract model mappings from the config - models_config = config.get("models", {}) - for llama_swap_name, model_config in models_config.items(): - cmd = model_config.get("cmd", "") - # Extract the actual model file path from the command - if "--model" in cmd: - parts = cmd.split("--model") - if len(parts) > 1: - model_path = parts[1].strip().split()[0] - # Extract just the filename without path and extension - filename = os.path.basename(model_path).replace(".gguf", "") - self.model_mapping[llama_swap_name] = { - "filename": filename, - "full_path": model_path, - } - - logger.info( - f"Loaded {len(self.model_mapping)} model mappings from llama-swap config" - ) - logger.debug(f"Model mappings: {self.model_mapping}") - else: - logger.warning(f"llama-swap config not found at {config_path}") - except Exception as e: - logger.error(f"Failed to load model mapping: {e}") - - def 
add_log(self, log_event: Dict[str, Any]): - """Add a log event to the buffer""" - self.recent_logs.append(log_event) - - def mark_model_loading(self, model_name: str): - """Mark a model as currently loading""" - self._models_loading[model_name] = datetime.utcnow() - self.llama_swap_client.mark_model_loading(model_name) - logger.info(f"Model '{model_name}' is now loading") - - def mark_model_ready(self, model_name: str): - """Mark a model as ready (finished loading)""" - if model_name in self._models_loading: - load_time = ( - datetime.utcnow() - self._models_loading[model_name] - ).total_seconds() - del self._models_loading[model_name] - logger.info( - f"Model '{model_name}' is now ready (loaded in {load_time:.1f}s)" - ) - self.llama_swap_client.mark_model_ready(model_name) - - def mark_model_stopped(self, model_name: str): - """Mark a model as stopped (clear loading state)""" - self._models_loading.pop(model_name, None) - self.llama_swap_client.clear_loading_state(model_name) - logger.debug(f"Model '{model_name}' loading state cleared") - - async def broadcast_model_event( - self, event_type: str, model_name: str, details: Dict[str, Any] = None - ): - """Broadcast a model event immediately (no polling needed). - - This is called when model state changes (start/stop/ready) to push - updates to frontend instantly without waiting for the next poll cycle. - """ - event_data = { - "type": "model_event", - "event": event_type, # "loading", "ready", "stopped", "error" - "model": model_name, - "timestamp": datetime.utcnow().isoformat(), - "details": details or {}, - } - - try: - await websocket_manager.broadcast(event_data) - logger.debug(f"Broadcast model event: {event_type} for {model_name}") - except Exception as e: - logger.error(f"Failed to broadcast model event: {e}") - - async def trigger_status_update(self): - """Trigger an immediate status update broadcast. - - Called after model start/stop to push fresh data to frontend - without waiting for the next poll cycle. 
- """ - try: - await self._collect_and_send_unified_data() - except Exception as e: - logger.error(f"Failed to trigger status update: {e}") - - def get_loading_models(self) -> Dict[str, Any]: - """Get currently loading models with their loading times""" - now = datetime.utcnow() - loading = {} - expired = [] - - for model_name, start_time in self._models_loading.items(): - elapsed = (now - start_time).total_seconds() - if elapsed > self._loading_timeout: - # Model has been loading too long, consider it stuck - expired.append(model_name) - logger.warning( - f"Model '{model_name}' loading timeout ({elapsed:.0f}s > {self._loading_timeout}s)" - ) - else: - loading[model_name] = { - "started_at": start_time.isoformat(), - "elapsed_seconds": elapsed, - } - - # Clean up expired loading states - for model_name in expired: - del self._models_loading[model_name] - self.llama_swap_client.clear_loading_state(model_name) - - return loading - - def has_loading_models(self) -> bool: - """Check if any models are currently loading""" - # Clean up expired first - self.get_loading_models() - return len(self._models_loading) > 0 - - async def add_subscriber(self, websocket): - """Accept and register a WebSocket subscriber (minimal implementation).""" - try: - await websocket.accept() - except Exception: - return - self.subscribers.append(websocket) - - async def remove_subscriber(self, websocket): - """Remove a WebSocket subscriber and close if open.""" - try: - if websocket in self.subscribers: - self.subscribers.remove(websocket) - try: - await websocket.close() - except Exception: - pass - except Exception: - pass - - async def start_monitoring(self): - """Start the unified monitoring background task""" - if self.is_running: - return - - self.is_running = True - self.monitor_task = asyncio.create_task(self._monitor_loop()) - - logger.info("Unified monitoring started") - - async def stop_monitoring(self): - """Stop the unified monitoring background task""" - self.is_running = False 
- - if self.monitor_task: - self.monitor_task.cancel() - - logger.info("Unified monitoring stopped") - - async def _monitor_loop(self): - """Main monitoring loop - polls /running endpoint every 2 seconds. - - The /running endpoint is SAFE to poll (never returns 503). - This allows us to: - - Detect when models finish loading - - Detect external model starts (via llama-swap UI) - - Keep database in sync with llama-swap - """ - while self.is_running: - try: - await self._collect_and_send_unified_data() - await asyncio.sleep(self.update_interval) - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"Unified monitoring error: {e}") - await asyncio.sleep(self.update_interval) - - async def _collect_and_send_unified_data(self): - """Collect all monitoring data and send as single unified message""" - try: - # 1. System metrics - cpu_percent = psutil.cpu_percent(interval=0) - memory = psutil.virtual_memory() - # Use data directory at project root or /app/data for Docker - data_dir = "data" if os.path.exists("data") else "/app/data" - try: - disk = psutil.disk_usage(data_dir) - except FileNotFoundError: - disk = psutil.disk_usage("/") - - # 2. Running instances from database - db = SessionLocal() - try: - running_instances = db.query(RunningInstance).all() - active_instances = [] - for instance in running_instances: - port = ( - LMDEPLOY_PORT - if instance.runtime_type == "lmdeploy" - else DEFAULT_PROXY_PORT - ) - active_instances.append( - { - "id": instance.id, - "model_id": instance.model_id, - "port": port, - "runtime_type": instance.runtime_type, - "proxy_model_name": instance.proxy_model_name, - "started_at": ( - instance.started_at.isoformat() - if instance.started_at - else None - ), - } - ) - finally: - db.close() - - # 3. 
Running models from llama-swap /running endpoint - # This endpoint is SAFE - it never returns 503, even during model loading - # We poll it every 2 seconds to detect: - # - Model state changes (loading → running) - # - External model starts (via llama-swap UI) - enhanced_external_models = [] - try: - external_response = await self.llama_swap_client.get_running_models() - - # Extract the running models array from the response - if ( - isinstance(external_response, dict) - and "running" in external_response - ): - external_models = external_response["running"] - elif isinstance(external_response, list): - external_models = external_response - else: - external_models = [] - - # Process models and detect state changes - for model in external_models: - model_name = model.get("model", "") - state = model.get("state", "unknown") - previous_state = self._previous_model_states.get(model_name) - - # Detect state transitions - if previous_state != state: - logger.info( - f"Model '{model_name}' state changed: {previous_state} → {state}" - ) - - if state == "loading": - # Model started loading - if model_name not in self._models_loading: - self._models_loading[model_name] = datetime.utcnow() - await self.broadcast_model_event("loading", model_name) - - elif state in ("running", "ready"): - # Model finished loading - broadcast ready event! 
- if model_name in self._models_loading: - load_time = ( - datetime.utcnow() - self._models_loading[model_name] - ).total_seconds() - logger.info( - f"Model '{model_name}' ready after {load_time:.1f}s" - ) - del self._models_loading[model_name] - await self.broadcast_model_event("ready", model_name) - - self._previous_model_states[model_name] = state - - enhanced_model = { - "model": model_name, - "state": state, - "mapping": self.model_mapping.get(model_name, {}), - "is_loading": state == "loading", - } - enhanced_external_models.append(enhanced_model) - - # Detect models that were removed (stopped externally) - current_model_names = {m.get("model", "") for m in external_models} - for prev_model in list(self._previous_model_states.keys()): - if prev_model not in current_model_names: - logger.info( - f"Model '{prev_model}' stopped (removed from llama-swap)" - ) - await self.broadcast_model_event("stopped", prev_model) - del self._previous_model_states[prev_model] - self._models_loading.pop(prev_model, None) - - # Sync database with llama-swap state - await self._sync_database_with_external_models(enhanced_external_models) - - except Exception as e: - logger.debug( - f"Failed to poll /running (llama-swap may be starting): {e}" - ) - - # 4. GPU info - try: - gpu_info = await get_gpu_info() - vram_data = None - if not gpu_info.get("cpu_only_mode", True): - vram_data = await self._get_vram_data(gpu_info) - except Exception as e: - logger.error(f"Failed to get GPU info: {e}") - gpu_info = {"cpu_only_mode": True, "device_count": 0} - vram_data = None - - # 5. Get loading models info - loading_models = self.get_loading_models() - - # 6. 
Create unified monitoring data - unified_data = { - "type": "unified_monitoring", - "timestamp": datetime.utcnow().isoformat(), - "system": { - "cpu_percent": cpu_percent, - "memory": { - "total": memory.total, - "available": memory.available, - "percent": memory.percent, - "used": memory.used, - "free": memory.free, - "cached": getattr(memory, "cached", 0), - "buffers": getattr(memory, "buffers", 0), - "swap_total": psutil.swap_memory().total, - "swap_used": psutil.swap_memory().used, - }, - "disk": { - "total": disk.total, - "used": disk.used, - "free": disk.free, - "percent": (disk.used / disk.total) * 100, - }, - }, - "gpu": { - "cpu_only_mode": gpu_info.get("cpu_only_mode", True), - "device_count": gpu_info.get("device_count", 0), - "total_vram": gpu_info.get("total_vram", 0), - "available_vram": gpu_info.get("available_vram", 0), - "vram_data": vram_data, - }, - "models": { - "running_instances": active_instances, - "loading": loading_models, # Models currently loading - "has_loading": len(loading_models) > 0, - }, - "proxy_status": { - "enabled": True, - "port": 2000, - "endpoint": "http://localhost:2000/v1/chat/completions", - }, - "logs": list(self.recent_logs)[-20:], # Last 20 logs - } - - # 6. 
Send unified data to all WebSocket connections - logger.debug(f"Broadcasting unified monitoring data: {unified_data}") - await websocket_manager.broadcast(unified_data) - - except Exception as e: - logger.error(f"Error collecting unified monitoring data: {e}") - - async def _get_vram_data(self, gpu_info: Dict[str, Any]) -> Dict[str, Any]: - """Get current VRAM usage data""" - if pynvml is None: - logger.debug("NVML not available; skipping VRAM detail collection") - return { - "total": 0, - "used": 0, - "free": 0, - "percent": 0, - "gpus": [], - "cuda_version": gpu_info.get("cuda_version", "N/A"), - "device_count": gpu_info.get("device_count", 0), - "timestamp": time.time(), - } - try: - pynvml.nvmlInit() - - device_count = gpu_info.get("device_count", 0) - total_vram = 0 - used_vram = 0 - gpu_details = [] - - for i in range(device_count): - handle = pynvml.nvmlDeviceGetHandleByIndex(i) - mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) - utilization = pynvml.nvmlDeviceGetUtilizationRates(handle) - - gpu_total = mem_info.total - gpu_used = mem_info.used - gpu_free = mem_info.free - - total_vram += gpu_total - used_vram += gpu_used - - gpu_details.append( - { - "device_id": i, - "total": gpu_total, - "used": gpu_used, - "free": gpu_free, - "utilization": utilization.gpu, - "memory_utilization": utilization.memory, - } - ) - - return { - "total": total_vram, - "used": used_vram, - "free": total_vram - used_vram, - "percent": (used_vram / total_vram * 100) if total_vram > 0 else 0, - "gpus": gpu_details, - "cuda_version": gpu_info.get("cuda_version", "N/A"), - "device_count": gpu_info.get("device_count", 0), - "timestamp": time.time(), - } - except Exception as e: - logger.error(f"Failed to get VRAM data: {e}") - return { - "total": 0, - "used": 0, - "free": 0, - "percent": 0, - "gpus": [], - "cuda_version": "N/A", - "device_count": 0, - "timestamp": time.time(), - } - finally: - try: - pynvml.nvmlShutdown() - except Exception: - pass - - async def 
_sync_database_with_external_models( - self, external_models: List[Dict[str, Any]] - ): - """Sync database RunningInstance records with external running models""" - try: - db: Session = SessionLocal() - try: - # Get all current running instances from database - current_instances = db.query(RunningInstance).all() - llama_cpp_instances = [ - instance - for instance in current_instances - if (instance.runtime_type or "llama_cpp") == "llama_cpp" - ] - current_proxy_names = { - instance.proxy_model_name - for instance in llama_cpp_instances - if instance.proxy_model_name - } - - # Get external model names - external_names = {model["model"] for model in external_models} - - # Find models that are running externally but not in database - missing_in_db = external_names - current_proxy_names - - # Find models that are in database but not running externally - missing_in_external = current_proxy_names - external_names - - # Add missing models to database - for proxy_name in missing_in_db: - # Find the corresponding model in the database by matching the proxy name - model = self._find_model_by_proxy_name(db, proxy_name) - if model: - # Create a new RunningInstance - new_instance = RunningInstance( - model_id=model.id, - proxy_model_name=proxy_name, - started_at=datetime.utcnow(), - runtime_type="llama_cpp", - ) - db.add(new_instance) - logger.info(f"Added missing model '{proxy_name}' to database") - - # Update model.is_active - model.is_active = True - - # Remove models that are no longer running externally - for proxy_name in missing_in_external: - instances_to_remove = ( - db.query(RunningInstance) - .filter( - RunningInstance.proxy_model_name == proxy_name, - RunningInstance.runtime_type == "llama_cpp", - ) - .all() - ) - - for instance in instances_to_remove: - # Update model.is_active - model = ( - db.query(Model) - .filter(Model.id == instance.model_id) - .first() - ) - if model: - model.is_active = False - - # Remove the RunningInstance - db.delete(instance) - 
logger.info( - f"Removed stopped model '{proxy_name}' from database" - ) - - db.commit() - logger.debug( - f"Database sync completed. Added: {len(missing_in_db)}, Removed: {len(missing_in_external)}" - ) - - finally: - db.close() - - except Exception as e: - logger.error(f"Error syncing database with external models: {e}") - import traceback - - logger.error(f"Traceback: {traceback.format_exc()}") - - def _find_model_by_proxy_name( - self, db: Session, proxy_name: str - ) -> Optional[Model]: - """Find a model in the database by matching the proxy name""" - try: - # Use the stored proxy_name field for direct lookup - model = db.query(Model).filter(Model.proxy_name == proxy_name).first() - - if model: - return model - - logger.warning(f"Could not find model for proxy name: {proxy_name}") - return None - - except Exception as e: - logger.error(f"Error finding model by proxy name '{proxy_name}': {e}") - return None - - # API methods for HTTP endpoints - async def get_system_status(self) -> Dict[str, Any]: - """Get current system status (for HTTP API)""" - cpu_percent = psutil.cpu_percent(interval=0) - memory = psutil.virtual_memory() - disk = psutil.disk_usage("/app/data") - - db = SessionLocal() - try: - running_instances = db.query(RunningInstance).all() - active_instances = [] - for instance in running_instances: - port = ( - LMDEPLOY_PORT - if instance.runtime_type == "lmdeploy" - else DEFAULT_PROXY_PORT - ) - active_instances.append( - { - "id": instance.id, - "model_id": instance.model_id, - "port": port, - "runtime_type": instance.runtime_type, - "proxy_model_name": instance.proxy_model_name, - "started_at": instance.started_at, - } - ) - finally: - db.close() - - return { - "system": { - "cpu_percent": cpu_percent, - "memory": { - "total": memory.total, - "available": memory.available, - "percent": memory.percent, - "used": memory.used, - "free": memory.free, - }, - "disk": { - "total": disk.total, - "used": disk.used, - "free": disk.free, - "percent": 
(disk.used / disk.total) * 100, - }, - }, - "running_instances": active_instances, - "proxy_status": { - "enabled": True, - "port": 2000, - "endpoint": "http://localhost:2000/v1/chat/completions", - }, - "timestamp": datetime.utcnow().isoformat(), - } - - async def get_running_models(self) -> List[Dict[str, Any]]: - """Get currently running models from llama-swap""" - try: - return await self.llama_swap_client.get_running_models() - except Exception as e: - logger.debug(f"Failed to get running models from llama-swap: {e}") - return [] - - async def unload_all_models(self) -> Dict[str, Any]: - """Unload all models via llama-swap""" - try: - return await self.llama_swap_client.unload_all_models() - except Exception as e: - logger.error(f"Failed to unload all models: {e}") - return {"error": str(e)} - - async def get_system_health(self) -> Dict[str, Any]: - """Get llama-swap and system health status""" - try: - health_result = await self.llama_swap_client.check_health() - llama_swap_healthy = health_result.get("healthy", False) - loading_models = health_result.get("loading_models", []) - except Exception as e: - logger.error(f"Failed to check llama-swap health: {e}") - llama_swap_healthy = False - loading_models = [] - - return { - "llama_swap_proxy": "healthy" if llama_swap_healthy else "unhealthy", - "monitoring_active": self.is_running, - "active_connections": len(websocket_manager.active_connections), - "loading_models": loading_models, - "has_loading_models": len(loading_models) > 0 or self.has_loading_models(), - } - - def get_recent_logs(self, limit: int = 100) -> List[Dict[str, Any]]: - """Get recent logs from monitor buffer""" - logs = list(self.recent_logs) - return logs[-limit:] - - def add_log(self, log_event: Dict[str, Any]): - """Add a log event to the buffer""" - self.recent_logs.append(log_event) - - -# Global unified monitor instance -unified_monitor = UnifiedMonitor() diff --git a/backend/websocket_manager.py b/backend/websocket_manager.py deleted 
file mode 100644 index 20960c0..0000000 --- a/backend/websocket_manager.py +++ /dev/null @@ -1,192 +0,0 @@ -from fastapi import WebSocket -from typing import List, Dict, Optional, Callable -import json -import asyncio -import time -from datetime import datetime -from backend.logging_config import get_logger - -logger = get_logger(__name__) - - -class WebSocketManager: - def __init__(self): - self.active_connections: List[WebSocket] = [] - self.subscribers: Dict[str, List[Callable]] = {} - - async def connect(self, websocket: WebSocket): - try: - await websocket.accept() - self.active_connections.append(websocket) - logger.info( - f"WebSocket connected. Total connections: {len(self.active_connections)}" - ) - - # Send a test message to verify the connection works - await self.send_personal_message( - json.dumps( - { - "type": "connection_test", - "message": "WebSocket connection established successfully", - "timestamp": datetime.utcnow().isoformat(), - } - ), - websocket, - ) - - except Exception as e: - logger.error(f"Error in WebSocketManager.connect: {e}") - raise - - def disconnect(self, websocket: WebSocket): - if websocket in self.active_connections: - self.active_connections.remove(websocket) - logger.info( - f"WebSocket disconnected. 
Total connections: {len(self.active_connections)}" - ) - - async def send_personal_message(self, message: str, websocket: WebSocket): - try: - await websocket.send_text(message) - except: - self.disconnect(websocket) - - async def broadcast(self, message: dict): - """Broadcast message to all active WebSocket connections""" - if not self.active_connections: - logger.debug( - f"No active WebSocket connections to broadcast message: {message.get('type', 'unknown')}" - ) - return - - message_str = json.dumps(message) - logger.debug( - f"Broadcasting to {len(self.active_connections)} connections: {message.get('type', 'unknown')}" - ) - - async def _send(conn): - try: - await conn.send_text(message_str) - return None - except Exception as e: - return conn - - # Send concurrently and collect failed connections - results = await asyncio.gather( - *[_send(c) for c in list(self.active_connections)], return_exceptions=False - ) - for failed in results: - if isinstance(failed, WebSocket): - self.disconnect(failed) - - # Legacy methods for backward compatibility - async def send_download_progress( - self, - task_id: str, - progress: int, - message: str = "", - bytes_downloaded: int = 0, - total_bytes: int = 0, - speed_mbps: float = 0, - eta_seconds: int = 0, - filename: str = "", - model_format: str = "gguf", - files_completed: int = None, - files_total: int = None, - current_filename: str = None, - huggingface_id: str = None, - ): - await self.broadcast( - { - "type": "download_progress", - "task_id": task_id, - "progress": progress, - "message": message, - "bytes_downloaded": bytes_downloaded, - "total_bytes": total_bytes, - "speed_mbps": speed_mbps, - "eta_seconds": eta_seconds, - "filename": filename, - "model_format": model_format, - "files_completed": files_completed, - "files_total": files_total, - "current_filename": current_filename or filename, - "huggingface_id": huggingface_id, - "timestamp": datetime.utcnow().isoformat(), - } - ) - - async def send_build_progress( - 
self, - task_id: str, - stage: str, - progress: int, - message: str = "", - log_lines: List[str] = None, - ): - message_data = { - "type": "build_progress", - "task_id": task_id, - "stage": stage, - "progress": progress, - "message": message, - "log_lines": log_lines or [], - "timestamp": datetime.utcnow().isoformat(), - } - - logger.debug( - f"Sending build progress: task_id={task_id}, stage={stage}, progress={progress}, message='{message}', connections={len(self.active_connections)}" - ) - logger.debug(f"Message data: {message_data}") - await self.broadcast(message_data) - - async def send_model_status_update( - self, model_id: int, status: str, details: dict = None - ): - await self.broadcast( - { - "type": "model_status", - "model_id": model_id, - "status": status, - "details": details or {}, - "timestamp": datetime.utcnow().isoformat(), - } - ) - - async def send_notification( - self, title: str, message: str, type: str = "info", actions: List[dict] = None - ): - await self.broadcast( - { - "type": "notification", - "title": title, - "message": message, - "notification_type": type, - "actions": actions or [], - "timestamp": datetime.utcnow().isoformat(), - } - ) - - async def send_lmdeploy_status(self, status: dict): - """Broadcast LMDeploy installer status update.""" - await self.broadcast( - { - "type": "lmdeploy_status", - **status, - "timestamp": datetime.utcnow().isoformat(), - } - ) - - async def send_lmdeploy_runtime_log(self, line: str): - """Broadcast LMDeploy runtime log line.""" - await self.broadcast( - { - "type": "lmdeploy_runtime_log", - "line": line, - "timestamp": datetime.utcnow().isoformat(), - } - ) - - -# Global WebSocket manager instance -websocket_manager = WebSocketManager() diff --git a/docker-compose.cuda.yml b/docker-compose.cuda.yml index 6e76ef1..57d5cc0 100644 --- a/docker-compose.cuda.yml +++ b/docker-compose.cuda.yml @@ -2,17 +2,20 @@ version: '3.8' services: llama-cpp-studio: - build: . + build: + context: . 
+ pull: true ports: - "8080:8080" - "2000:2000" volumes: - ./data:/app/data - - ./backend:/app/backend environment: - CUDA_VISIBLE_DEVICES=all - HF_HUB_ENABLE_HF_TRANSFER=1 - - RELOAD=true + - HF_HOME=/app/data/temp/.cache/huggingface + - HUGGINGFACE_HUB_CACHE=/app/data/temp/.cache/huggingface/hub + - RELOAD=false # Uncomment and set your HuggingFace API key to enable model search and download # - HUGGINGFACE_API_KEY=your_huggingface_token_here # Alternative: Use .env file for environment variables diff --git a/docker-compose.rocm.yml b/docker-compose.rocm.yml deleted file mode 100644 index e9c00d2..0000000 --- a/docker-compose.rocm.yml +++ /dev/null @@ -1,40 +0,0 @@ -version: '3.8' - -services: - llama-cpp-studio: - build: . - image: llama-cpp-studio:rocm - ports: - - "8080:8080" - - "2000:2000" - volumes: - - ./data:/app/data - - ./backend:/app/backend - # Mount ROCm devices - - /dev/kfd:/dev/kfd - - /dev/dri:/dev/dri - environment: - # Disable CUDA to use ROCm instead - - CUDA_VISIBLE_DEVICES="" - - RELOAD=true - # ROCm environment variables - - HSA_OVERRIDE_GFX_VERSION=10.3.0 - - HIP_VISIBLE_DEVICES=all - - ROC_ENABLE_PRE_VEGA=1 - # Uncomment and set your HuggingFace API key - # - HUGGINGFACE_API_KEY=your_huggingface_token_here - devices: - # AMD GPU access - - /dev/kfd:/dev/kfd - - /dev/dri:/dev/dri - # Enable privileged mode for GPU access - privileged: true - restart: unless-stopped - healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8080/api/status"] - interval: 30s - timeout: 10s - retries: 3 - cap_add: - - SYS_ADMIN - shm_size: '2gb' diff --git a/docker-compose.vulkan.yml b/docker-compose.vulkan.yml deleted file mode 100644 index 96324ca..0000000 --- a/docker-compose.vulkan.yml +++ /dev/null @@ -1,37 +0,0 @@ -version: '3.8' - -services: - llama-cpp-studio: - build: . 
- image: llama-cpp-studio:vulkan - ports: - - "8080:8080" - - "2000:2000" - volumes: - - ./data:/app/data - - ./backend:/app/backend - - /tmp/.X11-unix:/tmp/.X11-unix:rw - environment: - - DISPLAY=${DISPLAY} - - XDG_RUNTIME_DIR=${XDG_RUNTIME_DIR} - - HF_HUB_ENABLE_HF_TRANSFER=1 - # Disable CUDA to use Vulkan instead - - CUDA_VISIBLE_DEVICES="" - - RELOAD=true - # Vulkan device selection (optional) - - VK_ICD_FILENAMES=/usr/share/vulkan/icd.d/radeon_icd.x86_64.json - # Uncomment and set your HuggingFace API key - # - HUGGINGFACE_API_KEY=your_huggingface_token_here - devices: - # Mount DRI devices for Vulkan access - - /dev/dri:/dev/dri - # Enable privileged mode for GPU access (required for Vulkan) - privileged: true - restart: unless-stopped - healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8080/api/status"] - interval: 30s - timeout: 10s - retries: 3 - cap_add: - - SYS_ADMIN diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh index f4158fe..ff2c434 100644 --- a/docker-entrypoint.sh +++ b/docker-entrypoint.sh @@ -8,18 +8,9 @@ if [ -d "/app/data" ]; then # Check if we can write to the data directory if [ ! -w "/app/data" ]; then echo "WARNING: /app/data directory is not writable by current user ($(id -u))" - echo "This will cause database and file write errors." + echo "This will cause configuration and model write errors." echo "To fix, run on the host: sudo chown -R $(id -u):$(id -g) " fi - - # Check database file specifically - if [ -f "/app/data/db.sqlite" ] && [ ! 
-w "/app/data/db.sqlite" ]; then - echo "ERROR: Database file /app/data/db.sqlite exists but is not writable" - echo "Current user: $(id -u) ($(whoami))" - echo "File owner: $(stat -c '%U:%G (%u:%g)' /app/data/db.sqlite 2>/dev/null || echo 'unknown')" - echo "To fix, run on the host: sudo chown $(id -u):$(id -g) /db.sqlite" - echo "Or remove the database file to recreate it with correct permissions" - fi fi # Source the CUDA environment setup script if it exists diff --git a/frontend/src/App.vue b/frontend/src/App.vue index d749d0f..b4b0326 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -1,7 +1,7 @@