Skip to content

Commit b7ca655

Browse files
committed
fix: CI docker build
1 parent b79c442 commit b7ca655

File tree

9 files changed

+1126
-14
lines changed

9 files changed

+1126
-14
lines changed

.dockerignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ target/
99
# Large local artifacts (not needed for image builds)
1010
models/
1111
.plugins/
12+
!plugins/native/pocket-tts/vendor/pocket-tts/src/models/
13+
!plugins/native/pocket-tts/vendor/pocket-tts/src/models/**
1214

1315
# Node modules
1416
node_modules/

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ samples/audio/system/*.flac
5959
samples/audio/system/*.mp3
6060
samples/audio/system/*.m4a
6161
models
62+
!plugins/native/pocket-tts/vendor/pocket-tts/src/models/
63+
!plugins/native/pocket-tts/vendor/pocket-tts/src/models/**
6264

6365
# Example plugin build outputs
6466
/examples/plugins/*/build/

crates/plugin-native/src/wrapper.rs

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -310,9 +310,6 @@ impl ProcessorNode for NativeNodeWrapper {
310310

311311
// Move the blocking FFI call to spawn_blocking
312312
let state = Arc::clone(&self.state);
313-
// spawn_blocking can only fail with JoinError if the task panics.
314-
// If that happens, it's a serious bug that should crash.
315-
#[allow(clippy::expect_used)]
316313
let error_msg = tokio::task::spawn_blocking(move || {
317314
let handle = state.begin_call()?;
318315

@@ -338,8 +335,11 @@ impl ProcessorNode for NativeNodeWrapper {
338335
error
339336
})
340337
.await
341-
// spawn_blocking only panics if the task panics, which indicates a serious bug
342-
.expect("Update params task panicked");
338+
.map_err(|e| {
339+
StreamKitError::Runtime(format!(
340+
"Update params task panicked: {e}"
341+
))
342+
})?;
343343

344344
if let Some(err) = error_msg {
345345
warn!(node = %node_name, error = %err, "Parameter update failed");
@@ -369,7 +369,6 @@ impl ProcessorNode for NativeNodeWrapper {
369369
let session_id = context.session_id.clone();
370370
let node_id = node_name.clone();
371371

372-
#[allow(clippy::expect_used)]
373372
let (outputs, error) = tokio::task::spawn_blocking(move || {
374373
let Some(handle) = state.begin_call() else {
375374
return (Vec::new(), None);
@@ -418,7 +417,7 @@ impl ProcessorNode for NativeNodeWrapper {
418417
(outputs, error)
419418
})
420419
.await
421-
.expect("Plugin flush task panicked");
420+
.map_err(|e| StreamKitError::Runtime(format!("Plugin flush task panicked: {e}")))?;
422421

423422
// Send flush outputs
424423
for (pin, pkt) in outputs {
@@ -439,12 +438,7 @@ impl ProcessorNode for NativeNodeWrapper {
439438
let telemetry_tx = context.telemetry_tx.clone();
440439
let session_id = context.session_id.clone();
441440
let node_id = node_name.clone();
442-
// spawn_blocking can only fail with JoinError if the task panics.
443-
// If that happens, it's a serious bug that should crash.
444441
let pin_cstr = Arc::clone(&input_pin_cstrs[pin_index]);
445-
// spawn_blocking can only fail with JoinError if the task panics.
446-
// If that happens, it's a serious bug that should crash.
447-
#[allow(clippy::expect_used)]
448442
let (outputs, error) = tokio::task::spawn_blocking(move || {
449443
let Some(handle) = state.begin_call() else {
450444
return (Vec::new(), None);
@@ -499,8 +493,9 @@ impl ProcessorNode for NativeNodeWrapper {
499493
(outputs, error)
500494
})
501495
.await
502-
// spawn_blocking only panics if the task panics, which indicates a serious bug
503-
.expect("Plugin processing task panicked");
496+
.map_err(|e| {
497+
StreamKitError::Runtime(format!("Plugin processing task panicked: {e}"))
498+
})?;
504499

505500
// Now send outputs (after dropping c_packet and result)
506501
for (pin, pkt) in outputs {

plugins/native/pocket-tts/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ SPDX-License-Identifier: MPL-2.0
77
# Pocket TTS Native Plugin
88

99
A native StreamKit plugin for Kyutai Pocket TTS using the Rust/Candle port.
10+
Upstream Rust port: https://github.com/babybirdprd/pocket-tts
1011
This plugin runs fully on CPU and streams 24kHz mono audio.
1112

1213
## Build
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024 Pocket TTS Contributors
2+
//
3+
// SPDX-License-Identifier: MIT OR Apache-2.0
4+
5+
use crate::ModelState;
6+
use crate::models::transformer::StreamingTransformer;
7+
use crate::modules::mlp::{LayerNorm, ModulationParams, SimpleMLPAdaLN};
8+
use candle_core::{Result, Tensor};
9+
use candle_nn::{Linear, Module, VarBuilder};
10+
11+
/// Integrates the learned flow field with fixed-step Euler steps.
///
/// Each entry of `modulations` holds the pre-computed AdaLN modulation batch
/// for one ODE step, so the step count is `modulations.len()` and the step
/// size is its reciprocal. Starting from the noise sample `x_0`, each step
/// queries the flow network for a velocity and advances the state by
/// `velocity * dt`. Returns the final decoded latent.
pub fn lsd_decode(
    flow_net: &SimpleMLPAdaLN,
    modulations: &[Vec<ModulationParams>],
    x_0: &Tensor,
) -> Result<Tensor> {
    let dt = 1.0 / modulations.len() as f64;
    modulations.iter().try_fold(x_0.clone(), |state, step_mod| {
        // forward_step_cached reuses the pre-computed modulation batch for
        // this ODE step instead of recomputing it per call.
        let velocity = flow_net.forward_step_cached(&state, step_mod)?;
        state + velocity.affine(dt, 0.0)?
    })
}
27+
28+
/// Flow-matching language-model head for Pocket TTS.
///
/// Combines a streaming transformer backbone with a flow network
/// (`SimpleMLPAdaLN`) that decodes the next audio latent by ODE
/// integration (see `lsd_decode`).
#[derive(Clone)]
pub struct FlowLMModel {
    // Flow network producing per-step velocities, conditioned on the
    // transformer's last-frame output.
    pub flow_net: SimpleMLPAdaLN,
    pub transformer: StreamingTransformer,
    // Projects latents ([.., ldim]) into the transformer width ([.., dim]).
    pub input_linear: Linear,
    // Normalization applied to the transformer output before the heads.
    pub out_norm: LayerNorm,
    // Single-logit end-of-sequence classifier over the last frame.
    pub out_eos: Linear,
    // Beginning-of-sequence latent embedding (shape [ldim]).
    pub bos_emb: Tensor,
    // Latent statistics (shape [ldim] each) — presumably used by callers to
    // (de)normalize latents; not referenced in this file. TODO confirm.
    pub emb_mean: Tensor,
    pub emb_std: Tensor,
    // Latent dimensionality.
    pub ldim: usize,
    // Transformer model width.
    pub dim: usize,
    // Optional |x| <= limit truncation for sampled noise (None = unclamped).
    pub noise_clamp: Option<f32>,
}
42+
43+
fn sample_noise(
44+
device: &candle_core::Device,
45+
shape: (usize, usize),
46+
temp: f32,
47+
clamp: Option<f32>,
48+
) -> Result<Tensor> {
49+
let std = temp.sqrt();
50+
match clamp {
51+
None => Tensor::randn(0.0f32, std, shape, device),
52+
Some(limit) => {
53+
// Rejection sampling for truncated normal
54+
let count = shape.0 * shape.1;
55+
let mut data = Vec::with_capacity(count);
56+
let mut rng = rand::thread_rng();
57+
let dist = rand_distr::Normal::new(0.0f32, std)
58+
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
59+
60+
while data.len() < count {
61+
let v = rand_distr::Distribution::sample(&dist, &mut rng);
62+
if v.abs() <= limit {
63+
data.push(v);
64+
}
65+
}
66+
Tensor::from_vec(data, shape, device)
67+
}
68+
}
69+
}
70+
71+
impl FlowLMModel {
    /// Builds the model head around a pre-constructed flow network and
    /// streaming transformer, loading the remaining weights from `vb`
    /// (`input_linear`, `out_norm`, `out_eos`, and the three [ldim]
    /// embedding tensors).
    ///
    /// `ldim` is the latent width, `dim` the transformer width.
    /// `noise_clamp` starts as `None`; callers set it after construction
    /// if truncated noise is wanted.
    pub fn new(
        flow_net: SimpleMLPAdaLN,
        transformer: StreamingTransformer,
        ldim: usize,
        dim: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let input_linear = candle_nn::linear_no_bias(ldim, dim, vb.pp("input_linear"))?;
        let out_norm = LayerNorm::new(dim, 1e-5, true, vb.pp("out_norm"))?;
        // Single logit: end-of-sequence score for the current frame.
        let out_eos = candle_nn::linear(dim, 1, vb.pp("out_eos"))?;
        let bos_emb = vb.get(ldim, "bos_emb")?;
        let emb_mean = vb.get(ldim, "emb_mean")?;
        let emb_std = vb.get(ldim, "emb_std")?;

        Ok(Self {
            flow_net,
            transformer,
            input_linear,
            out_norm,
            out_eos,
            bos_emb,
            emb_mean,
            emb_std,
            ldim,
            dim,
            noise_clamp: None, // Default to no clamp
        })
    }

    /// Runs one autoregressive generation step.
    ///
    /// Feeds the (optional) text-embedding prefix plus the latent sequence
    /// through the streaming transformer, scores end-of-sequence on the last
    /// frame, then decodes the next latent from sampled noise via the flow
    /// network (`lsd_decode`). Returns `(next_latent, is_eos)`.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        sequence: &Tensor,
        text_embeddings: &Tensor,
        model_state: &mut ModelState,
        time_embeddings: &Tensor,
        temp: f32,
        eos_threshold: f32,
        step: usize,
    ) -> Result<(Tensor, bool)> {
        // sequence is [B, T, ldim]
        // text_embeddings is [B, S, dim]

        // Handle BOS (if NaN, use bos_emb) - simplistic check for NaN
        // In Candle we can use `Tensor::where_cond`
        // But for now let's assume sequence passed in doesn't have NaNs or handled upstream.
        // Original: sequence = torch.where(torch.isnan(sequence), self.bos_emb, sequence)

        // Let's assume BOS is handled by caller for now or if sequence empty.

        let x = self.input_linear.forward(sequence)?;
        let s_len = text_embeddings.dims()[1];

        // Cat text embeddings and sequence embeddings only if text_embeddings is not empty
        let transformer_out_pre_norm = if s_len > 0 {
            let input = Tensor::cat(&[text_embeddings, &x], 1)?;
            let mut out = self.transformer.forward(&input, model_state, step)?;
            // Remove prefix (text embeddings length)
            out = out.narrow(1, s_len, out.dims()[1] - s_len)?;
            out
        } else {
            self.transformer.forward(&x, model_state, step)?
        };

        let transformer_out = self.out_norm.forward(&transformer_out_pre_norm)?;

        // Only use the last frame for generation
        let last_frame = transformer_out
            .narrow(1, transformer_out.dims()[1] - 1, 1)?
            .squeeze(1)?;

        // NOTE(review): the double squeeze(0) below assumes batch size B == 1
        // ([1, 1] -> scalar); confirm callers never pass a larger batch here.
        let eos_score = self
            .out_eos
            .forward(&last_frame)?
            .squeeze(0)?
            .squeeze(0)?
            .to_scalar::<f32>()?;
        let is_eos = eos_score > eos_threshold;

        // Generate noise with optional clamping
        let noise = sample_noise(
            last_frame.device(),
            (last_frame.dims()[0], self.ldim),
            temp,
            self.noise_clamp,
        )?;

        // Pre-compute all modulations for this frame's ODE steps (8 steps * N blocks) in batch
        let c_emb = self.flow_net.embed_condition(&last_frame)?;
        let modulations = self
            .flow_net
            .precompute_modulations(&c_emb, time_embeddings)?;

        let next_latent = lsd_decode(&self.flow_net, &modulations, &noise)?;

        Ok((next_latent, is_eos))
    }
}

0 commit comments

Comments
 (0)