diff --git a/app/src/components/ServerSettings/ConnectionForm.tsx b/app/src/components/ServerSettings/ConnectionForm.tsx index 44eeb4e6..3b5ad845 100644 --- a/app/src/components/ServerSettings/ConnectionForm.tsx +++ b/app/src/components/ServerSettings/ConnectionForm.tsx @@ -1,9 +1,12 @@ import { zodResolver } from '@hookform/resolvers/zod'; +import { Loader2, XCircle } from 'lucide-react'; import { useEffect } from 'react'; import { useForm } from 'react-hook-form'; import * as z from 'zod'; +import { Badge } from '@/components/ui/badge'; import { Button } from '@/components/ui/button'; import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; +import { Checkbox } from '@/components/ui/checkbox'; import { Form, FormControl, @@ -14,10 +17,10 @@ import { FormMessage, } from '@/components/ui/form'; import { Input } from '@/components/ui/input'; -import { Checkbox } from '@/components/ui/checkbox'; import { useToast } from '@/components/ui/use-toast'; -import { useServerStore } from '@/stores/serverStore'; +import { useServerHealth } from '@/lib/hooks/useServer'; import { usePlatform } from '@/platform/PlatformContext'; +import { useServerStore } from '@/stores/serverStore'; const connectionSchema = z.object({ serverUrl: z.string().url('Please enter a valid URL'), @@ -34,6 +37,7 @@ export function ConnectionForm() { const mode = useServerStore((state) => state.mode); const setMode = useServerStore((state) => state.setMode); const { toast } = useToast(); + const { data: health, isLoading, error: healthError } = useServerHealth(); const form = useForm({ resolver: zodResolver(connectionSchema), @@ -51,7 +55,7 @@ export function ConnectionForm() { function onSubmit(data: ConnectionFormValues) { setServerUrl(data.serverUrl); - form.reset(data); // Reset form state after successful submission + form.reset(data); toast({ title: 'Server URL updated', description: `Connected to ${data.serverUrl}`, @@ -59,11 +63,7 @@ export function ConnectionForm() { } 
return ( - + Server Connection @@ -89,6 +89,37 @@ export function ConnectionForm() { + {/* Connection status */} +
+ {isLoading ? ( +
+ + Checking connection... +
+ ) : healthError ? ( +
+ + + Connection failed: {healthError.message} + +
+ ) : health ? ( +
+ + {health.model_loaded || health.model_downloaded ? 'Model Ready' : 'No Model'} + + + GPU: {health.gpu_available ? 'Available' : 'Not Available'} + + {health.vram_used_mb && ( + VRAM: {health.vram_used_mb.toFixed(0)} MB + )} +
+ ) : null} +
+
state.maxChunkChars); + const setMaxChunkChars = useServerStore((state) => state.setMaxChunkChars); + const crossfadeMs = useServerStore((state) => state.crossfadeMs); + const setCrossfadeMs = useServerStore((state) => state.setCrossfadeMs); + + return ( + + + Generation Settings + + Controls for long text generation. These settings apply to all engines. + + + +
+
+
+ + + {maxChunkChars} chars + +
+ setMaxChunkChars(value)} + min={100} + max={2000} + step={50} + aria-label="Auto-chunking character limit" + /> +

+ Long text is split into chunks at sentence boundaries before generating. Lower values + can improve quality for long outputs. +

+
+ +
+
+ + + {crossfadeMs === 0 ? 'Cut' : `${crossfadeMs}ms`} + +
+ setCrossfadeMs(value)} + min={0} + max={200} + step={10} + aria-label="Chunk crossfade duration" + /> +

+ Blends audio between chunks to smooth transitions. Set to 0 for a hard cut. +

+
+
+
+
+ ); +} diff --git a/app/src/components/ServerSettings/GpuAcceleration.tsx b/app/src/components/ServerSettings/GpuAcceleration.tsx index 69824d63..94cc1d6d 100644 --- a/app/src/components/ServerSettings/GpuAcceleration.tsx +++ b/app/src/components/ServerSettings/GpuAcceleration.tsx @@ -1,7 +1,6 @@ import { useQuery, useQueryClient } from '@tanstack/react-query'; -import { AlertCircle, Cpu, Download, Loader2, RotateCw, Trash2, Zap } from 'lucide-react'; +import { AlertCircle, Download, Loader2, RotateCw, Trash2 } from 'lucide-react'; import { useCallback, useEffect, useRef, useState } from 'react'; -import { Badge } from '@/components/ui/badge'; import { Button } from '@/components/ui/button'; import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; import { Progress } from '@/components/ui/progress'; @@ -216,31 +215,19 @@ export function GpuAcceleration() { return ( - - - GPU Acceleration - + GPU Acceleration {/* Current status */} -
-
-
Backend
-
- {isCurrentlyCuda ? 'CUDA (GPU accelerated)' : 'CPU'} -
+
+
Backend
+
+ {isCurrentlyCuda + ? 'CUDA (GPU accelerated)' + : hasNativeGpu + ? `${health.backend_type === 'mlx' ? 'MLX' : 'PyTorch'} (GPU accelerated)` + : 'CPU'}
- - {isCurrentlyCuda ? ( - <> - CUDA - - ) : ( - <> - CPU - - )} -
{/* GPU info from health */} @@ -257,14 +244,6 @@ export function GpuAcceleration() { )} {/* Native GPU detected - no CUDA download needed */} - {hasNativeGpu && ( -
-
- Your system uses {health.gpu_type} for acceleration. No additional - downloads needed. -
-
- )} {/* CUDA download section - only show when native GPU is NOT detected (i.e., Windows/Linux NVIDIA users) */} {!hasNativeGpu && ( diff --git a/app/src/components/ServerSettings/ModelManagement.tsx b/app/src/components/ServerSettings/ModelManagement.tsx index 9811f643..37e8b229 100644 --- a/app/src/components/ServerSettings/ModelManagement.tsx +++ b/app/src/components/ServerSettings/ModelManagement.tsx @@ -342,17 +342,18 @@ export function ModelManagement() { setDetailOpen(true); }; - const ttsModels = modelStatus?.models.filter((m) => m.model_name.startsWith('qwen-tts')) ?? []; - const otherTtsModels = + const voiceModels = modelStatus?.models.filter( - (m) => m.model_name.startsWith('luxtts') || m.model_name.startsWith('chatterbox'), + (m) => + m.model_name.startsWith('qwen-tts') || + m.model_name.startsWith('luxtts') || + m.model_name.startsWith('chatterbox'), ) ?? []; const whisperModels = modelStatus?.models.filter((m) => m.model_name.startsWith('whisper')) ?? []; // Build sections const sections: { label: string; models: ModelStatus[] }[] = [ - { label: 'Voice Generation', models: ttsModels }, - ...(otherTtsModels.length > 0 ? [{ label: 'Other Voice Models', models: otherTtsModels }] : []), + { label: 'Voice Generation', models: voiceModels }, { label: 'Transcription', models: whisperModels }, ]; @@ -564,12 +565,6 @@ export function ModelManagement() { Loaded )} - {freshSelectedModel.downloaded && !freshSelectedModel.loaded && ( - - - Downloaded - - )} {selectedState?.hasError && ( @@ -595,24 +590,6 @@ export function ModelManagement() { {hfModelInfo && (
- {/* Stats row */} -
- - - {formatDownloads(hfModelInfo.downloads)} - - - - {formatDownloads(hfModelInfo.likes)} - - {license && ( - - - {formatLicense(license)} - - )} -
- {/* Pipeline tag + author */}
{hfModelInfo.pipeline_tag && ( @@ -632,6 +609,24 @@ export function ModelManagement() { )}
+ {/* Stats row */} +
+ + + {formatDownloads(hfModelInfo.downloads)} + + + + {formatDownloads(hfModelInfo.likes)} + + {license && ( + + + {formatLicense(license)} + + )} +
+ {/* Languages */} {hfModelInfo.cardData?.language && hfModelInfo.cardData.language.length > 0 && (
@@ -647,8 +642,8 @@ export function ModelManagement() { {/* Disk size */} {freshSelectedModel.downloaded && freshSelectedModel.size_mb && ( -
- +
+ {formatSize(freshSelectedModel.size_mb)} on disk
)} @@ -661,7 +656,7 @@ export function ModelManagement() { )} {/* Actions */} -
+
{selectedState?.hasError ? ( <>
- {/* Model download progress */} -
- - - - - - -
- {isLoading ? (
diff --git a/app/src/components/ServerTab/ServerTab.tsx b/app/src/components/ServerTab/ServerTab.tsx index 1f32ac04..000ec5b7 100644 --- a/app/src/components/ServerTab/ServerTab.tsx +++ b/app/src/components/ServerTab/ServerTab.tsx @@ -1,19 +1,19 @@ import { ConnectionForm } from '@/components/ServerSettings/ConnectionForm'; +import { GenerationSettings } from '@/components/ServerSettings/GenerationSettings'; import { GpuAcceleration } from '@/components/ServerSettings/GpuAcceleration'; -import { ServerStatus } from '@/components/ServerSettings/ServerStatus'; import { UpdateStatus } from '@/components/ServerSettings/UpdateStatus'; import { usePlatform } from '@/platform/PlatformContext'; export function ServerTab() { const platform = usePlatform(); return ( -
+
- + + {platform.metadata.isTauri && } + {platform.metadata.isTauri && }
- {platform.metadata.isTauri && } - {platform.metadata.isTauri && }
Created by{' '} state.setAudioWithAutoPlay); const setIsGenerating = useGenerationStore((state) => state.setIsGenerating); + const maxChunkChars = useServerStore((state) => state.maxChunkChars); + const crossfadeMs = useServerStore((state) => state.crossfadeMs); const [downloadingModelName, setDownloadingModelName] = useState(null); const [downloadingDisplayName, setDownloadingDisplayName] = useState(null); @@ -110,6 +113,8 @@ export function useGenerationForm(options: UseGenerationFormOptions = {}) { model_size: isQwen ? data.modelSize : undefined, engine, instruct: isQwen ? data.instruct || undefined : undefined, + max_chunk_chars: maxChunkChars, + crossfade_ms: crossfadeMs, }); toast({ diff --git a/app/src/stores/serverStore.ts b/app/src/stores/serverStore.ts index 36d9f0af..1795b61c 100644 --- a/app/src/stores/serverStore.ts +++ b/app/src/stores/serverStore.ts @@ -13,6 +13,12 @@ interface ServerStore { keepServerRunningOnClose: boolean; setKeepServerRunningOnClose: (keepRunning: boolean) => void; + + maxChunkChars: number; + setMaxChunkChars: (value: number) => void; + + crossfadeMs: number; + setCrossfadeMs: (value: number) => void; } export const useServerStore = create()( @@ -29,6 +35,12 @@ export const useServerStore = create()( keepServerRunningOnClose: false, setKeepServerRunningOnClose: (keepRunning) => set({ keepServerRunningOnClose: keepRunning }), + + maxChunkChars: 800, + setMaxChunkChars: (value) => set({ maxChunkChars: value }), + + crossfadeMs: 50, + setCrossfadeMs: (value) => set({ crossfadeMs: value }), }), { name: 'voicebox-server', diff --git a/backend/main.py b/backend/main.py index 69fcc869..cb9a2bd3 100644 --- a/backend/main.py +++ b/backend/main.py @@ -824,18 +824,25 @@ async def download_chatterbox_turbo_background(): engine=engine, ) - audio, sample_rate = await tts_model.generate( - data.text, - voice_prompt, - data.language, - data.seed, - data.instruct, - ) + from .utils.chunked_tts import generate_chunked - # Trim trailing 
silence/hallucination for Chatterbox output + # Resolve per-chunk trim function for engines that need it + trim_fn = None if engine in ("chatterbox", "chatterbox_turbo"): from .utils.audio import trim_tts_output - audio = trim_tts_output(audio, sample_rate) + trim_fn = trim_tts_output + + audio, sample_rate = await generate_chunked( + tts_model, + data.text, + voice_prompt, + language=data.language, + seed=data.seed, + instruct=data.instruct, + max_chunk_chars=data.max_chunk_chars, + crossfade_ms=data.crossfade_ms, + trim_fn=trim_fn, + ) # Calculate duration duration = len(audio) / sample_rate @@ -949,18 +956,24 @@ async def stream_speech( data.profile_id, db, engine=engine, ) - audio, sample_rate = await tts_model.generate( - data.text, - voice_prompt, - data.language, - data.seed, - data.instruct, - ) + from .utils.chunked_tts import generate_chunked - # Trim trailing silence/hallucination for Chatterbox output + trim_fn = None if engine in ("chatterbox", "chatterbox_turbo"): from .utils.audio import trim_tts_output - audio = trim_tts_output(audio, sample_rate) + trim_fn = trim_tts_output + + audio, sample_rate = await generate_chunked( + tts_model, + data.text, + voice_prompt, + language=data.language, + seed=data.seed, + instruct=data.instruct, + max_chunk_chars=data.max_chunk_chars, + crossfade_ms=data.crossfade_ms, + trim_fn=trim_fn, + ) wav_bytes = tts.audio_to_wav_bytes(audio, sample_rate) diff --git a/backend/models.py b/backend/models.py index ebfe70e9..b462b67a 100644 --- a/backend/models.py +++ b/backend/models.py @@ -52,12 +52,14 @@ class Config: class GenerationRequest(BaseModel): """Request model for voice generation.""" profile_id: str - text: str = Field(..., min_length=1, max_length=5000) + text: str = Field(..., min_length=1, max_length=50000) language: str = Field(default="en", pattern="^(zh|en|ja|ko|de|fr|ru|pt|es|it|he)$") seed: Optional[int] = Field(None, ge=0) model_size: Optional[str] = Field(default="1.7B", pattern="^(1\\.7B|0\\.6B)$") 
instruct: Optional[str] = Field(None, max_length=500) engine: Optional[str] = Field(default="qwen", pattern="^(qwen|luxtts|chatterbox|chatterbox_turbo)$") + max_chunk_chars: int = Field(default=800, ge=100, le=5000, description="Max characters per chunk for long text splitting") + crossfade_ms: int = Field(default=50, ge=0, le=500, description="Crossfade duration in ms between chunks (0 for hard cut)") class GenerationResponse(BaseModel): diff --git a/backend/utils/chunked_tts.py b/backend/utils/chunked_tts.py new file mode 100644 index 00000000..53a454c6 --- /dev/null +++ b/backend/utils/chunked_tts.py @@ -0,0 +1,302 @@ +""" +Chunked TTS generation utilities. + +Splits long text into sentence-boundary chunks, generates audio per-chunk +via any TTSBackend, and concatenates with crossfade. All logic is +engine-agnostic — it wraps the standard ``TTSBackend.generate()`` interface. + +Short text (≤ max_chunk_chars) uses the single-shot fast path with zero +overhead. +""" + +import logging +import re +from typing import List, Tuple + +import numpy as np + +logger = logging.getLogger("voicebox.chunked-tts") + +# Default chunk size in characters. Can be overridden per-request via +# the ``max_chunk_chars`` field on GenerationRequest. +DEFAULT_MAX_CHUNK_CHARS = 800 + +# Common abbreviations that should NOT be treated as sentence endings. +# Lowercase for case-insensitive matching. +_ABBREVIATIONS = frozenset( + { + "mr", + "mrs", + "ms", + "dr", + "prof", + "sr", + "jr", + "st", + "ave", + "blvd", + "inc", + "ltd", + "corp", + "dept", + "est", + "approx", + "vs", + "etc", + "e.g", + "i.e", + "a.m", + "p.m", + "u.s", + "u.s.a", + "u.k", + } +) + +# Paralinguistic tags used by Chatterbox Turbo. The splitter must never +# cut inside one of these. 
_PARA_TAG_RE = re.compile(r"\[[^\]]*\]")


# ---------------------------------------------------------------------------
# Text splitting
# ---------------------------------------------------------------------------


def split_text_into_chunks(text: str, max_chars: int = DEFAULT_MAX_CHUNK_CHARS) -> List[str]:
    """Split *text* at natural boundaries into chunks of at most *max_chars*.

    Split-point priority inside each ``max_chars`` window:
    sentence end (``.!?`` not preceded by an abbreviation and not inside a
    bracket tag, or CJK ``。!?``) → clause boundary (``;:,—``) →
    whitespace → hard cut.

    Paralinguistic tags like ``[laugh]`` are treated as atomic.  If a tag
    opens inside the window but closes after it, the window is shortened to
    end just before the tag so that no split point can land inside it.

    Parameters
    ----------
    text : str
        Input text; leading/trailing whitespace is ignored.
    max_chars : int
        Upper bound on the length of every returned chunk (must be >= 1).

    Returns
    -------
    List[str]
        Non-empty, stripped chunks in reading order.  Empty list when *text*
        is empty or whitespace-only.

    Raises
    ------
    ValueError
        If *max_chars* is smaller than 1 (the splitting loop could never
        make progress).
    """
    if max_chars < 1:
        raise ValueError(f"max_chars must be >= 1, got {max_chars}")

    text = text.strip()
    if not text:
        return []
    if len(text) <= max_chars:
        return [text]

    chunks: List[str] = []
    remaining = text

    while remaining:
        remaining = remaining.lstrip()
        if not remaining:
            break
        if len(remaining) <= max_chars:
            chunks.append(remaining)
            break

        window = remaining[:max_chars]

        # A "[" with no matching "]" inside the window may be a tag that
        # straddles the window edge.  Restrict the search to the text before
        # it so every candidate split point is outside the tag.  When the
        # bracket is the very first character there is nothing before it,
        # so the window is left untouched.
        tag_start = _unclosed_tag_start(window)
        if tag_start > 0:
            window = window[:tag_start]

        split_pos = _find_last_sentence_end(window)
        if split_pos == -1:
            split_pos = _find_last_clause_boundary(window)
        if split_pos == -1:
            split_pos = _find_last_whitespace(window)
        if split_pos == -1:
            # Absolute fallback: hard cut at the end of the (tag-safe) window
            split_pos = _safe_hard_cut(window, len(window))

        # split_pos is always >= 0 here, so each iteration consumes at least
        # one character and the loop terminates.
        chunk = remaining[: split_pos + 1].strip()
        if chunk:
            chunks.append(chunk)
        remaining = remaining[split_pos + 1 :]

    return chunks


def _unclosed_tag_start(text: str) -> int:
    """Return the index of the first ``[`` after the last ``]``, or -1.

    Such a bracket has no closing partner inside *text*, so anything after
    it may belong to a tag that continues beyond the end of *text*.
    """
    return text.find("[", text.rfind("]") + 1)


def _find_last_sentence_end(text: str) -> int:
    """Return the index of the last sentence-ending punctuation in *text*.

    Skips periods that follow common abbreviations (``Dr.``, ``e.g.``,
    ``U.S.A.`` ...) and punctuation inside bracket tags (``[laugh]``).  Also
    handles CJK sentence-ending punctuation (``。!?``).  Returns -1 when no
    sentence end is found.
    """
    best = -1
    # ASCII sentence ends: punctuation followed by whitespace or end of text
    for m in re.finditer(r"[.!?](?:\s|$)", text):
        pos = m.start()
        if text[pos] == ".":
            # Walk backwards over letters to find the preceding word
            word_start = pos - 1
            while word_start >= 0 and text[word_start].isalpha():
                word_start -= 1
            word = text[word_start + 1 : pos].lower()

            # Keep walking over letters *and* periods so that dotted
            # abbreviations such as "e.g" or "u.s.a" are seen as one token;
            # the plain word above would only ever be their last letter.
            token_start = word_start
            while token_start >= 0 and (
                text[token_start].isalpha() or text[token_start] == "."
            ):
                token_start -= 1
            dotted = text[token_start + 1 : pos].lstrip(".").lower()

            if word in _ABBREVIATIONS or dotted in _ABBREVIATIONS:
                continue
            # Skip when a digit immediately precedes the word/period.
            # NOTE(review): a true decimal ("3.14") never matches the regex
            # above, so this mostly suppresses list markers ("1. ") and
            # sentences that end in a number — confirm that is intended.
            if word_start >= 0 and text[word_start].isdigit():
                continue
        # Skip if we're inside a bracket tag
        if _inside_bracket_tag(text, pos):
            continue
        best = pos
    # CJK sentence-ending punctuation
    for m in re.finditer(r"[\u3002\uff01\uff1f]", text):
        if m.start() > best:
            best = m.start()
    return best


def _find_last_clause_boundary(text: str) -> int:
    """Return the index of the last clause-boundary punctuation, or -1.

    A boundary is one of ``; : , —`` followed by whitespace or end of text,
    and not located inside a bracket tag.
    """
    best = -1
    for m in re.finditer(r"[;:,\u2014](?:\s|$)", text):
        pos = m.start()
        # Skip if inside a bracket tag
        if _inside_bracket_tag(text, pos):
            continue
        best = pos
    return best


def _find_last_whitespace(text: str) -> int:
    """Return the index of the last whitespace char outside any tag, or -1.

    Any whitespace counts (space, newline, tab), and whitespace inside a
    ``[...]`` tag such as ``[clear throat]`` is never returned.
    """
    tag_spans = [m.span() for m in _PARA_TAG_RE.finditer(text)]
    for pos in range(len(text) - 1, -1, -1):
        if text[pos].isspace() and not any(start < pos < end for start, end in tag_spans):
            return pos
    return -1


def _inside_bracket_tag(text: str, pos: int) -> bool:
    """Return True if *pos* falls inside a ``[...]`` tag of *text*.

    The closing ``]`` itself counts as inside; the opening ``[`` does not.
    """
    for m in _PARA_TAG_RE.finditer(text):
        if m.start() < pos < m.end():
            return True
    return False


def _safe_hard_cut(segment: str, max_chars: int) -> int:
    """Find a hard-cut index (inclusive) that doesn't split a ``[tag]``.

    The default cut is the last index of the first *max_chars* characters.
    If a ``[`` before that point has no closing ``]`` up to the cut, the tag
    would be split, so the cut moves to just before the bracket.  A bracket
    at index 0 cannot be avoided and the default cut is used.
    """
    cut = min(max_chars, len(segment)) - 1
    tag_start = _unclosed_tag_start(segment[: cut + 1])
    if tag_start > 0:
        return tag_start - 1
    return cut


# ---------------------------------------------------------------------------
# Audio concatenation
# ---------------------------------------------------------------------------


def concatenate_audio_chunks(
    chunks: List[np.ndarray],
    sample_rate: int,
    crossfade_ms: int = 50,
) -> np.ndarray:
    """Concatenate audio arrays with a short crossfade to eliminate clicks.

    Each chunk is expected to be a 1-D float32 ndarray at *sample_rate* Hz.
    Adjacent chunks overlap by ``crossfade_ms`` worth of samples (clamped to
    the audio available on both sides) and are blended with linear fades.
    ``crossfade_ms <= 0`` yields a plain hard-cut concatenation.

    Returns an empty float32 array for no chunks, the chunk itself (not a
    copy) for a single chunk, and a new float32 array otherwise.
    """
    if not chunks:
        return np.array([], dtype=np.float32)
    if len(chunks) == 1:
        return chunks[0]

    crossfade_samples = int(sample_rate * crossfade_ms / 1000)

    # One buffer sized for the no-overlap worst case; filling it in place
    # avoids re-copying the whole accumulated signal for every chunk.
    out = np.empty(sum(len(c) for c in chunks), dtype=np.float32)
    written = len(chunks[0])
    out[:written] = chunks[0]

    for chunk in chunks[1:]:
        n = len(chunk)
        if n == 0:
            continue
        overlap = min(crossfade_samples, written, n)
        if overlap > 0:
            fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32)
            fade_in = np.linspace(0.0, 1.0, overlap, dtype=np.float32)
            tail = out[written - overlap : written]
            out[written - overlap : written] = tail * fade_out + chunk[:overlap] * fade_in
        else:
            # Zero or negative crossfade: append the chunk untouched
            overlap = 0
        out[written : written + n - overlap] = chunk[overlap:]
        written += n - overlap

    return out[:written]


# ---------------------------------------------------------------------------
# Engine-agnostic chunked generation
# ---------------------------------------------------------------------------


async def generate_chunked(
    backend,
    text: str,
    voice_prompt: dict,
    language: str = "en",
    seed: int | None = None,
    instruct: str | None = None,
    max_chunk_chars: int = DEFAULT_MAX_CHUNK_CHARS,
    crossfade_ms: int = 50,
    trim_fn=None,
) -> Tuple[np.ndarray, int]:
    """Generate audio with automatic chunking for long text.

    When the text yields at most one chunk this is a thin wrapper around a
    single ``backend.generate()`` call (plus *trim_fn* if given).

    For longer text the input is split at natural sentence boundaries,
    each chunk is generated independently, optionally trimmed (useful for
    Chatterbox engines that hallucinate trailing noise), and the results
    are concatenated with a crossfade (or hard cut if *crossfade_ms* is 0).

    Parameters
    ----------
    backend : TTSBackend
        Any backend implementing the ``generate()`` protocol.
    text : str
        Input text (may be arbitrarily long).
    voice_prompt, language, instruct
        Forwarded to ``backend.generate()`` verbatim.
    seed : int | None
        Forwarded verbatim on the single-shot path.  On the chunked path
        chunk *i* receives ``seed + i`` (or ``None`` when no seed is set).
    max_chunk_chars : int
        Maximum characters per chunk (default 800).
    crossfade_ms : int
        Crossfade duration in milliseconds between chunks.  0 for a hard
        cut with no overlap (default 50).
    trim_fn : callable | None
        Optional ``(audio, sample_rate) -> audio`` post-processing
        function applied to each chunk before concatenation (e.g.
        ``trim_tts_output`` for Chatterbox engines).

    Returns
    -------
    (audio, sample_rate) : Tuple[np.ndarray, int]
        On the chunked path the sample rate reported is that of the first
        chunk; all chunks are assumed to share it.
    """
    chunks = split_text_into_chunks(text, max_chunk_chars)

    if len(chunks) <= 1:
        # Short text — single-shot fast path (the original, unstripped text
        # is passed through so the backend sees exactly what the caller sent)
        audio, sample_rate = await backend.generate(
            text, voice_prompt, language, seed, instruct,
        )
        if trim_fn is not None:
            audio = trim_fn(audio, sample_rate)
        return audio, sample_rate

    # Long text — chunked generation
    logger.info(
        "Splitting %d chars into %d chunks (max %d chars each)",
        len(text), len(chunks), max_chunk_chars,
    )
    audio_chunks: List[np.ndarray] = []
    sample_rate: int | None = None

    for i, chunk_text in enumerate(chunks):
        logger.info(
            "Generating chunk %d/%d (%d chars)",
            i + 1, len(chunks), len(chunk_text),
        )
        # Vary the seed per chunk to avoid correlated RNG artefacts,
        # but keep it deterministic so the same (text, seed) pair
        # always produces the same output.
        chunk_seed = (seed + i) if seed is not None else None

        chunk_audio, chunk_sr = await backend.generate(
            chunk_text, voice_prompt, language, chunk_seed, instruct,
        )
        if trim_fn is not None:
            chunk_audio = trim_fn(chunk_audio, chunk_sr)

        audio_chunks.append(np.asarray(chunk_audio, dtype=np.float32))
        if sample_rate is None:
            sample_rate = chunk_sr

    audio = concatenate_audio_chunks(audio_chunks, sample_rate, crossfade_ms=crossfade_ms)
    return audio, sample_rate