From d215cd2c42021d2cae00a8747e9bb5f0fb1031bb Mon Sep 17 00:00:00 2001
From: Alex Kholodniak
Date: Mon, 1 Dec 2025 01:19:08 +0200
Subject: [PATCH] fix: implement proper text generation inference

---
Reviewer notes (this area is not part of the commit message):

 * generate_text: replaces the "Text generation limited" stub. It clones the
   trained network, calls eval() on the clone, runs
   forward_sequence_with_cache over the current window, projects the last
   output to an embedding, samples a character and appends it to both the
   result and the window. The window is trimmed by sequence_length once it
   grows past 2 * sequence_length.
 * sample_next_char (new): cosine similarity between the predicted embedding
   and every vocabulary embedding, then a max-shifted exp(sim / temperature)
   normalised to probabilities, then a cumulative draw with rand::random().
 * All status println! messages lose their emoji prefix.
 * NOTE(review): temperature is used as a divisor, so 0.0 is not handled.
   The fallback path uses partial_cmp(..).unwrap() and similarities[0],
   which panic on NaN similarities / an empty vocabulary respectively.
 * NOTE(review): generate_text now takes &mut self; nothing in these hunks
   visibly mutates self. Confirm whether &self would still suffice.

 examples/text_generation_advanced.rs | 111 ++++++++++++++++++++++-----
 1 file changed, 93 insertions(+), 18 deletions(-)

