From 83ebababe7c514af8c3c9972f2c57c3184c1548c Mon Sep 17 00:00:00 2001 From: James Pine Date: Wed, 18 Mar 2026 11:24:51 -0700 Subject: [PATCH 1/2] feat: add Intel Arc (XPU) GPU support across all backends Auto-detect Intel Arc GPUs during Windows setup and install PyTorch with XPU support + intel-extension-for-pytorch. Enable allow_xpu=True on all TTS backends (Chatterbox, Chatterbox Turbo, Hume TADA, LuxTTS) that previously only supported CUDA. Add shared empty_device_cache() and manual_seed() helpers in base.py to handle XPU memory management and reproducible seeding alongside CUDA. --- backend/backends/base.py | 31 ++++++++++++++++ backend/backends/chatterbox_backend.py | 13 ++----- backend/backends/chatterbox_turbo_backend.py | 13 ++----- backend/backends/hume_backend.py | 39 ++++++++++---------- backend/backends/luxtts_backend.py | 21 ++++++++--- backend/backends/pytorch_backend.py | 12 +++--- justfile | 5 +++ 7 files changed, 83 insertions(+), 51 deletions(-) diff --git a/backend/backends/base.py b/backend/backends/base.py index 9a3049a0..0fdfa344 100644 --- a/backend/backends/base.py +++ b/backend/backends/base.py @@ -126,6 +126,37 @@ def get_torch_device( return "cpu" +def empty_device_cache(device: str) -> None: + """ + Free cached memory on the given device (CUDA or XPU). + + Backends should call this after unloading models so VRAM is returned + to the OS. + """ + import torch + + if device == "cuda" and torch.cuda.is_available(): + torch.cuda.empty_cache() + elif device == "xpu" and hasattr(torch, "xpu"): + torch.xpu.empty_cache() + + +def manual_seed(seed: int, device: str) -> None: + """ + Set the random seed on both CPU and the active accelerator. + + Covers CUDA and Intel XPU so that generation is reproducible + regardless of which GPU backend is in use. 
+ """ + import torch + + torch.manual_seed(seed) + if device == "cuda" and torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + elif device == "xpu" and hasattr(torch, "xpu"): + torch.xpu.manual_seed(seed) + + async def combine_voice_prompts( audio_paths: List[str], reference_texts: List[str], diff --git a/backend/backends/chatterbox_backend.py b/backend/backends/chatterbox_backend.py index 7efe371a..9b061bb1 100644 --- a/backend/backends/chatterbox_backend.py +++ b/backend/backends/chatterbox_backend.py @@ -18,6 +18,7 @@ from .base import ( is_model_cached, get_torch_device, + empty_device_cache, combine_voice_prompts as _combine_voice_prompts, model_load_progress, patch_chatterbox_f32, @@ -48,7 +49,7 @@ def __init__(self): self._model_load_lock = asyncio.Lock() def _get_device(self) -> str: - return get_torch_device(force_cpu_on_mac=True) + return get_torch_device(force_cpu_on_mac=True, allow_xpu=True) def is_loaded(self) -> bool: return self.model is not None @@ -117,10 +118,7 @@ def unload_model(self) -> None: del self.model self.model = None self._device = None - if device == "cuda": - import torch - - torch.cuda.empty_cache() + empty_device_cache(device) logger.info("Chatterbox unloaded") async def create_voice_prompt( @@ -220,10 +218,7 @@ def _generate_sync(): else: audio = np.asarray(wav, dtype=np.float32) - sample_rate = ( - getattr(self.model, "sr", None) - or getattr(self.model, "sample_rate", 24000) - ) + sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000) return audio, sample_rate diff --git a/backend/backends/chatterbox_turbo_backend.py b/backend/backends/chatterbox_turbo_backend.py index a8bfe503..fcdbf39e 100644 --- a/backend/backends/chatterbox_turbo_backend.py +++ b/backend/backends/chatterbox_turbo_backend.py @@ -18,6 +18,7 @@ from .base import ( is_model_cached, get_torch_device, + empty_device_cache, combine_voice_prompts as _combine_voice_prompts, model_load_progress, patch_chatterbox_f32, @@ -48,7 
+49,7 @@ def __init__(self): self._model_load_lock = asyncio.Lock() def _get_device(self) -> str: - return get_torch_device(force_cpu_on_mac=True) + return get_torch_device(force_cpu_on_mac=True, allow_xpu=True) def is_loaded(self) -> bool: return self.model is not None @@ -116,10 +117,7 @@ def unload_model(self) -> None: del self.model self.model = None self._device = None - if device == "cuda": - import torch - - torch.cuda.empty_cache() + empty_device_cache(device) logger.info("Chatterbox Turbo unloaded") async def create_voice_prompt( @@ -200,10 +198,7 @@ def _generate_sync(): else: audio = np.asarray(wav, dtype=np.float32) - sample_rate = ( - getattr(self.model, "sr", None) - or getattr(self.model, "sample_rate", 24000) - ) + sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000) return audio, sample_rate diff --git a/backend/backends/hume_backend.py b/backend/backends/hume_backend.py index 456fd46e..ecaa29b7 100644 --- a/backend/backends/hume_backend.py +++ b/backend/backends/hume_backend.py @@ -24,6 +24,8 @@ from .base import ( is_model_cached, get_torch_device, + empty_device_cache, + manual_seed, combine_voice_prompts as _combine_voice_prompts, model_load_progress, ) @@ -66,7 +68,7 @@ def __init__(self): def _get_device(self) -> str: # Force CPU on macOS — MPS has issues with flow matching # and large vocab lm_head (>65536 output channels) - return get_torch_device(force_cpu_on_mac=True) + return get_torch_device(force_cpu_on_mac=True, allow_xpu=True) def is_loaded(self) -> bool: return self.model is not None @@ -105,6 +107,7 @@ def _load_model_sync(self, model_size: str = "1B"): # package. The real package pulls in onnx/tensorboard/matplotlib via # descript-audiotools, so we use a lightweight shim instead. 
from ..utils.dac_shim import install_dac_shim + install_dac_shim() import torch @@ -142,9 +145,12 @@ def _load_model_sync(self, model_size: str = "1B"): allow_patterns=["tokenizer*", "special_tokens*"], ) - # Determine dtype — use bf16 on CUDA for ~50% memory savings + # Determine dtype — use bf16 on CUDA/XPU for ~50% memory savings if device == "cuda" and torch.cuda.is_bf16_supported(): model_dtype = torch.bfloat16 + elif device == "xpu": + # Intel Arc (Alchemist+) supports bf16 natively + model_dtype = torch.bfloat16 else: model_dtype = torch.float32 @@ -153,14 +159,14 @@ def _load_model_sync(self, model_size: str = "1B"): # This avoids monkey-patching AutoTokenizer.from_pretrained # which corrupts the classmethod descriptor for other engines. from tada.modules.aligner import AlignerConfig + AlignerConfig.tokenizer_name = tokenizer_path # Load encoder (only needed for voice prompt encoding) from tada.modules.encoder import Encoder + logger.info("Loading TADA encoder...") - self.encoder = Encoder.from_pretrained( - TADA_CODEC_REPO, subfolder="encoder" - ).to(device) + self.encoder = Encoder.from_pretrained(TADA_CODEC_REPO, subfolder="encoder").to(device) self.encoder.eval() # Load the causal LM (includes decoder for wav generation). @@ -169,12 +175,11 @@ def _load_model_sync(self, model_size: str = "1B"): # which hits the gated repo. Pre-load the config from HF, # inject the local tokenizer path, then pass it in. 
from tada.modules.tada import TadaForCausalLM, TadaConfig + logger.info(f"Loading TADA {model_size} model...") config = TadaConfig.from_pretrained(repo) config.tokenizer_name = tokenizer_path - self.model = TadaForCausalLM.from_pretrained( - repo, config=config, torch_dtype=model_dtype - ).to(device) + self.model = TadaForCausalLM.from_pretrained(repo, config=config, torch_dtype=model_dtype).to(device) self.model.eval() logger.info(f"HumeAI TADA {model_size} loaded successfully on {device}") @@ -188,11 +193,11 @@ def unload_model(self) -> None: del self.encoder self.encoder = None + device = self._device self._device = None - import torch - if torch.cuda.is_available(): - torch.cuda.empty_cache() + if device: + empty_device_cache(device) logger.info("HumeAI TADA unloaded") @@ -213,9 +218,7 @@ async def create_voice_prompt( """ await self.load_model(self.model_size) - cache_key = ( - "tada_" + get_cache_key(audio_path, reference_text) - ) if use_cache else None + cache_key = ("tada_" + get_cache_key(audio_path, reference_text)) if use_cache else None if cache_key: cached = get_cached_voice_prompt(cache_key) @@ -239,9 +242,7 @@ def _encode_sync(): # Encode with forced alignment text_arg = [reference_text] if reference_text else None - prompt = self.encoder( - audio, text=text_arg, sample_rate=sr - ) + prompt = self.encoder(audio, text=text_arg, sample_rate=sr) # Serialize EncoderOutput to a dict of CPU tensors for caching prompt_dict = {} @@ -299,9 +300,7 @@ def _generate_sync(): from tada.modules.encoder import EncoderOutput if seed is not None: - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + manual_seed(seed, self._device) device = self._device diff --git a/backend/backends/luxtts_backend.py b/backend/backends/luxtts_backend.py index ba00359e..6f88fa3e 100644 --- a/backend/backends/luxtts_backend.py +++ b/backend/backends/luxtts_backend.py @@ -12,7 +12,13 @@ import numpy as np from . 
import TTSBackend -from .base import is_model_cached, get_torch_device, combine_voice_prompts as _combine_voice_prompts, model_load_progress +from .base import ( + is_model_cached, + get_torch_device, + empty_device_cache, + combine_voice_prompts as _combine_voice_prompts, + model_load_progress, +) from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt logger = logging.getLogger(__name__) @@ -30,7 +36,7 @@ def __init__(self): self._device = None def _get_device(self) -> str: - return get_torch_device(allow_mps=True) + return get_torch_device(allow_mps=True, allow_xpu=True) def is_loaded(self) -> bool: return self.model is not None @@ -69,9 +75,12 @@ def _load_model_sync(self): if device == "cpu": import os + threads = os.cpu_count() or 4 self.model = LuxTTS( - model_path=LUXTTS_HF_REPO, device="cpu", threads=min(threads, 8), + model_path=LUXTTS_HF_REPO, + device="cpu", + threads=min(threads, 8), ) else: self.model = LuxTTS(model_path=LUXTTS_HF_REPO, device=device) @@ -81,12 +90,12 @@ def _load_model_sync(self): def unload_model(self) -> None: """Unload model to free memory.""" if self.model is not None: + device = self.device del self.model self.model = None + self._device = None - import torch - if torch.cuda.is_available(): - torch.cuda.empty_cache() + empty_device_cache(device) logger.info("LuxTTS unloaded") diff --git a/backend/backends/pytorch_backend.py b/backend/backends/pytorch_backend.py index 8ed4ab9c..a8dfbcb9 100644 --- a/backend/backends/pytorch_backend.py +++ b/backend/backends/pytorch_backend.py @@ -14,6 +14,8 @@ from .base import ( is_model_cached, get_torch_device, + empty_device_cache, + manual_seed, combine_voice_prompts as _combine_voice_prompts, model_load_progress, ) @@ -120,8 +122,7 @@ def unload_model(self): self.model = None self._current_model_size = None - if torch.cuda.is_available(): - torch.cuda.empty_cache() + empty_device_cache(self.device) logger.info("TTS model unloaded") @@ -213,9 +214,7 @@ def 
_generate_sync(): """Run synchronous generation in thread pool.""" # Set seed if provided if seed is not None: - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + manual_seed(seed, self.device) # Generate audio - this is the blocking operation wavs, sample_rate = self.model.generate_voice_clone( @@ -297,8 +296,7 @@ def unload_model(self): self.model = None self.processor = None - if torch.cuda.is_available(): - torch.cuda.empty_cache() + empty_device_cache(self.device) logger.info("Whisper model unloaded") diff --git a/justfile b/justfile index 24814e30..1bddbf69 100644 --- a/justfile +++ b/justfile @@ -70,9 +70,14 @@ setup-python: Write-Host "Installing Python dependencies..." & "{{ python }}" -m pip install --upgrade pip -q $hasNvidia = $null -ne (Get-WmiObject Win32_VideoController | Where-Object { $_.Name -match 'NVIDIA' }) + $hasIntelArc = $null -ne (Get-WmiObject Win32_VideoController | Where-Object { $_.Name -match 'Intel.*Arc' }) if ($hasNvidia) { \ Write-Host "NVIDIA GPU detected — installing PyTorch with CUDA support..."; \ & "{{ pip }}" install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128; \ + } elseif ($hasIntelArc) { \ + Write-Host "Intel Arc GPU detected — installing PyTorch with XPU support..."; \ + & "{{ pip }}" install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu; \ + & "{{ pip }}" install intel-extension-for-pytorch --index-url https://download.pytorch.org/whl/xpu; \ } & "{{ pip }}" install -r {{ backend_dir }}/requirements.txt & "{{ pip }}" install --no-deps chatterbox-tts From 707046237c42fe6c3c6b10c03c261321d8229105 Mon Sep 17 00:00:00 2001 From: James Pine Date: Wed, 18 Mar 2026 17:01:12 -0700 Subject: [PATCH 2/2] =?UTF-8?q?fix:=20complete=20Intel=20XPU=20support=20?= =?UTF-8?q?=E2=80=94=20device-aware=20seeding,=20GPU=20status=20reporting,?= =?UTF-8?q?=20and=20setup=20detection?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit Address CodeRabbit review feedback and user-reported GPU acceleration failure: - Use shared manual_seed() in chatterbox, chatterbox_turbo, and luxtts backends so XPU (and future accelerators) get proper device seeding - Add XPU branch to _get_gpu_status() so startup log reports Intel Arc GPUs instead of 'None (CPU only)' - Add XPU VRAM reporting and correct backend_variant fallback in the /health endpoint - Switch justfile GPU detection from Get-WmiObject to Get-CimInstance, simplify the Arc regex to match 'Arc' (not 'Intel.*Arc'), log detected GPUs, and print manual install instructions on miss Resolves the root cause where IPEX was silently not installed due to WMI detection failure, causing CPU-only fallback on Intel Arc systems. --- backend/app.py | 14 ++++++++++++++ backend/backends/chatterbox_backend.py | 3 ++- backend/backends/chatterbox_turbo_backend.py | 3 ++- backend/backends/luxtts_backend.py | 7 ++----- backend/routes/health.py | 10 +++++++++- justfile | 11 +++++++++-- 6 files changed, 38 insertions(+), 10 deletions(-) diff --git a/backend/app.py b/backend/app.py index f652d149..1293460a 100644 --- a/backend/app.py +++ b/backend/app.py @@ -155,6 +155,20 @@ def _get_gpu_status() -> str: return "MPS (Apple Silicon)" elif backend_type == "mlx": return "Metal (Apple Silicon via MLX)" + + # Intel XPU (Arc / Data Center) via IPEX + try: + import intel_extension_for_pytorch # noqa: F401 + + if hasattr(torch, "xpu") and torch.xpu.is_available(): + try: + xpu_name = torch.xpu.get_device_name(0) + except Exception: + xpu_name = "Intel GPU" + return f"XPU ({xpu_name})" + except ImportError: + pass + return "None (CPU only)" diff --git a/backend/backends/chatterbox_backend.py b/backend/backends/chatterbox_backend.py index 9b061bb1..e7a025b3 100644 --- a/backend/backends/chatterbox_backend.py +++ b/backend/backends/chatterbox_backend.py @@ -19,6 +19,7 @@ is_model_cached, get_torch_device, empty_device_cache, + manual_seed, 
combine_voice_prompts as _combine_voice_prompts, model_load_progress, patch_chatterbox_f32, @@ -198,7 +199,7 @@ def _generate_sync(): import torch if seed is not None: - torch.manual_seed(seed) + manual_seed(seed, self._device) logger.info(f"[Chatterbox] Generating: lang={language}") diff --git a/backend/backends/chatterbox_turbo_backend.py b/backend/backends/chatterbox_turbo_backend.py index fcdbf39e..6f7d6b94 100644 --- a/backend/backends/chatterbox_turbo_backend.py +++ b/backend/backends/chatterbox_turbo_backend.py @@ -19,6 +19,7 @@ is_model_cached, get_torch_device, empty_device_cache, + manual_seed, combine_voice_prompts as _combine_voice_prompts, model_load_progress, patch_chatterbox_f32, @@ -179,7 +180,7 @@ def _generate_sync(): import torch if seed is not None: - torch.manual_seed(seed) + manual_seed(seed, self._device) logger.info("[Chatterbox Turbo] Generating (English)") diff --git a/backend/backends/luxtts_backend.py b/backend/backends/luxtts_backend.py index 6f88fa3e..7f15686a 100644 --- a/backend/backends/luxtts_backend.py +++ b/backend/backends/luxtts_backend.py @@ -16,6 +16,7 @@ is_model_cached, get_torch_device, empty_device_cache, + manual_seed, combine_voice_prompts as _combine_voice_prompts, model_load_progress, ) @@ -163,12 +164,8 @@ async def generate( await self.load_model() def _generate_sync(): - import torch - if seed is not None: - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + manual_seed(seed, self.device) wav = self.model.generate_speech( text=text, diff --git a/backend/routes/health.py b/backend/routes/health.py index f138e336..66cc9b62 100644 --- a/backend/routes/health.py +++ b/backend/routes/health.py @@ -110,6 +110,11 @@ async def health(): vram_used = None if has_cuda: vram_used = torch.cuda.memory_allocated() / 1024 / 1024 + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + try: + vram_used = torch.xpu.memory_allocated() / 1024 / 1024 + except Exception: + pass # memory_allocated() may not be available on all IPEX 
versions model_loaded = False model_size = None @@ -162,7 +167,10 @@ async def health(): gpu_type=gpu_type, vram_used_mb=vram_used, backend_type=backend_type, - backend_variant=os.environ.get("VOICEBOX_BACKEND_VARIANT", "cuda" if torch.cuda.is_available() else "cpu"), + backend_variant=os.environ.get( + "VOICEBOX_BACKEND_VARIANT", + "cuda" if torch.cuda.is_available() else ("xpu" if (hasattr(torch, "xpu") and torch.xpu.is_available()) else "cpu"), + ), ) diff --git a/justfile b/justfile index 1bddbf69..b17243fb 100644 --- a/justfile +++ b/justfile @@ -69,8 +69,10 @@ setup-python: } Write-Host "Installing Python dependencies..." & "{{ python }}" -m pip install --upgrade pip -q - $hasNvidia = $null -ne (Get-WmiObject Win32_VideoController | Where-Object { $_.Name -match 'NVIDIA' }) - $hasIntelArc = $null -ne (Get-WmiObject Win32_VideoController | Where-Object { $_.Name -match 'Intel.*Arc' }) + $gpus = Get-CimInstance Win32_VideoController | Select-Object -ExpandProperty Name + Write-Host "Detected GPUs: $($gpus -join ', ')" + $hasNvidia = ($gpus | Where-Object { $_ -match 'NVIDIA' }).Count -gt 0 + $hasIntelArc = ($gpus | Where-Object { $_ -match 'Arc' }).Count -gt 0 if ($hasNvidia) { \ Write-Host "NVIDIA GPU detected — installing PyTorch with CUDA support..."; \ & "{{ pip }}" install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128; \ @@ -78,6 +80,11 @@ setup-python: Write-Host "Intel Arc GPU detected — installing PyTorch with XPU support..."; \ & "{{ pip }}" install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu; \ & "{{ pip }}" install intel-extension-for-pytorch --index-url https://download.pytorch.org/whl/xpu; \ + } else { \ + Write-Host "No NVIDIA or Intel Arc GPU detected — using CPU-only PyTorch."; \ + Write-Host "If you have an Intel Arc GPU, install XPU support manually:"; \ + Write-Host "  pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu"; \ + Write-Host "  pip install intel-extension-for-pytorch 
--index-url https://download.pytorch.org/whl/xpu"; \ } & "{{ pip }}" install -r {{ backend_dir }}/requirements.txt & "{{ pip }}" install --no-deps chatterbox-tts