111 changes: 93 additions & 18 deletions examples/text_generation_advanced.rs
@@ -36,7 +36,7 @@ impl CharacterLSTM {
// Create network: embedding_size input -> hidden_size (single layer)
let network = LSTMNetwork::new(embedding_size, hidden_size, 1);

println!("📚 Built vocabulary: {} unique characters", vocab_size);
println!("Built vocabulary: {} unique characters", vocab_size);
println!("Characters: {:?}", chars.iter().take(20).collect::<Vec<_>>());
println!("Network: {} -> {} -> {}", embedding_size, hidden_size, embedding_size);

@@ -197,18 +197,18 @@ impl CharacterLSTM {

/// Train the character-level language model
fn train(&mut self, text: &str, epochs: usize, validation_split: f64) {
println!("🔤 Creating character sequences from text...");
println!("Creating character sequences from text...");
let sequences = self.create_sequences(text);

if sequences.is_empty() {
println!("No training sequences created!");
println!("No training sequences created!");
return;
}

let split_idx = ((sequences.len() as f64) * (1.0 - validation_split)) as usize;
let (train_data, val_data) = sequences.split_at(split_idx);

println!("📖 Training on {} sequences, validating on {} sequences",
println!("Training on {} sequences, validating on {} sequences",
train_data.len(), val_data.len());

// Create trainer with MSE loss for embedding regression
@@ -228,18 +228,15 @@
trainer.train(train_data, if val_data.is_empty() { None } else { Some(val_data) });

self.trainer = Some(trainer);
println!("Character LSTM training completed!");
println!("Character LSTM training completed!");
}

/// Generate text starting with a seed string
fn generate_text(&self, seed: &str, length: usize, temperature: f64) -> String {
let trainer = match &self.trainer {
Some(trainer) => trainer,
None => {
println!("❌ Model not trained yet!");
return String::new();
}
};
fn generate_text(&mut self, seed: &str, length: usize, temperature: f64) -> String {
if self.trainer.is_none() {
println!("Model not trained yet!");
return String::new();
}

let mut generated = seed.to_string();
let mut current_sequence: Vec<char> = seed.chars().collect();
@@ -249,6 +246,16 @@
current_sequence.insert(0, ' '); // Pad with spaces
}

+ let network = if let Some(ref trainer) = self.trainer {
+ &trainer.network
+ } else {
+ println!("Trainer not available");
+ return generated;
+ };
+
+ let mut inference_network = network.clone();
+ inference_network.eval();
+
for _ in 0..length {
// Prepare input sequence
let start_idx = current_sequence.len().saturating_sub(self.sequence_length);
@@ -258,8 +265,23 @@
.map(|&ch| self.char_to_embedding(ch))
.collect();

println!("⚠️ Text generation limited due to trainer interface constraints");
break;
let (outputs, _) = inference_network.forward_sequence_with_cache(&inputs);

if let Some((last_output, _)) = outputs.last() {
let predicted_embedding = self.project_to_embedding(last_output);

let next_char = self.sample_next_char(&predicted_embedding, temperature);

generated.push(next_char);
current_sequence.push(next_char);

if current_sequence.len() > self.sequence_length * 2 {
current_sequence.drain(0..self.sequence_length);
}
} else {
println!("No prediction generated, stopping text generation");
break;
}
}

generated
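Note on the generation loop above: each step feeds the last sequence_length characters through the network, samples one character, and appends it; the drain call then trims current_sequence in blocks so the buffer stays bounded at roughly twice the window. A minimal standalone sketch of that buffer policy (step_window and the window parameter are illustrative names, not part of this PR):

    // Append one character and trim the history buffer in blocks,
    // mirroring the drain(0..sequence_length) call in the diff above.
    fn step_window(buf: &mut Vec<char>, next: char, window: usize) {
        buf.push(next);
        if buf.len() > window * 2 {
            buf.drain(0..window);
        }
    }

    fn main() {
        let mut buf: Vec<char> = "abcd".chars().collect();
        for c in "efghij".chars() {
            step_window(&mut buf, c, 4);
        }
        assert!(buf.len() <= 2 * 4);
        println!("{:?}", buf); // ['e', 'f', 'g', 'h', 'i', 'j']
    }

Because the forward pass only ever reads the last sequence_length characters (via saturating_sub), the block trim changes memory use, not the model input.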
@@ -278,6 +300,59 @@

Array2::from_shape_vec((self.embedding_size, 1), embedding).unwrap()
}

+ /// Sample next character using temperature-based sampling
+ fn sample_next_char(&self, predicted_embedding: &Array2<f64>, temperature: f64) -> char {
+ let mut similarities = Vec::new();
+
+ for (&ch, &_idx) in &self.char_to_idx {
+ let char_embedding = self.char_to_embedding(ch);
+
+ let dot_product = predicted_embedding.iter()
+ .zip(char_embedding.iter())
+ .map(|(a, b)| a * b)
+ .sum::<f64>();
+
+ let pred_norm = predicted_embedding.iter().map(|x| x * x).sum::<f64>().sqrt();
+ let char_norm = char_embedding.iter().map(|x| x * x).sum::<f64>().sqrt();
+
+ let similarity = if pred_norm > 0.0 && char_norm > 0.0 {
+ dot_product / (pred_norm * char_norm)
+ } else {
+ 0.0
+ };
+
+ similarities.push((ch, similarity));
+ }
+
+ let max_similarity = similarities.iter().map(|(_, s)| *s).fold(f64::NEG_INFINITY, f64::max);
+ let mut probabilities = Vec::new();
+ let mut total_prob = 0.0;
+
+ for (ch, similarity) in &similarities {
+ let scaled_similarity = (similarity - max_similarity) / temperature;
+ let prob = scaled_similarity.exp();
+ probabilities.push((*ch, prob));
+ total_prob += prob;
+ }
+
+ for (_, prob) in &mut probabilities {
+ *prob /= total_prob;
+ }
+
+ let random_value: f64 = rand::random();
+ let mut cumulative_prob = 0.0;
+
+ for (ch, prob) in probabilities {
+ cumulative_prob += prob;
+ if random_value <= cumulative_prob {
+ return ch;
+ }
+ }
+
+ similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
+ similarities[0].0
+ }
}

/// Sample training texts for different domains
@@ -297,20 +372,20 @@ fn get_sample_texts() -> HashMap<&'static str, &'static str> {
}

fn main() {
println!("📝 Advanced Text Generation with Character-Level LSTM");
println!("Advanced Text Generation with Character-Level LSTM");
println!("===================================================\n");

let sample_texts = get_sample_texts();

for (domain, text) in &sample_texts {
println!("🎭 Training {} model...", domain);
println!("Training {} model...", domain);
println!("Training text preview: {}...\n", &text[..text.len().min(100)]);

// Create and train model with embedding
let mut model = CharacterLSTM::new(text, 8, 32, 16); // 8-char sequences, 32 hidden, 16 embedding
model.train(text, 8, 0.1); // 8 epochs for quick demo, 10% validation

println!("\n🎲 Generating text samples:");
println!("\nGenerating text samples:");

// Generate with different temperatures
let temperatures = [0.8, 1.2];
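A closing note on sample_next_char: it scores every vocabulary character by cosine similarity to the predicted embedding, turns the scores into a distribution with a temperature-scaled softmax (subtracting the maximum score first so exp() cannot overflow), and samples by walking the cumulative distribution, falling back to the highest-similarity character if rounding leaves the draw unmatched. A minimal sketch of the same scheme over a plain score slice; softmax_sample and the u parameter (a uniform draw in [0, 1), e.g. from rand::random::<f64>()) are illustrative names, not part of this PR:

    /// Pick an index from `scores` via temperature-scaled softmax sampling.
    fn softmax_sample(scores: &[f64], temperature: f64, u: f64) -> usize {
        let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        // Shift by the max for numerical stability, then scale by 1/T:
        // low temperatures sharpen the distribution, high ones flatten it.
        let weights: Vec<f64> = scores.iter().map(|s| ((s - max) / temperature).exp()).collect();
        let total: f64 = weights.iter().sum();
        let mut cumulative = 0.0;
        for (i, w) in weights.iter().enumerate() {
            cumulative += w / total;
            if u <= cumulative {
                return i;
            }
        }
        scores.len() - 1 // rounding fallback; the PR falls back to argmax instead
    }

    fn main() {
        let scores = [0.9, 0.1, -0.3]; // e.g. cosine similarities per character
        println!("{}", softmax_sample(&scores, 0.8, 0.5)); // prints 0
    }

This is also why main generates at temperatures 0.8 and 1.2: the lower setting favors the most similar characters, while the higher one spreads probability across the vocabulary.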