diff --git a/app/api/routers/tts.py b/app/api/routers/tts.py index 3270c3d..ce6de69 100644 --- a/app/api/routers/tts.py +++ b/app/api/routers/tts.py @@ -1,17 +1,52 @@ -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query -from app.models.engine import TTSRequest +from app.api.deps import get_tts_engine +from app.engines.base import BaseTTSEngine router = APIRouter() @router.post("/synthesize") -async def synthesize_text(request: TTSRequest): - """TTS synthesis (placeholder - not implemented yet)""" +async def synthesize_text( + text: str = Query(..., description="Text to synthesize"), + voice: str | None = Query(None, description="Voice name/ID to use"), + speed: float = Query(1.0, gt=0, le=3.0, description="Speech speed multiplier"), + engine_params: str | None = Query(None, description="JSON engine parameters"), + tts_engine: BaseTTSEngine = Depends(get_tts_engine), +): + """ + Synthesize text to speech (invoke mode) + + Query params: + - engine: TTS engine name (required) + - text: Text to synthesize (required) + - voice: Optional voice name/ID + - speed: Speech speed multiplier (0 < speed <= 3.0) + - engine_params: Optional JSON engine parameters + + Returns complete audio with metrics. + """ raise HTTPException(501, "TTS not implemented yet") @router.post("/synthesize/stream") -async def synthesize_text_stream(request: TTSRequest): - """TTS streaming (placeholder - not implemented yet)""" +async def synthesize_text_stream( + text: str = Query(..., description="Text to synthesize"), + voice: str | None = Query(None, description="Voice name/ID to use"), + speed: float = Query(1.0, gt=0, le=3.0, description="Speech speed multiplier"), + engine_params: str | None = Query(None, description="JSON engine parameters"), + tts_engine: BaseTTSEngine = Depends(get_tts_engine), +): + """ + Synthesize text to speech with streaming + + Query params: + - engine: TTS engine name (required) + - text: Text to synthesize (required) + - voice: Optional voice name/ID + - speed: Speech speed multiplier (0 < speed <= 3.0) + - engine_params: Optional JSON engine parameters + + Returns progressive audio chunks followed by final response. + """ raise HTTPException(501, "TTS streaming not implemented yet") diff --git a/app/models/engine.py b/app/models/engine.py index d25f7b8..0df0d2b 100644 --- a/app/models/engine.py +++ b/app/models/engine.py @@ -2,14 +2,12 @@ Engine configuration and input/output data models for STT/TTS Design: -- STTRequest: Unified for REST (with audio_data) and WebSocket config (audio_data=None) -- TTSRequest: REST only with stream_response option - Lightweight chunks for streaming (minimal per-chunk metrics) - Full response models (STTResponse/TTSResponse) used for both invoke and streaming modes - STTPerformanceMetrics/TTSPerformanceMetrics: Extended with streaming-specific fields """ -from typing import Any, Literal +from typing import Literal from pydantic import BaseModel, Field @@ -41,69 +39,6 @@ class EngineConfig(BaseModel): timeout_seconds: int = Field(default=300, ge=1, description="Processing timeout") -# ============================================================================= -# Request Models -# ============================================================================= - - -class STTRequest(BaseModel): - """ - Input for STT processing (REST or WebSocket) - - Usage: - - REST POST: audio_data is required, stream_response controls output format - - WebSocket config: audio_data=None, subsequent messages are raw audio bytes - """ - - # Audio - required for REST, None for WebSocket config message - audio_data: bytes | None = Field( - None, description="Audio bytes (REST) or None (WebSocket config)" - ) - - # Config options - language: str | None = Field(None, description="Language hint (e.g., 'en', 'vi')") - format: str | None = Field(None, description="Audio format hint (wav, mp3, webm)") - sample_rate: int | None = Field( - None, description="Sample rate in Hz (important for streaming)" - ) - - # Response preference (REST only) - stream_response: bool = Field( - default=False, description="Return StreamingResponse chunks vs full response" - ) - - # Engine-specific parameters (flexible dict) - engine_params: dict[str, Any] = Field( - default_factory=dict, - description="Engine-specific parameters (e.g., temperature, beam_size, api_keys)", - ) - - -class TTSRequest(BaseModel): - """ - Input for TTS processing (REST only) - - stream_response controls whether to return full audio or StreamingResponse - """ - - text: str = Field(..., description="Text to synthesize") - voice: str | None = Field(None, description="Voice name/ID to use") - speed: float = Field( - default=1.0, gt=0, le=3.0, description="Speech speed multiplier" - ) - - # Response preference - stream_response: bool = Field( - default=False, description="Return StreamingResponse chunks vs full response" - ) - - # Engine-specific parameters (flexible dict) - engine_params: dict[str, Any] = Field( - default_factory=dict, - description="Engine-specific parameters (e.g., temperature, beam_size, api_keys)", - ) - - # ============================================================================= # Streaming Chunk Models (Lightweight - minimal per-chunk overhead) # ============================================================================= diff --git a/tests/unit/api/test_tts_router.py b/tests/unit/api/test_tts_router.py index 5cd62ce..9547b26 100644 --- a/tests/unit/api/test_tts_router.py +++ b/tests/unit/api/test_tts_router.py @@ -4,48 +4,50 @@ class TestTTSRouter: """TTS router stub endpoint tests""" - def test_synthesize_returns_501(self, client): + def test_synthesize_returns_501(self, client_both): """POST /synthesize returns 501""" - response = client.post( + response = client_both.post( "/api/v1/tts/synthesize", - json={"text": "Hello", "engine": "default"}, + params={"text": "Hello", "engine": "default"}, ) assert response.status_code == 501 - def test_synthesize_stream_returns_501(self, client): + def test_synthesize_stream_returns_501(self, client_both): """POST /synthesize/stream returns 501""" - response = client.post( + response = client_both.post( "/api/v1/tts/synthesize/stream", - json={"text": "Hello", "engine": "default"}, + params={"text": "Hello", "engine": "default"}, ) assert response.status_code == 501 - def test_endpoints_exist(self, client): + def test_endpoints_exist(self, client_both): """TTS endpoints exist (not 404)""" - r1 = client.post("/api/v1/tts/synthesize", json={"text": "x", "engine": "x"}) - r2 = client.post( - "/api/v1/tts/synthesize/stream", json={"text": "x", "engine": "x"} + r1 = client_both.post( + "/api/v1/tts/synthesize", params={"text": "x", "engine": "default"} + ) + r2 = client_both.post( + "/api/v1/tts/synthesize/stream", params={"text": "x", "engine": "default"} ) assert r1.status_code != 404 assert r2.status_code != 404 - def test_synthesize_with_minimal_request(self, client): + def test_synthesize_with_minimal_request(self, client_both): """Synthesize with minimal request""" - response = client.post( + response = client_both.post( "/api/v1/tts/synthesize", - json={"text": "Test", "engine": "test"}, + params={"text": "Test", "engine": "default"}, ) assert response.status_code == 501 - def test_synthesize_missing_fields_returns_422(self, client): + def test_synthesize_missing_fields_returns_422(self, client_both): """Missing required fields returns 422""" - response = client.post( + response = client_both.post( "/api/v1/tts/synthesize", - json={}, + params={}, ) assert response.status_code == 422 diff --git a/tests/unit/models/test_engine.py b/tests/unit/models/test_engine.py index 0671b2c..9744ed4 100644 --- a/tests/unit/models/test_engine.py +++ b/tests/unit/models/test_engine.py @@ -5,10 +5,8 @@ EngineConfig, Segment, STTChunk, - STTRequest, STTResponse, TTSChunk, - TTSRequest, TTSResponse, ) from app.models.metrics import ( @@ -95,105 +93,6 @@ def test_segment_requires_all_fields(self): Segment(end=1.0, text="hello") # Missing start -class TestSTTRequest: - """Test STTRequest model (unified for REST and WebSocket)""" - - def test_create_stt_request_for_rest(self): - """Should create STTRequest for REST with audio_data""" - request = STTRequest(audio_data=b"fake audio") - - assert request.audio_data == b"fake audio" - assert request.language is None - assert request.format is None - assert request.sample_rate is None - assert request.stream_response is False - - def test_create_stt_request_for_websocket_config(self): - """Should create STTRequest for WebSocket config (no audio_data)""" - request = STTRequest( - language="en", - format="wav", - sample_rate=16000, - ) - - assert request.audio_data is None # WebSocket config - assert request.language == "en" - assert request.format == "wav" - assert request.sample_rate == 16000 - - def test_create_stt_request_with_stream_response(self): - """Should create STTRequest with streaming response preference""" - request = STTRequest( - audio_data=b"fake audio", - stream_response=True, - ) - - assert request.audio_data == b"fake audio" - assert request.stream_response is True - - def test_stt_request_all_fields_optional_except_pattern(self): - """Should allow creating empty STTRequest (for WebSocket config)""" - request = STTRequest() - assert request.audio_data is None - assert request.engine_params == {} - - def test_stt_request_with_engine_params(self): - """Should allow passing engine-specific parameters""" - params = {"temperature": 0.7, "beam_size": 5} - request = STTRequest(engine_params=params) - assert request.engine_params == params - - -class TestTTSRequest: - """Test TTSRequest model (REST only)""" - - def test_create_tts_request_minimal(self): - """Should create TTSRequest with only text""" - request = TTSRequest(text="Hello world") - - assert request.text == "Hello world" - assert request.voice is None - assert request.speed == 1.0 # Default - assert request.stream_response is False - - def test_create_tts_request_with_all_fields(self): - """Should create TTSRequest with all fields""" - request = TTSRequest( - text="Hello world", - voice="en-US-JennyNeural", - speed=1.2, - stream_response=True, - ) - - assert request.text == "Hello world" - assert request.voice == "en-US-JennyNeural" - assert request.speed == 1.2 - assert request.stream_response is True - - def test_tts_request_requires_text(self): - """Should require text field""" - with pytest.raises(ValidationError): - TTSRequest(voice="en-US-JennyNeural") - - def test_tts_request_validates_speed(self): - """Speed must be > 0 and <= 3.0""" - TTSRequest(text="test", speed=0.5) - TTSRequest(text="test", speed=3.0) - - with pytest.raises(ValidationError): - TTSRequest(text="test", speed=0) # Too low - - with pytest.raises(ValidationError): - TTSRequest(text="test", speed=3.5) # Too high - - def test_tts_request_with_engine_params(self): - """Should allow passing engine-specific parameters""" - params = {"api_key": "test_key", "model_version": "v1"} - request = TTSRequest(text="hello", engine_params=params) - assert request.engine_params == params - assert request.text == "hello" - - class TestSTTResponse: """Test STTResponse model for invoke mode"""