diff --git a/backend/app.py b/backend/app.py index f652d14..1293460 100644 --- a/backend/app.py +++ b/backend/app.py @@ -155,6 +155,20 @@ def _get_gpu_status() -> str: return "MPS (Apple Silicon)" elif backend_type == "mlx": return "Metal (Apple Silicon via MLX)" + + # Intel XPU (Arc / Data Center) via IPEX + try: + import intel_extension_for_pytorch # noqa: F401 + + if hasattr(torch, "xpu") and torch.xpu.is_available(): + try: + xpu_name = torch.xpu.get_device_name(0) + except Exception: + xpu_name = "Intel GPU" + return f"XPU ({xpu_name})" + except ImportError: + pass + return "None (CPU only)" diff --git a/backend/backends/base.py b/backend/backends/base.py index 9a3049a..0fdfa34 100644 --- a/backend/backends/base.py +++ b/backend/backends/base.py @@ -126,6 +126,37 @@ def get_torch_device( return "cpu" +def empty_device_cache(device: str) -> None: + """ + Free cached memory on the given device (CUDA or XPU). + + Backends should call this after unloading models so VRAM is returned + to the OS. + """ + import torch + + if device == "cuda" and torch.cuda.is_available(): + torch.cuda.empty_cache() + elif device == "xpu" and hasattr(torch, "xpu"): + torch.xpu.empty_cache() + + +def manual_seed(seed: int, device: str) -> None: + """ + Set the random seed on both CPU and the active accelerator. + + Covers CUDA and Intel XPU so that generation is reproducible + regardless of which GPU backend is in use. + """ + import torch + + torch.manual_seed(seed) + if device == "cuda" and torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + elif device == "xpu" and hasattr(torch, "xpu"): + torch.xpu.manual_seed(seed) + + async def combine_voice_prompts( audio_paths: List[str], reference_texts: List[str], diff --git a/backend/backends/chatterbox_backend.py b/backend/backends/chatterbox_backend.py index 7efe371..e7a025b 100644 --- a/backend/backends/chatterbox_backend.py +++ b/backend/backends/chatterbox_backend.py @@ -18,6 +18,8 @@ from .base import ( is_model_cached, get_torch_device, + empty_device_cache, + manual_seed, combine_voice_prompts as _combine_voice_prompts, model_load_progress, patch_chatterbox_f32, @@ -48,7 +50,7 @@ def __init__(self): self._model_load_lock = asyncio.Lock() def _get_device(self) -> str: - return get_torch_device(force_cpu_on_mac=True) + return get_torch_device(force_cpu_on_mac=True, allow_xpu=True) def is_loaded(self) -> bool: return self.model is not None @@ -117,10 +119,7 @@ def unload_model(self) -> None: del self.model self.model = None self._device = None - if device == "cuda": - import torch - - torch.cuda.empty_cache() + empty_device_cache(device) logger.info("Chatterbox unloaded") async def create_voice_prompt( @@ -200,7 +199,7 @@ def _generate_sync(): import torch if seed is not None: - torch.manual_seed(seed) + manual_seed(seed, self._device) logger.info(f"[Chatterbox] Generating: lang={language}") @@ -220,10 +219,7 @@ def _generate_sync(): else: audio = np.asarray(wav, dtype=np.float32) - sample_rate = ( - getattr(self.model, "sr", None) - or getattr(self.model, "sample_rate", 24000) - ) + sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000) return audio, sample_rate diff --git a/backend/backends/chatterbox_turbo_backend.py b/backend/backends/chatterbox_turbo_backend.py index a8bfe50..6f7d6b9 100644 --- a/backend/backends/chatterbox_turbo_backend.py +++ b/backend/backends/chatterbox_turbo_backend.py @@ -18,6 +18,8 @@ from .base import ( is_model_cached, get_torch_device, + empty_device_cache, + manual_seed, combine_voice_prompts as _combine_voice_prompts, model_load_progress, patch_chatterbox_f32, @@ -48,7 +50,7 @@ def __init__(self): self._model_load_lock = asyncio.Lock() def _get_device(self) -> str: - return get_torch_device(force_cpu_on_mac=True) + return get_torch_device(force_cpu_on_mac=True, allow_xpu=True) def is_loaded(self) -> bool: return self.model is not None @@ -116,10 +118,7 @@ def unload_model(self) -> None: del self.model self.model = None self._device = None - if device == "cuda": - import torch - - torch.cuda.empty_cache() + empty_device_cache(device) logger.info("Chatterbox Turbo unloaded") async def create_voice_prompt( @@ -181,7 +180,7 @@ def _generate_sync(): import torch if seed is not None: - torch.manual_seed(seed) + manual_seed(seed, self._device) logger.info("[Chatterbox Turbo] Generating (English)") @@ -200,10 +199,7 @@ def _generate_sync(): else: audio = np.asarray(wav, dtype=np.float32) - sample_rate = ( - getattr(self.model, "sr", None) - or getattr(self.model, "sample_rate", 24000) - ) + sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000) return audio, sample_rate diff --git a/backend/backends/hume_backend.py b/backend/backends/hume_backend.py index 456fd46..ecaa29b 100644 --- a/backend/backends/hume_backend.py +++ b/backend/backends/hume_backend.py @@ -24,6 +24,8 @@ from .base import ( is_model_cached, get_torch_device, + empty_device_cache, + manual_seed, combine_voice_prompts as _combine_voice_prompts, model_load_progress, ) @@ -66,7 +68,7 @@ def __init__(self): def _get_device(self) -> str: # Force CPU on macOS — MPS has issues with flow matching # and large vocab lm_head (>65536 output channels) - return get_torch_device(force_cpu_on_mac=True) + return get_torch_device(force_cpu_on_mac=True, allow_xpu=True) def is_loaded(self) -> bool: return self.model is not None @@ -105,6 +107,7 @@ def _load_model_sync(self, model_size: str = "1B"): # package. The real package pulls in onnx/tensorboard/matplotlib via # descript-audiotools, so we use a lightweight shim instead. from ..utils.dac_shim import install_dac_shim + install_dac_shim() import torch @@ -142,9 +145,12 @@ def _load_model_sync(self, model_size: str = "1B"): allow_patterns=["tokenizer*", "special_tokens*"], ) - # Determine dtype — use bf16 on CUDA for ~50% memory savings + # Determine dtype — use bf16 on CUDA/XPU for ~50% memory savings if device == "cuda" and torch.cuda.is_bf16_supported(): model_dtype = torch.bfloat16 + elif device == "xpu": + # Intel Arc (Alchemist+) supports bf16 natively + model_dtype = torch.bfloat16 else: model_dtype = torch.float32 @@ -153,14 +159,14 @@ def _load_model_sync(self, model_size: str = "1B"): # This avoids monkey-patching AutoTokenizer.from_pretrained # which corrupts the classmethod descriptor for other engines. from tada.modules.aligner import AlignerConfig + AlignerConfig.tokenizer_name = tokenizer_path # Load encoder (only needed for voice prompt encoding) from tada.modules.encoder import Encoder + logger.info("Loading TADA encoder...") - self.encoder = Encoder.from_pretrained( - TADA_CODEC_REPO, subfolder="encoder" - ).to(device) + self.encoder = Encoder.from_pretrained(TADA_CODEC_REPO, subfolder="encoder").to(device) self.encoder.eval() # Load the causal LM (includes decoder for wav generation). @@ -169,12 +175,11 @@ def _load_model_sync(self, model_size: str = "1B"): # which hits the gated repo. Pre-load the config from HF, # inject the local tokenizer path, then pass it in. from tada.modules.tada import TadaForCausalLM, TadaConfig + logger.info(f"Loading TADA {model_size} model...") config = TadaConfig.from_pretrained(repo) config.tokenizer_name = tokenizer_path - self.model = TadaForCausalLM.from_pretrained( - repo, config=config, torch_dtype=model_dtype - ).to(device) + self.model = TadaForCausalLM.from_pretrained(repo, config=config, torch_dtype=model_dtype).to(device) self.model.eval() logger.info(f"HumeAI TADA {model_size} loaded successfully on {device}") @@ -188,11 +193,11 @@ def unload_model(self) -> None: del self.encoder self.encoder = None + device = self._device self._device = None - import torch - if torch.cuda.is_available(): - torch.cuda.empty_cache() + if device: + empty_device_cache(device) logger.info("HumeAI TADA unloaded") @@ -213,9 +218,7 @@ async def create_voice_prompt( """ await self.load_model(self.model_size) - cache_key = ( - "tada_" + get_cache_key(audio_path, reference_text) - ) if use_cache else None + cache_key = ("tada_" + get_cache_key(audio_path, reference_text)) if use_cache else None if cache_key: cached = get_cached_voice_prompt(cache_key) @@ -239,9 +242,7 @@ def _encode_sync(): # Encode with forced alignment text_arg = [reference_text] if reference_text else None - prompt = self.encoder( - audio, text=text_arg, sample_rate=sr - ) + prompt = self.encoder(audio, text=text_arg, sample_rate=sr) # Serialize EncoderOutput to a dict of CPU tensors for caching prompt_dict = {} @@ -299,9 +300,7 @@ def _generate_sync(): from tada.modules.encoder import EncoderOutput if seed is not None: - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + manual_seed(seed, self._device) device = self._device diff --git a/backend/backends/luxtts_backend.py b/backend/backends/luxtts_backend.py index ba00359..7f15686 100644 --- a/backend/backends/luxtts_backend.py +++ b/backend/backends/luxtts_backend.py @@ -12,7 +12,14 @@ import numpy as np from . import TTSBackend -from .base import is_model_cached, get_torch_device, combine_voice_prompts as _combine_voice_prompts, model_load_progress +from .base import ( + is_model_cached, + get_torch_device, + empty_device_cache, + manual_seed, + combine_voice_prompts as _combine_voice_prompts, + model_load_progress, +) from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt logger = logging.getLogger(__name__) @@ -30,7 +37,7 @@ def __init__(self): self._device = None def _get_device(self) -> str: - return get_torch_device(allow_mps=True) + return get_torch_device(allow_mps=True, allow_xpu=True) def is_loaded(self) -> bool: return self.model is not None @@ -69,9 +76,12 @@ def _load_model_sync(self): if device == "cpu": import os + threads = os.cpu_count() or 4 self.model = LuxTTS( - model_path=LUXTTS_HF_REPO, device="cpu", threads=min(threads, 8), + model_path=LUXTTS_HF_REPO, + device="cpu", + threads=min(threads, 8), ) else: self.model = LuxTTS(model_path=LUXTTS_HF_REPO, device=device) @@ -81,12 +91,12 @@ def _load_model_sync(self): def unload_model(self) -> None: """Unload model to free memory.""" if self.model is not None: + device = self.device del self.model self.model = None + self._device = None - import torch - if torch.cuda.is_available(): - torch.cuda.empty_cache() + empty_device_cache(device) logger.info("LuxTTS unloaded") @@ -154,12 +164,8 @@ async def generate( await self.load_model() def _generate_sync(): - import torch - if seed is not None: - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + manual_seed(seed, self.device) wav = self.model.generate_speech( text=text, diff --git a/backend/backends/pytorch_backend.py b/backend/backends/pytorch_backend.py index 8ed4ab9..a8dfbcb 100644 --- a/backend/backends/pytorch_backend.py +++ b/backend/backends/pytorch_backend.py @@ -14,6 +14,8 @@ from .base import ( is_model_cached, get_torch_device, + empty_device_cache, + manual_seed, combine_voice_prompts as _combine_voice_prompts, model_load_progress, ) @@ -120,8 +122,7 @@ def unload_model(self): self.model = None self._current_model_size = None - if torch.cuda.is_available(): - torch.cuda.empty_cache() + empty_device_cache(self.device) logger.info("TTS model unloaded") @@ -213,9 +214,7 @@ def _generate_sync(): """Run synchronous generation in thread pool.""" # Set seed if provided if seed is not None: - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + manual_seed(seed, self.device) # Generate audio - this is the blocking operation wavs, sample_rate = self.model.generate_voice_clone( @@ -297,8 +296,7 @@ def unload_model(self): self.model = None self.processor = None - if torch.cuda.is_available(): - torch.cuda.empty_cache() + empty_device_cache(self.device) logger.info("Whisper model unloaded") diff --git a/backend/routes/health.py b/backend/routes/health.py index f138e33..66cc9b6 100644 --- a/backend/routes/health.py +++ b/backend/routes/health.py @@ -110,6 +110,11 @@ async def health(): vram_used = None if has_cuda: vram_used = torch.cuda.memory_allocated() / 1024 / 1024 + elif has_xpu: + try: + vram_used = torch.xpu.memory_allocated() / 1024 / 1024 + except Exception: + pass # memory_allocated() may not be available on all IPEX versions model_loaded = False model_size = None @@ -162,7 +167,10 @@ async def health(): gpu_type=gpu_type, vram_used_mb=vram_used, backend_type=backend_type, - backend_variant=os.environ.get("VOICEBOX_BACKEND_VARIANT", "cuda" if torch.cuda.is_available() else "cpu"), + backend_variant=os.environ.get( + "VOICEBOX_BACKEND_VARIANT", + "cuda" if torch.cuda.is_available() else ("xpu" if has_xpu else "cpu"), + ), ) diff --git a/justfile b/justfile index 24814e3..b17243f 100644 --- a/justfile +++ b/justfile @@ -69,10 +69,22 @@ setup-python: } Write-Host "Installing Python dependencies..." & "{{ python }}" -m pip install --upgrade pip -q - $hasNvidia = $null -ne (Get-WmiObject Win32_VideoController | Where-Object { $_.Name -match 'NVIDIA' }) + $gpus = Get-CimInstance Win32_VideoController | Select-Object -ExpandProperty Name + Write-Host "Detected GPUs: $($gpus -join ', ')" + $hasNvidia = ($gpus | Where-Object { $_ -match 'NVIDIA' }).Count -gt 0 + $hasIntelArc = ($gpus | Where-Object { $_ -match 'Arc' }).Count -gt 0 if ($hasNvidia) { \ Write-Host "NVIDIA GPU detected — installing PyTorch with CUDA support..."; \ & "{{ pip }}" install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128; \ + } elseif ($hasIntelArc) { \ + Write-Host "Intel Arc GPU detected — installing PyTorch with XPU support..."; \ + & "{{ pip }}" install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu; \ + & "{{ pip }}" install intel-extension-for-pytorch --index-url https://download.pytorch.org/whl/xpu; \ + } else { \ + Write-Host "No NVIDIA or Intel Arc GPU detected — using CPU-only PyTorch."; \ + Write-Host "If you have an Intel Arc GPU, install XPU support manually:"; \ + Write-Host " pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu"; \ + Write-Host " pip install intel-extension-for-pytorch --index-url https://download.pytorch.org/whl/xpu"; \ } & "{{ pip }}" install -r {{ backend_dir }}/requirements.txt & "{{ pip }}" install --no-deps chatterbox-tts