|
| 1 | +// SPDX-FileCopyrightText: Copyright (c) 2024 Pocket TTS Contributors |
| 2 | +// |
| 3 | +// SPDX-License-Identifier: MIT OR Apache-2.0 |
| 4 | + |
| 5 | +use crate::ModelState; |
| 6 | +use crate::models::transformer::StreamingTransformer; |
| 7 | +use crate::modules::mlp::{LayerNorm, ModulationParams, SimpleMLPAdaLN}; |
| 8 | +use candle_core::{Result, Tensor}; |
| 9 | +use candle_nn::{Linear, Module, VarBuilder}; |
| 10 | + |
| 11 | +pub fn lsd_decode( |
| 12 | + flow_net: &SimpleMLPAdaLN, |
| 13 | + modulations: &[Vec<ModulationParams>], |
| 14 | + x_0: &Tensor, |
| 15 | +) -> Result<Tensor> { |
| 16 | + let mut current = x_0.clone(); |
| 17 | + let num_steps = modulations.len(); |
| 18 | + |
| 19 | + let step_factor = 1.0 / num_steps as f64; |
| 20 | + for step_mod in modulations { |
| 21 | + // Use forward_step_cached with pre-computed modulation batch for this ODE step |
| 22 | + let flow_dir = flow_net.forward_step_cached(¤t, step_mod)?; |
| 23 | + current = (current + flow_dir.affine(step_factor, 0.0)?)?; |
| 24 | + } |
| 25 | + Ok(current) |
| 26 | +} |
| 27 | + |
/// Flow-matching language model: a streaming transformer backbone that
/// conditions a small flow network (`SimpleMLPAdaLN`) used to decode each
/// next latent frame via ODE integration (see `lsd_decode`).
#[derive(Clone)]
pub struct FlowLMModel {
    /// Flow network producing the per-step flow direction, conditioned on the
    /// transformer's last output frame.
    pub flow_net: SimpleMLPAdaLN,
    /// Streaming transformer backbone (stateful; driven via `ModelState`).
    pub transformer: StreamingTransformer,
    /// Projects latent frames (`ldim`) into the transformer width (`dim`); no bias.
    pub input_linear: Linear,
    /// LayerNorm applied to transformer output before the EOS head.
    pub out_norm: LayerNorm,
    /// Scalar end-of-sequence score head (`dim` -> 1).
    pub out_eos: Linear,
    /// Learned beginning-of-sequence embedding, length `ldim`.
    pub bos_emb: Tensor,
    /// Latent normalization statistics, length `ldim` each.
    /// NOTE(review): not referenced in `forward` below — presumably used by
    /// callers to (de)normalize latents; confirm at call sites.
    pub emb_mean: Tensor,
    pub emb_std: Tensor,
    /// Latent dimension of input/output frames.
    pub ldim: usize,
    /// Transformer model dimension.
    pub dim: usize,
    /// Optional symmetric clamp for sampled noise (truncated normal);
    /// `None` means unclamped Gaussian noise.
    pub noise_clamp: Option<f32>,
}
| 42 | + |
| 43 | +fn sample_noise( |
| 44 | + device: &candle_core::Device, |
| 45 | + shape: (usize, usize), |
| 46 | + temp: f32, |
| 47 | + clamp: Option<f32>, |
| 48 | +) -> Result<Tensor> { |
| 49 | + let std = temp.sqrt(); |
| 50 | + match clamp { |
| 51 | + None => Tensor::randn(0.0f32, std, shape, device), |
| 52 | + Some(limit) => { |
| 53 | + // Rejection sampling for truncated normal |
| 54 | + let count = shape.0 * shape.1; |
| 55 | + let mut data = Vec::with_capacity(count); |
| 56 | + let mut rng = rand::thread_rng(); |
| 57 | + let dist = rand_distr::Normal::new(0.0f32, std) |
| 58 | + .map_err(|e| candle_core::Error::Msg(e.to_string()))?; |
| 59 | + |
| 60 | + while data.len() < count { |
| 61 | + let v = rand_distr::Distribution::sample(&dist, &mut rng); |
| 62 | + if v.abs() <= limit { |
| 63 | + data.push(v); |
| 64 | + } |
| 65 | + } |
| 66 | + Tensor::from_vec(data, shape, device) |
| 67 | + } |
| 68 | + } |
| 69 | +} |
| 70 | + |
impl FlowLMModel {
    /// Build a `FlowLMModel` from pre-constructed sub-networks and weights
    /// loaded through `vb`.
    ///
    /// Loads: `input_linear` (ldim -> dim, no bias), `out_norm` (LayerNorm over
    /// `dim`, eps 1e-5, affine), `out_eos` (dim -> 1), and the three length-`ldim`
    /// vectors `bos_emb`, `emb_mean`, `emb_std`. `noise_clamp` starts as `None`
    /// (unclamped noise); callers may set it afterwards.
    pub fn new(
        flow_net: SimpleMLPAdaLN,
        transformer: StreamingTransformer,
        ldim: usize,
        dim: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let input_linear = candle_nn::linear_no_bias(ldim, dim, vb.pp("input_linear"))?;
        let out_norm = LayerNorm::new(dim, 1e-5, true, vb.pp("out_norm"))?;
        let out_eos = candle_nn::linear(dim, 1, vb.pp("out_eos"))?;
        let bos_emb = vb.get(ldim, "bos_emb")?;
        let emb_mean = vb.get(ldim, "emb_mean")?;
        let emb_std = vb.get(ldim, "emb_std")?;

        Ok(Self {
            flow_net,
            transformer,
            input_linear,
            out_norm,
            out_eos,
            bos_emb,
            emb_mean,
            emb_std,
            ldim,
            dim,
            noise_clamp: None, // Default to no clamp
        })
    }

