6 changes: 5 additions & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "rust-lstm"
version = "0.4.0"
version = "0.5.0"
authors = ["Alex Kholodniak <alexandrkholodniak@gmail.com>"]
edition = "2021"
rust-version = "1.70"
@@ -84,3 +84,7 @@ path = "examples/time_series_prediction.rs"
[[example]]
name = "multi_layer_lstm"
path = "examples/multi_layer_lstm.rs"

+[[example]]
+name = "batch_processing_example"
+path = "examples/batch_processing_example.rs"
2 changes: 1 addition & 1 deletion README.md
@@ -49,7 +49,7 @@ Add to your `Cargo.toml`:

```toml
[dependencies]
-rust-lstm = "0.4.0"
+rust-lstm = "0.5.0"
```

### Basic Usage
230 changes: 230 additions & 0 deletions examples/batch_processing_example.rs
@@ -0,0 +1,230 @@
use ndarray::{Array2, arr2};
use rust_lstm::{LSTMNetwork, create_adam_batch_trainer, create_basic_trainer};
use std::time::Instant;

/// Generate synthetic sine wave sequences for batch processing demonstration
fn generate_batch_sine_data(num_sequences: usize, sequence_length: usize, input_size: usize) -> Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)> {
    let mut data = Vec::new();

    for i in 0..num_sequences {
        let mut inputs = Vec::new();
        let mut targets = Vec::new();

        let start = (i as f64) * 0.05; // Different starting points for variety
        let frequency = 1.0 + (i as f64) * 0.1; // Different frequencies

        for j in 0..sequence_length {
            let t = start + (j as f64) * 0.1;

            // Create multi-dimensional input
            let mut input_vec = vec![0.0; input_size];
            input_vec[0] = (t * frequency * 2.0 * std::f64::consts::PI).sin();
            if input_size > 1 {
                input_vec[1] = (t * frequency * 2.0 * std::f64::consts::PI).cos();
            }
            if input_size > 2 {
                input_vec[2] = t.sin() * t.cos(); // Some nonlinear combination
            }

            // Target is the next value in the sine sequence
            let target = ((t + 0.1) * frequency * 2.0 * std::f64::consts::PI).sin();

            inputs.push(Array2::from_shape_vec((input_size, 1), input_vec).unwrap());
            targets.push(arr2(&[[target]]));
        }

        data.push((inputs, targets));
    }

    data
}
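
// The layout above is what every trainer in this example consumes: each time
// step is a column vector of shape (input_size, 1) and each target a 1x1
// array. Data arriving as a flat Vec<f64> can be converted with the same
// Array2::from_shape_vec call used here.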

/// Benchmark training time comparison between single and batch processing
fn benchmark_training_performance() {
println!("BATCH PROCESSING PERFORMANCE BENCHMARK");
println!("======================================\n");

let input_size = 3;
let hidden_size = 16;
let num_layers = 2;
let learning_rate = 0.001;

// Generate training data
let train_data = generate_batch_sine_data(100, 10, input_size);
let val_data = generate_batch_sine_data(20, 10, input_size);

println!("Dataset: {} training sequences, {} validation sequences", train_data.len(), val_data.len());
println!("Network: {} -> {} hidden ({} layers)\n", input_size, hidden_size, num_layers);

// Test 1: Single sequence processing (traditional)
println!("Testing Traditional Single-Sequence Processing...");
let network1 = LSTMNetwork::new(input_size, hidden_size, num_layers);
let mut trainer1 = create_basic_trainer(network1, learning_rate);

// Configure for quick demo
trainer1.config.epochs = 5;
trainer1.config.print_every = 1;

let start_time = Instant::now();
trainer1.train(&train_data, Some(&val_data));
let single_time = start_time.elapsed();

let final_metrics1 = trainer1.get_latest_metrics().unwrap();
println!("Single-sequence - Final loss: {:.6}, Time: {:.2}s\n",
final_metrics1.train_loss, single_time.as_secs_f64());

// Test 2: Batch processing with small batches
println!("Testing Batch Processing (batch size 8)...");
let network2 = LSTMNetwork::new(input_size, hidden_size, num_layers);
let mut trainer2 = create_adam_batch_trainer(network2, learning_rate);

trainer2.config.epochs = 5;
trainer2.config.print_every = 1;

let start_time = Instant::now();
trainer2.train(&train_data, Some(&val_data), 8); // Batch size 8
let batch_time = start_time.elapsed();

let final_metrics2 = trainer2.get_latest_metrics().unwrap();
println!("Batch processing - Final loss: {:.6}, Time: {:.2}s\n",
final_metrics2.train_loss, batch_time.as_secs_f64());

// Test 3: Larger batch size
println!("Testing Larger Batch Processing (batch size 16)...");
let network3 = LSTMNetwork::new(input_size, hidden_size, num_layers);
let mut trainer3 = create_adam_batch_trainer(network3, learning_rate);

trainer3.config.epochs = 5;
trainer3.config.print_every = 1;

let start_time = Instant::now();
trainer3.train(&train_data, Some(&val_data), 16); // Batch size 16
let large_batch_time = start_time.elapsed();

let final_metrics3 = trainer3.get_latest_metrics().unwrap();
println!("Large batch processing - Final loss: {:.6}, Time: {:.2}s\n",
final_metrics3.train_loss, large_batch_time.as_secs_f64());

// Performance summary
println!("PERFORMANCE SUMMARY:");
println!("======================");
println!("Single-sequence: {:.2}s (baseline)", single_time.as_secs_f64());
println!("Batch-8: {:.2}s ({:.1}x speedup)",
batch_time.as_secs_f64(),
single_time.as_secs_f64() / batch_time.as_secs_f64());
println!("Batch-16: {:.2}s ({:.1}x speedup)",
large_batch_time.as_secs_f64(),
single_time.as_secs_f64() / large_batch_time.as_secs_f64());

if batch_time < single_time {
println!("Batch processing achieved {:.1}x speedup!",
single_time.as_secs_f64() / batch_time.as_secs_f64());
} else {
println!("Note: For small datasets, overhead may dominate. Try larger datasets for better speedup.");
}
}

/// Demonstrate batch prediction capabilities
fn demonstrate_batch_prediction() {
println!("\nBATCH PREDICTION DEMONSTRATION");
println!("==============================\n");

let input_size = 2;
let hidden_size = 8;
let num_layers = 1;

// Create and train a simple model
let network = LSTMNetwork::new(input_size, hidden_size, num_layers);
let mut trainer = create_adam_batch_trainer(network, 0.01);

// Generate small training dataset
let train_data = generate_batch_sine_data(20, 5, input_size);

trainer.config.epochs = 10;
trainer.config.print_every = 5;

println!("Training a small model for prediction demo...");
trainer.train(&train_data, None, 4);

// Create test sequences for batch prediction
let test_sequences = generate_batch_sine_data(3, 3, input_size);
let test_inputs: Vec<_> = test_sequences.iter().map(|(inputs, _)| inputs.clone()).collect();
let _test_targets: Vec<_> = test_sequences.iter().map(|(_, targets)| targets.clone()).collect();

println!("\nPerforming batch predictions...");
let predictions = trainer.predict_batch(&test_inputs);

println!("Input sequences vs Predictions:");
for (i, (inputs, preds)) in test_inputs.iter().zip(predictions.iter()).enumerate() {
println!("Sequence {}:", i + 1);
for (j, (input, pred)) in inputs.iter().zip(preds.iter()).enumerate() {
println!(" Step {}: Input=[{:.3}, {:.3}] -> Pred={:.3}",
j + 1, input[[0, 0]], input[[1, 0]], pred[[0, 0]]);
}
println!();
}
}
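
// The zip in the loop above relies on predict_batch returning one vector of
// per-step outputs for each input sequence, in the same order as the inputs;
// that pairing is what lets inputs and predictions line up step-for-step.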

/// Demonstrate memory efficiency and scalability
fn demonstrate_scalability() {
println!("SCALABILITY DEMONSTRATION");
println!("=========================\n");

let test_sizes = vec![
(50, 4), // Small: 50 sequences, batch size 4
(200, 8), // Medium: 200 sequences, batch size 8
(500, 16), // Large: 500 sequences, batch size 16
];

for (num_sequences, batch_size) in test_sizes {
println!("Testing with {} sequences, batch size {}...", num_sequences, batch_size);

let train_data = generate_batch_sine_data(num_sequences, 8, 2);
let network = LSTMNetwork::new(2, 12, 1);
let mut trainer = create_adam_batch_trainer(network, 0.001);

trainer.config.epochs = 3;
trainer.config.print_every = 1;

let start_time = Instant::now();
trainer.train(&train_data, None, batch_size);
let training_time = start_time.elapsed();

let final_loss = trainer.get_latest_metrics().unwrap().train_loss;
println!(" Completed in {:.2}s, final loss: {:.6}\n",
training_time.as_secs_f64(), final_loss);
}

println!("All scalability tests completed successfully!");
println!("Batch processing handles varying dataset sizes efficiently.");
}

fn main() {
println!("RUST-LSTM BATCH PROCESSING DEMONSTRATION");
println!("=========================================\n");

println!("This example demonstrates the new batch processing capabilities:");
println!("- Simultaneous processing of multiple sequences");
println!("- Performance improvements over single-sequence training");
println!("- Batch prediction capabilities");
println!("- Scalability with different batch sizes\n");

benchmark_training_performance();
demonstrate_batch_prediction();
demonstrate_scalability();

println!("\nBATCH PROCESSING DEMONSTRATION COMPLETED!");
println!("==========================================");
println!("Key Benefits Demonstrated:");
println!("- Faster training through batch processing");
println!("- Efficient memory utilization");
println!("- Scalable to different dataset sizes");
println!("- Easy-to-use batch training API");
println!("- Backward compatibility with existing code");

println!("\nNext Steps:");
println!("- Try batch processing with your own datasets");
println!("- Experiment with different batch sizes");
println!("- Compare performance with single-sequence training");
println!("- Use batch processing for faster model development");
}
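
The batch API this example exercises is small. The sketch below distills it to the essential calls: it reuses the `generate_batch_sine_data` helper defined in the file above, and every other call (`LSTMNetwork::new`, `create_adam_batch_trainer`, `trainer.train` with a batch size, `predict_batch`) appears verbatim in the example, though the exact trainer signatures belong to this crate and may differ in other versions.

```rust
use rust_lstm::{LSTMNetwork, create_adam_batch_trainer};

fn main() {
    // Sequences of (inputs, targets), shaped (input_size, 1) per step,
    // produced by the generate_batch_sine_data helper from the example above.
    let train_data = generate_batch_sine_data(20, 5, 2);

    // 2 inputs -> 8 hidden units, 1 layer; Adam batch trainer at lr = 0.01.
    let network = LSTMNetwork::new(2, 8, 1);
    let mut trainer = create_adam_batch_trainer(network, 0.01);
    trainer.config.epochs = 10;

    // Train in mini-batches of 4 sequences, with no validation set.
    trainer.train(&train_data, None, 4);

    // Batch inference: one output sequence per input sequence.
    let inputs: Vec<_> = train_data.iter().map(|(x, _)| x.clone()).collect();
    let _predictions = trainer.predict_batch(&inputs);
}
```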