From 4e0c731db85b1de355cef9fa45d386f0e69509bd Mon Sep 17 00:00:00 2001 From: James Pine Date: Thu, 19 Mar 2026 18:31:30 -0700 Subject: [PATCH 1/4] feat: add Qwen CustomVoice preset engine --- .../Generation/EngineModelSelector.tsx | 15 +- .../components/Generation/GenerationForm.tsx | 2 +- .../ServerSettings/ModelManagement.tsx | 5 + .../components/VoiceProfiles/ProfileCard.tsx | 8 +- .../components/VoiceProfiles/ProfileForm.tsx | 4 +- .../components/VoiceProfiles/ProfileList.tsx | 3 +- app/src/lib/api/types.ts | 9 +- app/src/lib/constants/languages.ts | 1 + app/src/lib/hooks/useGenerationForm.ts | 32 ++- backend/backends/__init__.py | 55 ++++- backend/backends/qwen_custom_voice_backend.py | 210 ++++++++++++++++++ backend/build_binary.py | 2 + backend/models.py | 2 +- backend/routes/profiles.py | 72 +++++- 14 files changed, 399 insertions(+), 21 deletions(-) create mode 100644 backend/backends/qwen_custom_voice_backend.py diff --git a/app/src/components/Generation/EngineModelSelector.tsx b/app/src/components/Generation/EngineModelSelector.tsx index 80fd82a..7195a28 100644 --- a/app/src/components/Generation/EngineModelSelector.tsx +++ b/app/src/components/Generation/EngineModelSelector.tsx @@ -19,6 +19,8 @@ import type { GenerationFormValues } from '@/lib/hooks/useGenerationForm'; const ENGINE_OPTIONS = [ { value: 'qwen:1.7B', label: 'Qwen3-TTS 1.7B', engine: 'qwen' }, { value: 'qwen:0.6B', label: 'Qwen3-TTS 0.6B', engine: 'qwen' }, + { value: 'qwen_custom_voice:1.7B', label: 'Qwen CustomVoice 1.7B', engine: 'qwen_custom_voice' }, + { value: 'qwen_custom_voice:0.6B', label: 'Qwen CustomVoice 0.6B', engine: 'qwen_custom_voice' }, { value: 'luxtts', label: 'LuxTTS', engine: 'luxtts' }, { value: 'chatterbox', label: 'Chatterbox', engine: 'chatterbox' }, { value: 'chatterbox_turbo', label: 'Chatterbox Turbo', engine: 'chatterbox_turbo' }, @@ -29,6 +31,7 @@ const ENGINE_OPTIONS = [ const ENGINE_DESCRIPTIONS: Record = { qwen: 'Multi-language, two sizes', + qwen_custom_voice: '9 preset voices, instruct control', luxtts: 'Fast, English-focused', chatterbox: '23 languages, incl. Hebrew', chatterbox_turbo: 'English, [laugh] [cough] tags', @@ -49,12 +52,22 @@ function getAvailableOptions(selectedProfile?: VoiceProfileResponse | null) { function getSelectValue(engine: string, modelSize?: string): string { if (engine === 'qwen') return `qwen:${modelSize || '1.7B'}`; + if (engine === 'qwen_custom_voice') return `qwen_custom_voice:${modelSize || '1.7B'}`; if (engine === 'tada') return `tada:${modelSize || '1B'}`; return engine; } function handleEngineChange(form: UseFormReturn, value: string) { - if (value.startsWith('qwen:')) { + if (value.startsWith('qwen_custom_voice:')) { + const [, modelSize] = value.split(':'); + form.setValue('engine', 'qwen_custom_voice'); + form.setValue('modelSize', modelSize as '1.7B' | '0.6B'); + const currentLang = form.getValues('language'); + const available = getLanguageOptionsForEngine('qwen_custom_voice'); + if (!available.some((l) => l.value === currentLang)) { + form.setValue('language', available[0]?.value ?? 'en'); + } + } else if (value.startsWith('qwen:')) { const [, modelSize] = value.split(':'); form.setValue('engine', 'qwen'); form.setValue('modelSize', modelSize as '1.7B' | '0.6B'); diff --git a/app/src/components/Generation/GenerationForm.tsx b/app/src/components/Generation/GenerationForm.tsx index 9f7a7cd..1195e8b 100644 --- a/app/src/components/Generation/GenerationForm.tsx +++ b/app/src/components/Generation/GenerationForm.tsx @@ -91,7 +91,7 @@ export function GenerationForm() { )} /> - {form.watch('engine') === 'qwen' && ( + {(form.watch('engine') === 'qwen' || form.watch('engine') === 'qwen_custom_voice') && ( = { 'HumeAI TADA 3B Multilingual — built on Llama 3.2 3B. Supports 10 languages with high-fidelity voice cloning via text-acoustic dual alignment.', kokoro: 'Kokoro 82M by hexgrad. Tiny 82M-parameter TTS that runs at CPU realtime. Supports 8 languages with pre-built voice styles. Apache 2.0 licensed.', + 'qwen-custom-voice-1.7B': + 'Qwen3-TTS CustomVoice 1.7B by Alibaba. 9 premium preset voices with instruct-based style control for tone, emotion, and prosody. Supports 10 languages.', + 'qwen-custom-voice-0.6B': + 'Qwen3-TTS CustomVoice 0.6B by Alibaba. Lightweight version with the same 9 preset voices and instruct control. Faster inference for lower-end hardware.', 'whisper-base': 'Smallest Whisper model (74M parameters). Fast transcription with moderate accuracy.', 'whisper-small': @@ -396,6 +400,7 @@ export function ModelManagement() { modelStatus?.models.filter( (m) => m.model_name.startsWith('qwen-tts') || + m.model_name.startsWith('qwen-custom-voice') || m.model_name.startsWith('luxtts') || m.model_name.startsWith('chatterbox') || m.model_name.startsWith('tada') || diff --git a/app/src/components/VoiceProfiles/ProfileCard.tsx b/app/src/components/VoiceProfiles/ProfileCard.tsx index e2a9d4d..e634c38 100644 --- a/app/src/components/VoiceProfiles/ProfileCard.tsx +++ b/app/src/components/VoiceProfiles/ProfileCard.tsx @@ -17,6 +17,12 @@ import { useDeleteProfile, useExportProfile } from '@/lib/hooks/useProfiles'; import { cn } from '@/lib/utils/cn'; import { useUIStore } from '@/stores/uiStore'; +/** Human-readable display names for preset engine badges. */ +const ENGINE_DISPLAY_NAMES: Record = { + kokoro: 'Kokoro', + qwen_custom_voice: 'CustomVoice', +}; + interface ProfileCardProps { profile: VoiceProfileResponse; } @@ -99,7 +105,7 @@ export function ProfileCard({ profile }: ProfileCardProps) { {profile.voice_type === 'preset' && ( - {profile.preset_engine} + {ENGINE_DISPLAY_NAMES[profile.preset_engine ?? ''] ?? profile.preset_engine} )} {profile.voice_type === 'designed' && ( diff --git a/app/src/components/VoiceProfiles/ProfileForm.tsx b/app/src/components/VoiceProfiles/ProfileForm.tsx index 0c4987b..a148bfb 100644 --- a/app/src/components/VoiceProfiles/ProfileForm.tsx +++ b/app/src/components/VoiceProfiles/ProfileForm.tsx @@ -60,9 +60,10 @@ import { AudioSampleUpload } from './AudioSampleUpload'; import { SampleList } from './SampleList'; const MAX_AUDIO_DURATION_SECONDS = 30; -const PRESET_ONLY_ENGINES = new Set(['kokoro']); +const PRESET_ONLY_ENGINES = new Set(['kokoro', 'qwen_custom_voice']); const DEFAULT_ENGINE_OPTIONS = [ { value: 'qwen', label: 'Qwen3-TTS' }, + { value: 'qwen_custom_voice', label: 'Qwen CustomVoice' }, { value: 'luxtts', label: 'LuxTTS' }, { value: 'chatterbox', label: 'Chatterbox' }, { value: 'chatterbox_turbo', label: 'Chatterbox Turbo' }, @@ -849,6 +850,7 @@ export function ProfileForm() { Kokoro 82M + Qwen CustomVoice diff --git a/app/src/components/VoiceProfiles/ProfileList.tsx b/app/src/components/VoiceProfiles/ProfileList.tsx index 606f4a2..be7332a 100644 --- a/app/src/components/VoiceProfiles/ProfileList.tsx +++ b/app/src/components/VoiceProfiles/ProfileList.tsx @@ -7,11 +7,12 @@ import { ProfileCard } from './ProfileCard'; import { ProfileForm } from './ProfileForm'; /** Engines that use preset (built-in) voices instead of cloned profiles. */ -const PRESET_ENGINES = new Set(['kokoro']); +const PRESET_ENGINES = new Set(['kokoro', 'qwen_custom_voice']); /** Human-readable engine names for empty state messages. */ const ENGINE_NAMES: Record = { kokoro: 'Kokoro', + qwen_custom_voice: 'Qwen CustomVoice', }; export function ProfileList() { diff --git a/app/src/lib/api/types.ts b/app/src/lib/api/types.ts index 34e8038..86e3012 100644 --- a/app/src/lib/api/types.ts +++ b/app/src/lib/api/types.ts @@ -62,7 +62,14 @@ export interface GenerationRequest { language: LanguageCode; seed?: number; model_size?: '1.7B' | '0.6B' | '1B' | '3B'; - engine?: 'qwen' | 'luxtts' | 'chatterbox' | 'chatterbox_turbo' | 'tada' | 'kokoro'; + engine?: + | 'qwen' + | 'qwen_custom_voice' + | 'luxtts' + | 'chatterbox' + | 'chatterbox_turbo' + | 'tada' + | 'kokoro'; instruct?: string; max_chunk_chars?: number; crossfade_ms?: number; diff --git a/app/src/lib/constants/languages.ts b/app/src/lib/constants/languages.ts index 1a5c5f2..e28c519 100644 --- a/app/src/lib/constants/languages.ts +++ b/app/src/lib/constants/languages.ts @@ -69,6 +69,7 @@ export const ENGINE_LANGUAGES: Record = { chatterbox_turbo: ['en'], tada: ['en', 'ar', 'zh', 'de', 'es', 'fr', 'it', 'ja', 'pl', 'pt'], kokoro: ['en', 'es', 'fr', 'hi', 'it', 'pt', 'ja', 'zh'], + qwen_custom_voice: ['zh', 'en', 'ja', 'ko', 'de', 'fr', 'ru', 'pt', 'es', 'it'], } as const; /** Helper: get language options for a given engine. */ diff --git a/app/src/lib/hooks/useGenerationForm.ts b/app/src/lib/hooks/useGenerationForm.ts index 894d933..0acdabb 100644 --- a/app/src/lib/hooks/useGenerationForm.ts +++ b/app/src/lib/hooks/useGenerationForm.ts @@ -17,7 +17,17 @@ const generationSchema = z.object({ seed: z.number().int().optional(), modelSize: z.enum(['1.7B', '0.6B', '1B', '3B']).optional(), instruct: z.string().max(500).optional(), - engine: z.enum(['qwen', 'luxtts', 'chatterbox', 'chatterbox_turbo', 'tada', 'kokoro']).optional(), + engine: z + .enum([ + 'qwen', + 'qwen_custom_voice', + 'luxtts', + 'chatterbox', + 'chatterbox_turbo', + 'tada', + 'kokoro', + ]) + .optional(), }); export type GenerationFormValues = z.infer; @@ -85,7 +95,9 @@ export function useGenerationForm(options: UseGenerationFormOptions = {}) { : 'tada-1b' : engine === 'kokoro' ? 'kokoro' - : `qwen-tts-${data.modelSize}`; + : engine === 'qwen_custom_voice' + ? `qwen-custom-voice-${data.modelSize}` + : `qwen-tts-${data.modelSize}`; const displayName = engine === 'luxtts' ? 'LuxTTS' @@ -99,9 +111,13 @@ export function useGenerationForm(options: UseGenerationFormOptions = {}) { : 'TADA 1B' : engine === 'kokoro' ? 'Kokoro 82M' - : data.modelSize === '1.7B' - ? 'Qwen TTS 1.7B' - : 'Qwen TTS 0.6B'; + : engine === 'qwen_custom_voice' + ? data.modelSize === '1.7B' + ? 'Qwen CustomVoice 1.7B' + : 'Qwen CustomVoice 0.6B' + : data.modelSize === '1.7B' + ? 'Qwen TTS 1.7B' + : 'Qwen TTS 0.6B'; // Check if model needs downloading try { @@ -116,7 +132,9 @@ export function useGenerationForm(options: UseGenerationFormOptions = {}) { console.error('Failed to check model status:', error); } - const hasModelSizes = engine === 'qwen' || engine === 'tada'; + const hasModelSizes = + engine === 'qwen' || engine === 'qwen_custom_voice' || engine === 'tada'; + const supportsInstruct = engine === 'qwen' || engine === 'qwen_custom_voice'; const effectsChain = options.getEffectsChain?.(); // This now returns immediately with status="generating" const result = await generation.mutateAsync({ @@ -126,7 +144,7 @@ export function useGenerationForm(options: UseGenerationFormOptions = {}) { seed: data.seed, model_size: hasModelSizes ? data.modelSize : undefined, engine, - instruct: engine === 'qwen' ? data.instruct || undefined : undefined, + instruct: supportsInstruct ? data.instruct || undefined : undefined, max_chunk_chars: maxChunkChars, crossfade_ms: crossfadeMs, normalize: normalizeAudio, diff --git a/backend/backends/__init__.py b/backend/backends/__init__.py index 33ae57d..db19b14 100644 --- a/backend/backends/__init__.py +++ b/backend/backends/__init__.py @@ -163,6 +163,7 @@ def is_loaded(self) -> bool: # The factory function uses this for the if/elif chain; the model configs live on the backend classes. TTS_ENGINES = { "qwen": "Qwen TTS", + "qwen_custom_voice": "Qwen CustomVoice", "luxtts": "LuxTTS", "chatterbox": "Chatterbox TTS", "chatterbox_turbo": "Chatterbox Turbo", @@ -205,6 +206,32 @@ def _get_qwen_model_configs() -> list[ModelConfig]: ] +def _get_qwen_custom_voice_configs() -> list[ModelConfig]: + """Return Qwen CustomVoice model configs.""" + return [ + ModelConfig( + model_name="qwen-custom-voice-1.7B", + display_name="Qwen CustomVoice 1.7B", + engine="qwen_custom_voice", + hf_repo_id="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", + model_size="1.7B", + size_mb=3500, + supports_instruct=True, + languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"], + ), + ModelConfig( + model_name="qwen-custom-voice-0.6B", + display_name="Qwen CustomVoice 0.6B", + engine="qwen_custom_voice", + hf_repo_id="Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice", + model_size="0.6B", + size_mb=1200, + supports_instruct=True, + languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"], + ), + ] + + def _get_non_qwen_tts_configs() -> list[ModelConfig]: """Return model configs for non-Qwen TTS engines. @@ -333,12 +360,12 @@ def _get_whisper_configs() -> list[ModelConfig]: def get_all_model_configs() -> list[ModelConfig]: """Return the full list of model configs (TTS + STT).""" - return _get_qwen_model_configs() + _get_non_qwen_tts_configs() + _get_whisper_configs() + return _get_qwen_model_configs() + _get_qwen_custom_voice_configs() + _get_non_qwen_tts_configs() + _get_whisper_configs() def get_tts_model_configs() -> list[ModelConfig]: """Return only TTS model configs.""" - return _get_qwen_model_configs() + _get_non_qwen_tts_configs() + return _get_qwen_model_configs() + _get_qwen_custom_voice_configs() + _get_non_qwen_tts_configs() # Lookup helpers — these replace the if/elif chains in main.py @@ -369,7 +396,7 @@ def engine_has_model_sizes(engine: str) -> bool: async def load_engine_model(engine: str, model_size: str = "default") -> None: """Load a model for the given engine, handling engines with multiple model sizes.""" backend = get_tts_backend_for_engine(engine) - if engine == "qwen": + if engine in ("qwen", "qwen_custom_voice"): await backend.load_model_async(model_size) elif engine == "tada": await backend.load_model(model_size) @@ -388,7 +415,7 @@ async def ensure_model_cached_or_raise(engine: str, model_size: str = "default") cfg = c break - if engine in ("qwen", "tada"): + if engine in ("qwen", "qwen_custom_voice", "tada"): if not backend._is_model_cached(model_size): raise HTTPException( status_code=400, @@ -423,6 +450,14 @@ def unload_model_by_config(config: ModelConfig) -> bool: return True return False + if config.engine == "qwen_custom_voice": + backend = get_tts_backend_for_engine(config.engine) + loaded_size = getattr(backend, "_current_model_size", None) or getattr(backend, "model_size", None) + if backend.is_loaded() and loaded_size == config.model_size: + backend.unload_model() + return True + return False + # All other TTS engines backend = get_tts_backend_for_engine(config.engine) if backend.is_loaded(): @@ -446,6 +481,11 @@ def check_model_loaded(config: ModelConfig) -> bool: loaded_size = getattr(tts_model, "_current_model_size", None) or getattr(tts_model, "model_size", None) return tts_model.is_loaded() and loaded_size == config.model_size + if config.engine == "qwen_custom_voice": + backend = get_tts_backend_for_engine(config.engine) + loaded_size = getattr(backend, "_current_model_size", None) or getattr(backend, "model_size", None) + return backend.is_loaded() and loaded_size == config.model_size + backend = get_tts_backend_for_engine(config.engine) return backend.is_loaded() except Exception: @@ -463,6 +503,9 @@ def get_model_load_func(config: ModelConfig): if config.engine == "qwen": return lambda: tts.get_tts_model().load_model(config.model_size) + if config.engine == "qwen_custom_voice": + return lambda: get_tts_backend_for_engine(config.engine).load_model(config.model_size) + return lambda: get_tts_backend_for_engine(config.engine).load_model() @@ -528,6 +571,10 @@ def get_tts_backend_for_engine(engine: str) -> TTSBackend: from .kokoro_backend import KokoroTTSBackend backend = KokoroTTSBackend() + elif engine == "qwen_custom_voice": + from .qwen_custom_voice_backend import QwenCustomVoiceBackend + + backend = QwenCustomVoiceBackend() else: raise ValueError(f"Unknown TTS engine: {engine}. Supported: {list(TTS_ENGINES.keys())}") diff --git a/backend/backends/qwen_custom_voice_backend.py b/backend/backends/qwen_custom_voice_backend.py new file mode 100644 index 0000000..fbbf9f3 --- /dev/null +++ b/backend/backends/qwen_custom_voice_backend.py @@ -0,0 +1,210 @@ +""" +Qwen3-TTS CustomVoice backend implementation. + +Wraps the Qwen3-TTS-12Hz CustomVoice model for preset-speaker TTS with +instruction-based style control. Uses the same qwen_tts library as the +Base model (pytorch_backend.py) but loads a different checkpoint and +calls generate_custom_voice() instead of generate_voice_clone(). + +Key differences from the Base engine: + - Uses preset speakers (9 built-in voices) instead of zero-shot cloning + - Supports instruct parameter for tone/emotion/prosody control + - Two model sizes: 1.7B and 0.6B + +Languages supported: zh, en, ja, ko, de, fr, ru, pt, es, it +""" + +import asyncio +import logging +from typing import Optional + +import numpy as np +import torch + +from . import TTSBackend, LANGUAGE_CODE_TO_NAME +from .base import ( + is_model_cached, + get_torch_device, + combine_voice_prompts as _combine_voice_prompts, + model_load_progress, +) + +logger = logging.getLogger(__name__) + +# ── Preset speakers ────────────────────────────────────────────────── + +# (speaker_id, display_name, gender, native_language_code, description) +QWEN_CUSTOM_VOICES = [ + ("Vivian", "Vivian", "female", "zh", "Bright, slightly edgy young female voice"), + ("Serena", "Serena", "female", "zh", "Warm, gentle young female voice"), + ("Uncle_Fu", "Uncle Fu", "male", "zh", "Seasoned male voice with a low, mellow timbre"), + ("Dylan", "Dylan", "male", "zh", "Youthful Beijing male voice with a clear, natural timbre"), + ("Eric", "Eric", "male", "zh", "Lively Chengdu male voice with a slightly husky brightness"), + ("Ryan", "Ryan", "male", "en", "Dynamic male voice with strong rhythmic drive"), + ("Aiden", "Aiden", "male", "en", "Sunny American male voice with a clear midrange"), + ("Ono_Anna", "Ono Anna", "female", "ja", "Playful Japanese female voice with a light, nimble timbre"), + ("Sohee", "Sohee", "female", "ko", "Warm Korean female voice with rich emotion"), +] + +QWEN_CV_DEFAULT_SPEAKER = "Ryan" + +# HuggingFace repo IDs per model size +QWEN_CV_HF_REPOS = { + "1.7B": "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", + "0.6B": "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice", +} + + +class QwenCustomVoiceBackend: + """Qwen3-TTS CustomVoice backend — preset speakers with instruct control.""" + + def __init__(self, model_size: str = "1.7B"): + self.model = None + self.model_size = model_size + self.device = self._get_device() + self._current_model_size: Optional[str] = None + + def _get_device(self) -> str: + return get_torch_device(allow_xpu=True, allow_directml=True) + + def is_loaded(self) -> bool: + return self.model is not None + + def _get_model_path(self, model_size: str) -> str: + if model_size not in QWEN_CV_HF_REPOS: + raise ValueError(f"Unknown model size: {model_size}") + return QWEN_CV_HF_REPOS[model_size] + + def _is_model_cached(self, model_size: Optional[str] = None) -> bool: + size = model_size or self.model_size + return is_model_cached(self._get_model_path(size)) + + async def load_model_async(self, model_size: Optional[str] = None) -> None: + if model_size is None: + model_size = self.model_size + + if self.model is not None and self._current_model_size == model_size: + return + + if self.model is not None and self._current_model_size != model_size: + self.unload_model() + + await asyncio.to_thread(self._load_model_sync, model_size) + + # Alias for compatibility with the TTSBackend protocol + load_model = load_model_async + + def _load_model_sync(self, model_size: str) -> None: + model_name = f"qwen-custom-voice-{model_size}" + is_cached = self._is_model_cached(model_size) + + with model_load_progress(model_name, is_cached): + from qwen_tts import Qwen3TTSModel + + model_path = self._get_model_path(model_size) + logger.info("Loading Qwen CustomVoice %s on %s...", model_size, self.device) + + if self.device == "cpu": + self.model = Qwen3TTSModel.from_pretrained( + model_path, + torch_dtype=torch.float32, + low_cpu_mem_usage=False, + ) + else: + self.model = Qwen3TTSModel.from_pretrained( + model_path, + device_map=self.device, + torch_dtype=torch.bfloat16, + ) + + self._current_model_size = model_size + self.model_size = model_size + logger.info("Qwen CustomVoice %s loaded successfully", model_size) + + def unload_model(self) -> None: + if self.model is not None: + del self.model + self.model = None + self._current_model_size = None + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + logger.info("Qwen CustomVoice unloaded") + + async def create_voice_prompt( + self, + audio_path: str, + reference_text: str, + use_cache: bool = True, + ) -> tuple[dict, bool]: + """ + Create voice prompt for CustomVoice. + + CustomVoice doesn't use reference audio — it uses preset speakers. + When called for a cloned profile (fallback), uses the default speaker. + For preset profiles, the voice_prompt dict is built by the profile + service and bypasses this method entirely. + """ + return { + "voice_type": "preset", + "preset_engine": "qwen_custom_voice", + "preset_voice_id": QWEN_CV_DEFAULT_SPEAKER, + }, False + + async def combine_voice_prompts( + self, + audio_paths: list[str], + reference_texts: list[str], + ) -> tuple[np.ndarray, str]: + return await _combine_voice_prompts(audio_paths, reference_texts) + + async def generate( + self, + text: str, + voice_prompt: dict, + language: str = "en", + seed: Optional[int] = None, + instruct: Optional[str] = None, + ) -> tuple[np.ndarray, int]: + """ + Generate audio using Qwen CustomVoice. + + Args: + text: Text to synthesize + voice_prompt: Dict with preset_voice_id (speaker name) + language: Language code (zh, en, ja, ko, etc.) + seed: Random seed for reproducibility + instruct: Natural language instruction for style control + (e.g. "Speak in an angry tone", "Very happy") + + Returns: + Tuple of (audio_array, sample_rate) + """ + await self.load_model_async(None) + + speaker = voice_prompt.get("preset_voice_id") or QWEN_CV_DEFAULT_SPEAKER + + def _generate_sync(): + if seed is not None: + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + lang_name = LANGUAGE_CODE_TO_NAME.get(language, "auto") + + kwargs = { + "text": text, + "language": lang_name.capitalize() if lang_name != "auto" else "Auto", + "speaker": speaker, + } + + # Only pass instruct if non-empty + if instruct: + kwargs["instruct"] = instruct + + wavs, sample_rate = self.model.generate_custom_voice(**kwargs) + return wavs[0], sample_rate + + audio, sample_rate = await asyncio.to_thread(_generate_sync) + return audio, sample_rate diff --git a/backend/build_binary.py b/backend/build_binary.py index ca7b21a..43ad071 100644 --- a/backend/build_binary.py +++ b/backend/build_binary.py @@ -86,6 +86,8 @@ def build_server(cuda=False): "--hidden-import", "backend.backends.pytorch_backend", "--hidden-import", + "backend.backends.qwen_custom_voice_backend", + "--hidden-import", "backend.utils.audio", "--hidden-import", "backend.utils.cache", diff --git a/backend/models.py b/backend/models.py index f568e0f..f2f43d4 100644 --- a/backend/models.py +++ b/backend/models.py @@ -78,7 +78,7 @@ class GenerationRequest(BaseModel): seed: Optional[int] = Field(None, ge=0) model_size: Optional[str] = Field(default="1.7B", pattern="^(1\\.7B|0\\.6B|1B|3B)$") instruct: Optional[str] = Field(None, max_length=500) - engine: Optional[str] = Field(default="qwen", pattern="^(qwen|luxtts|chatterbox|chatterbox_turbo|tada|kokoro)$") + engine: Optional[str] = Field(default="qwen", pattern="^(qwen|qwen_custom_voice|luxtts|chatterbox|chatterbox_turbo|tada|kokoro)$") max_chunk_chars: int = Field( default=800, ge=100, le=5000, description="Max characters per chunk for long text splitting" ) diff --git a/backend/routes/profiles.py b/backend/routes/profiles.py index a13b02f..665f6b2 100644 --- a/backend/routes/profiles.py +++ b/backend/routes/profiles.py @@ -90,6 +90,21 @@ async def list_preset_voices(engine: str): for vid, name, gender, lang in KOKORO_VOICES ], } + if engine == "qwen_custom_voice": + from ..backends.qwen_custom_voice_backend import QWEN_CUSTOM_VOICES + + return { + "engine": engine, + "voices": [ + { + "voice_id": speaker_id, + "name": display_name, + "gender": gender, + "language": lang, + } + for speaker_id, display_name, gender, lang, _desc in QWEN_CUSTOM_VOICES + ], + } return {"engine": engine, "voices": []} @@ -103,9 +118,15 @@ async def seed_preset_profiles_route( Creates profiles for all available preset voices that don't already exist. Returns the count of newly created profiles. """ - if engine != "kokoro": - raise HTTPException(status_code=400, detail=f"No presets available for engine: {engine}") + if engine == "kokoro": + return _seed_kokoro_presets(db) + if engine == "qwen_custom_voice": + return _seed_qwen_custom_voice_presets(db) + raise HTTPException(status_code=400, detail=f"No presets available for engine: {engine}") + +def _seed_kokoro_presets(db: Session): + """Seed Kokoro preset profiles.""" try: from ..backends.kokoro_backend import KOKORO_VOICES @@ -154,12 +175,57 @@ async def seed_preset_profiles_route( db.commit() logger.info(f"Seeded {created} Kokoro preset profiles") - return {"engine": engine, "created": created, "total_available": len(KOKORO_VOICES)} + return {"engine": "kokoro", "created": created, "total_available": len(KOKORO_VOICES)} except Exception as e: logger.exception(f"Failed to seed Kokoro profiles: {e}") raise HTTPException(status_code=500, detail=str(e)) +def _seed_qwen_custom_voice_presets(db: Session): + """Seed Qwen CustomVoice preset profiles.""" + try: + from ..backends.qwen_custom_voice_backend import QWEN_CUSTOM_VOICES + + created = 0 + for speaker_id, display_name, gender, lang, description in QWEN_CUSTOM_VOICES: + # Skip if preset already exists + existing = ( + db.query(DBVoiceProfile) + .filter_by(preset_engine="qwen_custom_voice", preset_voice_id=speaker_id) + .first() + ) + if existing: + continue + + # Skip name collisions + if db.query(DBVoiceProfile).filter_by(name=display_name).first(): + continue + + profile = DBVoiceProfile( + id=str(uuid.uuid4()), + name=display_name, + description=f"Qwen CustomVoice — {description}", + language=lang, + voice_type="preset", + preset_engine="qwen_custom_voice", + preset_voice_id=speaker_id, + default_engine="qwen_custom_voice", + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ) + db.add(profile) + created += 1 + + if created > 0: + db.commit() + logger.info(f"Seeded {created} Qwen CustomVoice preset profiles") + + return {"engine": "qwen_custom_voice", "created": created, "total_available": len(QWEN_CUSTOM_VOICES)} + except Exception as e: + logger.exception(f"Failed to seed Qwen CustomVoice profiles: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/profiles/{profile_id}", response_model=models.VoiceProfileResponse) async def get_profile( profile_id: str, From 72c13fd3fc533de246e12ccfc2a84ee413758b03 Mon Sep 17 00:00:00 2001 From: James Pine Date: Thu, 19 Mar 2026 19:13:28 -0700 Subject: [PATCH 2/4] fix: enforce preset profile engine compatibility --- .../Generation/EngineModelSelector.tsx | 6 +- .../components/Generation/GenerationForm.tsx | 27 +++- .../components/VoiceProfiles/ProfileForm.tsx | 9 ++ app/src/lib/api/client.ts | 6 - backend/database/migrations.py | 11 +- backend/requirements.txt | 2 +- backend/routes/generations.py | 25 ++-- backend/routes/profiles.py | 121 ------------------ backend/services/profiles.py | 43 +++++++ backend/voicebox-server.spec | 2 +- 10 files changed, 102 insertions(+), 150 deletions(-) diff --git a/app/src/components/Generation/EngineModelSelector.tsx b/app/src/components/Generation/EngineModelSelector.tsx index 7195a28..7f4f600 100644 --- a/app/src/components/Generation/EngineModelSelector.tsx +++ b/app/src/components/Generation/EngineModelSelector.tsx @@ -57,7 +57,7 @@ function getSelectValue(engine: string, modelSize?: string): string { return engine; } -function handleEngineChange(form: UseFormReturn, value: string) { +export function applyEngineSelection(form: UseFormReturn, value: string) { if (value.startsWith('qwen_custom_voice:')) { const [, modelSize] = value.split(':'); form.setValue('engine', 'qwen_custom_voice'); @@ -123,7 +123,7 @@ export function EngineModelSelector({ form, compact, selectedProfile }: EngineMo useEffect(() => { if (!currentEngineAvailable && availableOptions.length > 0) { - handleEngineChange(form, availableOptions[0].value); + applyEngineSelection(form, availableOptions[0].value); } }, [availableOptions, currentEngineAvailable, form]); @@ -133,7 +133,7 @@ export function EngineModelSelector({ form, compact, selectedProfile }: EngineMo : undefined; return ( - applyEngineSelection(form, v)}> diff --git a/app/src/components/Generation/GenerationForm.tsx b/app/src/components/Generation/GenerationForm.tsx index 1195e8b..ef3ff2c 100644 --- a/app/src/components/Generation/GenerationForm.tsx +++ b/app/src/components/Generation/GenerationForm.tsx @@ -1,3 +1,4 @@ +import { useEffect } from 'react'; import { Loader2, Mic } from 'lucide-react'; import { Button } from '@/components/ui/button'; import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; @@ -19,19 +20,41 @@ import { SelectValue, } from '@/components/ui/select'; import { Textarea } from '@/components/ui/textarea'; -import { getLanguageOptionsForEngine } from '@/lib/constants/languages'; +import { getLanguageOptionsForEngine, type LanguageCode } from '@/lib/constants/languages'; import { useGenerationForm } from '@/lib/hooks/useGenerationForm'; import { useProfile } from '@/lib/hooks/useProfiles'; import { useUIStore } from '@/stores/uiStore'; -import { EngineModelSelector, getEngineDescription } from './EngineModelSelector'; +import { EngineModelSelector, applyEngineSelection, getEngineDescription } from './EngineModelSelector'; import { ParalinguisticInput } from './ParalinguisticInput'; +function getEngineSelectValue(engine: string): string { + if (engine === 'qwen') return 'qwen:1.7B'; + if (engine === 'qwen_custom_voice') return 'qwen_custom_voice:1.7B'; + if (engine === 'tada') return 'tada:1B'; + return engine; +} + export function GenerationForm() { const selectedProfileId = useUIStore((state) => state.selectedProfileId); const { data: selectedProfile } = useProfile(selectedProfileId || ''); const { form, handleSubmit, isPending } = useGenerationForm(); + useEffect(() => { + if (!selectedProfile) { + return; + } + + if (selectedProfile.language) { + form.setValue('language', selectedProfile.language as LanguageCode); + } + + const preferredEngine = selectedProfile.default_engine || selectedProfile.preset_engine; + if (preferredEngine) { + applyEngineSelection(form, getEngineSelectValue(preferredEngine)); + } + }, [form, selectedProfile]); + async function onSubmit(data: Parameters[0]) { await handleSubmit(data, selectedProfileId); } diff --git a/app/src/components/VoiceProfiles/ProfileForm.tsx b/app/src/components/VoiceProfiles/ProfileForm.tsx index a148bfb..50b8cb5 100644 --- a/app/src/components/VoiceProfiles/ProfileForm.tsx +++ b/app/src/components/VoiceProfiles/ProfileForm.tsx @@ -375,6 +375,15 @@ export function ProfileForm() { } }, [availableDefaultEngines, defaultEngine]); + useEffect(() => { + if (!selectedPresetVoiceId) { + return; + } + + if (!presetVoices.some((voice: PresetVoice) => voice.voice_id === selectedPresetVoiceId)) { + setSelectedPresetVoiceId(''); + } + }, [presetVoices, selectedPresetVoiceId]); async function handleTranscribe() { const file = form.getValues('sampleFile'); if (!file) { diff --git a/app/src/lib/api/client.ts b/app/src/lib/api/client.ts index 6849b8c..98a375e 100644 --- a/app/src/lib/api/client.ts +++ b/app/src/lib/api/client.ts @@ -102,12 +102,6 @@ class ApiClient { return this.request<{ engine: string; voices: PresetVoice[] }>(`/profiles/presets/${engine}`); } - async seedPresetProfiles( - engine: string, - ): Promise<{ engine: string; created: number; total_available: number }> { - return this.request(`/profiles/presets/${engine}/seed`, { method: 'POST' }); - } - async updateProfile(profileId: string, data: VoiceProfileCreate): Promise { return this.request(`/profiles/${profileId}`, { method: 'PUT', diff --git a/backend/database/migrations.py b/backend/database/migrations.py index f4cc5ad..6256d92 100644 --- a/backend/database/migrations.py +++ b/backend/database/migrations.py @@ -194,6 +194,8 @@ def _resolve_relative_paths(engine, tables: set[str]) -> None: configured data directory. If the path starts with "data/", strip that prefix and prepend get_data_dir(). Otherwise, join the relative path directly under get_data_dir(). + directly under get_data_dir(). If the rebased path still does not exist, + fall back to resolving relative to CWD. """ from pathlib import Path @@ -222,16 +224,13 @@ def _resolve_relative_paths(engine, tables: set[str]) -> None: p = Path(path_val) if p.is_absolute(): continue - - # Try rebasing: "data/generations/abc.wav" → data_dir / "generations/abc.wav" parts = p.parts if parts and parts[0] == "data": - rebased = data_dir / Path(*parts[1:]) + rebased = (data_dir / Path(*parts[1:])).resolve() else: - rebased = data_dir / p - - resolved = rebased.resolve() + rebased = (data_dir / p).resolve() + resolved = rebased if rebased.exists() else p.resolve() if resolved.exists(): conn.execute( text(f"UPDATE {table} SET {column} = :path WHERE id = :id"), diff --git a/backend/requirements.txt b/backend/requirements.txt index c9c65b0..e916b1d 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -42,7 +42,7 @@ torchaudio # Kokoro TTS (lightweight 82M-param engine) kokoro>=0.9.4 -misaki[en]>=0.9.4 +misaki[en,ja,zh]>=0.9.4 # spacy model for misaki English G2P — must be pre-installed or misaki # tries spacy.cli.download() at runtime which crashes frozen builds en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl diff --git a/backend/routes/generations.py b/backend/routes/generations.py index 9f985fc..2af3832 100644 --- a/backend/routes/generations.py +++ b/backend/routes/generations.py @@ -20,6 +20,10 @@ router = APIRouter() +def _resolve_generation_engine(data: models.GenerationRequest, profile) -> str: + return data.engine or getattr(profile, "default_engine", None) or getattr(profile, "preset_engine", None) or "qwen" + + @router.post("/generate", response_model=models.GenerationResponse) async def generate_speech( data: models.GenerationRequest, @@ -35,7 +39,12 @@ async def generate_speech( from ..backends import engine_has_model_sizes - engine = data.engine or "qwen" + engine = _resolve_generation_engine(data, profile) + try: + profiles.validate_profile_engine(profile, engine) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + model_size = (data.model_size or "1.7B") if engine_has_model_sizes(engine) else None generation = await history.create_generation( @@ -230,15 +239,11 @@ async def stream_speech( if not profile: raise HTTPException(status_code=404, detail="Profile not found") - # Mirror the regular /generate endpoint behavior more closely: - # if the caller doesn't specify an engine, prefer the profile's default - # engine (or preset engine) before falling back to qwen. - engine = ( - data.engine - or getattr(profile, "default_engine", None) - or getattr(profile, "preset_engine", None) - or "qwen" - ) + engine = _resolve_generation_engine(data, profile) + try: + profiles.validate_profile_engine(profile, engine) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) tts_model = get_tts_backend_for_engine(engine) model_size = data.model_size or "1.7B" diff --git a/backend/routes/profiles.py b/backend/routes/profiles.py index 665f6b2..d65a113 100644 --- a/backend/routes/profiles.py +++ b/backend/routes/profiles.py @@ -4,8 +4,6 @@ import json as _json import logging import tempfile -import uuid -from datetime import datetime from pathlib import Path from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile @@ -107,125 +105,6 @@ async def list_preset_voices(engine: str): } return {"engine": engine, "voices": []} - -@router.post("/profiles/presets/{engine}/seed") -async def seed_preset_profiles_route( - engine: str, - db: Session = Depends(get_db), -): - """Seed preset voice profiles for an engine. - - Creates profiles for all available preset voices that don't already exist. - Returns the count of newly created profiles. - """ - if engine == "kokoro": - return _seed_kokoro_presets(db) - if engine == "qwen_custom_voice": - return _seed_qwen_custom_voice_presets(db) - raise HTTPException(status_code=400, detail=f"No presets available for engine: {engine}") - - -def _seed_kokoro_presets(db: Session): - """Seed Kokoro preset profiles.""" - try: - from ..backends.kokoro_backend import KOKORO_VOICES - - created = 0 - for voice_id, display_name, gender, lang in KOKORO_VOICES: - profile_name = display_name - - # Disambiguate duplicate display names across languages - # (e.g. "Alpha" exists in Hindi and Japanese, "Dora" in Spanish and Portuguese) - dupes = [v for v in KOKORO_VOICES if v[1] == display_name] - if len(dupes) > 1: - lang_labels = {"en": "English", "es": "Spanish", "fr": "French", "hi": "Hindi", - "it": "Italian", "pt": "Portuguese", "ja": "Japanese", "zh": "Chinese"} - profile_name = f"{display_name} {lang_labels.get(lang, lang)}" - - # Skip if preset already exists - existing = ( - db.query(DBVoiceProfile) - .filter_by(preset_engine="kokoro", preset_voice_id=voice_id) - .first() - ) - if existing: - continue - - unique_name = profile_name - suffix = 2 - while db.query(DBVoiceProfile).filter_by(name=unique_name).first(): - unique_name = f"{profile_name} {suffix}" - suffix += 1 - - profile = DBVoiceProfile( - id=str(uuid.uuid4()), - name=unique_name, - description=f"Kokoro preset voice — {display_name} ({gender})", - language=lang, - voice_type="preset", - preset_engine="kokoro", - preset_voice_id=voice_id, - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - ) - db.add(profile) - created += 1 - - if created > 0: - db.commit() - logger.info(f"Seeded {created} Kokoro preset profiles") - - return {"engine": "kokoro", "created": created, "total_available": len(KOKORO_VOICES)} - except Exception as e: - logger.exception(f"Failed to seed Kokoro profiles: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -def _seed_qwen_custom_voice_presets(db: Session): - """Seed Qwen CustomVoice preset profiles.""" - try: - from ..backends.qwen_custom_voice_backend import QWEN_CUSTOM_VOICES - - created = 0 - for speaker_id, display_name, gender, lang, description in QWEN_CUSTOM_VOICES: - # Skip if preset already exists - existing = ( - db.query(DBVoiceProfile) - .filter_by(preset_engine="qwen_custom_voice", preset_voice_id=speaker_id) - .first() - ) - if existing: - continue - - # Skip name collisions - if db.query(DBVoiceProfile).filter_by(name=display_name).first(): - continue - - profile = DBVoiceProfile( - id=str(uuid.uuid4()), - name=display_name, - description=f"Qwen CustomVoice — {description}", - language=lang, - voice_type="preset", - preset_engine="qwen_custom_voice", - preset_voice_id=speaker_id, - default_engine="qwen_custom_voice", - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), - ) - db.add(profile) - created += 1 - - if created > 0: - db.commit() - logger.info(f"Seeded {created} Qwen CustomVoice preset profiles") - - return {"engine": "qwen_custom_voice", "created": created, "total_available": len(QWEN_CUSTOM_VOICES)} - except Exception as e: - logger.exception(f"Failed to seed Qwen CustomVoice profiles: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - @router.get("/profiles/{profile_id}", response_model=models.VoiceProfileResponse) async def get_profile( profile_id: str, diff --git a/backend/services/profiles.py b/backend/services/profiles.py index 16ea7b5..cd418aa 100644 --- a/backend/services/profiles.py +++ b/backend/services/profiles.py @@ -61,6 +61,20 @@ def _profile_to_response( ) +def _get_preset_voice_ids(engine: str) -> set[str]: + if engine == "kokoro": + from ..backends.kokoro_backend import KOKORO_VOICES + + return {voice_id for voice_id, _name, _gender, _lang in KOKORO_VOICES} + + if engine == "qwen_custom_voice": + from ..backends.qwen_custom_voice_backend import QWEN_CUSTOM_VOICES + + return {voice_id for voice_id, _name, _gender, _lang, _desc in QWEN_CUSTOM_VOICES} + + return set() + + def _validate_profile_fields( *, voice_type: str, @@ -74,6 +88,10 @@ def _validate_profile_fields( return "Preset profiles require both preset_engine and preset_voice_id" if default_engine and default_engine != preset_engine: return "Preset profiles must use their preset_engine as default_engine" + + available_voice_ids = _get_preset_voice_ids(preset_engine) + if available_voice_ids and preset_voice_id not in available_voice_ids: + return f"Preset voice '{preset_voice_id}' is not valid for engine '{preset_engine}'" return None if voice_type == "designed": @@ -92,6 +110,30 @@ def _validate_profile_fields( return None +def validate_profile_engine(profile, engine: str) -> None: + voice_type = getattr(profile, "voice_type", None) or "cloned" + + if voice_type == "preset": + preset_engine = getattr(profile, "preset_engine", None) + preset_voice_id = getattr(profile, "preset_voice_id", None) + if not preset_engine or not preset_voice_id: + raise ValueError(f"Preset profile {profile.id} is missing preset engine metadata") + if preset_engine != engine: + raise ValueError( + f"Preset profile {profile.id} only supports engine '{preset_engine}', not '{engine}'" + ) + return + + if voice_type == "designed": + design_prompt = getattr(profile, "design_prompt", None) + if not design_prompt or not design_prompt.strip(): + raise ValueError(f"Designed profile {profile.id} is missing design_prompt") + return + + if engine not in CLONING_ENGINES: + raise ValueError(f"Engine '{engine}' does not support cloned voice profiles") + + async def create_profile( data: VoiceProfileCreate, db: Session, @@ -476,6 +518,7 @@ async def create_voice_prompt_for_profile( raise ValueError(f"Profile not found: {profile_id}") voice_type = getattr(profile, "voice_type", None) or "cloned" + validate_profile_engine(profile, engine) # ── Preset profiles: return engine-specific voice reference ── if voice_type == "preset": diff --git a/backend/voicebox-server.spec b/backend/voicebox-server.spec index 1c208ba..c1acc54 100644 --- a/backend/voicebox-server.spec +++ b/backend/voicebox-server.spec @@ -5,7 +5,7 @@ from PyInstaller.utils.hooks import copy_metadata datas = [] binaries = [] -hiddenimports = ['backend', 'backend.main', 'backend.config', 'backend.database', 'backend.models', 'backend.services.profiles', 'backend.services.history', 'backend.services.tts', 'backend.services.transcribe', 'backend.utils.platform_detect', 'backend.backends', 'backend.backends.pytorch_backend', 'backend.utils.audio', 'backend.utils.cache', 'backend.utils.progress', 'backend.utils.hf_progress', 'backend.services.cuda', 'backend.services.effects', 'backend.utils.effects', 'backend.services.versions', 'pedalboard', 'chatterbox', 'chatterbox.tts_turbo', 'chatterbox.mtl_tts', 'backend.backends.chatterbox_backend', 'backend.backends.chatterbox_turbo_backend', 'backend.backends.luxtts_backend', 'zipvoice', 'zipvoice.luxvoice', 'torch', 'transformers', 'fastapi', 'uvicorn', 'sqlalchemy', 'soundfile', 'qwen_tts', 'qwen_tts.inference', 'qwen_tts.inference.qwen3_tts_model', 'qwen_tts.inference.qwen3_tts_tokenizer', 'qwen_tts.core', 'qwen_tts.cli', 'requests', 'pkg_resources.extern', 'backend.backends.hume_backend', 'tada', 'tada.modules', 'tada.modules.tada', 'tada.modules.encoder', 'tada.modules.decoder', 'tada.modules.aligner', 'tada.modules.acoustic_spkr_verf', 'tada.nn', 'tada.nn.vibevoice', 'tada.utils', 'tada.utils.gray_code', 'tada.utils.text', 'backend.utils.dac_shim', 'torchaudio', 'backend.backends.kokoro_backend', 'kokoro', 'kokoro.pipeline', 'kokoro.model', 'kokoro.istftnet', 'kokoro.modules', 'kokoro.custom_stft', 'en_core_web_sm', 'loguru', 'backend.backends.mlx_backend', 'mlx', 'mlx.core', 'mlx.nn', 'mlx_audio', 'mlx_audio.tts', 'mlx_audio.stt'] +hiddenimports = ['backend', 'backend.main', 'backend.config', 'backend.database', 'backend.models', 'backend.services.profiles', 'backend.services.history', 'backend.services.tts', 'backend.services.transcribe', 'backend.utils.platform_detect', 'backend.backends', 'backend.backends.pytorch_backend', 'backend.backends.qwen_custom_voice_backend', 'backend.utils.audio', 'backend.utils.cache', 'backend.utils.progress', 'backend.utils.hf_progress', 'backend.services.cuda', 'backend.services.effects', 'backend.utils.effects', 'backend.services.versions', 'pedalboard', 'chatterbox', 'chatterbox.tts_turbo', 'chatterbox.mtl_tts', 'backend.backends.chatterbox_backend', 'backend.backends.chatterbox_turbo_backend', 'backend.backends.luxtts_backend', 'zipvoice', 'zipvoice.luxvoice', 'torch', 'transformers', 'fastapi', 'uvicorn', 'sqlalchemy', 'soundfile', 'qwen_tts', 'qwen_tts.inference', 'qwen_tts.inference.qwen3_tts_model', 'qwen_tts.inference.qwen3_tts_tokenizer', 'qwen_tts.core', 'qwen_tts.cli', 'requests', 'pkg_resources.extern', 'backend.backends.hume_backend', 'tada', 'tada.modules', 'tada.modules.tada', 'tada.modules.encoder', 'tada.modules.decoder', 'tada.modules.aligner', 'tada.modules.acoustic_spkr_verf', 'tada.nn', 'tada.nn.vibevoice', 'tada.utils', 'tada.utils.gray_code', 'tada.utils.text', 'backend.utils.dac_shim', 'torchaudio', 'backend.backends.kokoro_backend', 'kokoro', 'kokoro.pipeline', 'kokoro.model', 'kokoro.istftnet', 'kokoro.modules', 'kokoro.custom_stft', 'en_core_web_sm', 'loguru', 'backend.backends.mlx_backend', 'mlx', 'mlx.core', 'mlx.nn', 'mlx_audio', 'mlx_audio.tts', 'mlx_audio.stt'] datas += copy_metadata('qwen-tts') datas += copy_metadata('requests') datas += copy_metadata('transformers') From e6f419cd702dbd8b58d7c97d6959ca8e3794117c Mon Sep 17 00:00:00 2001 From: James Pine Date: Thu, 19 Mar 2026 19:30:03 -0700 Subject: [PATCH 3/4] fix: show all engines in floating generator --- app/src/components/Generation/FloatingGenerateBox.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/src/components/Generation/FloatingGenerateBox.tsx b/app/src/components/Generation/FloatingGenerateBox.tsx index ae1cad2..f1cd571 100644 --- a/app/src/components/Generation/FloatingGenerateBox.tsx +++ b/app/src/components/Generation/FloatingGenerateBox.tsx @@ -408,7 +408,7 @@ export function FloatingGenerateBox({ /> - + From b108bb1cb183df0581daf72b7f11b41e09728bfd Mon Sep 17 00:00:00 2001 From: James Pine Date: Fri, 20 Mar 2026 15:06:07 -0700 Subject: [PATCH 4/4] fix: store media paths relative to data dir --- backend/config.py | 47 +++++++++++++++++++++++++++++++ backend/database/migrations.py | 40 ++++++++------------------ backend/database/seed.py | 6 ++-- backend/routes/audio.py | 16 +++++------ backend/routes/effects.py | 13 +++++---- backend/routes/history.py | 7 ++--- backend/routes/profiles.py | 6 ++-- backend/services/export_import.py | 20 +++++++------ backend/services/generation.py | 14 +++++---- backend/services/history.py | 8 +++--- backend/services/profiles.py | 28 +++++++++++------- backend/services/stories.py | 5 ++-- backend/services/versions.py | 8 +++--- 13 files changed, 131 insertions(+), 87 deletions(-) diff --git a/backend/config.py b/backend/config.py index ecfc392..959731c 100644 --- a/backend/config.py +++ b/backend/config.py @@ -22,6 +22,21 @@ _data_dir = Path("data").resolve() +def _path_relative_to_any_data_dir(path: Path) -> Path | None: + """Extract the path within a data dir from an absolute or relative path.""" + parts = path.parts + for idx, part in enumerate(parts): + if part != "data": + continue + + tail = parts[idx + 1 :] + if tail: + return Path(*tail) + return Path() + + return None + + def set_data_dir(path: str | Path): """ Set the data directory path. @@ -45,6 +60,38 @@ def get_data_dir() -> Path: return _data_dir +def to_storage_path(path: str | Path) -> str: + """Convert a filesystem path to a DB-safe path relative to the data dir.""" + resolved_path = Path(path).resolve() + + relative_to_any_data_dir = _path_relative_to_any_data_dir(resolved_path) + if relative_to_any_data_dir is not None: + return str(relative_to_any_data_dir) + + try: + return str(resolved_path.relative_to(_data_dir)) + except ValueError: + return str(resolved_path) + + +def resolve_storage_path(path: str | Path | None) -> Path | None: + """Resolve a DB-stored path against the configured data dir.""" + if path is None: + return None + + stored_path = Path(path) + if stored_path.is_absolute(): + rebased_path = _path_relative_to_any_data_dir(stored_path) + if rebased_path is not None: + candidate = (_data_dir / rebased_path).resolve() + if candidate.exists() or not stored_path.exists(): + return candidate + + return stored_path + + return (_data_dir / stored_path).resolve() + + def get_db_path() -> Path: """Get database file path.""" return _data_dir / "voicebox.db" diff --git a/backend/database/migrations.py b/backend/database/migrations.py index 6256d92..2bdd928 100644 --- a/backend/database/migrations.py +++ b/backend/database/migrations.py @@ -34,7 +34,7 @@ def run_migrations(engine) -> None: _migrate_generations(engine, inspector, tables) _migrate_effect_presets(engine, inspector, tables) _migrate_generation_versions(engine, inspector, tables) - _resolve_relative_paths(engine, tables) + _normalize_storage_paths(engine, tables) # -- helpers --------------------------------------------------------------- @@ -182,24 +182,11 @@ def _migrate_generation_versions(engine, inspector, tables: set[str]) -> None: _add_column(engine, "generation_versions", "source_version_id VARCHAR", "source_version_id") -def _resolve_relative_paths(engine, tables: set[str]) -> None: - """Resolve any relative file paths in the database to absolute paths. - - Earlier versions stored paths relative to CWD (e.g. "data/generations/abc.wav"). - These break when the production binary's CWD differs from the data directory. - This migration converts them to absolute paths using the configured data dir. - Idempotent: absolute paths are left untouched. - - Strategy: paths like "data/generations/abc.wav" are rebased onto the - configured data directory. If the path starts with "data/", strip that - prefix and prepend get_data_dir(). Otherwise, join the relative path - directly under get_data_dir(). - directly under get_data_dir(). If the rebased path still does not exist, - fall back to resolving relative to CWD. - """ +def _normalize_storage_paths(engine, tables: set[str]) -> None: + """Normalize stored file paths to be relative to the configured data dir.""" from pathlib import Path - from ..config import get_data_dir + from ..config import get_data_dir, to_storage_path, resolve_storage_path data_dir = get_data_dir() @@ -222,21 +209,18 @@ def _resolve_relative_paths(engine, tables: set[str]) -> None: if not path_val: continue p = Path(path_val) - if p.is_absolute(): + resolved = resolve_storage_path(p) + if resolved is None: continue - parts = p.parts - if parts and parts[0] == "data": - rebased = (data_dir / Path(*parts[1:])).resolve() - else: - rebased = (data_dir / p).resolve() - - resolved = rebased if rebased.exists() else p.resolve() - if resolved.exists(): + + normalized = to_storage_path(resolved) + + if normalized != path_val: conn.execute( text(f"UPDATE {table} SET {column} = :path WHERE id = :id"), - {"path": str(resolved), "id": row_id}, + {"path": normalized, "id": row_id}, ) total_fixed += 1 if total_fixed > 0: conn.commit() - logger.info("Resolved %d relative file paths to absolute", total_fixed) + logger.info("Normalized %d stored file paths", total_fixed) diff --git a/backend/database/seed.py b/backend/database/seed.py index b62edc2..b09e0b8 100644 --- a/backend/database/seed.py +++ b/backend/database/seed.py @@ -3,7 +3,8 @@ import json import logging import uuid -from pathlib import Path + +from .. import config logger = logging.getLogger(__name__) @@ -25,7 +26,8 @@ def backfill_generation_versions(SessionLocal, Generation, GenerationVersion) -> for gen in generations: if gen.id in existing_version_gen_ids: continue - if not Path(gen.audio_path).exists(): + resolved_audio_path = config.resolve_storage_path(gen.audio_path) + if resolved_audio_path is None or not resolved_audio_path.exists(): continue version = GenerationVersion( id=str(uuid.uuid4()), diff --git a/backend/routes/audio.py b/backend/routes/audio.py index 682d7aa..f80a44d 100644 --- a/backend/routes/audio.py +++ b/backend/routes/audio.py @@ -1,12 +1,10 @@ """Audio file serving endpoints.""" -from pathlib import Path - from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import FileResponse from sqlalchemy.orm import Session -from .. import models +from .. import config, models from ..services import history from ..database import get_db @@ -22,8 +20,8 @@ async def get_version_audio(version_id: str, db: Session = Depends(get_db)): if not version: raise HTTPException(status_code=404, detail="Version not found") - audio_path = Path(version.audio_path) - if not audio_path.exists(): + audio_path = config.resolve_storage_path(version.audio_path) + if audio_path is None or not audio_path.exists(): raise HTTPException(status_code=404, detail="Audio file not found") return FileResponse( @@ -40,8 +38,8 @@ async def get_audio(generation_id: str, db: Session = Depends(get_db)): if not generation: raise HTTPException(status_code=404, detail="Generation not found") - audio_path = Path(generation.audio_path) - if not audio_path.exists(): + audio_path = config.resolve_storage_path(generation.audio_path) + if audio_path is None or not audio_path.exists(): raise HTTPException(status_code=404, detail="Audio file not found") return FileResponse( @@ -60,8 +58,8 @@ async def get_sample_audio(sample_id: str, db: Session = Depends(get_db)): if not sample: raise HTTPException(status_code=404, detail="Sample not found") - audio_path = Path(sample.audio_path) - if not audio_path.exists(): + audio_path = config.resolve_storage_path(sample.audio_path) + if audio_path is None or not audio_path.exists(): raise HTTPException(status_code=404, detail="Audio file not found") return FileResponse( diff --git a/backend/routes/effects.py b/backend/routes/effects.py index 8139176..52bbc8f 100644 --- a/backend/routes/effects.py +++ b/backend/routes/effects.py @@ -3,7 +3,6 @@ import asyncio import io import uuid -from pathlib import Path from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse @@ -41,10 +40,11 @@ async def preview_effects( all_versions = versions_mod.list_versions(generation_id, db) clean_version = next((v for v in all_versions if v.effects_chain is None), None) source_path = clean_version.audio_path if clean_version else gen.audio_path - if not source_path or not Path(source_path).exists(): + resolved_source_path = config.resolve_storage_path(source_path) + if resolved_source_path is None or not resolved_source_path.exists(): raise HTTPException(status_code=404, detail="Source audio file not found") - audio, sample_rate = await asyncio.to_thread(load_audio, source_path) + audio, sample_rate = await asyncio.to_thread(load_audio, str(resolved_source_path)) processed = await asyncio.to_thread(apply_effects, audio, sample_rate, chain_dicts) import soundfile as sf @@ -193,10 +193,11 @@ async def apply_effects_to_generation( source_path = clean_version.audio_path source_version_id = clean_version.id - if not source_path or not Path(source_path).exists(): + resolved_source_path = config.resolve_storage_path(source_path) + if resolved_source_path is None or not resolved_source_path.exists(): raise HTTPException(status_code=404, detail="Source audio file not found") - audio, sample_rate = await asyncio.to_thread(load_audio, source_path) + audio, sample_rate = await asyncio.to_thread(load_audio, str(resolved_source_path)) processed_audio = await asyncio.to_thread(apply_effects, audio, sample_rate, chain_dicts) version_id = str(uuid.uuid4()) @@ -208,7 +209,7 @@ async def apply_effects_to_generation( version = versions_mod.create_version( generation_id=generation_id, label=label, - audio_path=str(processed_path), + audio_path=config.to_storage_path(processed_path), db=db, effects_chain=chain_dicts, is_default=data.set_as_default, diff --git a/backend/routes/history.py b/backend/routes/history.py index 5435e29..f8233e3 100644 --- a/backend/routes/history.py +++ b/backend/routes/history.py @@ -1,13 +1,12 @@ """Generation history endpoints.""" import io -from pathlib import Path from fastapi import APIRouter, Depends, File, HTTPException, UploadFile from fastapi.responses import FileResponse, StreamingResponse from sqlalchemy.orm import Session -from .. import models +from .. import config, models from ..services import export_import, history from ..app import safe_content_disposition from ..database import Generation as DBGeneration, VoiceProfile as DBVoiceProfile, get_db @@ -162,8 +161,8 @@ async def export_generation_audio( if not generation.audio_path: raise HTTPException(status_code=404, detail="Generation has no audio file") - audio_path = Path(generation.audio_path) - if not audio_path.is_file(): + audio_path = config.resolve_storage_path(generation.audio_path) + if audio_path is None or not audio_path.is_file(): raise HTTPException(status_code=404, detail="Audio file not found") safe_text = "".join(c for c in generation.text[:30] if c.isalnum() or c in (" ", "-", "_")).strip() diff --git a/backend/routes/profiles.py b/backend/routes/profiles.py index d65a113..7bc075c 100644 --- a/backend/routes/profiles.py +++ b/backend/routes/profiles.py @@ -10,7 +10,7 @@ from fastapi.responses import FileResponse, StreamingResponse from sqlalchemy.orm import Session -from .. import models +from .. import config, models from ..app import safe_content_disposition from ..database import VoiceProfile as DBVoiceProfile, get_db from ..services import channels, export_import, profiles @@ -258,8 +258,8 @@ async def get_profile_avatar( if not profile.avatar_path: raise HTTPException(status_code=404, detail="No avatar found for this profile") - avatar_path = Path(profile.avatar_path) - if not avatar_path.exists(): + avatar_path = config.resolve_storage_path(profile.avatar_path) + if avatar_path is None or not avatar_path.exists(): raise HTTPException(status_code=404, detail="Avatar file not found") return FileResponse(avatar_path) diff --git a/backend/services/export_import.py b/backend/services/export_import.py index 93252f5..514eaac 100644 --- a/backend/services/export_import.py +++ b/backend/services/export_import.py @@ -73,8 +73,8 @@ def export_profile_to_zip(profile_id: str, db: Session) -> bytes: # Check if profile has avatar has_avatar = False if profile.avatar_path: - avatar_path = Path(profile.avatar_path) - if avatar_path.exists(): + avatar_path = config.resolve_storage_path(profile.avatar_path) + if avatar_path is not None and avatar_path.exists(): has_avatar = True # Add avatar to ZIP root with original extension avatar_ext = avatar_path.suffix @@ -98,7 +98,9 @@ def export_profile_to_zip(profile_id: str, db: Session) -> bytes: for sample in samples: # Get filename from audio_path (should be {sample_id}.wav) - audio_path = Path(sample.audio_path) + audio_path = config.resolve_storage_path(sample.audio_path) + if audio_path is None: + raise ValueError(f"Audio file not found: {sample.audio_path}") filename = audio_path.name # Read audio file @@ -279,7 +281,7 @@ def export_generation_to_zip(generation_id: str, db: Session) -> bytes: # Build version manifest entries version_entries = [] for v in versions: - v_path = Path(v.audio_path) + v_path = config.resolve_storage_path(v.audio_path) effects_chain = None if v.effects_chain: effects_chain = json.loads(v.effects_chain) @@ -314,14 +316,14 @@ def export_generation_to_zip(generation_id: str, db: Session) -> bytes: # Add all version audio files for v in versions: - v_path = Path(v.audio_path) - if v_path.exists(): + v_path = config.resolve_storage_path(v.audio_path) + if v_path is not None and v_path.exists(): zip_file.write(v_path, f"audio/{v_path.name}") # Fallback: if no versions exist, include the generation's main audio if not versions: - audio_path = Path(generation.audio_path) - if audio_path.exists(): + audio_path = config.resolve_storage_path(generation.audio_path) + if audio_path is not None and audio_path.exists(): zip_file.write(audio_path, f"audio/{audio_path.name}") zip_buffer.seek(0) @@ -426,7 +428,7 @@ async def import_generation_from_zip(file_bytes: bytes, db: Session) -> dict: profile_id=profile_id, text=generation_data["text"], language=generation_data["language"], - audio_path=str(audio_dest), + audio_path=config.to_storage_path(audio_dest), duration=generation_data["duration"], seed=generation_data.get("seed"), instruct=generation_data.get("instruct"), diff --git a/backend/services/generation.py b/backend/services/generation.py index d8d5214..a70e633 100644 --- a/backend/services/generation.py +++ b/backend/services/generation.py @@ -163,7 +163,7 @@ def _save_generate( versions_mod.create_version( generation_id=generation_id, label="original", - audio_path=str(clean_audio_path), + audio_path=config.to_storage_path(clean_audio_path), db=db, effects_chain=None, is_default=not has_effects, @@ -174,6 +174,8 @@ def _save_generate( if has_effects: from ..utils.effects import apply_effects, validate_effects_chain + assert effects_chain is not None + error_msg = validate_effects_chain(effects_chain) if error_msg: import logging @@ -189,13 +191,13 @@ def _save_generate( versions_mod.create_version( generation_id=generation_id, label="version-2", - audio_path=str(processed_path), + audio_path=config.to_storage_path(processed_path), db=db, effects_chain=effects_chain, is_default=True, ) - return final_audio_path + return config.to_storage_path(final_audio_path) def _save_retry( @@ -211,7 +213,7 @@ def _save_retry( """ audio_path = config.get_generations_dir() / f"{generation_id}.wav" save_audio(audio, str(audio_path), sample_rate) - return str(audio_path) + return config.to_storage_path(audio_path) def _save_regenerate( @@ -244,10 +246,10 @@ def _save_regenerate( versions_mod.create_version( generation_id=generation_id, label=label, - audio_path=str(audio_path), + audio_path=config.to_storage_path(audio_path), db=db, effects_chain=None, is_default=True, ) - return str(audio_path) + return config.to_storage_path(audio_path) diff --git a/backend/services/history.py b/backend/services/history.py index 8f45d48..473c4b3 100644 --- a/backend/services/history.py +++ b/backend/services/history.py @@ -253,8 +253,8 @@ async def delete_generation( # Delete main audio file (if not already removed by version cleanup) if generation.audio_path: - audio_path = Path(generation.audio_path) - if audio_path.exists(): + audio_path = config.resolve_storage_path(generation.audio_path) + if audio_path is not None and audio_path.exists(): audio_path.unlink() # Delete from database @@ -283,8 +283,8 @@ async def delete_generations_by_profile( count = 0 for generation in generations: # Delete audio file - audio_path = Path(generation.audio_path) - if audio_path.exists(): + audio_path = config.resolve_storage_path(generation.audio_path) + if audio_path is not None and audio_path.exists(): audio_path.unlink() # Delete from database diff --git a/backend/services/profiles.py b/backend/services/profiles.py index cd418aa..7883950 100644 --- a/backend/services/profiles.py +++ b/backend/services/profiles.py @@ -236,7 +236,7 @@ async def add_profile_sample( db_sample = DBProfileSample( id=sample_id, profile_id=profile_id, - audio_path=str(dest_path), + audio_path=config.to_storage_path(dest_path), reference_text=reference_text, ) @@ -441,8 +441,8 @@ async def delete_profile_sample( # Store profile_id before deleting profile_id = sample.profile_id - audio_path = Path(sample.audio_path) - if audio_path.exists(): + audio_path = config.resolve_storage_path(sample.audio_path) + if audio_path is not None and audio_path.exists(): audio_path.unlink() db.delete(sample) @@ -556,14 +556,22 @@ async def create_voice_prompt_for_profile( if len(samples) == 1: sample = samples[0] + sample_audio_path = config.resolve_storage_path(sample.audio_path) + if sample_audio_path is None: + raise ValueError(f"Sample audio not found for profile {profile_id}") voice_prompt, _ = await tts_model.create_voice_prompt( - sample.audio_path, + str(sample_audio_path), sample.reference_text, use_cache=use_cache, ) return voice_prompt - audio_paths = [s.audio_path for s in samples] + audio_paths = [] + for sample in samples: + sample_audio_path = config.resolve_storage_path(sample.audio_path) + if sample_audio_path is None: + raise ValueError(f"Sample audio not found for profile {profile_id}") + audio_paths.append(str(sample_audio_path)) reference_texts = [s.reference_text for s in samples] combined_audio, combined_text = await tts_model.combine_voice_prompts( @@ -617,8 +625,8 @@ async def upload_avatar( raise ValueError(error_msg) if profile.avatar_path: - old_avatar = Path(profile.avatar_path) - if old_avatar.exists(): + old_avatar = config.resolve_storage_path(profile.avatar_path) + if old_avatar is not None and old_avatar.exists(): old_avatar.unlink() # Determine file extension from uploaded file @@ -639,7 +647,7 @@ async def upload_avatar( process_avatar(image_path, str(output_path)) - profile.avatar_path = str(output_path) + profile.avatar_path = config.to_storage_path(output_path) profile.updated_at = datetime.utcnow() db.commit() @@ -666,8 +674,8 @@ async def delete_avatar( if not profile or not profile.avatar_path: return False - avatar_path = Path(profile.avatar_path) - if avatar_path.exists(): + avatar_path = config.resolve_storage_path(profile.avatar_path) + if avatar_path is not None and avatar_path.exists(): avatar_path.unlink() profile.avatar_path = None diff --git a/backend/services/stories.py b/backend/services/stories.py index ac8e22b..611ab29 100644 --- a/backend/services/stories.py +++ b/backend/services/stories.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from sqlalchemy import func +from .. import config from ..models import ( StoryCreate, StoryResponse, @@ -826,8 +827,8 @@ async def export_story_audio( if version: resolved_audio_path = version.audio_path - audio_path = Path(resolved_audio_path) - if not audio_path.exists(): + audio_path = config.resolve_storage_path(resolved_audio_path) + if audio_path is None or not audio_path.exists(): continue try: diff --git a/backend/services/versions.py b/backend/services/versions.py index 1743a25..cbeb4b8 100644 --- a/backend/services/versions.py +++ b/backend/services/versions.py @@ -158,8 +158,8 @@ def delete_version(version_id: str, db: Session) -> bool: gen_id = version.generation_id # Delete audio file - audio_path = Path(version.audio_path) - if audio_path.exists(): + audio_path = config.resolve_storage_path(version.audio_path) + if audio_path is not None and audio_path.exists(): audio_path.unlink() db.delete(version) @@ -193,8 +193,8 @@ def delete_versions_for_generation(generation_id: str, db: Session) -> int: ) count = 0 for v in versions: - audio_path = Path(v.audio_path) - if audio_path.exists(): + audio_path = config.resolve_storage_path(v.audio_path) + if audio_path is not None and audio_path.exists(): audio_path.unlink() db.delete(v) count += 1