From 427968123c46fc1220464cb998ece51d69d10ba0 Mon Sep 17 00:00:00 2001 From: Alex Baur Date: Fri, 20 Mar 2026 12:39:27 +0100 Subject: [PATCH] Add Blackwell Cuda GB10 Support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Create Dockerfile.blackwell - Set fixed requirement of huggingface_hub - Create new docker-compose.blackwell.yml - Check in opentranscribe.sh if it’s a Blackwell GPU and load new docker-compose.blackwell.yml - Update update_gpu_stats function for Blackwell GPU --- backend/Dockerfile.blackwell | 199 +++++++++++++++++++++++++++++++++++ backend/app/tasks/utility.py | 163 +++++++++++++++++----------- backend/requirements.txt | 3 + docker-compose.blackwell.yml | 43 ++++++++ docker-compose.yml | 13 --- opentranscribe.sh | 13 ++- 6 files changed, 354 insertions(+), 80 deletions(-) create mode 100644 backend/Dockerfile.blackwell create mode 100644 docker-compose.blackwell.yml diff --git a/backend/Dockerfile.blackwell b/backend/Dockerfile.blackwell new file mode 100644 index 00000000..6933c3f5 --- /dev/null +++ b/backend/Dockerfile.blackwell @@ -0,0 +1,199 @@ +# ============================================================================= +# OpenTranscribe GPU Worker for NVIDIA DGX Spark / ARM64 / Blackwell +# - basiert auf NVIDIA PyTorch Container +# - ohne appuser, stattdessen USER 1000:1000 +# - mit Blackwell/NVRTC Patches fuer WhisperX / Torchaudio +# - mit pyannote SemVer Workaround fuer NVIDIA Torch Dev-Versionen +# ============================================================================= + +FROM nvcr.io/nvidia/pytorch:25.01-py3 + +WORKDIR /app + +# Blackwell / NVRTC Kompatibilitaet + Cache-Pfade +ENV TORCH_CUDA_ARCH_LIST="9.0" \ + CUDA_FORCE_PTX_JIT=1 \ + PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + HF_HOME=/home/user/.cache/huggingface \ + TRANSFORMERS_CACHE=/home/user/.cache/huggingface/transformers \ + TORCH_HOME=/home/user/.cache/torch \ + NLTK_DATA=/home/user/.cache/nltk_data \ + 
SENTENCE_TRANSFORMERS_HOME=/home/user/.cache/sentence-transformers \ + XDG_CACHE_HOME=/home/user/.cache \ + PATH=/usr/local/bin:$PATH + +# System-Abhaengigkeiten +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + ffmpeg \ + libsndfile1 \ + libimage-exiftool-perl \ + libgomp1 \ + git \ + cmake \ + build-essential \ + libopenblas-dev \ + sox \ + libsox-dev \ + && rm -rf /var/lib/apt/lists/* + +# Verzeichnisse fuer Runtime und Caches +RUN mkdir -p \ + /app \ + /app/models \ + /app/temp \ + /home/user/.cache/huggingface \ + /home/user/.cache/torch \ + /home/user/.cache/nltk_data \ + /home/user/.cache/sentence-transformers \ + /home/user/.cache/yt-dlp \ + && chown -R 1000:1000 /app /home/user + +# Requirements zuerst fuer Layer-Cache +COPY requirements.txt /app/requirements.txt + +# Basis Python-Werkzeuge +RUN python -m pip install --no-cache-dir --upgrade pip setuptools wheel && \ + python -m pip install --no-cache-dir pybind11 packaging ninja cmake + +# ----------------------------------------------------------------------------- +# Torchaudio gegen NVIDIA Torch bauen +# ----------------------------------------------------------------------------- +RUN git clone --depth 1 --branch v2.6.0 --recursive https://github.com/pytorch/audio.git /tmp/torchaudio && \ + cd /tmp/torchaudio && \ + pip install --no-cache-dir --no-deps --no-build-isolation . && \ + rm -rf /tmp/torchaudio + +# ----------------------------------------------------------------------------- +# CTranslate2 mit CUDA/cuDNN bauen +# ----------------------------------------------------------------------------- +RUN git clone --recursive --depth 1 --branch v4.4.0 https://github.com/OpenNMT/CTranslate2.git /tmp/ctranslate2 && \ + cd /tmp/ctranslate2 && \ + mkdir build && cd build && \ + cmake .. 
\ + -DCMAKE_BUILD_TYPE=Release \ + -DWITH_CUDA=ON \ + -DWITH_CUDNN=ON \ + -DCUDNN_ROOT=/usr \ + -DCUDA_DYNAMIC_LOADING=ON \ + -DWITH_MKL=OFF \ + -DWITH_OPENBLAS=ON \ + -DOPENMP_RUNTIME=COMP \ + -DCMAKE_INSTALL_PREFIX=/usr/local && \ + make -j"$(nproc)" && \ + make install && \ + ldconfig && \ + cd /tmp/ctranslate2/python && \ + pip install --no-cache-dir --no-build-isolation . && \ + rm -rf /tmp/ctranslate2 + +# ----------------------------------------------------------------------------- +# NVIDIA Torch-Stack sichern, bevor pip evtl. Dinge ueberschreibt +# ----------------------------------------------------------------------------- +RUN cp -r /usr/local/lib/python3.12/dist-packages/torch /tmp/torch_nvidia && \ + cp -r /usr/local/lib/python3.12/dist-packages/torchvision /tmp/torchvision_nvidia && \ + cp -r /usr/local/lib/python3.12/dist-packages/torchaudio /tmp/torchaudio_custom && \ + cp -r /usr/local/lib/python3.12/dist-packages/torio /tmp/torio_custom && \ + cp -r /usr/local/lib/python3.12/dist-packages/numpy /tmp/numpy_nvidia && \ + cp -r /usr/local/lib/python3.12/dist-packages/numpy.libs /tmp/numpy_libs_nvidia 2>/dev/null || true + +# ----------------------------------------------------------------------------- +# OpenTranscribe-Requirements ohne GPU-kritische Pakete installieren +# So bleibt der NVIDIA-Torch-Stack erhalten +# ----------------------------------------------------------------------------- +RUN grep -vE '^(torch==|torch>=|torchaudio==|torchaudio>=|ctranslate2|whisperx==|whisperx>=|pyannote\.audio)' /app/requirements.txt > /tmp/requirements.safe.txt && \ + pip install --no-cache-dir -r /tmp/requirements.safe.txt + +# GPU-/WhisperX-relevante Pakete explizit setzen +RUN pip install --no-cache-dir \ + huggingface_hub==0.23.5 \ + whisperx==3.3.1 \ + faster-whisper==1.1.0 \ + pyannote.audio==3.3.2 \ + python-multipart \ + nltk \ + matplotlib + +# ----------------------------------------------------------------------------- +# NVIDIA Torch-Stack 
wiederherstellen +# ----------------------------------------------------------------------------- +RUN rm -rf /usr/local/lib/python3.12/dist-packages/torch && \ + rm -rf /usr/local/lib/python3.12/dist-packages/torchvision && \ + rm -rf /usr/local/lib/python3.12/dist-packages/torchaudio && \ + rm -rf /usr/local/lib/python3.12/dist-packages/torio && \ + rm -rf /usr/local/lib/python3.12/dist-packages/numpy && \ + rm -rf /usr/local/lib/python3.12/dist-packages/numpy.libs && \ + mv /tmp/torch_nvidia /usr/local/lib/python3.12/dist-packages/torch && \ + mv /tmp/torchvision_nvidia /usr/local/lib/python3.12/dist-packages/torchvision && \ + mv /tmp/torchaudio_custom /usr/local/lib/python3.12/dist-packages/torchaudio && \ + mv /tmp/torio_custom /usr/local/lib/python3.12/dist-packages/torio && \ + mv /tmp/numpy_nvidia /usr/local/lib/python3.12/dist-packages/numpy && \ + mv /tmp/numpy_libs_nvidia /usr/local/lib/python3.12/dist-packages/numpy.libs 2>/dev/null || true + +# ========================= +# BLACKWELL PATCHES +# ========================= + +# Patch 1: get_device_capability fuer SM_121 auf SM_90 umbiegen +RUN sed -i 's/def get_device_capability/def _original_get_device_capability/g' \ + /usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py && \ + printf '\n# BLACKWELL PATCH: Spoof SM_121 as SM_90 for nvrtc compatibility\n' >> /usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py && \ + printf 'def get_device_capability(device=None):\n' >> /usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py && \ + printf ' major, minor = _original_get_device_capability(device)\n' >> /usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py && \ + printf ' if major == 12 and minor == 1:\n' >> /usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py && \ + printf ' return (9, 0)\n' >> /usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py && \ + printf ' return (major, minor)\n' >> /usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py 
+ +# Patch 2: harte compute_121 / sm_121 Referenzen ersetzen +RUN sed -i 's/compute_121/compute_90/g' /usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py 2>/dev/null || true && \ + sed -i 's/sm_121/sm_90/g' /usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py 2>/dev/null || true + +# Patch 3: optional Inductor-Codecache patchen +RUN sed -i 's/compute_121/compute_90/g' /usr/local/lib/python3.12/dist-packages/torch/_inductor/codecache.py 2>/dev/null || true && \ + sed -i 's/sm_121/sm_90/g' /usr/local/lib/python3.12/dist-packages/torch/_inductor/codecache.py 2>/dev/null || true + +# Patch 4: Torchaudio fbank / jiterator Crash umgehen +RUN sed -i 's/spectrum = torch.fft.rfft(strided_input).abs()/# BLACKWELL PATCH: Avoid jiterator by computing abs manually\n fft_result = torch.fft.rfft(strided_input)\n spectrum = torch.sqrt(fft_result.real**2 + fft_result.imag**2)/' \ + /usr/local/lib/python3.12/dist-packages/torchaudio/compliance/kaldi.py + +# Patch 5: pyannote SemVer Check deaktivieren fuer NVIDIA Torch Dev-Versionen +RUN python - <<'PY' +from pathlib import Path +import re + +p = Path("/usr/local/lib/python3.12/dist-packages/pyannote/audio/utils/version.py") +if not p.exists(): + raise SystemExit("pyannote version.py not found") + +text = p.read_text() + +patched = re.sub( + r"def check_version\(.*?\n(?=def |\Z)", + "def check_version(*args, **kwargs):\n return\n\n", + text, + flags=re.S, +) + +if patched == text: + print("No check_version patch applied; pattern not found exactly, leaving file unchanged") +else: + p.write_text(patched) + print("Disabled pyannote check_version") +PY + +# App-Code kopieren +COPY . 
/app + +# Besitzrechte fuer Runtime-User +RUN chown -R 1000:1000 /app /home/user + +USER 1000:1000 + +EXPOSE 8080 + +HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \ + CMD curl -f http://localhost:8080/health || exit 1 + +# Default fuer Backend-Container +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8080"] \ No newline at end of file diff --git a/backend/app/tasks/utility.py b/backend/app/tasks/utility.py index 26546346..590bc1bd 100644 --- a/backend/app/tasks/utility.py +++ b/backend/app/tasks/utility.py @@ -79,18 +79,33 @@ def update_gpu_stats(self): """ Periodic task to update GPU statistics in Redis. - This task runs on the celery worker (which has GPU access) and stores - GPU memory stats in Redis so the backend API can retrieve them. + On DGX Spark / GB10, nvidia-smi framebuffer memory stats can be unavailable + because the platform uses unified memory. In that case we fall back to + torch.cuda.mem_get_info() and clearly label the source. + """ - Uses nvidia-smi to get accurate GPU memory usage including all processes, - not just PyTorch allocated memory. 
+ def safe_float(value): + if value is None: + return None + value = str(value).strip() + if value in {"[N/A]", "N/A", "", "Unknown", "Not Supported"}: + return None + try: + return float(value) + except Exception: + return None + + def format_bytes(byte_count): + if byte_count is None: + return "N/A" + for unit in ["B", "KB", "MB", "GB", "TB"]: + if byte_count < 1024 or unit == "TB": + return f"{byte_count:.2f} {unit}" + byte_count /= 1024 + return f"{byte_count:.2f} TB" - Returns: - Dictionary with GPU stats or error status - """ try: import subprocess - import torch if not torch.cuda.is_available(): @@ -101,49 +116,78 @@ def update_gpu_stats(self): "memory_used": "N/A", "memory_free": "N/A", "memory_percent": "N/A", + "memory_source": "none", } else: - # Get GPU device info from PyTorch - device_id = 0 # Primary GPU + device_id = 0 gpu_properties = torch.cuda.get_device_properties(device_id) - # Use nvidia-smi for accurate memory usage (includes all processes) - # Format: memory.used,memory.total,memory.free (in MiB) - # Security: Safe subprocess call with hardcoded system command (nvidia-smi). - # Only dynamic parameter is device_id (integer), preventing command injection. 
- result = subprocess.run( - [ # noqa: S603 S607 # nosec B603 B607 - hardcoded nvidia-smi, integer device_id - "nvidia-smi", - "--query-gpu=memory.used,memory.total,memory.free", - "--format=csv,noheader,nounits", - f"--id={device_id}", - ], - capture_output=True, - text=True, - check=True, - ) - - # Parse the output: "used, total, free" in MiB - memory_values = result.stdout.strip().split(", ") - memory_used_mib = float(memory_values[0]) - memory_total_mib = float(memory_values[1]) - memory_free_mib = float(memory_values[2]) - - # Convert MiB to bytes for formatting - memory_used = memory_used_mib * 1024 * 1024 - memory_total = memory_total_mib * 1024 * 1024 - memory_free = memory_free_mib * 1024 * 1024 - - # Calculate percentage used - memory_percent = (memory_used / memory_total * 100) if memory_total > 0 else 0 - - # Format bytes to human-readable - def format_bytes(byte_count): - for unit in ["B", "KB", "MB", "GB", "TB"]: - if byte_count < 1024 or unit == "TB": - return f"{byte_count:.2f} {unit}" - byte_count /= 1024 - return f"{byte_count:.2f} TB" + memory_total = None + memory_used = None + memory_free = None + memory_percent = None + memory_source = None + memory_note = None + + # 1) Preferred fallback on DGX Spark / GB10: + # CUDA-visible memory via PyTorch + try: + free_bytes, total_bytes = torch.cuda.mem_get_info(device_id) + memory_free = float(free_bytes) + memory_total = float(total_bytes) + memory_used = memory_total - memory_free + if memory_total > 0: + memory_percent = memory_used / memory_total * 100 + memory_source = "cudaMemGetInfo" + memory_note = "CUDA-visible memory on unified-memory system" + except Exception: + pass + + # 2) Try nvidia-smi only if CUDA mem info did not work + if memory_total is None: + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=memory.used,memory.total,memory.free", + "--format=csv,noheader,nounits", + f"--id={device_id}", + ], + capture_output=True, + text=True, + check=True, + ) + + raw_values = 
[v.strip() for v in result.stdout.strip().split(",")] + + memory_used_mib = safe_float(raw_values[0]) if len(raw_values) > 0 else None + memory_total_mib = safe_float(raw_values[1]) if len(raw_values) > 1 else None + memory_free_mib = safe_float(raw_values[2]) if len(raw_values) > 2 else None + + memory_used = ( + memory_used_mib * 1024 * 1024 if memory_used_mib is not None else None + ) + memory_total = ( + memory_total_mib * 1024 * 1024 if memory_total_mib is not None else None + ) + memory_free = ( + memory_free_mib * 1024 * 1024 if memory_free_mib is not None else None + ) + + if memory_used is not None and memory_total not in (None, 0): + memory_percent = memory_used / memory_total * 100 + + memory_source = "nvidia-smi" + except Exception: + pass + + # 3) Final fallback for DGX Spark / GB10 UMA + if memory_total is None: + memory_source = "unified-memory" + memory_note = ( + "DGX Spark / GB10 uses unified memory; nvidia-smi framebuffer " + "memory stats may be unavailable" + ) gpu_stats = { "available": True, @@ -151,21 +195,16 @@ def format_bytes(byte_count): "memory_total": format_bytes(memory_total), "memory_used": format_bytes(memory_used), "memory_free": format_bytes(memory_free), - "memory_percent": f"{memory_percent:.1f}%", + "memory_percent": f"{memory_percent:.1f}%" if memory_percent is not None else "N/A", + "memory_source": memory_source or "unknown", + "memory_note": memory_note, } - # Store in Redis with 60 second expiration redis_client = celery_app.backend.client - redis_client.setex( - "gpu_stats", - 60, # Expire after 60 seconds - json.dumps(gpu_stats), - ) + redis_client.setex("gpu_stats", 60, json.dumps(gpu_stats)) - # Broadcast to all connected WebSocket clients try: import redis as sync_redis - from app.core.config import settings broadcast_client = sync_redis.from_url(settings.REDIS_URL) @@ -179,12 +218,10 @@ def format_bytes(byte_count): } ), ) - logger.debug("Broadcast GPU stats update via WebSocket") except Exception as broadcast_err: 
logger.warning(f"Failed to broadcast GPU stats: {broadcast_err}") - # Clear debounce lock (best-effort, non-critical) - with contextlib.suppress(Exception): # noqa: S110 + with contextlib.suppress(Exception): redis_client.delete("gpu_stats_pending") logger.debug(f"Updated GPU stats in Redis: {gpu_stats}") @@ -192,15 +229,15 @@ def format_bytes(byte_count): except ImportError: logger.warning("PyTorch not available for GPU monitoring") - gpu_stats = { + return { "available": False, "name": "PyTorch Not Installed", "memory_total": "N/A", "memory_used": "N/A", "memory_free": "N/A", "memory_percent": "N/A", + "memory_source": "none", } - return gpu_stats except Exception as e: logger.error(f"Error updating GPU stats: {str(e)}") return { @@ -210,8 +247,8 @@ def format_bytes(byte_count): "memory_used": "Unknown", "memory_free": "Unknown", "memory_percent": "Unknown", + "memory_source": "error", "error": str(e), } - - + # All recovery tasks have been moved to app.tasks.recovery diff --git a/backend/requirements.txt b/backend/requirements.txt index 13e6c809..1bab2de6 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -46,6 +46,9 @@ torchaudio==2.8.0 # WhisperX latest version with ctranslate2 4.5+ support whisperx==3.7.0 +# Hugging Face Hub pinned for compatibility with libraries still using use_auth_token +huggingface_hub==0.23.5 + # CTranslate2 with cuDNN 9 support (required for CUDA 12.8) ctranslate2>=4.6.0 diff --git a/docker-compose.blackwell.yml b/docker-compose.blackwell.yml new file mode 100644 index 00000000..82735cbe --- /dev/null +++ b/docker-compose.blackwell.yml @@ -0,0 +1,43 @@ +services: + celery-worker: + build: + context: ./backend + dockerfile: Dockerfile.blackwell + image: opentranscribe-backend:blackwell + pull_policy: never + user: "1000:1000" + command: celery -A app.core.celery worker --loglevel=info -Q gpu --pool=solo + env_file: .env + environment: + POSTGRES_HOST: postgres + POSTGRES_PORT: 5432 + MINIO_HOST: minio + MINIO_PORT: 
9000 + REDIS_HOST: redis + REDIS_PORT: 6379 + OPENSEARCH_HOST: opensearch + OPENSEARCH_PORT: 9200 + MODELS_DIRECTORY: /app/models + MODEL_BASE_DIR: /app/models + TEMP_DIR: /app/temp + TORCH_CUDA_ARCH_LIST: "9.0" + CUDA_FORCE_PTX_JIT: "1" + HF_HOME: /home/user/.cache/huggingface + TRANSFORMERS_CACHE: /home/user/.cache/huggingface/transformers + TORCH_HOME: /home/user/.cache/torch + NLTK_DATA: /home/user/.cache/nltk_data + SENTENCE_TRANSFORMERS_HOME: /home/user/.cache/sentence-transformers + XDG_CACHE_HOME: /home/user/.cache + volumes: + - ${MODEL_CACHE_DIR:-./models}/huggingface:/home/user/.cache/huggingface + - ${MODEL_CACHE_DIR:-./models}/torch:/home/user/.cache/torch + - ${MODEL_CACHE_DIR:-./models}/nltk_data:/home/user/.cache/nltk_data + - ${MODEL_CACHE_DIR:-./models}/sentence-transformers:/home/user/.cache/sentence-transformers + - ${MODEL_CACHE_DIR:-./models}/yt-dlp:/home/user/.cache/yt-dlp + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: [gpu] + device_ids: ["${GPU_DEVICE_ID:-0}"] \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 3e416c2b..e11c4e7f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,7 +8,6 @@ services: postgres: - container_name: opentranscribe-postgres image: postgres:17.5-alpine restart: always env_file: .env @@ -24,7 +23,6 @@ services: retries: 5 minio: - container_name: opentranscribe-minio image: minio/minio:RELEASE.2025-09-07T16-13-09Z restart: always env_file: .env @@ -41,7 +39,6 @@ services: retries: 5 redis: - container_name: opentranscribe-redis image: redis:8.2.2-alpine3.22 restart: always env_file: .env @@ -68,7 +65,6 @@ services: retries: 50 opensearch: - container_name: opentranscribe-opensearch image: opensearchproject/opensearch:3.3.1 restart: always environment: @@ -92,7 +88,6 @@ services: retries: 20 backend: - container_name: opentranscribe-backend restart: always env_file: .env # Backend handles API requests only - no AI models needed @@ 
-130,7 +125,6 @@ services: condition: service_healthy celery-worker: - container_name: opentranscribe-celery-worker restart: always env_file: .env command: celery -A app.core.celery worker --loglevel=info -Q gpu --concurrency=1 @@ -169,7 +163,6 @@ services: - opensearch celery-download-worker: - container_name: opentranscribe-celery-download-worker restart: always env_file: .env command: celery -A app.core.celery worker --loglevel=info -Q download --concurrency=3 --max-tasks-per-child=10 @@ -200,7 +193,6 @@ services: - opensearch celery-cpu-worker: - container_name: opentranscribe-celery-cpu-worker restart: always env_file: .env command: celery -A app.core.celery worker --loglevel=info -Q cpu,utility --concurrency=8 --max-tasks-per-child=20 @@ -228,7 +220,6 @@ services: - opensearch celery-nlp-worker: - container_name: opentranscribe-celery-nlp-worker restart: always env_file: .env command: celery -A app.core.celery worker --loglevel=info -Q nlp,celery --concurrency=4 --max-tasks-per-child=50 @@ -256,7 +247,6 @@ services: - opensearch celery-beat: - container_name: opentranscribe-celery-beat restart: always env_file: .env command: celery -A app.core.celery beat --loglevel=info @@ -286,7 +276,6 @@ services: # GPU configuration (deploy.resources.reservations) is provided by gpu-scale.yml # Image/build/volumes provided by override.yml (dev) or prod.yml/offline.yml celery-worker-gpu-scaled: - container_name: ${COMPOSE_PROJECT_NAME:-opentranscribe}-celery-worker-gpu-scaled profiles: - gpu-scale # Only activated with --gpu-scale flag scale: 0 # Disabled by default, gpu-scale.yml sets scale: 1 @@ -314,7 +303,6 @@ services: - opensearch frontend: - container_name: opentranscribe-frontend restart: always env_file: .env # Note: ports, healthcheck, and NODE_ENV defined in override files (dev vs prod) @@ -323,7 +311,6 @@ services: condition: service_healthy flower: - container_name: opentranscribe-flower restart: always env_file: .env command: > diff --git a/opentranscribe.sh 
b/opentranscribe.sh index 4a7b1345..8e6b13a6 100755 --- a/opentranscribe.sh +++ b/opentranscribe.sh @@ -115,12 +115,17 @@ get_compose_files() { compose_files="$compose_files -f docker-compose.prod.yml" fi - # Add GPU overlay if NVIDIA runtime is available and overlay exists + # Add GPU overlay if NVIDIA runtime is available local docker_runtime docker_runtime=$(detect_nvidia_runtime) - if [ "$docker_runtime" = "nvidia" ] && [ -f docker-compose.gpu.yml ]; then - compose_files="$compose_files -f docker-compose.gpu.yml" - echo -e "${BLUE}🎯 GPU acceleration enabled (NVIDIA Container Toolkit detected)${NC}" >&2 + if [ "$docker_runtime" = "nvidia" ]; then + if [ -f docker-compose.blackwell.yml ]; then + compose_files="$compose_files -f docker-compose.blackwell.yml" + echo -e "${BLUE}🎯 Blackwell GPU overlay enabled${NC}" >&2 + elif [ -f docker-compose.gpu.yml ]; then + compose_files="$compose_files -f docker-compose.gpu.yml" + echo -e "${BLUE}🎯 GPU acceleration enabled (standard NVIDIA overlay)${NC}" >&2 + fi fi # Add NGINX overlay if NGINX_SERVER_NAME is configured