Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion crates/app/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,16 +378,20 @@ pub async fn start_with_options(
.await
.whatever_context("Failed to initialize BackendState")?;

let stt_correction = config.stt.as_ref().and_then(|s| s.correction.clone());

let web_adapter = Arc::new(
rara_channels::web::WebAdapter::new(config.owner_token.clone())
.with_stt_service(stt_service.clone()),
.with_stt_service(stt_service.clone())
.with_stt_correction(stt_correction.clone()),
);
let web_router = web_adapter.router();

let telegram_adapter = match try_build_telegram(
&backend.settings_svc,
rara.user_question_manager.clone(),
stt_service,
stt_correction,
tts_service,
)
.await
Expand Down Expand Up @@ -737,6 +741,7 @@ async fn try_build_telegram(
settings_svc: &rara_backend_admin::settings::SettingsSvc,
user_question_manager: rara_kernel::user_question::UserQuestionManagerRef,
stt_service: Option<rara_stt::SttService>,
stt_correction: Option<rara_stt::SttCorrectionConfig>,
tts_service: Option<rara_tts::TtsService>,
) -> Result<Option<Arc<rara_channels::telegram::TelegramAdapter>>, Whatever> {
use rara_domain_shared::settings::{SettingsProvider, keys};
Expand Down Expand Up @@ -786,6 +791,7 @@ async fn try_build_telegram(
.with_config(tg_config)
.with_user_question_manager(user_question_manager)
.with_stt_service(stt_service)
.with_stt_correction(stt_correction)
.with_tts_service(tts_service),
);

Expand Down
1 change: 1 addition & 0 deletions crates/channels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
pub mod telegram;
pub mod terminal;
pub mod tool_display;
pub mod voice;
pub mod web;
pub mod wechat;

Expand Down
43 changes: 41 additions & 2 deletions crates/channels/src/telegram/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,8 @@ pub struct TelegramAdapter {
user_question_manager: Option<UserQuestionManagerRef>,
/// Optional STT service for transcribing voice messages to text.
stt_service: Option<rara_stt::SttService>,
/// Optional configuration for the post-transcription LLM correction pass.
stt_correction: Option<rara_stt::SttCorrectionConfig>,
/// Optional TTS service for synthesizing voice replies.
tts_service: Option<rara_tts::TtsService>,
/// Chat IDs whose most recent inbound message was a voice note.
Expand Down Expand Up @@ -966,6 +968,7 @@ impl TelegramAdapter {
active_streams: Arc::new(DashMap::new()),
user_question_manager: None,
stt_service: None,
stt_correction: None,
tts_service: None,
voice_chat_ids: Arc::new(DashSet::new()),
}
Expand Down Expand Up @@ -1088,6 +1091,22 @@ impl TelegramAdapter {
self
}

/// Set the optional configuration for the post-transcription correction
/// pass.
///
/// The value is stored as-is, replacing whatever was set before; passing
/// `None` leaves the adapter without a correction config. When the stored
/// `SttCorrectionConfig::enabled` is `true`, voice transcriptions are sent
/// through an LLM that repairs obvious speech-recognition mistakes before
/// the message is delivered. The driver registry used for that call is
/// obtained from the kernel handle while the message is being handled, so
/// nothing else has to be wired in here.
#[must_use]
pub fn with_stt_correction(
    self,
    correction: Option<rara_stt::SttCorrectionConfig>,
) -> Self {
    let mut adapter = self;
    adapter.stt_correction = correction;
    adapter
}

/// Attach a TTS service for synthesizing voice replies.
#[must_use]
pub fn with_tts_service(mut self, tts: Option<rara_tts::TtsService>) -> Self {
Expand Down Expand Up @@ -1395,6 +1414,7 @@ impl ChannelAdapter for TelegramAdapter {
.clone()
.into();
let stt_service = self.stt_service.clone();
let stt_correction = self.stt_correction.clone();
let voice_chat_ids = Arc::clone(&self.voice_chat_ids);

// Register slash-menu with Telegram so '/' shows available commands.
Expand Down Expand Up @@ -1468,6 +1488,7 @@ impl ChannelAdapter for TelegramAdapter {
command_handlers,
callback_handlers,
stt_service,
stt_correction,
voice_chat_ids,
)
.await;
Expand Down Expand Up @@ -1522,6 +1543,7 @@ async fn polling_loop(
command_handlers: Arc<[Arc<dyn CommandHandler>]>,
callback_handlers: Arc<[Arc<dyn CallbackHandler>]>,
stt_service: Option<rara_stt::SttService>,
stt_correction: Option<rara_stt::SttCorrectionConfig>,
voice_chat_ids: Arc<DashSet<i64>>,
) {
let mut offset: Option<i32> = None;
Expand Down Expand Up @@ -1580,6 +1602,7 @@ async fn polling_loop(
let command_handlers = Arc::clone(&command_handlers);
let callback_handlers = Arc::clone(&callback_handlers);
let stt = stt_service.clone();
let stt_corr = stt_correction.clone();
let voice_ids = Arc::clone(&voice_chat_ids);
tokio::spawn(async move {
handle_update(
Expand All @@ -1594,6 +1617,7 @@ async fn polling_loop(
&command_handlers,
&callback_handlers,
&stt,
stt_corr.as_ref(),
&voice_ids,
)
.await;
Expand Down Expand Up @@ -2366,6 +2390,7 @@ async fn handle_update(
command_handlers: &[Arc<dyn CommandHandler>],
callback_handlers: &[Arc<dyn CallbackHandler>],
stt_service: &Option<rara_stt::SttService>,
stt_correction: Option<&rara_stt::SttCorrectionConfig>,
voice_chat_ids: &Arc<DashSet<i64>>,
) {
// Read a snapshot of the runtime config for this update.
Expand Down Expand Up @@ -2801,11 +2826,25 @@ async fn handle_update(
tracing::info!(len = text.len(), "voice message transcribed");
// Mark this chat so egress replies with a voice note.
voice_chat_ids.insert(chat_id);

// Layer B: optional LLM correction (non-fatal on failure).
let corrected = crate::voice::maybe_correct(
&text,
stt_correction,
Some(handle.driver_registry()),
)
.await;

// Layer A: annotate so the downstream LLM treats the
// text as speech-recognised input that may contain
// mistakes.
let annotated = crate::voice::annotate_voice(&corrected);

let combined = match raw.content {
MessageContent::Text(ref caption) if !caption.trim().is_empty() => {
format!("{caption}\n\n{text}")
format!("{caption}\n\n{annotated}")
}
_ => text,
_ => annotated,
};
RawPlatformMessage {
content: MessageContent::Text(combined),
Expand Down
152 changes: 152 additions & 0 deletions crates/channels/src/voice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// Copyright 2025 Rararulab
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Voice transcription post-processing utilities.
//!
//! Two layers of quality improvement for STT output:
//!
//! - **Layer A (annotation)**: Always prepends a hint so the downstream LLM
//! knows the text came from speech recognition and may contain errors.
//! - **Layer B (correction)**: Optionally runs a fast LLM pass to fix obvious
//! transcription mistakes before delivery. Controlled by
//! [`SttCorrectionConfig`].

use rara_kernel::llm::{
DriverRegistryRef,
types::{CompletionRequest, Message, ToolChoice},
};
use rara_stt::SttCorrectionConfig;

/// Hint line placed ahead of every voice transcription so the downstream
/// LLM reads the text as speech-recognised input that may contain mistakes.
pub const VOICE_ANNOTATION_PREFIX: &str =
    "[Voice transcription \u{2014} may contain errors, interpret by context]";

/// Put [`VOICE_ANNOTATION_PREFIX`] on its own line in front of `text`.
///
/// The result is the prefix, exactly one `\n`, and then `text` unchanged
/// (embedded newlines in `text` are kept as they are).
pub fn annotate_voice(text: &str) -> String {
    let mut annotated = String::with_capacity(VOICE_ANNOTATION_PREFIX.len() + 1 + text.len());
    annotated.push_str(VOICE_ANNOTATION_PREFIX);
    annotated.push('\n');
    annotated.push_str(text);
    annotated
}

/// Optionally run an LLM pass that repairs obvious STT mistakes in `text`.
///
/// Returns the model's corrected text when correction is enabled and the
/// call yields usable output. The original `text` is returned unchanged
/// when any of the following holds:
///
/// - `correction` is `None` or its `enabled` flag is `false`;
/// - `driver_registry` is `None`;
/// - `text` is empty or whitespace-only (there is nothing to correct);
/// - the registry cannot resolve a driver/model for the request;
/// - the completion call returns an error;
/// - the completion has no content, or only whitespace.
///
/// Correction is strictly best-effort: a failure here must never block or
/// erase the user's message.
pub async fn maybe_correct(
    text: &str,
    correction: Option<&SttCorrectionConfig>,
    driver_registry: Option<&DriverRegistryRef>,
) -> String {
    let Some(cfg) = correction.filter(|c| c.enabled) else {
        return text.to_owned();
    };
    let Some(registry) = driver_registry else {
        tracing::debug!("STT correction enabled but no driver registry available, skipping");
        return text.to_owned();
    };

    // A blank transcription has nothing to fix; sending it to the model only
    // invites it to invent content that would then be delivered as the
    // user's words.
    if text.trim().is_empty() {
        return text.to_owned();
    }

    let (driver, model) = match registry.resolve(
        "_stt_correction",
        cfg.provider.as_deref(),
        cfg.model.as_deref(),
    ) {
        Ok(pair) => pair,
        Err(e) => {
            tracing::warn!(error = %e, "STT correction: failed to resolve LLM driver, skipping");
            return text.to_owned();
        }
    };

    let request = CompletionRequest {
        model,
        messages: vec![
            Message::system(
                "You are a transcription error corrector. Fix obvious speech recognition \
                 mistakes. Output only the corrected text.",
            ),
            Message::user(format!(
                "Correct any transcription errors in the following voice message. Preserve the \
                 original meaning. Only fix obvious mistakes. Output the corrected text only, no \
                 explanation.\n\n{text}"
            )),
        ],
        // Plain text-in/text-out call: no tools are offered to the model.
        tools: Vec::new(),
        temperature: Some(0.1),
        max_tokens: Some(4096),
        thinking: None,
        tool_choice: ToolChoice::None,
        parallel_tool_calls: false,
        frequency_penalty: None,
        top_p: None,
    };

    match driver.complete(request).await {
        Ok(resp) => match resp.content {
            // Only accept output that actually carries text; an empty or
            // whitespace-only completion would otherwise wipe the message.
            Some(corrected) if !corrected.trim().is_empty() => corrected,
            _ => {
                tracing::warn!("STT correction returned empty output, using raw transcription");
                text.to_owned()
            }
        },
        Err(e) => {
            tracing::warn!(error = %e, "STT correction LLM call failed, using raw transcription");
            text.to_owned()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The annotation is the prefix, a single newline, then the text.
    #[test]
    fn annotation_format() {
        let annotated = annotate_voice("hello world");
        let want = [VOICE_ANNOTATION_PREFIX, "hello world"].join("\n");
        assert_eq!(annotated, want);
        assert!(annotated.starts_with(VOICE_ANNOTATION_PREFIX));
        assert!(annotated.ends_with("hello world"));
    }

    /// Newlines already present in the transcription are left untouched.
    #[test]
    fn annotation_preserves_multiline() {
        let want = format!("{VOICE_ANNOTATION_PREFIX}\nline one\nline two");
        assert_eq!(annotate_voice("line one\nline two"), want);
    }

    /// Without a config, or with a disabled one, the text passes through.
    #[tokio::test]
    async fn correction_disabled_returns_original() {
        let raw = "some text";
        let disabled = SttCorrectionConfig {
            enabled: false,
            model: None,
            provider: None,
        };

        for correction in [None, Some(&disabled)] {
            assert_eq!(maybe_correct(raw, correction, None).await, raw);
        }
    }

    /// An enabled config without a driver registry falls back to the input.
    #[tokio::test]
    async fn correction_enabled_no_registry_returns_original() {
        let enabled = SttCorrectionConfig {
            enabled: true,
            model: Some(String::from("test-model")),
            provider: Some(String::from("test")),
        };

        let output = maybe_correct("raw text", Some(&enabled), None).await;
        assert_eq!(output, "raw text");
    }
}
Loading
Loading