
Commit 9cd19f1

feat: add voice prompt extraction for cached voice cloning
- Add POST /v1/voice/extract endpoint to extract reusable voice prompts
- Update POST /v1/audio/speech to accept voice_prompt parameter
- Voice prompts are serialized as base64-encoded numpy arrays
- Allows caching voice embeddings to avoid re-processing ref_audio
- Add tests for voice prompt extraction and synthesis

Workflow:
1. Extract: POST /v1/voice/extract with ref_audio → get voice_prompt
2. Reuse: POST /v1/audio/speech with voice_prompt → faster synthesis
1 parent 2cb1b48 commit 9cd19f1
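The serialization scheme the commit describes (voice embedding → numpy array → `np.save` → base64 string) round-trips losslessly. A minimal standalone sketch; the embedding shape here is illustrative, not the model's actual prompt shape:

```python
import base64
import io

import numpy as np


def encode_prompt(prompt_np: np.ndarray) -> str:
    """Serialize an embedding array to a base64 string (npy bytes -> base64)."""
    buffer = io.BytesIO()
    np.save(buffer, prompt_np, allow_pickle=False)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def decode_prompt(voice_prompt: str) -> np.ndarray:
    """Restore the embedding array from its base64 string."""
    buffer = io.BytesIO(base64.b64decode(voice_prompt))
    return np.load(buffer, allow_pickle=False)


# Illustrative shape only -- the real prompt's shape depends on the model.
embedding = np.random.rand(1, 256).astype(np.float32)
encoded = encode_prompt(embedding)
assert np.array_equal(decode_prompt(encoded), embedding)
```

Since the encoded string is plain text, it survives JSON transport and multipart form fields unchanged, which is what lets clients store it and pass it back later.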

File tree

4 files changed: +332 -22 lines changed

README.md

Lines changed: 46 additions & 5 deletions

@@ -6,14 +6,15 @@ Multi-model text-to-speech API with voice cloning support.
 - **Multi-model architecture** - Pluggable backend system for different TTS models
 - **Voice cloning** - Clone voices with reference audio + transcript
+- **Voice prompt caching** - Extract and reuse voice embeddings for faster synthesis
 - **ref_text support** - Provide transcript for better voice cloning quality
 - **GPU accelerated** - CUDA support for fast inference
 
 ## Supported Backends
 
-| Backend | Voice Cloning | ref_text Support |
-|---------|--------------|------------------|
-| `qwen3-tts` | ✅ | ✅ |
+| Backend | Voice Cloning | ref_text | Voice Prompt |
+|---------|--------------|----------|--------------|
+| `qwen3-tts` | ✅ | ✅ | ✅ |
 
 ## API Endpoints
@@ -25,8 +26,9 @@ Synthesize speech from text.
 - `text` (required): Text to synthesize
 - `language`: Target language (default: "English")
 - `speaker`: Preset speaker for basic TTS (e.g., "Vivian", "Ryan")
-- `ref_audio`: Reference audio file for voice cloning
+- `ref_audio`: Reference audio file for voice cloning (on-the-fly)
 - `ref_text`: Transcript of reference audio (improves cloning quality)
+- `voice_prompt`: Pre-extracted voice prompt from `/v1/voice/extract` (cached)
 - `speed`: Speech speed multiplier (default: 1.0)
 
 **Example:**
@@ -42,12 +44,51 @@ curl -X POST http://localhost:8000/v1/audio/speech \
   -F "speaker=Ryan" \
   -o output.wav
 
-# Voice cloning with ref_text
+# Voice cloning (on-the-fly) - processes ref_audio each time
 curl -X POST http://localhost:8000/v1/audio/speech \
   -F "text=Hello, this is my cloned voice." \
   -F "ref_audio=@reference.wav" \
   -F "ref_text=This is the transcript of my reference audio." \
   -o cloned.wav
+
+# Voice cloning (cached) - faster, uses pre-extracted prompt
+curl -X POST http://localhost:8000/v1/audio/speech \
+  -F "text=Hello, this is my cloned voice." \
+  -F "voice_prompt=$VOICE_PROMPT" \
+  -o cloned.wav
+```
+
+### `POST /v1/voice/extract`
+
+Extract a reusable voice prompt from reference audio. The returned prompt can be cached and reused with `/v1/audio/speech` to avoid re-processing the reference audio on every request.
+
+**Parameters:**
+- `ref_audio` (required): Reference audio file
+- `ref_text`: Transcript of reference audio (improves quality)
+- `language`: Language of the reference audio (default: "English")
+
+**Returns:**
+- `voice_prompt`: Base64-encoded voice embedding (store this)
+- `format`: Encoding format (e.g., "base64-numpy")
+
+**Example:**
+```bash
+# Extract voice prompt
+VOICE_PROMPT=$(curl -X POST http://localhost:8000/v1/voice/extract \
+  -F "ref_audio=@reference.wav" \
+  -F "ref_text=This is the transcript of my reference audio." \
+  | jq -r '.voice_prompt')
+
+# Use the cached prompt for multiple synthesis requests
+curl -X POST http://localhost:8000/v1/audio/speech \
+  -F "text=First sentence with cloned voice." \
+  -F "voice_prompt=$VOICE_PROMPT" \
+  -o output1.wav
+
+curl -X POST http://localhost:8000/v1/audio/speech \
+  -F "text=Second sentence with same voice." \
+  -F "voice_prompt=$VOICE_PROMPT" \
+  -o output2.wav
 ```
 
 ### `GET /health`
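Because the extracted prompt is a plain string, a client can cache prompts keyed by a content hash of the reference audio so `/v1/voice/extract` is called at most once per file. A hypothetical client-side sketch (`extract_fn` stands in for the actual HTTP call to the endpoint):

```python
import hashlib
from typing import Callable, Dict


class VoicePromptCache:
    """Cache extracted voice prompts by SHA-256 of the reference audio bytes."""

    def __init__(self, extract_fn: Callable[[bytes], str]):
        # extract_fn would wrap POST /v1/voice/extract in a real client
        self._extract_fn = extract_fn
        self._cache: Dict[str, str] = {}

    def get_prompt(self, ref_audio_bytes: bytes) -> str:
        key = hashlib.sha256(ref_audio_bytes).hexdigest()
        if key not in self._cache:
            self._cache[key] = self._extract_fn(ref_audio_bytes)
        return self._cache[key]


# Demonstrate with a stand-in extractor that records each invocation.
calls = []

def fake_extract(audio: bytes) -> str:
    calls.append(audio)
    return f"prompt-{len(calls)}"

cache = VoicePromptCache(fake_extract)
cache.get_prompt(b"ref1")
cache.get_prompt(b"ref1")  # served from cache, no second extract call
assert len(calls) == 1
```

Hashing the audio content (rather than, say, the filename) means renamed or re-uploaded copies of the same recording still hit the cache.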

requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ python-multipart>=0.0.12
 
 # Audio processing
 soundfile>=0.12.1
+numpy>=1.24.0
 
 # Qwen3-TTS backend
 qwen-tts>=0.0.5

server.py

Lines changed: 190 additions & 16 deletions

@@ -3,15 +3,21 @@
 
 Supported models:
 - qwen3-tts: Qwen3-TTS with voice cloning via ref_audio + ref_text
+
+Voice cloning can be done two ways:
+1. On-the-fly: Pass ref_audio + ref_text with each synthesis request
+2. Cached: Extract a voice_prompt once, then reuse it for multiple requests
 """
 
 import io
 import os
+import base64
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple, List
+from typing import Optional, Tuple, List, Any
 from contextlib import asynccontextmanager
 
 import torch
+import numpy as np
 import soundfile as sf
 from fastapi import FastAPI, UploadFile, File, Form, HTTPException
 from fastapi.responses import StreamingResponse, JSONResponse
@@ -65,6 +71,46 @@ def get_speakers(self) -> List[str]:
         """Return available preset speakers."""
         pass
 
+    def extract_voice_prompt(
+        self,
+        ref_audio: Tuple[Any, int],
+        ref_text: Optional[str] = None,
+        language: str = "English",
+    ) -> str:
+        """
+        Extract a reusable voice prompt from reference audio.
+
+        Args:
+            ref_audio: Tuple of (audio_data, sample_rate)
+            ref_text: Optional transcript of reference audio
+            language: Language of the reference audio
+
+        Returns:
+            Base64-encoded voice prompt that can be reused
+        """
+        raise NotImplementedError("This backend does not support voice prompt extraction")
+
+    def synthesize_with_prompt(
+        self,
+        text: str,
+        voice_prompt: str,
+        language: str = "English",
+        speed: float = 1.0,
+    ) -> Tuple[bytes, int]:
+        """
+        Synthesize speech using a pre-extracted voice prompt.
+
+        Args:
+            text: Text to synthesize
+            voice_prompt: Base64-encoded voice prompt from extract_voice_prompt
+            language: Target language
+            speed: Speech speed multiplier
+
+        Returns:
+            Tuple of (wav_bytes, sample_rate)
+        """
+        raise NotImplementedError("This backend does not support voice prompt synthesis")
+
 
 class Qwen3TTSBackend(TTSBackend):
     """
@@ -149,6 +195,7 @@ def get_info(self) -> dict:
             "base_model": self.base_model_name,
             "supports_voice_cloning": True,
             "supports_ref_text": True,
+            "supports_voice_prompt": True,
             "device": "cuda" if torch.cuda.is_available() else "cpu",
         }
 
@@ -162,6 +209,65 @@ def get_speakers(self) -> List[str]:
         # Fallback to known speakers
         return ["Vivian", "Ryan", "Sophia", "Isabella", "Evan", "Lily"]
 
+    def extract_voice_prompt(
+        self,
+        ref_audio: Tuple[Any, int],
+        ref_text: Optional[str] = None,
+        language: str = "English",
+    ) -> str:
+        """Extract a reusable voice prompt from reference audio."""
+        if self.base_model is None:
+            raise RuntimeError("Base model not loaded")
+
+        # Use the Base model's create_voice_clone_prompt method
+        voice_prompt = self.base_model.create_voice_clone_prompt(
+            ref_audio=ref_audio,
+            ref_text=ref_text,
+            language=language,
+        )
+
+        # Serialize to base64 - voice_prompt is typically tensor data.
+        # Convert to numpy, then to bytes, then base64.
+        if hasattr(voice_prompt, 'cpu'):
+            # It's a torch tensor
+            prompt_np = voice_prompt.cpu().numpy()
+        else:
+            prompt_np = np.array(voice_prompt)
+
+        buffer = io.BytesIO()
+        np.save(buffer, prompt_np, allow_pickle=False)
+        return base64.b64encode(buffer.getvalue()).decode('utf-8')
+
+    def synthesize_with_prompt(
+        self,
+        text: str,
+        voice_prompt: str,
+        language: str = "English",
+        speed: float = 1.0,
+    ) -> Tuple[bytes, int]:
+        """Synthesize speech using a pre-extracted voice prompt."""
+        if self.base_model is None:
+            raise RuntimeError("Base model not loaded")
+
+        # Decode the voice prompt
+        buffer = io.BytesIO(base64.b64decode(voice_prompt))
+        prompt_np = np.load(buffer, allow_pickle=False)
+        prompt_tensor = torch.from_numpy(prompt_np).to(self.device)
+
+        # Generate using the cached voice prompt
+        wavs, sr = self.base_model.generate_voice_clone(
+            text=text,
+            language=language,
+            voice_clone_prompt=prompt_tensor,
+        )
+
+        # Convert to WAV bytes
+        wav_buffer = io.BytesIO()
+        sf.write(wav_buffer, wavs[0], sr, format='WAV')
+        wav_buffer.seek(0)
+
+        return wav_buffer.read(), sr
 
 
 # =============================================================================
 # Backend Registry
@@ -203,7 +309,7 @@ async def lifespan(app: FastAPI):
 app = FastAPI(
     title="TTS Server",
     description="Multi-model text-to-speech API with voice cloning support",
-    version="0.2.0",
+    version="0.3.0",
     lifespan=lifespan,
 )
 
@@ -233,45 +339,111 @@ async def list_speakers():
     return {"speakers": backend.get_speakers()}
 
 
+@app.post("/v1/voice/extract")
+async def extract_voice_prompt(
+    ref_audio: UploadFile = File(..., description="Reference audio for voice extraction"),
+    ref_text: Optional[str] = Form(None, description="Transcript of reference audio"),
+    language: str = Form("English", description="Language of the reference audio"),
+):
+    """
+    Extract a reusable voice prompt from reference audio.
+
+    The returned `voice_prompt` can be cached and reused with `/v1/audio/speech`
+    to avoid re-processing the reference audio on every request.
+
+    **Returns:**
+    - `voice_prompt`: Base64-encoded voice embedding (store this)
+    - `format`: Always "base64-numpy" for this backend
+    """
+    if backend is None:
+        raise HTTPException(status_code=503, detail="Model not loaded")
+
+    try:
+        # Read reference audio
+        audio_bytes = await ref_audio.read()
+        audio_buffer = io.BytesIO(audio_bytes)
+        audio_data, sample_rate = sf.read(audio_buffer)
+        ref_audio_data = (audio_data, sample_rate)
+
+        # Extract voice prompt
+        voice_prompt = backend.extract_voice_prompt(
+            ref_audio=ref_audio_data,
+            ref_text=ref_text,
+            language=language,
+        )
+
+        return JSONResponse({
+            "voice_prompt": voice_prompt,
+            "format": "base64-numpy",
+            "ref_text": ref_text,
+            "language": language,
+        })
+
+    except NotImplementedError as e:
+        raise HTTPException(status_code=501, detail=str(e))
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 @app.post("/v1/audio/speech")
 async def synthesize_speech(
     text: str = Form(..., description="Text to synthesize"),
     language: str = Form("English", description="Target language"),
     speaker: Optional[str] = Form(None, description="Preset speaker for basic TTS (e.g., Vivian, Ryan)"),
     ref_audio: Optional[UploadFile] = File(None, description="Reference audio for voice cloning"),
     ref_text: Optional[str] = Form(None, description="Transcript of reference audio"),
+    voice_prompt: Optional[str] = Form(None, description="Pre-extracted voice prompt from /v1/voice/extract"),
     speed: float = Form(1.0, description="Speech speed multiplier"),
 ):
     """
     Synthesize speech from text.
 
     **Basic TTS:** Just provide `text` and optionally `speaker`.
 
-    **Voice Cloning:** Provide `ref_audio` and `ref_text` for best quality.
+    **Voice Cloning (on-the-fly):** Provide `ref_audio` and `ref_text`.
     The `ref_text` should be the exact transcript of the reference audio.
+
+    **Voice Cloning (cached):** Provide `voice_prompt` from `/v1/voice/extract`.
+    This is faster as it skips re-processing the reference audio.
     """
     if backend is None:
         raise HTTPException(status_code=503, detail="Model not loaded")
 
     try:
-        # Read reference audio if provided
-        ref_audio_data = None
-        if ref_audio:
+        # Priority: voice_prompt > ref_audio > speaker
+        if voice_prompt:
+            # Use pre-extracted voice prompt
+            wav_bytes, sample_rate = backend.synthesize_with_prompt(
+                text=text,
+                voice_prompt=voice_prompt,
+                language=language,
+                speed=speed,
+            )
+        elif ref_audio:
+            # On-the-fly voice cloning
             audio_bytes = await ref_audio.read()
-            # Load audio to get actual sample rate
             audio_buffer = io.BytesIO(audio_bytes)
             audio_data, sample_rate = sf.read(audio_buffer)
             ref_audio_data = (audio_data, sample_rate)
 
-        # Synthesize
-        wav_bytes, sample_rate = backend.synthesize(
-            text=text,
-            language=language,
-            speaker=speaker,
-            ref_audio=ref_audio_data,
-            ref_text=ref_text,
-            speed=speed,
-        )
+            wav_bytes, sample_rate = backend.synthesize(
+                text=text,
+                language=language,
+                speaker=speaker,
+                ref_audio=ref_audio_data,
+                ref_text=ref_text,
+                speed=speed,
+            )
+        else:
+            # Basic TTS with preset speaker
+            wav_bytes, sample_rate = backend.synthesize(
+                text=text,
+                language=language,
+                speaker=speaker,
+                ref_audio=None,
+                ref_text=None,
+                speed=speed,
+            )
 
         return StreamingResponse(
             io.BytesIO(wav_bytes),
@@ -282,6 +454,8 @@ async def synthesize_speech(
             }
         )
 
+    except NotImplementedError as e:
+        raise HTTPException(status_code=501, detail=str(e))
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
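The endpoint's new dispatch order (voice_prompt, then ref_audio, then preset speaker) can be isolated as a small pure function; a sketch of the selection logic only, with illustrative mode names, not the server's actual code:

```python
from typing import Optional


def choose_synthesis_mode(
    voice_prompt: Optional[str],
    ref_audio: Optional[bytes],
    speaker: Optional[str],
) -> str:
    """Mirror the priority used by POST /v1/audio/speech."""
    if voice_prompt:
        return "cached"       # synthesize_with_prompt path
    if ref_audio:
        return "on-the-fly"   # ref_audio/ref_text cloning path
    return "preset"           # basic TTS with preset speaker


# A cached prompt wins even when the other inputs are also supplied.
assert choose_synthesis_mode("b64...", b"wav-bytes", "Ryan") == "cached"
assert choose_synthesis_mode(None, b"wav-bytes", "Ryan") == "on-the-fly"
assert choose_synthesis_mode(None, None, "Ryan") == "preset"
```

Keeping the precedence in one place like this also makes the behavior easy to unit-test without loading a model, which matches the commit's note about adding tests for both extraction and synthesis.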
