Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions app/api/routers/tts.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,52 @@
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query

from app.models.engine import TTSRequest
from app.api.deps import get_tts_engine
from app.engines.base import BaseTTSEngine

router = APIRouter()


@router.post("/synthesize")
async def synthesize_text(request: TTSRequest):
"""TTS synthesis (placeholder - not implemented yet)"""
async def synthesize_text(
text: str = Query(..., description="Text to synthesize"),
voice: str | None = Query(None, description="Voice name/ID to use"),
speed: float = Query(1.0, gt=0, le=3.0, description="Speech speed multiplier"),
engine_params: str | None = Query(None, description="JSON engine parameters"),
tts_engine: BaseTTSEngine = Depends(get_tts_engine),
):
"""
Synthesize text to speech (invoke mode)

Query params:
- engine: TTS engine name (required)
- text: Text to synthesize (required)
- voice: Optional voice name/ID
- speed: Speech speed multiplier (0 < speed <= 3.0)
- engine_params: Optional JSON engine parameters

Returns complete audio with metrics.
"""
raise HTTPException(501, "TTS not implemented yet")


@router.post("/synthesize/stream")
async def synthesize_text_stream(request: TTSRequest):
"""TTS streaming (placeholder - not implemented yet)"""
async def synthesize_text_stream(
text: str = Query(..., description="Text to synthesize"),
voice: str | None = Query(None, description="Voice name/ID to use"),
speed: float = Query(1.0, gt=0, le=3.0, description="Speech speed multiplier"),
engine_params: str | None = Query(None, description="JSON engine parameters"),
tts_engine: BaseTTSEngine = Depends(get_tts_engine),
):
"""
Synthesize text to speech with streaming

Query params:
- engine: TTS engine name (required)
- text: Text to synthesize (required)
- voice: Optional voice name/ID
- speed: Speech speed multiplier (0 < speed <= 3.0)
- engine_params: Optional JSON engine parameters

Returns progressive audio chunks followed by final response.
"""
raise HTTPException(501, "TTS streaming not implemented yet")
67 changes: 1 addition & 66 deletions app/models/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
Engine configuration and input/output data models for STT/TTS

Design:
- STTRequest: Unified for REST (with audio_data) and WebSocket config (audio_data=None)
- TTSRequest: REST only with stream_response option
- Lightweight chunks for streaming (minimal per-chunk metrics)
- Full response models (STTResponse/TTSResponse) used for both invoke and streaming modes
- STTPerformanceMetrics/TTSPerformanceMetrics: Extended with streaming-specific fields
"""

from typing import Any, Literal
from typing import Literal

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -41,69 +39,6 @@ class EngineConfig(BaseModel):
timeout_seconds: int = Field(default=300, ge=1, description="Processing timeout")


# =============================================================================
# Request Models
# =============================================================================


class STTRequest(BaseModel):
"""
Input for STT processing (REST or WebSocket)

Usage:
- REST POST: audio_data is required, stream_response controls output format
- WebSocket config: audio_data=None, subsequent messages are raw audio bytes
"""

# Audio - required for REST, None for WebSocket config message
audio_data: bytes | None = Field(
None, description="Audio bytes (REST) or None (WebSocket config)"
)

# Config options
language: str | None = Field(None, description="Language hint (e.g., 'en', 'vi')")
format: str | None = Field(None, description="Audio format hint (wav, mp3, webm)")
sample_rate: int | None = Field(
None, description="Sample rate in Hz (important for streaming)"
)

# Response preference (REST only)
stream_response: bool = Field(
default=False, description="Return StreamingResponse chunks vs full response"
)

# Engine-specific parameters (flexible dict)
engine_params: dict[str, Any] = Field(
default_factory=dict,
description="Engine-specific parameters (e.g., temperature, beam_size, api_keys)",
)


class TTSRequest(BaseModel):
"""
Input for TTS processing (REST only)

