Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions src/serve/asr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
//! Managed Whisper ASR child process.
//!
//! Spawns a Python-based Whisper ASR server as a child process using `uv run`,
//! writing the embedded Python script to a temporary file at runtime.

use std::time::Duration;

use snafu::{ResultExt, Snafu};
use tempfile::NamedTempFile;
use tokio::process::{Child, Command};
use tracing::{debug, info, warn};

/// Embedded Python script for the Whisper ASR server.
const WHISPER_SCRIPT: &str = include_str!("whisper_server.py");

/// Errors that can occur while managing the Whisper ASR process.
///
/// Every variant carries an explicit `#[snafu(display(...))]` message, and
/// `visibility(pub)` makes the generated context selectors (for example
/// `SpawnSnafu`) public.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum AsrError {
/// `uv` is not installed or not on `PATH`.
///
/// Reported when the `uv --version` probe cannot be run or exits with a
/// non-success status.
#[snafu(display(
"uv not found on PATH — install with: curl -LsSf https://astral.sh/uv/install.sh | sh"
))]
UvNotFound,

/// Failed to write the embedded Python script to a temp file.
///
/// Covers both creating the temporary file and writing the script into it.
#[snafu(display("failed to write whisper script: {source}"))]
ScriptWrite { source: std::io::Error },

/// Failed to spawn the `uv run` child process.
#[snafu(display("failed to spawn whisper process: {source}"))]
Spawn { source: std::io::Error },

/// The server did not become ready within the timeout.
///
/// `timeout_secs` is the timeout that was waited for, in whole seconds.
#[snafu(display("whisper server did not start within {timeout_secs}s"))]
StartTimeout { timeout_secs: u64 },

/// Failed to find a free port.
///
/// Covers both binding the probe listener and reading back its local
/// address.
#[snafu(display("failed to bind ephemeral port: {source}"))]
PortBind { source: std::io::Error },
}

/// Result type for ASR operations; the error type is always [`AsrError`].
pub type Result<T> = std::result::Result<T, AsrError>;

/// Managed Whisper ASR child process.
///
/// Bundles the spawned `uv run` child with the port it was told to listen on
/// and the temporary script file it executes. The child is spawned with
/// `kill_on_drop(true)`, so dropping this value also requests that the
/// process be killed.
pub struct WhisperProcess {
/// The spawned child process.
child: Child,
/// The port the server is listening on.
port: u16,
/// Keep the temp file alive so the script is not deleted.
///
/// Never read after construction; it is stored only so the temporary
/// script stays on disk for as long as this handle exists.
_script: NamedTempFile,
}

impl WhisperProcess {
/// Start the Whisper ASR server on a random available port.
///
/// Writes the embedded Python script to a temp file, then spawns it
/// via `uv run` with the required dependencies.
pub async fn start() -> Result<Self> {
// Verify uv is available.
let uv_ok = Command::new("uv")
.arg("--version")
.stdout(std::process::Stdio::null())
.stderr(std::process::Stdio::null())
.status()
.await;

match uv_ok {
Ok(status) if status.success() => {}
_ => return Err(AsrError::UvNotFound),
}

// Find a free port by binding to port 0.
let port = {
let listener = std::net::TcpListener::bind("127.0.0.1:0").context(PortBindSnafu)?;
listener.local_addr().context(PortBindSnafu)?.port()
};

// Write the embedded script to a temp file.
let script = tempfile::Builder::new()
.prefix("kotoba-whisper-")
.suffix(".py")
.tempfile()
.context(ScriptWriteSnafu)?;

std::fs::write(script.path(), WHISPER_SCRIPT).context(ScriptWriteSnafu)?;

info!(port, "spawning whisper ASR server");

let child = Command::new("uv")
.arg("run")
.arg("--python")
.arg("3.11")
.arg("--with")
.arg("faster-whisper")
.arg("--with")
.arg("fastapi[standard]")
.arg("python")
.arg(script.path())
.arg("large-v3-turbo")
.arg(port.to_string())
.stdout(std::process::Stdio::inherit())
.stderr(std::process::Stdio::inherit())
.kill_on_drop(true)
.spawn()
.context(SpawnSnafu)?;

let process = Self {
child,
port,
_script: script,
};

process.wait_ready(Duration::from_secs(120)).await?;

Ok(process)
}

/// The URL of the running ASR endpoint.
pub fn url(&self) -> String {
format!("http://127.0.0.1:{}/v1/audio/transcriptions", self.port)
}

/// Wait for the server to be ready by polling the TCP port.
async fn wait_ready(&self, timeout: Duration) -> Result<()> {
let deadline = tokio::time::Instant::now() + timeout;
let addr = format!("127.0.0.1:{}", self.port);

loop {
if tokio::time::Instant::now() >= deadline {
return Err(AsrError::StartTimeout {
timeout_secs: timeout.as_secs(),
});
}

match tokio::net::TcpStream::connect(&addr).await {
Ok(_) => {
debug!(port = self.port, "whisper ASR server is ready");
return Ok(());
}
Err(_) => {
tokio::time::sleep(Duration::from_millis(500)).await;
}
}
}
}

/// Kill the child process.
pub async fn shutdown(&mut self) {
if let Err(e) = self.child.kill().await {
warn!("failed to kill whisper process: {e}");
}
}
}
11 changes: 9 additions & 2 deletions src/serve/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ pub struct AppState {
pub config: Arc<crate::app_config::AppConfig>,
/// Factory for creating TTS backends.
pub factory: Arc<dyn BackendFactory>,
/// Default ASR endpoint URL (from managed Whisper or fallback).
pub asr_url: Arc<String>,
}

