88import io
99import os
1010from abc import ABC , abstractmethod
11- from typing import Optional , Tuple
11+ from typing import Optional , Tuple , List
1212from contextlib import asynccontextmanager
1313
1414import torch
@@ -34,6 +34,7 @@ def synthesize(
3434 self ,
3535 text : str ,
3636 language : str = "English" ,
37+ speaker : Optional [str ] = None ,
3738 ref_audio : Optional [Tuple [bytes , int ]] = None ,
3839 ref_text : Optional [str ] = None ,
3940 speed : float = 1.0 ,
@@ -44,6 +45,7 @@ def synthesize(
4445 Args:
4546 text: Text to synthesize
4647 language: Target language
48+ speaker: Preset speaker name (for basic TTS)
4749 ref_audio: Optional (audio_bytes, sample_rate) for voice cloning
4850 ref_text: Optional transcript of reference audio for voice cloning
4951 speed: Speech speed multiplier
@@ -58,56 +60,79 @@ def get_info(self) -> dict:
5860 """Return model information."""
5961 pass
6062
63+ @abstractmethod
64+ def get_speakers (self ) -> List [str ]:
65+ """Return available preset speakers."""
66+ pass
67+
6168
6269class Qwen3TTSBackend (TTSBackend ):
63- """Qwen3-TTS backend with voice cloning support."""
70+ """
71+ Qwen3-TTS backend with voice cloning support.
72+
73+ Uses CustomVoice model for basic TTS and Base model for voice cloning.
74+ """
6475
65- def __init__ (self , model_name : str = "Qwen/Qwen3-TTS-12Hz-1.7B-Base" ):
66- self .model_name = model_name
67- self .model = None
76+ def __init__ (
77+ self ,
78+ custom_voice_model : str = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice" ,
79+ base_model : str = "Qwen/Qwen3-TTS-12Hz-1.7B-Base" ,
80+ ):
81+ self .custom_voice_model_name = custom_voice_model
82+ self .base_model_name = base_model
83+ self .custom_voice_model = None
84+ self .base_model = None
85+ self .device = "cuda:0" if torch .cuda .is_available () else "cpu"
86+ self .dtype = torch .bfloat16 if torch .cuda .is_available () else torch .float32
6887
6988 def load (self ) -> None :
7089 from qwen_tts import Qwen3TTSModel
7190
72- print (f"Loading Qwen3-TTS model: { self .model_name } " )
73- self .model = Qwen3TTSModel .from_pretrained (
74- self .model_name ,
75- device_map = "cuda:0" if torch .cuda .is_available () else "cpu" ,
76- dtype = torch .bfloat16 if torch .cuda .is_available () else torch .float32 ,
91+ # Load CustomVoice model for basic TTS with preset speakers
92+ print (f"Loading CustomVoice model: { self .custom_voice_model_name } " )
93+ self .custom_voice_model = Qwen3TTSModel .from_pretrained (
94+ self .custom_voice_model_name ,
95+ device_map = self .device ,
96+ dtype = self .dtype ,
97+ )
98+ print ("CustomVoice model loaded" )
99+
100+ # Load Base model for voice cloning
101+ print (f"Loading Base model: { self .base_model_name } " )
102+ self .base_model = Qwen3TTSModel .from_pretrained (
103+ self .base_model_name ,
104+ device_map = self .device ,
105+ dtype = self .dtype ,
77106 )
78- print ("Qwen3-TTS model loaded successfully " )
107+ print ("Base model loaded" )
79108
80109 def synthesize (
81110 self ,
82111 text : str ,
83112 language : str = "English" ,
113+ speaker : Optional [str ] = None ,
84114 ref_audio : Optional [Tuple [bytes , int ]] = None ,
85115 ref_text : Optional [str ] = None ,
86116 speed : float = 1.0 ,
87117 ) -> Tuple [bytes , int ]:
88- if self .model is None :
89- raise RuntimeError ("Model not loaded" )
118+ if self .custom_voice_model is None or self . base_model is None :
119+ raise RuntimeError ("Models not loaded" )
90120
91- if ref_audio and ref_text :
92- # Voice cloning with reference audio and text
93- wavs , sr = self .model .generate_voice_clone (
121+ if ref_audio :
122+ # Voice cloning path - use Base model
123+ wavs , sr = self .base_model .generate_voice_clone (
94124 text = text ,
95125 language = language ,
96126 ref_audio = ref_audio ,
97127 ref_text = ref_text ,
98128 )
99- elif ref_audio :
100- # Voice cloning with just reference audio (model will auto-transcribe)
101- wavs , sr = self .model .generate_voice_clone (
102- text = text ,
103- language = language ,
104- ref_audio = ref_audio ,
105- )
106129 else :
107- # Basic TTS without voice cloning
108- wavs , sr = self .model .generate (
130+ # Basic TTS path - use CustomVoice model with preset speaker
131+ speaker = speaker or "Vivian"
132+ wavs , sr = self .custom_voice_model .generate_custom_voice (
109133 text = text ,
110134 language = language ,
135+ speaker = speaker ,
111136 )
112137
113138 # Convert to WAV bytes
@@ -120,12 +145,23 @@ def synthesize(
120145 def get_info (self ) -> dict :
121146 return {
122147 "backend" : "qwen3-tts" ,
123- "model" : self .model_name ,
148+ "custom_voice_model" : self .custom_voice_model_name ,
149+ "base_model" : self .base_model_name ,
124150 "supports_voice_cloning" : True ,
125151 "supports_ref_text" : True ,
126152 "device" : "cuda" if torch .cuda .is_available () else "cpu" ,
127153 }
128154
155+ def get_speakers (self ) -> List [str ]:
156+ """Return available preset speakers from CustomVoice model."""
157+ if self .custom_voice_model is None :
158+ return []
159+ try :
160+ return self .custom_voice_model .get_supported_speakers ()
161+ except Exception :
162+ # Fallback to known speakers
163+ return ["Vivian" , "Ryan" , "Sophia" , "Isabella" , "Evan" , "Lily" ]
164+
129165
130166# =============================================================================
131167# Backend Registry
@@ -155,12 +191,7 @@ async def lifespan(app: FastAPI):
155191 global backend
156192
157193 backend_name = os .environ .get ("TTS_BACKEND" , "qwen3-tts" )
158- model_name = os .environ .get ("TTS_MODEL" , None )
159-
160194 backend = get_backend (backend_name )
161- if model_name and hasattr (backend , 'model_name' ):
162- backend .model_name = model_name
163-
164195 backend .load ()
165196
166197 yield
@@ -172,7 +203,7 @@ async def lifespan(app: FastAPI):
# Application object; `lifespan` loads the selected TTS backend at startup.
app = FastAPI(
    title="TTS Server",
    description="Multi-model text-to-speech API with voice cloning support",
    version="0.2.0",
    lifespan=lifespan,
)
178209
@@ -194,20 +225,30 @@ async def list_models():
194225 }
195226
196227
@app.get("/v1/speakers")
async def list_speakers():
    """List available preset speakers for basic TTS."""
    if backend is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    speakers = backend.get_speakers()
    return {"speakers": speakers}
234+
235+
197236@app .post ("/v1/audio/speech" )
198237async def synthesize_speech (
199238 text : str = Form (..., description = "Text to synthesize" ),
200239 language : str = Form ("English" , description = "Target language" ),
240+ speaker : Optional [str ] = Form (None , description = "Preset speaker for basic TTS (e.g., Vivian, Ryan)" ),
201241 ref_audio : Optional [UploadFile ] = File (None , description = "Reference audio for voice cloning" ),
202242 ref_text : Optional [str ] = Form (None , description = "Transcript of reference audio" ),
203243 speed : float = Form (1.0 , description = "Speech speed multiplier" ),
204244):
205245 """
206246 Synthesize speech from text.
207247
208- For voice cloning, provide both ref_audio and ref_text.
209- The ref_text should be the exact transcript of the reference audio
210- for best voice cloning quality.
248+ **Basic TTS:** Just provide `text` and optionally `speaker`.
249+
250+ **Voice Cloning:** Provide `ref_audio` and `ref_text` for best quality.
251+ The `ref_text` should be the exact transcript of the reference audio.
211252 """
212253 if backend is None :
213254 raise HTTPException (status_code = 503 , detail = "Model not loaded" )
@@ -217,13 +258,16 @@ async def synthesize_speech(
217258 ref_audio_data = None
218259 if ref_audio :
219260 audio_bytes = await ref_audio .read ()
220- # Assume 16kHz sample rate for reference audio, model will resample if needed
221- ref_audio_data = (audio_bytes , 16000 )
261+ # Load audio to get actual sample rate
262+ audio_buffer = io .BytesIO (audio_bytes )
263+ audio_data , sample_rate = sf .read (audio_buffer )
264+ ref_audio_data = (audio_data , sample_rate )
222265
223266 # Synthesize
224267 wav_bytes , sample_rate = backend .synthesize (
225268 text = text ,
226269 language = language ,
270+ speaker = speaker ,
227271 ref_audio = ref_audio_data ,
228272 ref_text = ref_text ,
229273 speed = speed ,
@@ -242,13 +286,6 @@ async def synthesize_speech(
242286 raise HTTPException (status_code = 500 , detail = str (e ))
243287
244288
245- @app .get ("/docs" , include_in_schema = False )
246- async def docs_redirect ():
247- """Redirect to Swagger UI."""
248- from fastapi .responses import RedirectResponse
249- return RedirectResponse (url = "/docs" )
250-
251-
if __name__ == "__main__":
    # Allow running this module directly as a standalone development server.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
0 commit comments