    /// One autoregressive generation step: run the transformer over the latent
    /// sequence (optionally prefixed by text embeddings), score end-of-sequence,
    /// then decode the next latent frame by integrating the flow ODE from noise.
    ///
    /// Returns `(next_latent, is_eos)` where `is_eos` is true when the raw EOS
    /// score exceeds `eos_threshold`.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        sequence: &Tensor,          // [B, T, ldim] latent frames
        text_embeddings: &Tensor,   // [B, S, dim] conditioning prefix (S may be 0)
        model_state: &mut ModelState, // streaming KV/cache state, mutated in place
        time_embeddings: &Tensor,   // ODE step time embeddings for the flow net
        temp: f32,                  // noise variance for the initial ODE sample
        eos_threshold: f32,         // raw-score threshold for end-of-sequence
        step: usize,                // current streaming step index
    ) -> Result<(Tensor, bool)> {
        // sequence is [B, T, ldim]
        // text_embeddings is [B, S, dim]

        // Handle BOS (if NaN, use bos_emb) - simplistic check for NaN
        // In Candle we can use `Tensor::where_cond`
        // But for now let's assume sequence passed in doesn't have NaNs or handled upstream.
        // Original: sequence = torch.where(torch.isnan(sequence), self.bos_emb, sequence)

        // Let's assume BOS is handled by caller for now or if sequence empty.

        let x = self.input_linear.forward(sequence)?;
        let s_len = text_embeddings.dims()[1];

        // Cat text embeddings and sequence embeddings only if text_embeddings is not empty
        let transformer_out_pre_norm = if s_len > 0 {
            let input = Tensor::cat(&[text_embeddings, &x], 1)?;
            let mut out = self.transformer.forward(&input, model_state, step)?;
            // Remove prefix (text embeddings length) so only the latent
            // positions remain: [B, T, dim].
            out = out.narrow(1, s_len, out.dims()[1] - s_len)?;
            out
        } else {
            self.transformer.forward(&x, model_state, step)?
        };

        let transformer_out = self.out_norm.forward(&transformer_out_pre_norm)?;

        // Only use the last frame for generation: [B, 1, dim] -> [B, dim].
        let last_frame = transformer_out
            .narrow(1, transformer_out.dims()[1] - 1, 1)?
            .squeeze(1)?;

        // NOTE(review): the double squeeze + to_scalar implies this path only
        // works for batch size B == 1 — confirm callers never pass B > 1.
        let eos_score = self
            .out_eos
            .forward(&last_frame)?
            .squeeze(0)?
            .squeeze(0)?
            .to_scalar::<f32>()?;
        let is_eos = eos_score > eos_threshold;

        // Generate noise with optional clamping
        let noise = sample_noise(
            last_frame.device(),
            (last_frame.dims()[0], self.ldim),
            temp,
            self.noise_clamp,
        )?;

        // Pre-compute all modulations for this frame's ODE steps (8 steps * N blocks) in batch
        let c_emb = self.flow_net.embed_condition(&last_frame)?;
        let modulations = self
            .flow_net
            .precompute_modulations(&c_emb, time_embeddings)?;

        let next_latent = lsd_decode(&self.flow_net, &modulations, &noise)?;

        Ok((next_latent, is_eos))
    }
}
0 commit comments