diff --git a/examples/text_generation_advanced.rs b/examples/text_generation_advanced.rs
index 6c65a8f..2f02c01 100644
--- a/examples/text_generation_advanced.rs
+++ b/examples/text_generation_advanced.rs
@@ -36,7 +36,7 @@ impl CharacterLSTM {
         // Create network: embedding_size input -> hidden_size (single layer)
         let network = LSTMNetwork::new(embedding_size, hidden_size, 1);

-        println!("📚 Built vocabulary: {} unique characters", vocab_size);
+        println!("Built vocabulary: {} unique characters", vocab_size);
         println!("Characters: {:?}", chars.iter().take(20).collect::<Vec<_>>());
         println!("Network: {} -> {} -> {}", embedding_size, hidden_size, embedding_size);

@@ -197,18 +197,18 @@ impl CharacterLSTM {

     /// Train the character-level language model
     fn train(&mut self, text: &str, epochs: usize, validation_split: f64) {
-        println!("🔤 Creating character sequences from text...");
+        println!("Creating character sequences from text...");
         let sequences = self.create_sequences(text);

         if sequences.is_empty() {
-            println!("❌ No training sequences created!");
+            println!("No training sequences created!");
             return;
         }

         let split_idx = ((sequences.len() as f64) * (1.0 - validation_split)) as usize;
         let (train_data, val_data) = sequences.split_at(split_idx);

-        println!("📖 Training on {} sequences, validating on {} sequences",
+        println!("Training on {} sequences, validating on {} sequences",
                  train_data.len(), val_data.len());

         // Create trainer with MSE loss for embedding regression
@@ -228,18 +228,15 @@ impl CharacterLSTM {
         trainer.train(train_data, if val_data.is_empty() { None } else { Some(val_data) });

         self.trainer = Some(trainer);
-        println!("✅ Character LSTM training completed!");
+        println!("Character LSTM training completed!");
     }

     /// Generate text starting with a seed string
-    fn generate_text(&self, seed: &str, length: usize, temperature: f64) -> String {
-        let trainer = match &self.trainer {
-            Some(trainer) => trainer,
-            None => {
-                println!("❌ Model not trained yet!");
-                return String::new();
-            }
-        };
+    fn generate_text(&mut self, seed: &str, length: usize, temperature: f64) -> String {
+        if self.trainer.is_none() {
+            println!("Model not trained yet!");
+            return String::new();
+        }

         let mut generated = seed.to_string();
         let mut current_sequence: Vec<char> = seed.chars().collect();
@@ -249,6 +246,16 @@ impl CharacterLSTM {
             current_sequence.insert(0, ' '); // Pad with spaces
         }

+        let network = if let Some(ref trainer) = self.trainer {
+            &trainer.network
+        } else {
+            println!("Trainer not available");
+            return generated;
+        };
+
+        let mut inference_network = network.clone();
+        inference_network.eval();
+
         for _ in 0..length {
             // Prepare input sequence
             let start_idx = current_sequence.len().saturating_sub(self.sequence_length);
@@ -258,8 +265,23 @@ impl CharacterLSTM {
                 .map(|&ch| self.char_to_embedding(ch))
                 .collect();

-            println!("⚠️ Text generation limited due to trainer interface constraints");
-            break;
+            let (outputs, _) = inference_network.forward_sequence_with_cache(&inputs);
+
+            if let Some((last_output, _)) = outputs.last() {
+                let predicted_embedding = self.project_to_embedding(last_output);
+
+                let next_char = self.sample_next_char(&predicted_embedding, temperature);
+
+                generated.push(next_char);
+                current_sequence.push(next_char);
+
+                if current_sequence.len() > self.sequence_length * 2 {
+                    current_sequence.drain(0..self.sequence_length);
+                }
+            } else {
+                println!("No prediction generated, stopping text generation");
+                break;
+            }
         }

         generated
@@ -278,6 +300,59 @@ impl CharacterLSTM {

         Array2::from_shape_vec((self.embedding_size, 1), embedding).unwrap()
     }
+
+    /// Sample next character using temperature-based sampling
+    fn sample_next_char(&self, predicted_embedding: &Array2<f64>, temperature: f64) -> char {
+        let mut similarities = Vec::new();
+
+        for (&ch, &_idx) in &self.char_to_idx {
+            let char_embedding = self.char_to_embedding(ch);
+
+            let dot_product = predicted_embedding.iter()
+                .zip(char_embedding.iter())
+                .map(|(a, b)| a * b)
+                .sum::<f64>();
+
+            let pred_norm = predicted_embedding.iter().map(|x| x * x).sum::<f64>().sqrt();
+            let char_norm = char_embedding.iter().map(|x| x * x).sum::<f64>().sqrt();
+
+            let similarity = if pred_norm > 0.0 && char_norm > 0.0 {
+                dot_product / (pred_norm * char_norm)
+            } else {
+                0.0
+            };
+
+            similarities.push((ch, similarity));
+        }
+
+        let max_similarity = similarities.iter().map(|(_, s)| *s).fold(f64::NEG_INFINITY, f64::max);
+        let mut probabilities = Vec::new();
+        let mut total_prob = 0.0;
+
+        for (ch, similarity) in &similarities {
+            let scaled_similarity = (similarity - max_similarity) / temperature;
+            let prob = scaled_similarity.exp();
+            probabilities.push((*ch, prob));
+            total_prob += prob;
+        }
+
+        for (_, prob) in &mut probabilities {
+            *prob /= total_prob;
+        }
+
+        let random_value: f64 = rand::random();
+        let mut cumulative_prob = 0.0;
+
+        for (ch, prob) in probabilities {
+            cumulative_prob += prob;
+            if random_value <= cumulative_prob {
+                return ch;
+            }
+        }
+
+        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
+        similarities[0].0
+    }
 }

 /// Sample training texts for different domains
@@ -297,20 +372,20 @@ fn get_sample_texts() -> HashMap<&'static str, &'static str> {
 }

 fn main() {
-    println!("📝 Advanced Text Generation with Character-Level LSTM");
+    println!("Advanced Text Generation with Character-Level LSTM");
     println!("===================================================\n");

     let sample_texts = get_sample_texts();

     for (domain, text) in &sample_texts {
-        println!("🎭 Training {} model...", domain);
+        println!("Training {} model...", domain);
         println!("Training text preview: {}...\n", &text[..text.len().min(100)]);

         // Create and train model with embedding
         let mut model = CharacterLSTM::new(text, 8, 32, 16); // 8-char sequences, 32 hidden, 16 embedding
         model.train(text, 8, 0.1); // 8 epochs for quick demo, 10% validation

-        println!("\n🎲 Generating text samples:");
+        println!("\nGenerating text samples:");

         // Generate with different temperatures
         let temperatures = [0.8, 1.2];