diff --git a/backend/backends/mlx_backend.py b/backend/backends/mlx_backend.py index 92405db..c6157a4 100644 --- a/backend/backends/mlx_backend.py +++ b/backend/backends/mlx_backend.py @@ -6,7 +6,6 @@ import asyncio import logging import numpy as np -import os from pathlib import Path logger = logging.getLogger(__name__) @@ -21,6 +20,7 @@ from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS from .base import is_model_cached, combine_voice_prompts as _combine_voice_prompts, model_load_progress from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt +from ..utils.hf_offline_patch import force_offline_if_cached class MLXTTSBackend: @@ -96,32 +96,13 @@ def _load_model_sync(self, model_size: str): model_name = f"qwen-tts-{model_size}" is_cached = self._is_model_cached(model_size) - # Force offline mode when cached to avoid network requests - original_hf_hub_offline = os.environ.get("HF_HUB_OFFLINE") - if is_cached: - os.environ["HF_HUB_OFFLINE"] = "1" - logger.info("[PATCH] Model %s is cached, forcing HF_HUB_OFFLINE=1 to avoid network requests", model_size) - - try: - with model_load_progress(model_name, is_cached): - from mlx_audio.tts import load - - logger.info("Loading MLX TTS model %s...", model_size) - - try: - self.model = load(model_path) - except Exception as load_error: - if is_cached and "offline" in str(load_error).lower(): - logger.warning("[PATCH] Offline load failed, trying with network: %s", load_error) - os.environ.pop("HF_HUB_OFFLINE", None) - self.model = load(model_path) - else: - raise - finally: - if original_hf_hub_offline is not None: - os.environ["HF_HUB_OFFLINE"] = original_hf_hub_offline - else: - os.environ.pop("HF_HUB_OFFLINE", None) + with model_load_progress(model_name, is_cached): + from mlx_audio.tts import load + + logger.info("Loading MLX TTS model %s...", model_size) + + with force_offline_if_cached(is_cached, model_name): + self.model = load(model_path) self._current_model_size = 
model_size self.model_size = model_size @@ -329,7 +310,9 @@ def _load_model_sync(self, model_size: str): model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}") logger.info("Loading MLX Whisper model %s...", model_size) - self.model = load(model_name) + + with force_offline_if_cached(is_cached, progress_model_name): + self.model = load(model_name) self.model_size = model_size logger.info("MLX Whisper model %s loaded successfully", model_size) diff --git a/backend/backends/pytorch_backend.py b/backend/backends/pytorch_backend.py index 8ed4ab9..25a8b6a 100644 --- a/backend/backends/pytorch_backend.py +++ b/backend/backends/pytorch_backend.py @@ -19,6 +19,7 @@ ) from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt from ..utils.audio import load_audio +from ..utils.hf_offline_patch import force_offline_if_cached class PyTorchTTSBackend: @@ -96,18 +97,19 @@ def _load_model_sync(self, model_size: str): model_path = self._get_model_path(model_size) logger.info("Loading TTS model %s on %s...", model_size, self.device) - if self.device == "cpu": - self.model = Qwen3TTSModel.from_pretrained( - model_path, - torch_dtype=torch.float32, - low_cpu_mem_usage=False, - ) - else: - self.model = Qwen3TTSModel.from_pretrained( - model_path, - device_map=self.device, - torch_dtype=torch.bfloat16, - ) + with force_offline_if_cached(is_cached, model_name): + if self.device == "cpu": + self.model = Qwen3TTSModel.from_pretrained( + model_path, + torch_dtype=torch.float32, + low_cpu_mem_usage=False, + ) + else: + self.model = Qwen3TTSModel.from_pretrained( + model_path, + device_map=self.device, + torch_dtype=torch.bfloat16, + ) self._current_model_size = model_size self.model_size = model_size @@ -282,8 +284,9 @@ def _load_model_sync(self, model_size: str): model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}") logger.info("Loading Whisper model %s on %s...", model_size, self.device) - self.processor = 
WhisperProcessor.from_pretrained(model_name) - self.model = WhisperForConditionalGeneration.from_pretrained(model_name) + with force_offline_if_cached(is_cached, progress_model_name): + self.processor = WhisperProcessor.from_pretrained(model_name) + self.model = WhisperForConditionalGeneration.from_pretrained(model_name) self.model.to(self.device) self.model_size = model_size diff --git a/backend/utils/hf_offline_patch.py b/backend/utils/hf_offline_patch.py index 51a99a1..1abd147 100644 --- a/backend/utils/hf_offline_patch.py +++ b/backend/utils/hf_offline_patch.py @@ -1,17 +1,64 @@ """Monkey-patch huggingface_hub to force offline mode with cached models. -Prevents mlx_audio from making network requests when models are already -downloaded. Must be imported BEFORE mlx_audio. +Prevents mlx_audio / transformers from making network requests when models +are already downloaded. Must be imported BEFORE mlx_audio. """ import logging import os +from contextlib import contextmanager from pathlib import Path from typing import Optional, Union logger = logging.getLogger(__name__) +@contextmanager +def force_offline_if_cached(is_cached: bool, model_label: str = ""): + """Context manager that sets ``HF_HUB_OFFLINE=1`` while loading a cached model. + + If *is_cached* is ``False`` the block runs normally (network allowed). + If the offline load raises an error containing "offline" it is logged and + re-raised; HF_HUB_OFFLINE is restored so the caller may retry online. + + Args: + is_cached: Whether the model weights are already on disk. + model_label: Human-readable name used in log messages. 
+ """ + if not is_cached: + yield + return + + original_value = os.environ.get("HF_HUB_OFFLINE") + os.environ["HF_HUB_OFFLINE"] = "1" + logger.info( + "[offline-guard] %s is cached — forcing HF_HUB_OFFLINE=1", + model_label or "model", + ) + + try: + yield + except Exception as exc: + if "offline" in str(exc).lower(): + logger.warning( + "[offline-guard] Offline load failed for %s; re-raising after restoring env: %s", + model_label or "model", + exc, + ) + # A generator-based context manager can only yield once, so the + # failed load cannot be re-run from inside this block. Re-raise + # and let the finally clause restore HF_HUB_OFFLINE to its + # original value; the caller is then free to retry the load + # with network access enabled. + raise + raise + finally: + if original_value is not None: + os.environ["HF_HUB_OFFLINE"] = original_value + else: + os.environ.pop("HF_HUB_OFFLINE", None) + + def patch_huggingface_hub_offline(): """Monkey-patch huggingface_hub to force offline mode.""" try: