Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,20 @@ def _get_gpu_status() -> str:
return "MPS (Apple Silicon)"
elif backend_type == "mlx":
return "Metal (Apple Silicon via MLX)"

# Intel XPU (Arc / Data Center) via IPEX
try:
import intel_extension_for_pytorch # noqa: F401

if hasattr(torch, "xpu") and torch.xpu.is_available():
try:
xpu_name = torch.xpu.get_device_name(0)
except Exception:
xpu_name = "Intel GPU"
return f"XPU ({xpu_name})"
except ImportError:
pass

return "None (CPU only)"


Expand Down
31 changes: 31 additions & 0 deletions backend/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,37 @@ def get_torch_device(
return "cpu"


def empty_device_cache(device: str) -> None:
    """
    Free cached memory on the given device (CUDA or XPU).

    Backends should call this after unloading models so VRAM is returned
    to the OS.

    Args:
        device: Torch device string, e.g. "cuda", "xpu", or "cpu".
            Anything other than an available accelerator is a no-op.
    """
    import torch

    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif device == "xpu" and hasattr(torch, "xpu") and torch.xpu.is_available():
        # Mirror the CUDA branch: on CPU-only builds of recent torch the
        # `torch.xpu` module exists but no device is usable, and calling
        # empty_cache() there can raise — guard with is_available().
        torch.xpu.empty_cache()


def manual_seed(seed: int, device: str) -> None:
    """
    Set the random seed on both CPU and the active accelerator.

    Covers CUDA and Intel XPU so that generation is reproducible
    regardless of which GPU backend is in use.

    Args:
        seed: The seed value to apply.
        device: Torch device string ("cuda", "xpu", "cpu", ...); only an
            available matching accelerator is seeded in addition to CPU.
    """
    import torch

    torch.manual_seed(seed)
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    elif device == "xpu" and hasattr(torch, "xpu") and torch.xpu.is_available():
        # Guard like the CUDA branch: `torch.xpu` can exist on builds with
        # no usable XPU device, where manual_seed() would raise.
        torch.xpu.manual_seed(seed)


async def combine_voice_prompts(
audio_paths: List[str],
reference_texts: List[str],
Expand Down
16 changes: 6 additions & 10 deletions backend/backends/chatterbox_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
patch_chatterbox_f32,
Expand Down Expand Up @@ -48,7 +50,7 @@ def __init__(self):
self._model_load_lock = asyncio.Lock()

def _get_device(self) -> str:
return get_torch_device(force_cpu_on_mac=True)
return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)

def is_loaded(self) -> bool:
return self.model is not None
Expand Down Expand Up @@ -117,10 +119,7 @@ def unload_model(self) -> None:
del self.model
self.model = None
self._device = None
if device == "cuda":
import torch

torch.cuda.empty_cache()
empty_device_cache(device)
logger.info("Chatterbox unloaded")

async def create_voice_prompt(
Expand Down Expand Up @@ -200,7 +199,7 @@ def _generate_sync():
import torch

if seed is not None:
torch.manual_seed(seed)
manual_seed(seed, self._device)

logger.info(f"[Chatterbox] Generating: lang={language}")

Expand All @@ -220,10 +219,7 @@ def _generate_sync():
else:
audio = np.asarray(wav, dtype=np.float32)

sample_rate = (
getattr(self.model, "sr", None)
or getattr(self.model, "sample_rate", 24000)
)
sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000)

return audio, sample_rate

Expand Down
16 changes: 6 additions & 10 deletions backend/backends/chatterbox_turbo_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
patch_chatterbox_f32,
Expand Down Expand Up @@ -48,7 +50,7 @@ def __init__(self):
self._model_load_lock = asyncio.Lock()

def _get_device(self) -> str:
return get_torch_device(force_cpu_on_mac=True)
return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)

def is_loaded(self) -> bool:
return self.model is not None
Expand Down Expand Up @@ -116,10 +118,7 @@ def unload_model(self) -> None:
del self.model
self.model = None
self._device = None
if device == "cuda":
import torch

torch.cuda.empty_cache()
empty_device_cache(device)
logger.info("Chatterbox Turbo unloaded")

async def create_voice_prompt(
Expand Down Expand Up @@ -181,7 +180,7 @@ def _generate_sync():
import torch

if seed is not None:
torch.manual_seed(seed)
manual_seed(seed, self._device)

logger.info("[Chatterbox Turbo] Generating (English)")

Expand All @@ -200,10 +199,7 @@ def _generate_sync():
else:
audio = np.asarray(wav, dtype=np.float32)

sample_rate = (
getattr(self.model, "sr", None)
or getattr(self.model, "sample_rate", 24000)
)
sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000)

return audio, sample_rate

Expand Down
39 changes: 19 additions & 20 deletions backend/backends/hume_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
)
Expand Down Expand Up @@ -66,7 +68,7 @@ def __init__(self):
def _get_device(self) -> str:
# Force CPU on macOS — MPS has issues with flow matching
# and large vocab lm_head (>65536 output channels)
return get_torch_device(force_cpu_on_mac=True)
return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)

def is_loaded(self) -> bool:
return self.model is not None
Expand Down Expand Up @@ -105,6 +107,7 @@ def _load_model_sync(self, model_size: str = "1B"):
# package. The real package pulls in onnx/tensorboard/matplotlib via
# descript-audiotools, so we use a lightweight shim instead.
from ..utils.dac_shim import install_dac_shim

install_dac_shim()

import torch
Expand Down Expand Up @@ -142,9 +145,12 @@ def _load_model_sync(self, model_size: str = "1B"):
allow_patterns=["tokenizer*", "special_tokens*"],
)