/// `GET /health` — returns a simple health-check response.
Expand Down Expand Up @@ -589,11 +591,16 @@ pub async fn pcm_worklet() -> impl IntoResponse {
/// `GET /ws/voice` — WebSocket endpoint for real-time voice conversation.
///
/// Accepts query parameters for ASR/LLM/TTS configuration and upgrades to
/// a WebSocket that runs the full voice pipeline server-side.
/// a WebSocket that runs the full voice pipeline server-side. When the client
/// does not supply an `asr_url` query param, the managed Whisper URL is used.
pub async fn voice_ws(
State(state): State<AppState>,
Query(params): Query<super::voice::VoiceParams>,
Query(mut params): Query<super::voice::VoiceParams>,
ws: WebSocketUpgrade,
) -> impl IntoResponse {
// Fall back to the managed ASR URL when the client did not provide one.
if params.asr_url.is_none() {
params.asr_url = Some((*state.asr_url).clone());
}
ws.on_upgrade(move |socket| super::voice::handle_voice_ws(socket, state, params))
}
26 changes: 26 additions & 0 deletions src/serve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
//! recording
//! - `GET /demo` — bundled web demo for real-time voice conversation

mod asr;
mod handlers;
mod models;
#[cfg(test)]
Expand Down Expand Up @@ -51,11 +52,31 @@ pub fn build_router(state: AppState) -> Router {

/// Start the HTTP server and listen for requests.
pub async fn run(host: &str, port: u16) -> crate::error::Result<()> {
// Start managed Whisper ASR server.
eprintln!("Starting Whisper ASR server...");
let mut whisper = match asr::WhisperProcess::start().await {
Ok(w) => {
eprintln!(" Whisper ASR: {}", w.url());
Some(w)
}
Err(e) => {
eprintln!(" WARNING: Failed to start Whisper ASR: {e}");
eprintln!(" Voice pipeline will use external ASR endpoint");
None
}
};

let asr_url = whisper.as_ref().map_or_else(
|| "http://localhost:8000/v1/audio/transcriptions".to_string(),
asr::WhisperProcess::url,
);

let config = Arc::new(crate::app_config::load().clone());

let state = AppState {
config,
factory: Arc::new(DefaultBackendFactory),
asr_url: Arc::new(asr_url),
};

let app = build_router(state);
Expand All @@ -75,5 +96,10 @@ pub async fn run(host: &str, port: u16) -> crate::error::Result<()> {

axum::serve(listener, app).await.context(error::IoSnafu)?;

// Cleanup on exit.
if let Some(ref mut w) = whisper {
w.shutdown().await;
}

Ok(())
}
1 change: 1 addition & 0 deletions src/serve/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ async fn test_server() -> String {
let state = AppState {
config: Arc::new(AppConfig::default()),
factory: Arc::new(StubBackendFactory),
asr_url: Arc::new("http://localhost:8000/v1/audio/transcriptions".to_string()),
};
let app = super::build_router(state);
let listener = TcpListener::bind("127.0.0.1:0")
Expand Down
13 changes: 8 additions & 5 deletions src/serve/voice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,10 @@ enum ServerMessage {
/// Configuration passed as query parameters on the WebSocket URL.
#[derive(Debug, Clone, Deserialize)]
pub struct VoiceParams {
/// Whisper-compatible ASR endpoint URL.
#[serde(default = "default_asr_url")]
pub asr_url: String,
/// Whisper-compatible ASR endpoint URL (filled from managed process when
/// absent).
#[serde(default)]
pub asr_url: Option<String>,
/// OpenAI-compatible LLM endpoint base URL.
#[serde(default = "default_llm_url")]
pub llm_url: String,
Expand All @@ -108,7 +109,8 @@ pub struct VoiceParams {
pub system_prompt: String,
}

fn default_asr_url() -> String { "http://localhost:8000/v1/audio/transcriptions".to_string() }
/// Fallback ASR URL used when no managed Whisper process is available.
const FALLBACK_ASR_URL: &str = "http://localhost:8000/v1/audio/transcriptions";

/// Serde default for `VoiceParams::llm_url`: the local OpenAI-compatible LLM base URL.
fn default_llm_url() -> String {
    String::from("http://localhost:11434/v1")
}

Expand Down Expand Up @@ -655,7 +657,8 @@ async fn process_utterance(
.await;

// 1. ASR
let transcript = transcribe(client, &params.asr_url, audio, INPUT_SAMPLE_RATE).await?;
let asr_url = params.asr_url.as_deref().unwrap_or(FALLBACK_ASR_URL);
let transcript = transcribe(client, asr_url, audio, INPUT_SAMPLE_RATE).await?;

if transcript.is_empty() {
return Err("empty transcript".to_string());
Expand Down
31 changes: 31 additions & 0 deletions src/serve/whisper_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Minimal OpenAI-compatible Whisper ASR server managed by kotoba."""
import sys, io, tempfile, uvicorn
from fastapi import FastAPI, UploadFile, File, Form
from faster_whisper import WhisperModel

# Application object served by uvicorn; the startup hook and route below attach to it.
app = FastAPI()
# Shared Whisper model; stays None until the startup hook `load` assigns it.
model = None

@app.on_event("startup")
def load():
    """Load the Whisper model into the module-level ``model`` at startup.

    The model size is taken from the first command-line argument and
    defaults to ``large-v3-turbo`` when none is given. Progress is printed
    to stdout and flushed immediately.
    """
    global model
    cli_args = sys.argv[1:]
    model_size = cli_args[0] if cli_args else "large-v3-turbo"
    print(f"Loading Whisper model ({model_size})...", flush=True)
    model = WhisperModel(model_size, device="cpu", compute_type="float32")
    print("Whisper ready.", flush=True)

@app.post("/v1/audio/transcriptions")
async def transcribe(file: UploadFile = File(...), model_name: str = Form("whisper-1", alias="model")):
    """Transcribe an uploaded audio file and return ``{"text": ...}``.

    Args:
        file: The uploaded audio (multipart form field ``file``).
        model_name: The ``model`` form field. It is accepted so clients can
            send it, but it is not used: the model loaded at startup always
            does the work.

    Returns:
        A dict with a single ``text`` key holding the concatenated segment
        texts with surrounding whitespace stripped.
    """
    import asyncio

    data = await file.read()

    def _run_inference() -> str:
        # Decode straight from memory instead of round-tripping through a
        # named temp file (reopening one by name is not portable).
        segments, _info = model.transcribe(
            io.BytesIO(data),
            beam_size=5,
            language="ja",
            vad_filter=True,
            vad_parameters=dict(min_silence_duration_ms=500),
        )
        # `segments` is consumed here, inside the worker thread, so all of the
        # blocking work stays off the event loop.
        return "".join(segment.text for segment in segments).strip()

    # Inference is blocking; run it in a worker thread so the server keeps
    # responding while a transcription is in progress.
    text = await asyncio.to_thread(_run_inference)
    return {"text": text}

if __name__ == "__main__":
    # Expected argv layout: [script, model_size, port]; the port falls back to 8000.
    port = 8000 if len(sys.argv) <= 2 else int(sys.argv[2])
    # Serve on loopback only.
    uvicorn.run(app, port=port, host="127.0.0.1", log_level="warning")
Loading