Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 11 additions & 28 deletions backend/backends/mlx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import asyncio
import logging
import numpy as np
import os
from pathlib import Path

logger = logging.getLogger(__name__)
Expand All @@ -21,6 +20,7 @@
from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS
from .base import is_model_cached, combine_voice_prompts as _combine_voice_prompts, model_load_progress
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
from ..utils.hf_offline_patch import force_offline_if_cached


class MLXTTSBackend:
Expand Down Expand Up @@ -96,32 +96,13 @@ def _load_model_sync(self, model_size: str):
model_name = f"qwen-tts-{model_size}"
is_cached = self._is_model_cached(model_size)

# Force offline mode when cached to avoid network requests
original_hf_hub_offline = os.environ.get("HF_HUB_OFFLINE")
if is_cached:
os.environ["HF_HUB_OFFLINE"] = "1"
logger.info("[PATCH] Model %s is cached, forcing HF_HUB_OFFLINE=1 to avoid network requests", model_size)

try:
with model_load_progress(model_name, is_cached):
from mlx_audio.tts import load

logger.info("Loading MLX TTS model %s...", model_size)

try:
self.model = load(model_path)
except Exception as load_error:
if is_cached and "offline" in str(load_error).lower():
logger.warning("[PATCH] Offline load failed, trying with network: %s", load_error)
os.environ.pop("HF_HUB_OFFLINE", None)
self.model = load(model_path)
else:
raise
finally:
if original_hf_hub_offline is not None:
os.environ["HF_HUB_OFFLINE"] = original_hf_hub_offline
else:
os.environ.pop("HF_HUB_OFFLINE", None)
with model_load_progress(model_name, is_cached):
from mlx_audio.tts import load

logger.info("Loading MLX TTS model %s...", model_size)

with force_offline_if_cached(is_cached, model_name):
self.model = load(model_path)

self._current_model_size = model_size
self.model_size = model_size
Expand Down Expand Up @@ -329,7 +310,9 @@ def _load_model_sync(self, model_size: str):

model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
logger.info("Loading MLX Whisper model %s...", model_size)
self.model = load(model_name)

with force_offline_if_cached(is_cached, progress_model_name):
self.model = load(model_name)

self.model_size = model_size
logger.info("MLX Whisper model %s loaded successfully", model_size)
Expand Down
31 changes: 17 additions & 14 deletions backend/backends/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
from ..utils.audio import load_audio
from ..utils.hf_offline_patch import force_offline_if_cached


class PyTorchTTSBackend:
Expand Down Expand Up @@ -96,18 +97,19 @@ def _load_model_sync(self, model_size: str):
model_path = self._get_model_path(model_size)
logger.info("Loading TTS model %s on %s...", model_size, self.device)

if self.device == "cpu":
self.model = Qwen3TTSModel.from_pretrained(
model_path,
torch_dtype=torch.float32,
low_cpu_mem_usage=False,
)
else:
self.model = Qwen3TTSModel.from_pretrained(
model_path,
device_map=self.device,
torch_dtype=torch.bfloat16,
)
with force_offline_if_cached(is_cached, model_name):
if self.device == "cpu":
self.model = Qwen3TTSModel.from_pretrained(
model_path,
torch_dtype=torch.float32,
low_cpu_mem_usage=False,
)
else:
self.model = Qwen3TTSModel.from_pretrained(
model_path,
device_map=self.device,
torch_dtype=torch.bfloat16,
)

self._current_model_size = model_size
self.model_size = model_size
Expand Down Expand Up @@ -282,8 +284,9 @@ def _load_model_sync(self, model_size: str):
model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
logger.info("Loading Whisper model %s on %s...", model_size, self.device)

self.processor = WhisperProcessor.from_pretrained(model_name)
self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
with force_offline_if_cached(is_cached, progress_model_name):
self.processor = WhisperProcessor.from_pretrained(model_name)
self.model = WhisperForConditionalGeneration.from_pretrained(model_name)

self.model.to(self.device)
self.model_size = model_size
Expand Down
51 changes: 49 additions & 2 deletions backend/utils/hf_offline_patch.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,64 @@
"""Monkey-patch huggingface_hub to force offline mode with cached models.

Prevents mlx_audio from making network requests when models are already
downloaded. Must be imported BEFORE mlx_audio.
Prevents mlx_audio / transformers from making network requests when models
are already downloaded. Must be imported BEFORE mlx_audio.
"""

import logging
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, Union

logger = logging.getLogger(__name__)


@contextmanager
def force_offline_if_cached(is_cached: bool, model_label: str = ""):
    """Context manager that sets ``HF_HUB_OFFLINE=1`` while loading a cached model.

    If *is_cached* is ``False`` the block runs normally (network allowed).
    If the offline load raises an error whose message mentions "offline",
    a warning is logged and the exception is re-raised — no retry happens
    here.  Callers that want a network fallback for a partially-cached
    model must catch the exception and retry the load themselves, outside
    this context manager.  The original ``HF_HUB_OFFLINE`` value is always
    restored on exit.

    Args:
        is_cached: Whether the model weights are already on disk.
        model_label: Human-readable name used in log messages.

    Raises:
        Exception: Whatever the wrapped load raises is propagated unchanged.
    """
    if not is_cached:
        # Model is not (fully) on disk — leave network access enabled.
        yield
        return

    original_value = os.environ.get("HF_HUB_OFFLINE")
    os.environ["HF_HUB_OFFLINE"] = "1"
    logger.info(
        "[offline-guard] %s is cached — forcing HF_HUB_OFFLINE=1",
        model_label or "model",
    )

    try:
        yield
    except Exception as exc:
        # A partially-cached model can still fail in offline mode; surface
        # a clear hint that a caller-side retry with network may succeed.
        if "offline" in str(exc).lower():
            logger.warning(
                "[offline-guard] Offline load failed for %s; caller may "
                "retry with network access: %s",
                model_label or "model",
                exc,
            )
        raise
    finally:
        # Restore the caller's environment regardless of outcome.
        if original_value is not None:
            os.environ["HF_HUB_OFFLINE"] = original_value
        else:
            os.environ.pop("HF_HUB_OFFLINE", None)


def patch_huggingface_hub_offline():
"""Monkey-patch huggingface_hub to force offline mode."""
try:
Expand Down