stream_response controls whether to return full audio or StreamingResponse
"""

text: str = Field(..., description="Text to synthesize")
voice: str | None = Field(None, description="Voice name/ID to use")
speed: float = Field(
default=1.0, gt=0, le=3.0, description="Speech speed multiplier"
)

# Response preference
stream_response: bool = Field(
default=False, description="Return StreamingResponse chunks vs full response"
)

# Engine-specific parameters (flexible dict)
engine_params: dict[str, Any] = Field(
default_factory=dict,
description="Engine-specific parameters (e.g., temperature, beam_size, api_keys)",
)


# =============================================================================
# Streaming Chunk Models (Lightweight - minimal per-chunk overhead)
# =============================================================================
Expand Down
34 changes: 18 additions & 16 deletions tests/unit/api/test_tts_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,48 +4,50 @@
class TestTTSRouter:
"""TTS router stub endpoint tests"""

def test_synthesize_returns_501(self, client):
def test_synthesize_returns_501(self, client_both):
"""POST /synthesize returns 501"""
response = client.post(
response = client_both.post(
"/api/v1/tts/synthesize",
json={"text": "Hello", "engine": "default"},
params={"text": "Hello", "engine": "default"},
)

assert response.status_code == 501

def test_synthesize_stream_returns_501(self, client):
def test_synthesize_stream_returns_501(self, client_both):
"""POST /synthesize/stream returns 501"""
response = client.post(
response = client_both.post(
"/api/v1/tts/synthesize/stream",
json={"text": "Hello", "engine": "default"},
params={"text": "Hello", "engine": "default"},
)

assert response.status_code == 501

def test_endpoints_exist(self, client):
def test_endpoints_exist(self, client_both):
"""TTS endpoints exist (not 404)"""
r1 = client.post("/api/v1/tts/synthesize", json={"text": "x", "engine": "x"})
r2 = client.post(
"/api/v1/tts/synthesize/stream", json={"text": "x", "engine": "x"}
r1 = client_both.post(
"/api/v1/tts/synthesize", params={"text": "x", "engine": "default"}
)
r2 = client_both.post(
"/api/v1/tts/synthesize/stream", params={"text": "x", "engine": "default"}
)

assert r1.status_code != 404
assert r2.status_code != 404

def test_synthesize_with_minimal_request(self, client):
def test_synthesize_with_minimal_request(self, client_both):
"""Synthesize with minimal request"""
response = client.post(
response = client_both.post(
"/api/v1/tts/synthesize",
json={"text": "Test", "engine": "test"},
params={"text": "Test", "engine": "default"},
)

assert response.status_code == 501

def test_synthesize_missing_fields_returns_422(self, client):
def test_synthesize_missing_fields_returns_422(self, client_both):
"""Missing required fields returns 422"""
response = client.post(
response = client_both.post(
"/api/v1/tts/synthesize",
json={},
params={},
)

assert response.status_code == 422
101 changes: 0 additions & 101 deletions tests/unit/models/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
EngineConfig,
Segment,
STTChunk,
STTRequest,
STTResponse,
TTSChunk,
TTSRequest,
TTSResponse,
)
from app.models.metrics import (
Expand Down Expand Up @@ -95,105 +93,6 @@ def test_segment_requires_all_fields(self):
Segment(end=1.0, text="hello") # Missing start


class TestSTTRequest:
"""Test STTRequest model (unified for REST and WebSocket)"""

def test_create_stt_request_for_rest(self):
"""Should create STTRequest for REST with audio_data"""
request = STTRequest(audio_data=b"fake audio")

assert request.audio_data == b"fake audio"
assert request.language is None
assert request.format is None
assert request.sample_rate is None
assert request.stream_response is False

def test_create_stt_request_for_websocket_config(self):
"""Should create STTRequest for WebSocket config (no audio_data)"""
request = STTRequest(
language="en",
format="wav",
sample_rate=16000,
)

assert request.audio_data is None # WebSocket config
assert request.language == "en"
assert request.format == "wav"
assert request.sample_rate == 16000

def test_create_stt_request_with_stream_response(self):
"""Should create STTRequest with streaming response preference"""
request = STTRequest(
audio_data=b"fake audio",
stream_response=True,
)

assert request.audio_data == b"fake audio"
assert request.stream_response is True

def test_stt_request_all_fields_optional_except_pattern(self):
"""Should allow creating empty STTRequest (for WebSocket config)"""
request = STTRequest()
assert request.audio_data is None
assert request.engine_params == {}

def test_stt_request_with_engine_params(self):
"""Should allow passing engine-specific parameters"""
params = {"temperature": 0.7, "beam_size": 5}
request = STTRequest(engine_params=params)
assert request.engine_params == params


class TestTTSRequest:
"""Test TTSRequest model (REST only)"""

def test_create_tts_request_minimal(self):
"""Should create TTSRequest with only text"""
request = TTSRequest(text="Hello world")

assert request.text == "Hello world"
assert request.voice is None
assert request.speed == 1.0 # Default
assert request.stream_response is False

def test_create_tts_request_with_all_fields(self):
"""Should create TTSRequest with all fields"""
request = TTSRequest(
text="Hello world",
voice="en-US-JennyNeural",
speed=1.2,
stream_response=True,
)

assert request.text == "Hello world"
assert request.voice == "en-US-JennyNeural"
assert request.speed == 1.2
assert request.stream_response is True

def test_tts_request_requires_text(self):
"""Should require text field"""
with pytest.raises(ValidationError):
TTSRequest(voice="en-US-JennyNeural")

def test_tts_request_validates_speed(self):
"""Speed must be > 0 and <= 3.0"""
TTSRequest(text="test", speed=0.5)
TTSRequest(text="test", speed=3.0)

with pytest.raises(ValidationError):
TTSRequest(text="test", speed=0) # Too low

with pytest.raises(ValidationError):
TTSRequest(text="test", speed=3.5) # Too high

def test_tts_request_with_engine_params(self):
"""Should allow passing engine-specific parameters"""
params = {"api_key": "test_key", "model_version": "v1"}
request = TTSRequest(text="hello", engine_params=params)
assert request.engine_params == params
assert request.text == "hello"


class TestSTTResponse:
"""Test STTResponse model for invoke mode"""

Expand Down