# Determine dtype — use bf16 on CUDA for ~50% memory savings
# Determine dtype — use bf16 on CUDA/XPU for ~50% memory savings
if device == "cuda" and torch.cuda.is_bf16_supported():
model_dtype = torch.bfloat16
elif device == "xpu":
# Intel Arc (Alchemist+) supports bf16 natively
model_dtype = torch.bfloat16
else:
model_dtype = torch.float32

Expand All @@ -153,14 +159,14 @@ def _load_model_sync(self, model_size: str = "1B"):
# This avoids monkey-patching AutoTokenizer.from_pretrained
# which corrupts the classmethod descriptor for other engines.
from tada.modules.aligner import AlignerConfig

AlignerConfig.tokenizer_name = tokenizer_path

# Load encoder (only needed for voice prompt encoding)
from tada.modules.encoder import Encoder

logger.info("Loading TADA encoder...")
self.encoder = Encoder.from_pretrained(
TADA_CODEC_REPO, subfolder="encoder"
).to(device)
self.encoder = Encoder.from_pretrained(TADA_CODEC_REPO, subfolder="encoder").to(device)
self.encoder.eval()

# Load the causal LM (includes decoder for wav generation).
Expand All @@ -169,12 +175,11 @@ def _load_model_sync(self, model_size: str = "1B"):
# which hits the gated repo. Pre-load the config from HF,
# inject the local tokenizer path, then pass it in.
from tada.modules.tada import TadaForCausalLM, TadaConfig

logger.info(f"Loading TADA {model_size} model...")
config = TadaConfig.from_pretrained(repo)
config.tokenizer_name = tokenizer_path
self.model = TadaForCausalLM.from_pretrained(
repo, config=config, torch_dtype=model_dtype
).to(device)
self.model = TadaForCausalLM.from_pretrained(repo, config=config, torch_dtype=model_dtype).to(device)
self.model.eval()

logger.info(f"HumeAI TADA {model_size} loaded successfully on {device}")
Expand All @@ -188,11 +193,11 @@ def unload_model(self) -> None:
del self.encoder
self.encoder = None

device = self._device
self._device = None

import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
if device:
empty_device_cache(device)

logger.info("HumeAI TADA unloaded")

Expand All @@ -213,9 +218,7 @@ async def create_voice_prompt(
"""
await self.load_model(self.model_size)

cache_key = (
"tada_" + get_cache_key(audio_path, reference_text)
) if use_cache else None
cache_key = ("tada_" + get_cache_key(audio_path, reference_text)) if use_cache else None

if cache_key:
cached = get_cached_voice_prompt(cache_key)
Expand All @@ -239,9 +242,7 @@ def _encode_sync():

# Encode with forced alignment
text_arg = [reference_text] if reference_text else None
prompt = self.encoder(
audio, text=text_arg, sample_rate=sr
)
prompt = self.encoder(audio, text=text_arg, sample_rate=sr)

# Serialize EncoderOutput to a dict of CPU tensors for caching
prompt_dict = {}
Expand Down Expand Up @@ -299,9 +300,7 @@ def _generate_sync():
from tada.modules.encoder import EncoderOutput

if seed is not None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
manual_seed(seed, self._device)

device = self._device

Expand Down
28 changes: 17 additions & 11 deletions backend/backends/luxtts_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@
import numpy as np

from . import TTSBackend
from .base import is_model_cached, get_torch_device, combine_voice_prompts as _combine_voice_prompts, model_load_progress
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
)
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt

logger = logging.getLogger(__name__)
Expand All @@ -30,7 +37,7 @@ def __init__(self):
self._device = None

def _get_device(self) -> str:
return get_torch_device(allow_mps=True)
return get_torch_device(allow_mps=True, allow_xpu=True)

def is_loaded(self) -> bool:
return self.model is not None
Expand Down Expand Up @@ -69,9 +76,12 @@ def _load_model_sync(self):

if device == "cpu":
import os

threads = os.cpu_count() or 4
self.model = LuxTTS(
model_path=LUXTTS_HF_REPO, device="cpu", threads=min(threads, 8),
model_path=LUXTTS_HF_REPO,
device="cpu",
threads=min(threads, 8),
)
else:
self.model = LuxTTS(model_path=LUXTTS_HF_REPO, device=device)
Expand All @@ -81,12 +91,12 @@ def _load_model_sync(self):
def unload_model(self) -> None:
"""Unload model to free memory."""
if self.model is not None:
device = self.device
del self.model
self.model = None
self._device = None

import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
empty_device_cache(device)

logger.info("LuxTTS unloaded")

Expand Down Expand Up @@ -154,12 +164,8 @@ async def generate(
await self.load_model()

def _generate_sync():
import torch

if seed is not None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
manual_seed(seed, self.device)

wav = self.model.generate_speech(
text=text,
Expand Down
12 changes: 5 additions & 7 deletions backend/backends/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
)
Expand Down Expand Up @@ -120,8 +122,7 @@ def unload_model(self):
self.model = None
self._current_model_size = None

if torch.cuda.is_available():
torch.cuda.empty_cache()
empty_device_cache(self.device)

logger.info("TTS model unloaded")

Expand Down Expand Up @@ -213,9 +214,7 @@ def _generate_sync():
"""Run synchronous generation in thread pool."""
# Set seed if provided
if seed is not None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
manual_seed(seed, self.device)

# Generate audio - this is the blocking operation
wavs, sample_rate = self.model.generate_voice_clone(
Expand Down Expand Up @@ -297,8 +296,7 @@ def unload_model(self):
self.model = None
self.processor = None

if torch.cuda.is_available():
torch.cuda.empty_cache()
empty_device_cache(self.device)

logger.info("Whisper model unloaded")

Expand Down
Loading