Skip to content

Commit 1056cb6

Browse files
committed
fix: use correct Qwen3-TTS model methods
- Load both CustomVoice and Base models
- Use generate_custom_voice() for basic TTS with preset speakers
- Use generate_voice_clone() for voice cloning with ref_audio/ref_text
- Add /v1/speakers endpoint to list available preset speakers
- Add speaker parameter to /v1/audio/speech endpoint
- Fix audio loading to properly read the sample rate from reference files

The Base model only supports generate_voice_clone(), not a generic
generate() method. For basic TTS, we need the CustomVoice model.
1 parent 2f0cc07 commit 1056cb6

1 file changed

Lines changed: 81 additions & 44 deletions

File tree

server.py

Lines changed: 81 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import io
99
import os
1010
from abc import ABC, abstractmethod
11-
from typing import Optional, Tuple
11+
from typing import Optional, Tuple, List
1212
from contextlib import asynccontextmanager
1313

1414
import torch
@@ -34,6 +34,7 @@ def synthesize(
3434
self,
3535
text: str,
3636
language: str = "English",
37+
speaker: Optional[str] = None,
3738
ref_audio: Optional[Tuple[bytes, int]] = None,
3839
ref_text: Optional[str] = None,
3940
speed: float = 1.0,
@@ -44,6 +45,7 @@ def synthesize(
4445
Args:
4546
text: Text to synthesize
4647
language: Target language
48+
speaker: Preset speaker name (for basic TTS)
4749
ref_audio: Optional (audio_bytes, sample_rate) for voice cloning
4850
ref_text: Optional transcript of reference audio for voice cloning
4951
speed: Speech speed multiplier
@@ -58,56 +60,79 @@ def get_info(self) -> dict:
5860
"""Return model information."""
5961
pass
6062

63+
@abstractmethod
def get_speakers(self) -> List[str]:
    """Return the preset speaker names this backend offers.

    Returns:
        List of speaker name strings; implementations may return an
        empty list when no presets are available (e.g. model not loaded).
    """
    pass
67+
6168

6269
class Qwen3TTSBackend(TTSBackend):
63-
"""Qwen3-TTS backend with voice cloning support."""
70+
"""
71+
Qwen3-TTS backend with voice cloning support.
72+
73+
Uses CustomVoice model for basic TTS and Base model for voice cloning.
74+
"""
6475

65-
def __init__(self, model_name: str = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"):
66-
self.model_name = model_name
67-
self.model = None
76+
def __init__(
    self,
    custom_voice_model: str = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
    base_model: str = "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
):
    """Record model identifiers and select device/dtype; loads no weights.

    Args:
        custom_voice_model: Identifier of the CustomVoice model, used for
            basic TTS with preset speakers.
        base_model: Identifier of the Base model, used for voice cloning.
    """
    self.custom_voice_model_name = custom_voice_model
    self.base_model_name = base_model
    # Populated by load(); both stay None until then.
    self.custom_voice_model = None
    self.base_model = None
    # GPU when available (bfloat16), otherwise CPU (float32).
    self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
    self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
6887

6988
def load(self) -> None:
    """Load both Qwen3-TTS models onto the configured device.

    CustomVoice serves basic TTS with preset speakers; Base serves voice
    cloning (per the commit notes, Base only supports
    generate_voice_clone()).  Must run before synthesize().
    """
    # Imported lazily so the server can start without qwen_tts installed
    # until a load is actually requested.
    from qwen_tts import Qwen3TTSModel

    # Load CustomVoice model for basic TTS with preset speakers.
    print(f"Loading CustomVoice model: {self.custom_voice_model_name}")
    self.custom_voice_model = Qwen3TTSModel.from_pretrained(
        self.custom_voice_model_name,
        device_map=self.device,
        dtype=self.dtype,
    )
    print("CustomVoice model loaded")

    # Load Base model for voice cloning.
    print(f"Loading Base model: {self.base_model_name}")
    self.base_model = Qwen3TTSModel.from_pretrained(
        self.base_model_name,
        device_map=self.device,
        dtype=self.dtype,
    )
    print("Base model loaded")
79108

80109
def synthesize(
81110
self,
82111
text: str,
83112
language: str = "English",
113+
speaker: Optional[str] = None,
84114
ref_audio: Optional[Tuple[bytes, int]] = None,
85115
ref_text: Optional[str] = None,
86116
speed: float = 1.0,
87117
) -> Tuple[bytes, int]:
88-
if self.model is None:
89-
raise RuntimeError("Model not loaded")
118+
if self.custom_voice_model is None or self.base_model is None:
119+
raise RuntimeError("Models not loaded")
90120

91-
if ref_audio and ref_text:
92-
# Voice cloning with reference audio and text
93-
wavs, sr = self.model.generate_voice_clone(
121+
if ref_audio:
122+
# Voice cloning path - use Base model
123+
wavs, sr = self.base_model.generate_voice_clone(
94124
text=text,
95125
language=language,
96126
ref_audio=ref_audio,
97127
ref_text=ref_text,
98128
)
99-
elif ref_audio:
100-
# Voice cloning with just reference audio (model will auto-transcribe)
101-
wavs, sr = self.model.generate_voice_clone(
102-
text=text,
103-
language=language,
104-
ref_audio=ref_audio,
105-
)
106129
else:
107-
# Basic TTS without voice cloning
108-
wavs, sr = self.model.generate(
130+
# Basic TTS path - use CustomVoice model with preset speaker
131+
speaker = speaker or "Vivian"
132+
wavs, sr = self.custom_voice_model.generate_custom_voice(
109133
text=text,
110134
language=language,
135+
speaker=speaker,
111136
)
112137

113138
# Convert to WAV bytes
@@ -120,12 +145,23 @@ def synthesize(
120145
def get_info(self) -> dict:
    """Summarize this backend: configured model names and capabilities."""
    on_gpu = torch.cuda.is_available()
    return dict(
        backend="qwen3-tts",
        custom_voice_model=self.custom_voice_model_name,
        base_model=self.base_model_name,
        supports_voice_cloning=True,
        supports_ref_text=True,
        device="cuda" if on_gpu else "cpu",
    )
128154

155+
def get_speakers(self) -> List[str]:
    """Preset speakers offered by the CustomVoice model ([] if not loaded)."""
    model = self.custom_voice_model
    if model is None:
        return []
    try:
        return model.get_supported_speakers()
    except Exception:
        # Best-effort: if the model cannot report its roster,
        # fall back to the known preset names.
        return ["Vivian", "Ryan", "Sophia", "Isabella", "Evan", "Lily"]
164+
129165

130166
# =============================================================================
131167
# Backend Registry
@@ -155,12 +191,7 @@ async def lifespan(app: FastAPI):
155191
global backend
156192

157193
backend_name = os.environ.get("TTS_BACKEND", "qwen3-tts")
158-
model_name = os.environ.get("TTS_MODEL", None)
159-
160194
backend = get_backend(backend_name)
161-
if model_name and hasattr(backend, 'model_name'):
162-
backend.model_name = model_name
163-
164195
backend.load()
165196

166197
yield
@@ -172,7 +203,7 @@ async def lifespan(app: FastAPI):
172203
# FastAPI application; `lifespan` loads the TTS backend at startup.
app = FastAPI(
    title="TTS Server",
    description="Multi-model text-to-speech API with voice cloning support",
    version="0.2.0",
    lifespan=lifespan,
)
178209

@@ -194,20 +225,30 @@ async def list_models():
194225
}
195226

196227

228+
@app.get("/v1/speakers")
async def list_speakers():
    """List available preset speakers for basic TTS.

    Returns:
        JSON object ``{"speakers": [...]}``.

    Raises:
        HTTPException: 503 when the backend has not been loaded yet.
    """
    if backend is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"speakers": backend.get_speakers()}
234+
235+
197236
@app.post("/v1/audio/speech")
198237
async def synthesize_speech(
199238
text: str = Form(..., description="Text to synthesize"),
200239
language: str = Form("English", description="Target language"),
240+
speaker: Optional[str] = Form(None, description="Preset speaker for basic TTS (e.g., Vivian, Ryan)"),
201241
ref_audio: Optional[UploadFile] = File(None, description="Reference audio for voice cloning"),
202242
ref_text: Optional[str] = Form(None, description="Transcript of reference audio"),
203243
speed: float = Form(1.0, description="Speech speed multiplier"),
204244
):
205245
"""
206246
Synthesize speech from text.
207247
208-
For voice cloning, provide both ref_audio and ref_text.
209-
The ref_text should be the exact transcript of the reference audio
210-
for best voice cloning quality.
248+
**Basic TTS:** Just provide `text` and optionally `speaker`.
249+
250+
**Voice Cloning:** Provide `ref_audio` and `ref_text` for best quality.
251+
The `ref_text` should be the exact transcript of the reference audio.
211252
"""
212253
if backend is None:
213254
raise HTTPException(status_code=503, detail="Model not loaded")
@@ -217,13 +258,16 @@ async def synthesize_speech(
217258
ref_audio_data = None
218259
if ref_audio:
219260
audio_bytes = await ref_audio.read()
220-
# Assume 16kHz sample rate for reference audio, model will resample if needed
221-
ref_audio_data = (audio_bytes, 16000)
261+
# Load audio to get actual sample rate
262+
audio_buffer = io.BytesIO(audio_bytes)
263+
audio_data, sample_rate = sf.read(audio_buffer)
264+
ref_audio_data = (audio_data, sample_rate)
222265

223266
# Synthesize
224267
wav_bytes, sample_rate = backend.synthesize(
225268
text=text,
226269
language=language,
270+
speaker=speaker,
227271
ref_audio=ref_audio_data,
228272
ref_text=ref_text,
229273
speed=speed,
@@ -242,13 +286,6 @@ async def synthesize_speech(
242286
raise HTTPException(status_code=500, detail=str(e))
243287

244288

245-
@app.get("/docs", include_in_schema=False)
246-
async def docs_redirect():
247-
"""Redirect to Swagger UI."""
248-
from fastapi.responses import RedirectResponse
249-
return RedirectResponse(url="/docs")
250-
251-
252289
if __name__ == "__main__":
    # Dev entry point: serve on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

0 commit comments

Comments
 (0)