From 37d3a518dc72e58cea7c94995f122b45cc8c61e3 Mon Sep 17 00:00:00 2001
From: Alex Kholodniak
Date: Fri, 15 Aug 2025 02:57:03 +0300
Subject: [PATCH] feat: Add batch processing capabilities

- Implement batch forward/backward passes for LSTM cells
- Add LSTMBatchTrainer for efficient batch training
- Add batch processing support to LSTM networks
- Implement batch prediction capabilities
- Add batch processing example
- Maintain backward compatibility with existing API
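Illustrative usage of the new batch API (a minimal sketch built from the
functions added in this patch; `train_data` is assumed to be a prepared
Vec of (input, target) sequence pairs - see the new example for details):

    use rust_lstm::{LSTMNetwork, create_adam_batch_trainer};

    let network = LSTMNetwork::new(3, 16, 2);
    let mut trainer = create_adam_batch_trainer(network, 0.001);
    trainer.config.epochs = 5;
    // train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]
    trainer.train(&train_data, None, 8); // batch size 8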
---
 Cargo.toml                           |   6 +-
 README.md                            |   2 +-
 examples/batch_processing_example.rs | 230 +++++++++++++++++++
 src/layers/lstm_cell.rs              | 270 ++++++++++++++++++++++
 src/lib.rs                           |   9 +-
 src/loss.rs                          |  52 +++++
 src/models/lstm_network.rs           | 155 ++++++++++++-
 src/training.rs                      | 332 ++++++++++++++++++++++++++-
 8 files changed, 1042 insertions(+), 14 deletions(-)
 create mode 100644 examples/batch_processing_example.rs

diff --git a/Cargo.toml b/Cargo.toml
index 95a446e..b12b214 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "rust-lstm"
-version = "0.4.0"
+version = "0.5.0"
 authors = ["Alex Kholodniak"]
 edition = "2021"
 rust-version = "1.70"
@@ -84,3 +84,7 @@ path = "examples/time_series_prediction.rs"
 [[example]]
 name = "multi_layer_lstm"
 path = "examples/multi_layer_lstm.rs"
+
+[[example]]
+name = "batch_processing_example"
+path = "examples/batch_processing_example.rs"
diff --git a/README.md b/README.md
index 94f03ac..8380e78 100644
--- a/README.md
+++ b/README.md
@@ -49,7 +49,7 @@ Add to your `Cargo.toml`:

 ```toml
 [dependencies]
-rust-lstm = "0.4.0"
+rust-lstm = "0.5.0"
 ```

 ### Basic Usage
Try larger datasets for better speedup."); + } +} + +/// Demonstrate batch prediction capabilities +fn demonstrate_batch_prediction() { + println!("\nBATCH PREDICTION DEMONSTRATION"); + println!("==============================\n"); + + let input_size = 2; + let hidden_size = 8; + let num_layers = 1; + + // Create and train a simple model + let network = LSTMNetwork::new(input_size, hidden_size, num_layers); + let mut trainer = create_adam_batch_trainer(network, 0.01); + + // Generate small training dataset + let train_data = generate_batch_sine_data(20, 5, input_size); + + trainer.config.epochs = 10; + trainer.config.print_every = 5; + + println!("Training a small model for prediction demo..."); + trainer.train(&train_data, None, 4); + + // Create test sequences for batch prediction + let test_sequences = generate_batch_sine_data(3, 3, input_size); + let test_inputs: Vec<_> = test_sequences.iter().map(|(inputs, _)| inputs.clone()).collect(); + let _test_targets: Vec<_> = test_sequences.iter().map(|(_, targets)| targets.clone()).collect(); + + println!("\nPerforming batch predictions..."); + let predictions = trainer.predict_batch(&test_inputs); + + println!("Input sequences vs Predictions:"); + for (i, (inputs, preds)) in test_inputs.iter().zip(predictions.iter()).enumerate() { + println!("Sequence {}:", i + 1); + for (j, (input, pred)) in inputs.iter().zip(preds.iter()).enumerate() { + println!(" Step {}: Input=[{:.3}, {:.3}] -> Pred={:.3}", + j + 1, input[[0, 0]], input[[1, 0]], pred[[0, 0]]); + } + println!(); + } +} + +/// Demonstrate memory efficiency and scalability +fn demonstrate_scalability() { + println!("SCALABILITY DEMONSTRATION"); + println!("=========================\n"); + + let test_sizes = vec![ + (50, 4), // Small: 50 sequences, batch size 4 + (200, 8), // Medium: 200 sequences, batch size 8 + (500, 16), // Large: 500 sequences, batch size 16 + ]; + + for (num_sequences, batch_size) in test_sizes { + println!("Testing with {} sequences, batch size {}...", num_sequences, batch_size); + + let train_data = generate_batch_sine_data(num_sequences, 8, 2); + let network = LSTMNetwork::new(2, 12, 1); + let mut trainer = create_adam_batch_trainer(network, 0.001); + + trainer.config.epochs = 3; + trainer.config.print_every = 1; + + let start_time = Instant::now(); + trainer.train(&train_data, None, batch_size); + let training_time = start_time.elapsed(); + + let final_loss = trainer.get_latest_metrics().unwrap().train_loss; + println!(" Completed in {:.2}s, final loss: {:.6}\n", + training_time.as_secs_f64(), final_loss); + } + + println!("All scalability tests completed successfully!"); + println!("Batch processing handles varying dataset sizes efficiently."); +} + +fn main() { + println!("RUST-LSTM BATCH PROCESSING DEMONSTRATION"); + println!("=========================================\n"); + + println!("This example demonstrates the new batch processing capabilities:"); + println!("- Simultaneous processing of multiple sequences"); + println!("- Performance improvements over single-sequence training"); + println!("- Batch prediction capabilities"); + println!("- Scalability with different batch sizes\n"); + + benchmark_training_performance(); + demonstrate_batch_prediction(); + demonstrate_scalability(); + + println!("\nBATCH PROCESSING DEMONSTRATION COMPLETED!"); + println!("=========================================="); + println!("Key Benefits Demonstrated:"); + println!("- Faster training through batch processing"); + println!("- Efficient memory utilization"); + println!("- 
Scalable to different dataset sizes"); + println!("- Easy-to-use batch training API"); + println!("- Backward compatibility with existing code"); + + println!("\nNext Steps:"); + println!("- Try batch processing with your own datasets"); + println!("- Experiment with different batch sizes"); + println!("- Compare performance with single-sequence training"); + println!("- Use batch processing for faster model development"); +} \ No newline at end of file diff --git a/src/layers/lstm_cell.rs b/src/layers/lstm_cell.rs index cfdebbd..b0d2601 100644 --- a/src/layers/lstm_cell.rs +++ b/src/layers/lstm_cell.rs @@ -31,6 +31,25 @@ pub struct LSTMCellCache { pub output_dropout_mask: Option>, } +/// Batch cache for multiple sequences processed simultaneously +#[derive(Clone)] +pub struct LSTMCellBatchCache { + pub input: Array2, + pub hx: Array2, + pub cx: Array2, + pub gates: Array2, + pub input_gate: Array2, + pub forget_gate: Array2, + pub cell_gate: Array2, + pub output_gate: Array2, + pub cy: Array2, + pub hy: Array2, + pub input_dropout_mask: Option>, + pub recurrent_dropout_mask: Option>, + pub output_dropout_mask: Option>, + pub batch_size: usize, +} + /// LSTM cell with trainable parameters and dropout support #[derive(Clone)] pub struct LSTMCell { @@ -199,6 +218,172 @@ impl LSTMCell { (hy_final, cy, cache) } + /// Batch forward pass for multiple sequences simultaneously + /// + /// # Arguments + /// * `input` - Input tensor of shape (input_size, batch_size) + /// * `hx` - Hidden state tensor of shape (hidden_size, batch_size) + /// * `cx` - Cell state tensor of shape (hidden_size, batch_size) + /// + /// # Returns + /// * Tuple of (new_hidden_state, new_cell_state) with same batch dimensions + pub fn forward_batch(&mut self, input: &Array2, hx: &Array2, cx: &Array2) -> (Array2, Array2) { + let batch_size = input.ncols(); + assert_eq!(hx.ncols(), batch_size, "Hidden state batch size must match input batch size"); + assert_eq!(cx.ncols(), batch_size, "Cell state batch size must match input batch size"); + assert_eq!(input.nrows(), self.w_ih.ncols(), "Input feature size must match weight matrix"); + assert_eq!(hx.nrows(), self.hidden_size, "Hidden state size must match network hidden size"); + assert_eq!(cx.nrows(), self.hidden_size, "Cell state size must match network hidden size"); + + // Apply input dropout across the entire batch + let (input_dropped, _input_mask) = if let Some(ref mut dropout) = self.input_dropout { + let dropped = dropout.forward(input); + let mask = dropout.get_last_mask().map(|m| m.clone()); + (dropped, mask) + } else { + (input.clone(), None) + }; + + // Apply recurrent dropout across the entire batch + let (hx_dropped, _recurrent_mask) = if let Some(ref mut dropout) = self.recurrent_dropout { + let dropped = dropout.forward(hx); + let mask = dropout.get_last_mask().map(|m| m.clone()); + (dropped, mask) + } else { + (hx.clone(), None) + }; + + // Compute all gates in parallel for the entire batch + // gates shape: (4 * hidden_size, batch_size) + let gates = &self.w_ih.dot(&input_dropped) + &self.b_ih.broadcast((4 * self.hidden_size, batch_size)).unwrap() + + &self.w_hh.dot(&hx_dropped) + &self.b_hh.broadcast((4 * self.hidden_size, batch_size)).unwrap(); + + // Extract and compute gate activations for the entire batch + let input_gate = gates.slice(s![0..self.hidden_size, ..]).map(|&x| sigmoid(x)); + let forget_gate = gates.slice(s![self.hidden_size..2*self.hidden_size, ..]).map(|&x| sigmoid(x)); + let cell_gate = gates.slice(s![2*self.hidden_size..3*self.hidden_size, 
diff --git a/src/layers/lstm_cell.rs b/src/layers/lstm_cell.rs
index cfdebbd..b0d2601 100644
--- a/src/layers/lstm_cell.rs
+++ b/src/layers/lstm_cell.rs
@@ -31,6 +31,25 @@ pub struct LSTMCellCache {
     pub output_dropout_mask: Option<Array2<f64>>,
 }
 
+/// Batch cache for multiple sequences processed simultaneously
+#[derive(Clone)]
+pub struct LSTMCellBatchCache {
+    pub input: Array2<f64>,
+    pub hx: Array2<f64>,
+    pub cx: Array2<f64>,
+    pub gates: Array2<f64>,
+    pub input_gate: Array2<f64>,
+    pub forget_gate: Array2<f64>,
+    pub cell_gate: Array2<f64>,
+    pub output_gate: Array2<f64>,
+    pub cy: Array2<f64>,
+    pub hy: Array2<f64>,
+    pub input_dropout_mask: Option<Array2<f64>>,
+    pub recurrent_dropout_mask: Option<Array2<f64>>,
+    pub output_dropout_mask: Option<Array2<f64>>,
+    pub batch_size: usize,
+}
+
 /// LSTM cell with trainable parameters and dropout support
 #[derive(Clone)]
 pub struct LSTMCell {
@@ -199,6 +218,172 @@ impl LSTMCell {
         (hy_final, cy, cache)
     }
 
+    /// Batch forward pass for multiple sequences simultaneously
+    ///
+    /// # Arguments
+    /// * `input` - Input tensor of shape (input_size, batch_size)
+    /// * `hx` - Hidden state tensor of shape (hidden_size, batch_size)
+    /// * `cx` - Cell state tensor of shape (hidden_size, batch_size)
+    ///
+    /// # Returns
+    /// * Tuple of (new_hidden_state, new_cell_state) with same batch dimensions
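+    ///
+    /// # Example
+    ///
+    /// An illustrative sketch; `LSTMCell::new(input_size, hidden_size)` is assumed
+    /// here and is not introduced by this patch:
+    ///
+    /// ```ignore
+    /// use ndarray::Array2;
+    /// use rust_lstm::LSTMCell;
+    ///
+    /// let mut cell = LSTMCell::new(4, 8);
+    /// let input = Array2::<f64>::zeros((4, 32)); // (input_size, batch_size)
+    /// let hx = Array2::<f64>::zeros((8, 32));    // (hidden_size, batch_size)
+    /// let cx = Array2::<f64>::zeros((8, 32));
+    /// let (hy, cy) = cell.forward_batch(&input, &hx, &cx);
+    /// assert_eq!(hy.dim(), (8, 32));
+    /// assert_eq!(cy.dim(), (8, 32));
+    /// ```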
+    pub fn forward_batch(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
+        let batch_size = input.ncols();
+        assert_eq!(hx.ncols(), batch_size, "Hidden state batch size must match input batch size");
+        assert_eq!(cx.ncols(), batch_size, "Cell state batch size must match input batch size");
+        assert_eq!(input.nrows(), self.w_ih.ncols(), "Input feature size must match weight matrix");
+        assert_eq!(hx.nrows(), self.hidden_size, "Hidden state size must match network hidden size");
+        assert_eq!(cx.nrows(), self.hidden_size, "Cell state size must match network hidden size");
+
+        // Apply input dropout across the entire batch
+        let (input_dropped, _input_mask) = if let Some(ref mut dropout) = self.input_dropout {
+            let dropped = dropout.forward(input);
+            let mask = dropout.get_last_mask().map(|m| m.clone());
+            (dropped, mask)
+        } else {
+            (input.clone(), None)
+        };
+
+        // Apply recurrent dropout across the entire batch
+        let (hx_dropped, _recurrent_mask) = if let Some(ref mut dropout) = self.recurrent_dropout {
+            let dropped = dropout.forward(hx);
+            let mask = dropout.get_last_mask().map(|m| m.clone());
+            (dropped, mask)
+        } else {
+            (hx.clone(), None)
+        };
+
+        // Compute all gates in parallel for the entire batch
+        // gates shape: (4 * hidden_size, batch_size)
+        let gates = &self.w_ih.dot(&input_dropped) + &self.b_ih.broadcast((4 * self.hidden_size, batch_size)).unwrap()
+            + &self.w_hh.dot(&hx_dropped) + &self.b_hh.broadcast((4 * self.hidden_size, batch_size)).unwrap();
+
+        // Extract and compute gate activations for the entire batch
+        let input_gate = gates.slice(s![0..self.hidden_size, ..]).map(|&x| sigmoid(x));
+        let forget_gate = gates.slice(s![self.hidden_size..2*self.hidden_size, ..]).map(|&x| sigmoid(x));
+        let cell_gate = gates.slice(s![2*self.hidden_size..3*self.hidden_size, ..]).map(|&x| x.tanh());
+        let output_gate = gates.slice(s![3*self.hidden_size..4*self.hidden_size, ..]).map(|&x| sigmoid(x));
+
+        // Update cell state for entire batch
+        let mut cy = &forget_gate * cx + &input_gate * &cell_gate;
+
+        // Apply zoneout to cell state if configured
+        if let Some(ref zoneout) = self.zoneout {
+            for col_idx in 0..batch_size {
+                let cy_col = cy.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
+                let cx_col = cx.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
+                let cy_zoneout = zoneout.apply_cell_zoneout(&cy_col, &cx_col);
+                cy.column_mut(col_idx).assign(&cy_zoneout.column(0));
+            }
+        }
+
+        // Compute hidden state for entire batch
+        let mut hy = &output_gate * cy.map(|&x| x.tanh());
+
+        // Apply zoneout to hidden state if configured
+        if let Some(ref zoneout) = self.zoneout {
+            for col_idx in 0..batch_size {
+                let hy_col = hy.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
+                let hx_col = hx.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
+                let hy_zoneout = zoneout.apply_hidden_zoneout(&hy_col, &hx_col);
+                hy.column_mut(col_idx).assign(&hy_zoneout.column(0));
+            }
+        }
+
+        // Apply output dropout to the entire batch
+        let hy_final = if let Some(ref mut dropout) = self.output_dropout {
+            dropout.forward(&hy)
+        } else {
+            hy
+        };
+
+        (hy_final, cy)
+    }
+
+    /// Batch forward pass with caching for training
+    ///
+    /// Similar to forward_batch but caches intermediate values needed for backpropagation
+    pub fn forward_batch_with_cache(&mut self, input: &Array2<f64>, hx: &Array2<f64>, cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>, LSTMCellBatchCache) {
+        let batch_size = input.ncols();
+
+        // Apply dropout and track masks
+        let (input_dropped, input_mask) = if let Some(ref mut dropout) = self.input_dropout {
+            let dropped = dropout.forward(input);
+            let mask = dropout.get_last_mask().map(|m| m.clone());
+            (dropped, mask)
+        } else {
+            (input.clone(), None)
+        };
+
+        let (hx_dropped, recurrent_mask) = if let Some(ref mut dropout) = self.recurrent_dropout {
+            let dropped = dropout.forward(hx);
+            let mask = dropout.get_last_mask().map(|m| m.clone());
+            (dropped, mask)
+        } else {
+            (hx.clone(), None)
+        };
+
+        // Compute gates for entire batch
+        let gates = &self.w_ih.dot(&input_dropped) + &self.b_ih.broadcast((4 * self.hidden_size, batch_size)).unwrap()
+            + &self.w_hh.dot(&hx_dropped) + &self.b_hh.broadcast((4 * self.hidden_size, batch_size)).unwrap();
+
+        let input_gate = gates.slice(s![0..self.hidden_size, ..]).map(|&x| sigmoid(x));
+        let forget_gate = gates.slice(s![self.hidden_size..2*self.hidden_size, ..]).map(|&x| sigmoid(x));
+        let cell_gate = gates.slice(s![2*self.hidden_size..3*self.hidden_size, ..]).map(|&x| x.tanh());
+        let output_gate = gates.slice(s![3*self.hidden_size..4*self.hidden_size, ..]).map(|&x| sigmoid(x));
+
+        let mut cy = &forget_gate * cx + &input_gate * &cell_gate;
+
+        // Apply zoneout if configured
+        if let Some(ref zoneout) = self.zoneout {
+            for col_idx in 0..batch_size {
+                let cy_col = cy.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
+                let cx_col = cx.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
+                let cy_zoneout = zoneout.apply_cell_zoneout(&cy_col, &cx_col);
+                cy.column_mut(col_idx).assign(&cy_zoneout.column(0));
+            }
+        }
+
+        let mut hy = &output_gate * cy.map(|&x| x.tanh());
+
+        if let Some(ref zoneout) = self.zoneout {
+            for col_idx in 0..batch_size {
+                let hy_col = hy.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
+                let hx_col = hx.column(col_idx).to_owned().insert_axis(ndarray::Axis(1));
+                let hy_zoneout = zoneout.apply_hidden_zoneout(&hy_col, &hx_col);
+                hy.column_mut(col_idx).assign(&hy_zoneout.column(0));
+            }
+        }
+
+        let (hy_final, output_mask) = if let Some(ref mut dropout) = self.output_dropout {
+            let dropped = dropout.forward(&hy);
+            let mask = dropout.get_last_mask().map(|m| m.clone());
+            (dropped, mask)
+        } else {
+            (hy, None)
+        };
+
+        // Create cache for backpropagation
+        let cache = LSTMCellBatchCache {
+            input: input.clone(),
+            hx: hx.clone(),
+            cx: cx.clone(),
+            gates: gates.to_owned(),
+            input_gate: input_gate.to_owned(),
+            forget_gate: forget_gate.to_owned(),
+            cell_gate: cell_gate.to_owned(),
+            output_gate: output_gate.to_owned(),
+            cy: cy.clone(),
+            hy: hy_final.clone(),
+            input_dropout_mask: input_mask,
+            recurrent_dropout_mask: recurrent_mask,
+            output_dropout_mask: output_mask,
+            batch_size,
+        };
+
+        (hy_final, cy, cache)
+    }
 
 /// Backward pass implementing LSTM gradient computation with dropout
 ///
 /// Returns (parameter_gradients, input_gradient, hidden_gradient, cell_gradient)
@@ -283,6 +468,91 @@ impl LSTMCell {
         (gradients, dx, dhx, dcx)
     }
 
+    /// Batch backward pass for training with multiple sequences
+    ///
+    /// Computes gradients for an entire batch simultaneously
+    pub fn backward_batch(&self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMCellBatchCache) -> (LSTMCellGradients, Array2<f64>, Array2<f64>, Array2<f64>) {
+        let batch_size = cache.batch_size;
+        let hidden_size = self.hidden_size;
+
+        // Apply output dropout backward pass using saved mask
+        let dhy_dropped = if let Some(ref mask) = cache.output_dropout_mask {
+            let keep_prob = if let Some(ref dropout) = self.output_dropout {
+                1.0 - dropout.dropout_rate
+            } else {
+                1.0
+            };
+            dhy * mask / keep_prob
+        } else {
+            dhy.clone()
+        };
+
+        // Output gate gradients for entire batch
+        let tanh_cy = cache.cy.map(|&x| x.tanh());
+        let do_t = &dhy_dropped * &tanh_cy;
+        let do_raw = &do_t * &cache.output_gate * &cache.output_gate.map(|&x| 1.0 - x);
+
+        // Cell state gradients from both tanh and direct paths
+        let dcy_from_tanh = &dhy_dropped * &cache.output_gate * cache.cy.map(|&x| 1.0 - x.tanh().powi(2));
+        let dcy_total = dcy + dcy_from_tanh;
+
+        // Gate gradients for entire batch
+        let df_t = &dcy_total * &cache.cx;
+        let df_raw = &df_t * &cache.forget_gate * cache.forget_gate.map(|&x| 1.0 - x);
+
+        let di_t = &dcy_total * &cache.cell_gate;
+        let di_raw = &di_t * &cache.input_gate * cache.input_gate.map(|&x| 1.0 - x);
+
+        let dc_t = &dcy_total * &cache.input_gate;
+        let dc_raw = &dc_t * cache.cell_gate.map(|&x| 1.0 - x.powi(2));
+
+        // Concatenate gate gradients
+        let mut dgates = Array2::zeros((4 * hidden_size, batch_size));
+        dgates.slice_mut(s![0..hidden_size, ..]).assign(&di_raw);
+        dgates.slice_mut(s![hidden_size..2*hidden_size, ..]).assign(&df_raw);
+        dgates.slice_mut(s![2*hidden_size..3*hidden_size, ..]).assign(&dc_raw);
+        dgates.slice_mut(s![3*hidden_size..4*hidden_size, ..]).assign(&do_raw);
+
+        // Parameter gradients - sum across batch dimension
+        let dw_ih = dgates.dot(&cache.input.t());
+        let dw_hh = dgates.dot(&cache.hx.t());
+        let db_ih = dgates.sum_axis(ndarray::Axis(1)).insert_axis(ndarray::Axis(1));
+        // Both bias vectors feed the same gate pre-activation, so they share gradients
+        let db_hh = db_ih.clone();
+
+        let gradients = LSTMCellGradients {
+            w_ih: dw_ih,
+            w_hh: dw_hh,
+            b_ih: db_ih,
+            b_hh: db_hh,
+        };
+
+        // Input and hidden gradients for entire batch
+        let mut dx = self.w_ih.t().dot(&dgates);
+        let mut dhx = self.w_hh.t().dot(&dgates);
+        let dcx = &dcy_total * &cache.forget_gate;
+
+        // Apply dropout gradients if masks exist
+        if let Some(ref mask) = cache.input_dropout_mask {
+            let keep_prob = if let Some(ref dropout) = self.input_dropout {
+                1.0 - dropout.dropout_rate
+            } else {
+                1.0
+            };
+            dx = dx * mask / keep_prob;
+        }
+
+        if let Some(ref mask) = cache.recurrent_dropout_mask {
+            let keep_prob = if let Some(ref dropout) = self.recurrent_dropout {
+                1.0 - dropout.dropout_rate
+            } else {
+                1.0
+            };
+            dhx = dhx * mask / keep_prob;
+        }
+
+        (gradients, dx, dhx, dcx)
+    }
+
     /// Initialize zero gradients for accumulation
     pub fn zero_gradients(&self) -> LSTMCellGradients {
         LSTMCellGradients {
diff --git a/src/lib.rs b/src/lib.rs
index 96b13d1..3795150 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -43,16 +43,17 @@ pub mod training;
 pub mod persistence;
 
 // Re-export commonly used items
-pub use models::lstm_network::{LSTMNetwork, LayerDropoutConfig};
+pub use models::lstm_network::{LSTMNetwork, LSTMNetworkCache, LSTMNetworkBatchCache, LayerDropoutConfig};
 pub use models::gru_network::{GRUNetwork, LayerDropoutConfig as GRULayerDropoutConfig, GRUNetworkCache};
-pub use layers::lstm_cell::LSTMCell;
+pub use layers::lstm_cell::{LSTMCell, LSTMCellCache, LSTMCellBatchCache, LSTMCellGradients};
 pub use layers::peephole_lstm_cell::PeepholeLSTMCell;
 pub use layers::gru_cell::{GRUCell, GRUCellGradients, GRUCellCache};
 pub use layers::bilstm_network::{BiLSTMNetwork, CombineMode, BiLSTMNetworkCache};
 pub use layers::dropout::{Dropout, Zoneout};
 pub use training::{
-    LSTMTrainer, ScheduledLSTMTrainer, TrainingConfig,
-    create_basic_trainer, create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer
+    LSTMTrainer, ScheduledLSTMTrainer, LSTMBatchTrainer, TrainingConfig, TrainingMetrics,
+    create_basic_trainer, create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer,
+    create_basic_batch_trainer, create_adam_batch_trainer
 };
 pub use optimizers::{SGD, Adam, RMSprop, ScheduledOptimizer};
 pub use schedulers::{
diff --git a/src/loss.rs b/src/loss.rs
index c96efad..d3675a8 100644
--- a/src/loss.rs
+++ b/src/loss.rs
@@ -7,6 +7,37 @@ pub trait LossFunction {
 
     /// Compute the gradient of the loss with respect to predictions
     fn compute_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64>;
+
+    /// Compute batch loss for multiple predictions and targets.
+    /// The default implementation averages the per-column (per-sample) losses.
+    fn compute_batch_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
+        let batch_size = predictions.ncols();
+        let mut total_loss = 0.0;
+
+        for i in 0..batch_size {
+            let pred_col = predictions.column(i).to_owned().insert_axis(ndarray::Axis(1));
+            let target_col = targets.column(i).to_owned().insert_axis(ndarray::Axis(1));
+            total_loss += self.compute_loss(&pred_col, &target_col);
+        }
+
+        total_loss / batch_size as f64
+    }
+
+    /// Compute batch gradients for multiple predictions and targets.
+    /// The default implementation computes per-sample gradients and assembles them column by column.
+    fn compute_batch_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
+        let batch_size = predictions.ncols();
+        let mut batch_gradients = Array2::zeros(predictions.raw_dim());
+
+        for i in 0..batch_size {
+            let pred_col = predictions.column(i).to_owned().insert_axis(ndarray::Axis(1));
+            let target_col = targets.column(i).to_owned().insert_axis(ndarray::Axis(1));
+            let grad = self.compute_gradient(&pred_col, &target_col);
+            batch_gradients.column_mut(i).assign(&grad.column(0));
+        }
+
+        batch_gradients
+    }
 }
 
 /// Mean Squared Error loss function
@@ -23,6 +54,17 @@ impl LossFunction for MSELoss {
         let diff = predictions - targets;
         2.0 * diff / (predictions.len() as f64)
     }
+
+    fn compute_batch_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
+        let diff = predictions - targets;
+        let squared_diff = &diff * &diff;
+        squared_diff.sum() / (predictions.len() as f64)
+    }
+
+    fn compute_batch_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
+        let diff = predictions - targets;
+        2.0 * diff / (predictions.len() as f64)
+    }
 }
 
 /// Mean Absolute Error loss function
@@ -38,6 +80,16 @@ impl LossFunction for MAELoss {
         let diff = predictions - targets;
         diff.map(|x| if *x > 0.0 { 1.0 } else if *x < 0.0 { -1.0 } else { 0.0 }) / (predictions.len() as f64)
     }
+
+    fn compute_batch_loss(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> f64 {
+        let diff = predictions - targets;
+        diff.map(|x| x.abs()).sum() / (predictions.len() as f64)
+    }
+
+    fn compute_batch_gradient(&self, predictions: &Array2<f64>, targets: &Array2<f64>) -> Array2<f64> {
+        let diff = predictions - targets;
+        diff.map(|x| if *x > 0.0 { 1.0 } else if *x < 0.0 { -1.0 } else { 0.0 }) / (predictions.len() as f64)
+    }
 }
 
 /// Cross-Entropy Loss with softmax
diff --git a/src/models/lstm_network.rs b/src/models/lstm_network.rs
index 97a994e..05a8ba1 100644
--- a/src/models/lstm_network.rs
+++ b/src/models/lstm_network.rs
@@ -1,5 +1,5 @@
 use ndarray::Array2;
-use crate::layers::lstm_cell::{LSTMCell, LSTMCellGradients, LSTMCellCache};
+use crate::layers::lstm_cell::{LSTMCell, LSTMCellGradients, LSTMCellCache, LSTMCellBatchCache};
 use crate::optimizers::Optimizer;
 
 /// Holds cached values for all layers during network forward pass
@@ -8,6 +8,13 @@ pub struct LSTMNetworkCache {
     pub cell_caches: Vec<LSTMCellCache>,
 }
 
+/// Holds cached values for batch processing during network forward pass
+#[derive(Clone)]
+pub struct LSTMNetworkBatchCache {
+    pub cell_caches: Vec<LSTMCellBatchCache>,
+    pub batch_size: usize,
+}
+
 /// Multi-layer LSTM network for sequence modeling with dropout support
 ///
 /// Stacks multiple LSTM cells where the output of layer i becomes
@@ -227,6 +234,152 @@
 
         (outputs, caches)
     }
+
+    /// Process multiple sequences in a batch
+    ///
+    /// # Arguments
+    /// * `batch_sequences` - Vector of sequences, where each sequence is a Vec<Array2<f64>>
+    ///   and each Array2<f64> has shape (input_size, 1) for single sequences
+    ///
+    /// # Returns
+    /// * Vector of sequence outputs, where each sequence output is a Vec<(Array2<f64>, Array2<f64>)>
+    ///   of (hidden_state, cell_state) pairs, one per time step
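+    ///
+    /// # Example
+    ///
+    /// An illustrative sketch (zero-valued inputs, shapes only):
+    ///
+    /// ```ignore
+    /// use ndarray::Array2;
+    /// use rust_lstm::LSTMNetwork;
+    ///
+    /// let mut network = LSTMNetwork::new(3, 16, 2);
+    /// // Two sequences of different lengths; each step has shape (input_size, 1).
+    /// let seq_a: Vec<Array2<f64>> = (0..5).map(|_| Array2::zeros((3, 1))).collect();
+    /// let seq_b: Vec<Array2<f64>> = (0..3).map(|_| Array2::zeros((3, 1))).collect();
+    /// let outputs = network.forward_batch_sequences(&[seq_a, seq_b]);
+    /// assert_eq!(outputs[0].len(), 5); // one (hidden, cell) pair per time step
+    /// assert_eq!(outputs[1].len(), 3); // shorter sequences simply stop early
+    /// ```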
+    pub fn forward_batch_sequences(&mut self, batch_sequences: &[Vec<Array2<f64>>]) -> Vec<Vec<(Array2<f64>, Array2<f64>)>> {
+        // Find the maximum sequence length for padding
+        let max_seq_len = batch_sequences.iter().map(|seq| seq.len()).max().unwrap_or(0);
+        let batch_size = batch_sequences.len();
+
+        if batch_size == 0 || max_seq_len == 0 {
+            return Vec::new();
+        }
+
+        let mut batch_outputs = vec![Vec::new(); batch_size];
+
+        // Initialize batch hidden and cell states
+        let mut batch_hx = Array2::zeros((self.hidden_size, batch_size));
+        let mut batch_cx = Array2::zeros((self.hidden_size, batch_size));
+
+        // Process each time step across all sequences in the batch
+        for t in 0..max_seq_len {
+            // Prepare batch input for current time step
+            let mut batch_input = Array2::zeros((self.input_size, batch_size));
+            let mut active_sequences = Vec::new();
+
+            for (batch_idx, sequence) in batch_sequences.iter().enumerate() {
+                if t < sequence.len() {
+                    // Copy input for this sequence at time step t
+                    batch_input.column_mut(batch_idx).assign(&sequence[t].column(0));
+                    active_sequences.push(batch_idx);
+                }
+            }
+
+            if active_sequences.is_empty() {
+                break; // No more active sequences
+            }
+
+            // Forward pass for this time step across the batch
+            let (new_batch_hx, new_batch_cx) = self.forward_batch(&batch_input, &batch_hx, &batch_cx);
+
+            // Update states and collect outputs for active sequences
+            batch_hx = new_batch_hx.clone();
+            batch_cx = new_batch_cx.clone();
+
+            // Store outputs for each active sequence
+            for &batch_idx in &active_sequences {
+                let hy = new_batch_hx.column(batch_idx).to_owned().insert_axis(ndarray::Axis(1));
+                let cy = new_batch_cx.column(batch_idx).to_owned().insert_axis(ndarray::Axis(1));
+                batch_outputs[batch_idx].push((hy, cy));
+            }
+        }
+
+        batch_outputs
+    }
+
+    /// Batch forward pass for a single time step across multiple sequences
+    ///
+    /// # Arguments
+    /// * `batch_input` - Input tensor of shape (input_size, batch_size)
+    /// * `batch_hx` - Hidden states tensor of shape (hidden_size, batch_size)
+    /// * `batch_cx` - Cell states tensor of shape (hidden_size, batch_size)
+    ///
+    /// # Returns
+    /// * Tuple of (new_hidden_states, new_cell_states) with same batch dimensions
+    pub fn forward_batch(&mut self, batch_input: &Array2<f64>, batch_hx: &Array2<f64>, batch_cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
+        let mut current_input = batch_input.clone();
+        let mut current_hx = batch_hx.clone();
+        let mut current_cx = batch_cx.clone();
+
+        // Process through each layer
+        for cell in &mut self.cells {
+            let (new_hx, new_cx) = cell.forward_batch(&current_input, &current_hx, &current_cx);
+            current_input = new_hx.clone(); // Output of layer i becomes input to layer i+1
+            current_hx = new_hx;
+            current_cx = new_cx;
+        }
+
+        (current_hx, current_cx)
+    }
+
+    /// Batch forward pass with caching for training
+    ///
+    /// Similar to forward_batch but caches intermediate values needed for backpropagation
+    pub fn forward_batch_with_cache(&mut self, batch_input: &Array2<f64>, batch_hx: &Array2<f64>, batch_cx: &Array2<f64>) -> (Array2<f64>, Array2<f64>, LSTMNetworkBatchCache) {
+        let mut current_input = batch_input.clone();
+        let mut current_hx = batch_hx.clone();
+        let mut current_cx = batch_cx.clone();
+        let mut cell_caches = Vec::new();
+
+        // Process through each layer with caching
+        for cell in &mut self.cells {
+            let (new_hx, new_cx, cache) = cell.forward_batch_with_cache(&current_input, &current_hx, &current_cx);
+            cell_caches.push(cache);
+
+            current_input = new_hx.clone();
+            current_hx = new_hx;
+            current_cx = new_cx;
+        }
+
+        let network_cache = LSTMNetworkBatchCache {
+            cell_caches,
+            batch_size: batch_input.ncols(),
+        };
+
+        (current_hx, current_cx, network_cache)
+    }
+
+    /// Batch backward pass for training
+    ///
+    /// Computes gradients for an entire batch simultaneously
+    pub fn backward_batch(&self, dhy: &Array2<f64>, dcy: &Array2<f64>, cache: &LSTMNetworkBatchCache) -> (Vec<LSTMCellGradients>, Array2<f64>) {
+        let mut gradients = Vec::new();
+        let mut current_dhy = dhy.clone();
+        let mut current_dcy = dcy.clone();
+        let mut dx_input = Array2::<f64>::zeros(dhy.raw_dim());
+
+        // Backward through layers in reverse order
+        for (i, cell) in self.cells.iter().enumerate().rev() {
+            let cell_cache = &cache.cell_caches[i];
+            let (cell_gradients, dx, _dhx_prev, _dcx_prev) = cell.backward_batch(&current_dhy, &current_dcy, cell_cache);
+
+            gradients.push(cell_gradients);
+
+            if i > 0 {
+                // The input of layer i is the hidden output of layer i - 1, so dx
+                // flows down as that layer's hidden-state gradient. Cell states are
+                // not shared across layers, so the incoming cell gradient is zero.
+                current_dcy = Array2::zeros(dx.raw_dim());
+                current_dhy = dx;
+            } else {
+                // dx of the bottom layer is the gradient w.r.t. the network input
+                dx_input = dx;
+            }
+        }
+
+        gradients.reverse();
+
+        (gradients, dx_input)
+    }
 }
 
 /// Configuration for layer-specific dropout settings
diff --git a/src/training.rs b/src/training.rs
index 97e7777..a42fe9f 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -439,6 +439,312 @@ impl ScheduledLSTMTrain
     }
 }
 
+/// Batch trainer for LSTM networks with configurable loss and optimizer.
+/// Processes multiple sequences simultaneously for improved performance.
+pub struct LSTMBatchTrainer<L: LossFunction, O: Optimizer> {
+    pub network: LSTMNetwork,
+    pub loss_function: L,
+    pub optimizer: O,
+    pub config: TrainingConfig,
+    pub metrics_history: Vec<TrainingMetrics>,
+}
+
+impl<L: LossFunction, O: Optimizer> LSTMBatchTrainer<L, O> {
+    pub fn new(network: LSTMNetwork, loss_function: L, optimizer: O) -> Self {
+        LSTMBatchTrainer {
+            network,
+            loss_function,
+            optimizer,
+            config: TrainingConfig::default(),
+            metrics_history: Vec::new(),
+        }
+    }
+
+    pub fn with_config(mut self, config: TrainingConfig) -> Self {
+        self.config = config;
+        self
+    }
+
+    /// Train on a batch of sequences using batch processing
+    ///
+    /// # Arguments
+    /// * `batch_inputs` - Vector of input sequences, each sequence a Vec<Array2<f64>>
+    /// * `batch_targets` - Vector of target sequences, each sequence a Vec<Array2<f64>>
+    ///
+    /// # Returns
+    /// * Average loss across the batch
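+    ///
+    /// # Example
+    ///
+    /// An illustrative sketch (tiny dataset; single-row targets are broadcast
+    /// across the hidden dimension by this implementation):
+    ///
+    /// ```ignore
+    /// use ndarray::arr2;
+    /// use rust_lstm::{LSTMNetwork, create_adam_batch_trainer};
+    ///
+    /// let network = LSTMNetwork::new(1, 4, 1);
+    /// let mut trainer = create_adam_batch_trainer(network, 0.01);
+    /// // One sequence with two steps; each step is an (input_size, 1) column.
+    /// let inputs = vec![vec![arr2(&[[0.1]]), arr2(&[[0.2]])]];
+    /// let targets = vec![vec![arr2(&[[0.2]]), arr2(&[[0.3]])]];
+    /// let avg_loss = trainer.train_batch(&inputs, &targets);
+    /// assert!(avg_loss.is_finite());
+    /// ```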
+    pub fn train_batch(&mut self, batch_inputs: &[Vec<Array2<f64>>], batch_targets: &[Vec<Array2<f64>>]) -> f64 {
+        assert_eq!(batch_inputs.len(), batch_targets.len(), "Batch inputs and targets must have same length");
+
+        if batch_inputs.is_empty() {
+            return 0.0;
+        }
+
+        self.network.train();
+
+        // Find the maximum sequence length; shorter sequences drop out of the
+        // batch once exhausted
+        let max_seq_len = batch_inputs.iter().map(|seq| seq.len()).max().unwrap_or(0);
+        let batch_size = batch_inputs.len();
+
+        let mut total_loss = 0.0;
+        let mut total_gradients = self.network.zero_gradients();
+        let mut valid_steps = 0;
+
+        // Initialize batch states
+        let mut batch_hx = Array2::zeros((self.network.hidden_size, batch_size));
+        let mut batch_cx = Array2::zeros((self.network.hidden_size, batch_size));
+
+        // Process each time step
+        for t in 0..max_seq_len {
+            // Prepare batch input and targets for current time step
+            let mut batch_input = Array2::zeros((self.network.input_size, batch_size));
+            let mut batch_target = Array2::zeros((self.network.hidden_size, batch_size));
+            let mut active_sequences = Vec::new();
+
+            // Collect active sequences for this time step
+            for (batch_idx, (input_seq, target_seq)) in batch_inputs.iter().zip(batch_targets.iter()).enumerate() {
+                if t < input_seq.len() && t < target_seq.len() {
+                    batch_input.column_mut(batch_idx).assign(&input_seq[t].column(0));
+                    // Single-row targets are broadcast across the hidden dimension
+                    batch_target.column_mut(batch_idx).assign(&target_seq[t].column(0));
+                    active_sequences.push(batch_idx);
+                }
+            }
+
+            if active_sequences.is_empty() {
+                break;
+            }
+
+            // Forward pass with caching
+            let (new_batch_hx, new_batch_cx, cache) = self.network.forward_batch_with_cache(&batch_input, &batch_hx, &batch_cx);
+
+            // Compute loss only for active sequences
+            let active_predictions = if active_sequences.len() == batch_size {
+                new_batch_hx.clone()
+            } else {
+                let mut active_preds = Array2::zeros((self.network.hidden_size, active_sequences.len()));
+                for (idx, &batch_idx) in active_sequences.iter().enumerate() {
+                    active_preds.column_mut(idx).assign(&new_batch_hx.column(batch_idx));
+                }
+                active_preds
+            };
+
+            let active_targets = if active_sequences.len() == batch_size {
+                batch_target.clone()
+            } else {
+                let mut active_targs = Array2::zeros((self.network.hidden_size, active_sequences.len()));
+                for (idx, &batch_idx) in active_sequences.iter().enumerate() {
+                    active_targs.column_mut(idx).assign(&batch_target.column(batch_idx));
+                }
+                active_targs
+            };
+
+            let step_loss = self.loss_function.compute_batch_loss(&active_predictions, &active_targets);
+            total_loss += step_loss;
+            valid_steps += 1;
+
+            // Compute gradients
+            let dhy = self.loss_function.compute_batch_gradient(&active_predictions, &active_targets);
+
+            // Expand gradients back to full batch size if needed
+            let full_dhy = if active_sequences.len() == batch_size {
+                dhy
+            } else {
+                let mut full_grad = Array2::zeros((self.network.hidden_size, batch_size));
+                for (idx, &batch_idx) in active_sequences.iter().enumerate() {
+                    full_grad.column_mut(batch_idx).assign(&dhy.column(idx));
+                }
+                full_grad
+            };
+
+            let full_dcy = Array2::<f64>::zeros(full_dhy.raw_dim());
+
+            // Backward pass (truncated: gradients are not propagated across time steps)
+            let (step_gradients, _) = self.network.backward_batch(&full_dhy, &full_dcy, &cache);
+
+            // Accumulate gradients
+            for (total_grad, step_grad) in total_gradients.iter_mut().zip(step_gradients.iter()) {
+                total_grad.w_ih = &total_grad.w_ih + &step_grad.w_ih;
+                total_grad.w_hh = &total_grad.w_hh + &step_grad.w_hh;
+                total_grad.b_ih = &total_grad.b_ih + &step_grad.b_ih;
+                total_grad.b_hh = &total_grad.b_hh + &step_grad.b_hh;
+            }
+
+            // Update states
+            batch_hx = new_batch_hx;
+            batch_cx = new_batch_cx;
+        }
+
+        // Apply gradient clipping
+        if let Some(clip_value) = self.config.clip_gradient {
+            self.clip_gradients(&mut total_gradients, clip_value);
+        }
+
+        // Update parameters
+        self.network.update_parameters(&total_gradients, &mut self.optimizer);
+
+        if valid_steps > 0 {
+            total_loss / valid_steps as f64
+        } else {
+            0.0
+        }
+    }
+
+    /// Train for multiple epochs with batch processing
+    ///
+    /// # Arguments
+    /// * `train_data` - Vector of (input_sequences, target_sequences) tuples for training
+    /// * `validation_data` - Optional validation data
+    /// * `batch_size` - Number of sequences to process in each batch
+    pub fn train(&mut self,
+                 train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
+                 validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>,
+                 batch_size: usize) {
+
+        println!("Starting batch training for {} epochs with batch size {}...",
+                 self.config.epochs, batch_size);
+
+        for epoch in 0..self.config.epochs {
+            let start_time = Instant::now();
+            let mut epoch_loss = 0.0;
+            let mut num_batches = 0;
+
+            // Create batches
+            for batch_start in (0..train_data.len()).step_by(batch_size) {
+                let batch_end = (batch_start + batch_size).min(train_data.len());
+                let batch = &train_data[batch_start..batch_end];
+
+                let batch_inputs: Vec<_> = batch.iter().map(|(inputs, _)| inputs.clone()).collect();
+                let batch_targets: Vec<_> = batch.iter().map(|(_, targets)| targets.clone()).collect();
+
+                let batch_loss = self.train_batch(&batch_inputs, &batch_targets);
+                epoch_loss += batch_loss;
+                num_batches += 1;
+            }
+
+            epoch_loss /= num_batches as f64;
+
+            // Validation
+            let validation_loss = if let Some(val_data) = validation_data {
+                self.network.eval();
+                Some(self.evaluate_batch(val_data, batch_size))
+            } else {
+                None
+            };
+
+            let time_elapsed = start_time.elapsed().as_secs_f64();
+            let current_lr = self.optimizer.get_learning_rate();
+
+            let metrics = TrainingMetrics {
+                epoch,
+                train_loss: epoch_loss,
+                validation_loss,
+                time_elapsed,
+                learning_rate: current_lr,
+            };
+
+            self.metrics_history.push(metrics.clone());
+
+            if epoch % self.config.print_every == 0 {
+                if let Some(val_loss) = validation_loss {
+                    println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s, Batches: {}",
+                             epoch, epoch_loss, val_loss, current_lr, time_elapsed, num_batches);
+                } else {
+                    println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s, Batches: {}",
+                             epoch, epoch_loss, current_lr, time_elapsed, num_batches);
+                }
+            }
+        }
+
+        println!("Batch training completed!");
+    }
+
+    /// Evaluate model performance using batch processing
+    pub fn evaluate_batch(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)], batch_size: usize) -> f64 {
+        self.network.eval();
+
+        let mut total_loss = 0.0;
+        let mut num_batches = 0;
+
+        for batch_start in (0..data.len()).step_by(batch_size) {
+            let batch_end = (batch_start + batch_size).min(data.len());
+            let batch = &data[batch_start..batch_end];
+
+            let batch_inputs: Vec<_> = batch.iter().map(|(inputs, _)| inputs.clone()).collect();
+            let batch_targets: Vec<_> = batch.iter().map(|(_, targets)| targets.clone()).collect();
+
+            // Process batch and compute loss (simplified evaluation)
+            let batch_outputs = self.network.forward_batch_sequences(&batch_inputs);
+
+            let mut batch_loss = 0.0;
+            let mut valid_samples = 0;
+
+            for (outputs, targets) in batch_outputs.iter().zip(batch_targets.iter()) {
+                for ((output, _), target) in outputs.iter().zip(targets.iter()) {
+                    let loss = self.loss_function.compute_loss(output, target);
+                    batch_loss += loss;
+                    valid_samples += 1;
+                }
+            }
+
+            if valid_samples > 0 {
+                total_loss += batch_loss / valid_samples as f64;
+                num_batches += 1;
+            }
+        }
+
+        if num_batches > 0 {
+            total_loss / num_batches as f64
+        } else {
+            0.0
+        }
+    }
+
+    /// Generate predictions using batch processing
+    pub fn predict_batch(&mut self, inputs: &[Vec<Array2<f64>>]) -> Vec<Vec<Array2<f64>>> {
+        self.network.eval();
+
+        let batch_outputs = self.network.forward_batch_sequences(inputs);
+        batch_outputs.into_iter()
+            .map(|sequence_outputs| sequence_outputs.into_iter().map(|(output, _)| output).collect())
+            .collect()
+    }
+
+    /// Clip each gradient matrix by its own norm to prevent exploding gradients
+    /// (per-matrix clipping, not a single global norm)
+    fn clip_gradients(&self, gradients: &mut [crate::layers::lstm_cell::LSTMCellGradients], max_norm: f64) {
+        for gradient in gradients.iter_mut() {
+            self.clip_gradient_matrix(&mut gradient.w_ih, max_norm);
+            self.clip_gradient_matrix(&mut gradient.w_hh, max_norm);
+            self.clip_gradient_matrix(&mut gradient.b_ih, max_norm);
+            self.clip_gradient_matrix(&mut gradient.b_hh, max_norm);
+        }
+    }
+
+    fn clip_gradient_matrix(&self, matrix: &mut Array2<f64>, max_norm: f64) {
+        let norm = (&*matrix * &*matrix).sum().sqrt();
+        if norm > max_norm {
+            let scale = max_norm / norm;
+            *matrix = matrix.map(|x| x * scale);
+        }
+    }
+
+    pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics> {
+        self.metrics_history.last()
+    }
+
+    pub fn get_metrics_history(&self) -> &[TrainingMetrics] {
+        &self.metrics_history
+    }
+
+    pub fn set_training_mode(&mut self, training: bool) {
+        if training {
+            self.network.train();
+        } else {
+            self.network.eval();
+        }
+    }
+}
 
 /// Create a basic trainer with SGD optimizer and MSE loss
 pub fn create_basic_trainer(network: LSTMNetwork, learning_rate: f64) -> LSTMTrainer<MSELoss, SGD> {
     let loss_function = MSELoss;
@@ -481,13 +787,25 @@ pub fn create_cosine_annealing_trainer(
     eta_min: f64
 ) -> ScheduledLSTMTrainer<MSELoss, ScheduledOptimizer<Adam, CosineAnnealingLR>> {
     let loss_function = MSELoss;
-    let optimizer = ScheduledOptimizer::cosine_annealing(
-        crate::optimizers::Adam::new(learning_rate),
-        learning_rate,
-        t_max,
-        eta_min
-    );
-    ScheduledLSTMTrainer::new(network, loss_function, optimizer)
+    let optimizer = crate::optimizers::Adam::new(learning_rate);
+    let scheduler = crate::schedulers::CosineAnnealingLR::new(t_max, eta_min);
+    let scheduled_optimizer = crate::optimizers::ScheduledOptimizer::new(optimizer, scheduler, learning_rate);
+
+    ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer)
+}
+
+/// Create a basic batch trainer with SGD optimizer and MSE loss
+pub fn create_basic_batch_trainer(network: LSTMNetwork, learning_rate: f64) -> LSTMBatchTrainer<MSELoss, SGD> {
+    let loss_function = MSELoss;
+    let optimizer = SGD::new(learning_rate);
+    LSTMBatchTrainer::new(network, loss_function, optimizer)
+}
+
+/// Create a batch trainer with Adam optimizer and MSE loss
+pub fn create_adam_batch_trainer(network: LSTMNetwork, learning_rate: f64) -> LSTMBatchTrainer<MSELoss, Adam> {
+    let loss_function = MSELoss;
+    let optimizer = crate::optimizers::Adam::new(learning_rate);
+    LSTMBatchTrainer::new(network, loss_function, optimizer)
+}
 
 #[cfg(test)]