From ef5137b0100724241912c4b2bcd6297fa1e52647 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Thu, 9 Apr 2026 12:02:43 +0800
Subject: [PATCH] =?UTF-8?q?feat(channels):=20voice=20transcription=20quali?=
 =?UTF-8?q?ty=20=E2=80=94=20annotation=20+=20optional=20LLM=20correction?=
 =?UTF-8?q?=20(#1215)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add two layers of post-processing for STT output in voice channels:

- Layer A (always on): prepend a hint so the downstream LLM treats voice
  input as speech-recognised text that may contain errors.
- Layer B (opt-in via stt.correction.enabled): run a fast LLM pass to fix
  obvious transcription mistakes before delivery. Failure is non-fatal —
  the adapter falls back to the raw transcription.

Wires SttCorrectionConfig + the kernel driver registry into the Telegram
and Web channel adapters. The driver registry is read from KernelHandle at
message-handling time to avoid extra plumbing through polling loops.

Closes #1215 --- crates/app/src/lib.rs | 8 +- crates/channels/src/lib.rs | 1 + crates/channels/src/telegram/adapter.rs | 43 ++++++- crates/channels/src/voice.rs | 152 ++++++++++++++++++++++++ crates/channels/src/web.rs | 92 ++++++++++---- crates/drivers/stt/src/config.rs | 30 +++++ crates/drivers/stt/src/lib.rs | 2 +- 7 files changed, 302 insertions(+), 26 deletions(-) create mode 100644 crates/channels/src/voice.rs diff --git a/crates/app/src/lib.rs b/crates/app/src/lib.rs index 3f5ca27cc..e2894f36b 100644 --- a/crates/app/src/lib.rs +++ b/crates/app/src/lib.rs @@ -378,9 +378,12 @@ pub async fn start_with_options( .await .whatever_context("Failed to initialize BackendState")?; + let stt_correction = config.stt.as_ref().and_then(|s| s.correction.clone()); + let web_adapter = Arc::new( rara_channels::web::WebAdapter::new(config.owner_token.clone()) - .with_stt_service(stt_service.clone()), + .with_stt_service(stt_service.clone()) + .with_stt_correction(stt_correction.clone()), ); let web_router = web_adapter.router(); @@ -388,6 +391,7 @@ pub async fn start_with_options( &backend.settings_svc, rara.user_question_manager.clone(), stt_service, + stt_correction, tts_service, ) .await @@ -737,6 +741,7 @@ async fn try_build_telegram( settings_svc: &rara_backend_admin::settings::SettingsSvc, user_question_manager: rara_kernel::user_question::UserQuestionManagerRef, stt_service: Option, + stt_correction: Option, tts_service: Option, ) -> Result>, Whatever> { use rara_domain_shared::settings::{SettingsProvider, keys}; @@ -786,6 +791,7 @@ async fn try_build_telegram( .with_config(tg_config) .with_user_question_manager(user_question_manager) .with_stt_service(stt_service) + .with_stt_correction(stt_correction) .with_tts_service(tts_service), ); diff --git a/crates/channels/src/lib.rs b/crates/channels/src/lib.rs index 4ca1326b3..3402fe20c 100644 --- a/crates/channels/src/lib.rs +++ b/crates/channels/src/lib.rs @@ -27,6 +27,7 @@ pub mod telegram; pub mod terminal; pub mod 
tool_display; +pub mod voice; pub mod web; pub mod wechat; diff --git a/crates/channels/src/telegram/adapter.rs b/crates/channels/src/telegram/adapter.rs index 0ff5e9a18..85b616faa 100644 --- a/crates/channels/src/telegram/adapter.rs +++ b/crates/channels/src/telegram/adapter.rs @@ -935,6 +935,8 @@ pub struct TelegramAdapter { user_question_manager: Option, /// Optional STT service for transcribing voice messages to text. stt_service: Option, + /// Optional configuration for the post-transcription LLM correction pass. + stt_correction: Option, /// Optional TTS service for synthesizing voice replies. tts_service: Option, /// Chat IDs whose most recent inbound message was a voice note. @@ -966,6 +968,7 @@ impl TelegramAdapter { active_streams: Arc::new(DashMap::new()), user_question_manager: None, stt_service: None, + stt_correction: None, tts_service: None, voice_chat_ids: Arc::new(DashSet::new()), } @@ -1088,6 +1091,22 @@ impl TelegramAdapter { self } + /// Attach an optional STT correction config. + /// + /// When `SttCorrectionConfig::enabled` is `true`, raw transcriptions are + /// routed through a fast LLM that fixes obvious speech-recognition + /// mistakes before delivery. The driver registry is read from the bound + /// `KernelHandle` at message-handling time, so no extra plumbing is + /// required. + #[must_use] + pub fn with_stt_correction( + mut self, + correction: Option, + ) -> Self { + self.stt_correction = correction; + self + } + /// Attach a TTS service for synthesizing voice replies. #[must_use] pub fn with_tts_service(mut self, tts: Option) -> Self { @@ -1395,6 +1414,7 @@ impl ChannelAdapter for TelegramAdapter { .clone() .into(); let stt_service = self.stt_service.clone(); + let stt_correction = self.stt_correction.clone(); let voice_chat_ids = Arc::clone(&self.voice_chat_ids); // Register slash-menu with Telegram so '/' shows available commands. 
@@ -1468,6 +1488,7 @@ impl ChannelAdapter for TelegramAdapter { command_handlers, callback_handlers, stt_service, + stt_correction, voice_chat_ids, ) .await; @@ -1522,6 +1543,7 @@ async fn polling_loop( command_handlers: Arc<[Arc]>, callback_handlers: Arc<[Arc]>, stt_service: Option, + stt_correction: Option, voice_chat_ids: Arc>, ) { let mut offset: Option = None; @@ -1580,6 +1602,7 @@ async fn polling_loop( let command_handlers = Arc::clone(&command_handlers); let callback_handlers = Arc::clone(&callback_handlers); let stt = stt_service.clone(); + let stt_corr = stt_correction.clone(); let voice_ids = Arc::clone(&voice_chat_ids); tokio::spawn(async move { handle_update( @@ -1594,6 +1617,7 @@ async fn polling_loop( &command_handlers, &callback_handlers, &stt, + stt_corr.as_ref(), &voice_ids, ) .await; @@ -2366,6 +2390,7 @@ async fn handle_update( command_handlers: &[Arc], callback_handlers: &[Arc], stt_service: &Option, + stt_correction: Option<&rara_stt::SttCorrectionConfig>, voice_chat_ids: &Arc>, ) { // Read a snapshot of the runtime config for this update. @@ -2801,11 +2826,25 @@ async fn handle_update( tracing::info!(len = text.len(), "voice message transcribed"); // Mark this chat so egress replies with a voice note. voice_chat_ids.insert(chat_id); + + // Layer B: optional LLM correction (non-fatal on failure). + let corrected = crate::voice::maybe_correct( + &text, + stt_correction, + Some(handle.driver_registry()), + ) + .await; + + // Layer A: annotate so the downstream LLM treats the + // text as speech-recognised input that may contain + // mistakes. 
+ let annotated = crate::voice::annotate_voice(&corrected); + let combined = match raw.content { MessageContent::Text(ref caption) if !caption.trim().is_empty() => { - format!("{caption}\n\n{text}") + format!("{caption}\n\n{annotated}") } - _ => text, + _ => annotated, }; RawPlatformMessage { content: MessageContent::Text(combined), diff --git a/crates/channels/src/voice.rs b/crates/channels/src/voice.rs new file mode 100644 index 000000000..1a0fb4e3f --- /dev/null +++ b/crates/channels/src/voice.rs @@ -0,0 +1,152 @@ +// Copyright 2025 Rararulab +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Voice transcription post-processing utilities. +//! +//! Two layers of quality improvement for STT output: +//! +//! - **Layer A (annotation)**: Always prepends a hint so the downstream LLM +//! knows the text came from speech recognition and may contain errors. +//! - **Layer B (correction)**: Optionally runs a fast LLM pass to fix obvious +//! transcription mistakes before delivery. Controlled by +//! [`SttCorrectionConfig`]. + +use rara_kernel::llm::{ + DriverRegistryRef, + types::{CompletionRequest, Message, ToolChoice}, +}; +use rara_stt::SttCorrectionConfig; + +/// Prefix prepended to every voice transcription so the downstream LLM +/// interprets the text with appropriate error tolerance. 
+pub const VOICE_ANNOTATION_PREFIX: &str =
+    "[Voice transcription \u{2014} may contain errors, interpret by context]";
+
+/// Annotate a transcribed text with the voice-transcription hint.
+pub fn annotate_voice(text: &str) -> String { format!("{VOICE_ANNOTATION_PREFIX}\n{text}") }
+
+/// Run an optional LLM correction pass on the raw STT output.
+///
+/// Returns the model's text when correction is enabled and yields non-blank
+/// output. Any failure (no registry, unresolved driver, LLM error, missing or
+/// blank reply) falls back to `text`; correction never blocks the message.
+pub async fn maybe_correct(
+    text: &str,
+    correction: Option<&SttCorrectionConfig>,
+    driver_registry: Option<&DriverRegistryRef>,
+) -> String {
+    let Some(cfg) = correction.filter(|c| c.enabled) else {
+        return text.to_owned();
+    };
+    let Some(registry) = driver_registry else {
+        tracing::debug!("STT correction enabled but no driver registry available, skipping");
+        return text.to_owned();
+    };
+
+    let (driver, model) = match registry.resolve(
+        "_stt_correction",
+        cfg.provider.as_deref(),
+        cfg.model.as_deref(),
+    ) {
+        Ok(pair) => pair,
+        Err(e) => {
+            tracing::warn!(error = %e, "STT correction: failed to resolve LLM driver, skipping");
+            return text.to_owned();
+        }
+    };
+
+    let request = CompletionRequest {
+        model,
+        messages: vec![
+            Message::system(
+                "You are a transcription error corrector. Fix obvious speech recognition \
+                 mistakes. Output only the corrected text.",
+            ),
+            Message::user(format!(
+                "Correct any transcription errors in the following voice message. Preserve the \
+                 original meaning. Only fix obvious mistakes. Output the corrected text only, no \
+                 explanation.\n\n{text}"
+            )),
+        ],
+        tools: Vec::new(),
+        temperature: Some(0.1),
+        max_tokens: Some(4096),
+        thinking: None,
+        tool_choice: ToolChoice::None,
+        parallel_tool_calls: false,
+        frequency_penalty: None,
+        top_p: None,
+    };
+    match driver.complete(request).await {
+        Ok(resp) => resp.content.filter(|c| !c.trim().is_empty()), // blank reply: keep raw text
+        Err(e) => {
+            tracing::warn!(error = %e, "STT correction LLM call failed, using raw transcription");
+            None
+        }
+    }
+    .unwrap_or_else(|| text.to_owned())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn annotation_format() {
+        let result = annotate_voice("hello world");
+        assert!(result.starts_with(VOICE_ANNOTATION_PREFIX));
+        assert!(result.ends_with("hello world"));
+        // Exactly one newline between prefix and text.
+        let expected = format!("{VOICE_ANNOTATION_PREFIX}\nhello world");
+        assert_eq!(result, expected);
+    }
+
+    #[test]
+    fn annotation_preserves_multiline() {
+        let input = "line one\nline two";
+        let result = annotate_voice(input);
+        assert_eq!(
+            result,
+            format!("{VOICE_ANNOTATION_PREFIX}\nline one\nline two")
+        );
+    }
+
+    #[tokio::test]
+    async fn correction_disabled_returns_original() {
+        let text = "some text";
+        // No correction config.
+        assert_eq!(maybe_correct(text, None, None).await, text);
+
+        // Config present but disabled.
+ let cfg = SttCorrectionConfig { + enabled: false, + model: None, + provider: None, + }; + assert_eq!(maybe_correct(text, Some(&cfg), None).await, text); + } + + #[tokio::test] + async fn correction_enabled_no_registry_returns_original() { + let cfg = SttCorrectionConfig { + enabled: true, + model: Some("test-model".to_owned()), + provider: Some("test".to_owned()), + }; + assert_eq!( + maybe_correct("raw text", Some(&cfg), None).await, + "raw text" + ); + } +} diff --git a/crates/channels/src/web.rs b/crates/channels/src/web.rs index c11d882ac..19692db2f 100644 --- a/crates/channels/src/web.rs +++ b/crates/channels/src/web.rs @@ -370,6 +370,8 @@ pub struct WebAdapter { shutdown_rx: watch::Receiver, /// Optional STT service for transcribing voice messages to text. stt_service: Option, + /// Optional configuration for the post-transcription LLM correction pass. + stt_correction: Option, } impl WebAdapter { @@ -385,6 +387,7 @@ impl WebAdapter { shutdown_tx, shutdown_rx, stt_service: None, + stt_correction: None, } } @@ -395,6 +398,21 @@ impl WebAdapter { self } + /// Attach an optional STT correction config. + /// + /// When `SttCorrectionConfig::enabled` is `true`, raw transcriptions are + /// routed through a fast LLM that fixes obvious speech-recognition + /// mistakes before delivery. The driver registry is read from the bound + /// `KernelHandle` at message-handling time. + #[must_use] + pub fn with_stt_correction( + mut self, + correction: Option, + ) -> Self { + self.stt_correction = correction; + self + } + /// Returns an [`axum::Router`] with WebSocket, SSE, and message endpoints. 
/// /// Mount this into your application: @@ -410,6 +428,7 @@ impl WebAdapter { owner_token: self.owner_token.clone(), shutdown_rx: self.shutdown_rx.clone(), stt_service: self.stt_service.clone(), + stt_correction: self.stt_correction.clone(), }; Router::new() @@ -440,9 +459,6 @@ impl WebAdapter { user_id: &str, content: MessageContent, ) -> Result<(), String> { - let content = transcribe_audio_blocks(content, &self.stt_service).await; - let raw = build_raw_platform_message(session_key, user_id, content); - let handle = { let guard = self.sink.read().await; guard @@ -451,6 +467,15 @@ impl WebAdapter { .ok_or_else(|| "adapter not started".to_owned())? }; + let content = transcribe_audio_blocks( + content, + &self.stt_service, + self.stt_correction.as_ref(), + Some(handle.driver_registry()), + ) + .await; + let raw = build_raw_platform_message(session_key, user_id, content); + let msg = handle.resolve(raw).await.map_err(|e| e.to_string())?; handle.submit_message(msg).map_err(|e| e.to_string())?; @@ -505,6 +530,7 @@ struct WebAdapterState { owner_token: Option, shutdown_rx: watch::Receiver, stt_service: Option, + stt_correction: Option, } // --------------------------------------------------------------------------- @@ -588,9 +614,16 @@ use rara_kernel::channel::types::ContentBlock; /// Transcribe any `AudioBase64` blocks in the message content, replacing them /// with `Text` blocks containing the transcribed text. +/// +/// When `correction` is enabled and a `driver_registry` is provided, every +/// transcribed clip is routed through an LLM correction pass. The result is +/// then prefixed with [`crate::voice::VOICE_ANNOTATION_PREFIX`] so the +/// downstream agent treats it as speech-recognised input. 
async fn transcribe_audio_blocks( content: MessageContent, stt: &Option, + correction: Option<&rara_stt::SttCorrectionConfig>, + driver_registry: Option<&rara_kernel::llm::DriverRegistryRef>, ) -> MessageContent { let blocks = match content { MessageContent::Text(_) => return content, @@ -608,12 +641,9 @@ async fn transcribe_audio_blocks( for block in blocks { match block { ContentBlock::AudioBase64 { data, media_type } => { - let text = transcribe_single_audio(&data, &media_type, stt).await; - let text = if text.is_empty() { - "[voice message]".to_owned() - } else { - text - }; + let text = + transcribe_single_audio(&data, &media_type, stt, correction, driver_registry) + .await; result.push(ContentBlock::Text { text }); } other => result.push(other), @@ -630,10 +660,16 @@ async fn transcribe_audio_blocks( } /// Transcribe a single base64-encoded audio clip via the STT service. +/// +/// On success, the raw transcription is optionally corrected via LLM +/// (Layer B) and then annotated (Layer A) so the downstream agent +/// interprets it as speech-recognised input. async fn transcribe_single_audio( data_b64: &str, media_type: &str, stt: &Option, + correction: Option<&rara_stt::SttCorrectionConfig>, + driver_registry: Option<&rara_kernel::llm::DriverRegistryRef>, ) -> String { use base64::Engine; @@ -653,7 +689,13 @@ async fn transcribe_single_audio( match stt.transcribe(audio_bytes, media_type).await { Ok(text) => { info!(len = text.len(), "voice message transcribed"); - text + if text.is_empty() { + return "[voice message]".to_owned(); + } + // Layer B: optional LLM correction (non-fatal on failure). + let corrected = crate::voice::maybe_correct(&text, correction, driver_registry).await; + // Layer A: annotate so downstream LLM interprets it as STT output. 
+ crate::voice::annotate_voice(&corrected) } Err(e) => { warn!(error = %e, "STT transcription failed"); @@ -751,6 +793,7 @@ async fn handle_ws(socket: WebSocket, params: SessionQuery, state: WebAdapterSta let session_key = session_key.clone(); let user_id = params.user_id.clone(); let stt_service = state.stt_service.clone(); + let stt_correction = state.stt_correction.clone(); tokio::spawn(async move { while let Some(Ok(msg)) = ws_rx.next().await { let text = match msg { @@ -767,12 +810,26 @@ async fn handle_ws(socket: WebSocket, params: SessionQuery, state: WebAdapterSta } let payload = parse_inbound_text_frame(&text); + + let guard = sink.read().await; + let Some(s) = guard.as_ref().cloned() else { + drop(guard); + warn!(session_key, "adapter not started, dropping inbound frame"); + continue; + }; + drop(guard); + // Transcribe any audio blocks before submitting to the kernel. - let content = transcribe_audio_blocks(payload.content, &stt_service).await; + let content = transcribe_audio_blocks( + payload.content, + &stt_service, + stt_correction.as_ref(), + Some(s.driver_registry()), + ) + .await; let raw = build_raw_platform_message(&session_key, &user_id, content); - let guard = sink.read().await; - if let Some(ref s) = *guard { + { // Send typing indicator before processing. 
WebAdapter::broadcast_event(&sessions, &session_key, &WebEvent::Typing); // Resolve identity + session first (like TG adapter), @@ -826,15 +883,6 @@ async fn handle_ws(socket: WebSocket, params: SessionQuery, state: WebAdapterSta ); } } - } else { - warn!(session_key, "sink not set, cannot dispatch message"); - WebAdapter::broadcast_event( - &sessions, - &session_key, - &WebEvent::Error { - message: "adapter not started".to_owned(), - }, - ); } } }) diff --git a/crates/drivers/stt/src/config.rs b/crates/drivers/stt/src/config.rs index 250c24558..f3336d520 100644 --- a/crates/drivers/stt/src/config.rs +++ b/crates/drivers/stt/src/config.rs @@ -14,6 +14,10 @@ use serde::{Deserialize, Serialize}; /// model: "whisper-1" # optional /// language: "zh" # optional, auto-detect if omitted /// timeout_secs: 60 # optional, default 60 +/// correction: # optional LLM post-correction +/// enabled: true +/// model: "glm-4-flash" +/// provider: "glm" /// ``` #[derive(Debug, Clone, Builder, Serialize, Deserialize)] pub struct SttConfig { @@ -45,6 +49,32 @@ pub struct SttConfig { /// Path to the whisper model file (required when `managed: true`). #[serde(skip_serializing_if = "Option::is_none")] pub model_path: Option, + + /// Optional LLM correction pass after transcription. + /// + /// When enabled, the raw transcription is sent through a fast LLM to + /// fix obvious speech-recognition errors before delivery. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub correction: Option, +} + +/// Configuration for the optional LLM post-correction pass. +/// +/// When `enabled` is `true`, the raw STT output is sent through a fast LLM +/// that fixes obvious transcription errors while preserving the original +/// meaning. Correction failure never blocks the message — the adapter falls +/// back to the raw transcription. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SttCorrectionConfig { + /// Whether to run an LLM correction pass after transcription. 
+ pub enabled: bool, + /// The LLM model to use for correction (e.g. `"glm-4-flash"`). + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// The LLM provider to use (e.g. `"glm"`). Falls back to the default + /// driver when omitted. + #[serde(skip_serializing_if = "Option::is_none")] + pub provider: Option, } fn default_model() -> String { "whisper-1".to_owned() } diff --git a/crates/drivers/stt/src/lib.rs b/crates/drivers/stt/src/lib.rs index d086536eb..5f91adc6e 100644 --- a/crates/drivers/stt/src/lib.rs +++ b/crates/drivers/stt/src/lib.rs @@ -5,7 +5,7 @@ mod error; mod process; mod service; -pub use config::SttConfig; +pub use config::{SttConfig, SttCorrectionConfig}; pub use error::{Result, SttError}; pub use process::WhisperProcess; pub use service::SttService;