From 2699f0cf6de50b4e95b575ba1057c395cee04e35 Mon Sep 17 00:00:00 2001 From: Wang Zhenyu Date: Sun, 12 Oct 2025 21:34:03 +0800 Subject: [PATCH] feat: add multi head attention --- README.md | 32 +-- src/lib.rs | 2 + src/llm.rs | 7 +- src/main.rs | 15 +- src/multi_head_attention.rs | 404 +++++++++++++++++++++++++++++ src/self_attention.rs | 9 + src/transformer.rs | 65 ++++- tests/llm_test.rs | 12 +- tests/multi_head_attention_test.rs | 281 ++++++++++++++++++++ tests/transformer_test.rs | 65 ++++- 10 files changed, 861 insertions(+), 31 deletions(-) create mode 100644 src/multi_head_attention.rs create mode 100644 tests/multi_head_attention_test.rs diff --git a/README.md b/README.md index 73f0b34..5fdd7bc 100644 --- a/README.md +++ b/README.md @@ -41,17 +41,18 @@ Input Text → Tokenization → Embeddings → Transformer Blocks → Output Pro ``` src/ -├── main.rs # 🎯 Training pipeline and interactive mode -├── llm.rs # 🧠 Core LLM implementation and training logic -├── lib.rs # 📚 Library exports and constants -├── transformer.rs # 🔄 Transformer block (attention + feed-forward) -├── self_attention.rs # 👀 Multi-head self-attention mechanism -├── feed_forward.rs # ⚡ Position-wise feed-forward networks -├── embeddings.rs # 📊 Token embedding layer -├── output_projection.rs # 🎰 Final linear layer for vocabulary predictions -├── vocab.rs # 📝 Vocabulary management and tokenization -├── layer_norm.rs # 🧮 Layer normalization -└── adam.rs # 🏃 Adam optimizer implementation +├── main.rs # 🎯 Training pipeline and interactive mode +├── llm.rs # 🧠 Core LLM implementation and training logic +├── lib.rs # 📚 Library exports and constants +├── transformer.rs # 🔄 Transformer block (multi-head attention + feed-forward) +├── multi_head_attention.rs # 👀 Multi-head self-attention mechanism (default) +├── self_attention.rs # 👁️ Single-head attention (legacy) +├── feed_forward.rs # ⚡ Position-wise feed-forward networks +├── embeddings.rs # 📊 Token embedding layer +├── output_projection.rs # 🎰 Final linear layer for vocabulary predictions +├── vocab.rs # 📝 Vocabulary management and tokenization +├── layer_norm.rs # 🧮 Layer normalization +└── adam.rs # 🏃 Adam optimizer implementation tests/ ├── llm_test.rs # Tests for core LLM functionality @@ -110,8 +111,9 @@ Model output: Rain is caused by water vapor in clouds condensing into droplets t - **Vocabulary Size**: Dynamic (built from training data) - **Embedding Dimension**: 128 (defined by `EMBEDDING_DIM` in `src/lib.rs`) - **Hidden Dimension**: 256 (defined by `HIDDEN_DIM` in `src/lib.rs`) +- **Number of Attention Heads**: 8 (defined by `NUM_HEADS` in `src/lib.rs`) - **Max Sequence Length**: 80 tokens (defined by `MAX_SEQ_LEN` in `src/lib.rs`) -- **Architecture**: 3 Transformer blocks + embeddings + output projection +- **Architecture**: 3 Multi-Head Transformer blocks + embeddings + output projection ### Training Details - **Optimizer**: Adam with gradient clipping @@ -174,7 +176,7 @@ Contributions are welcome! 
This project is perfect for learning and experimentat - **📊 Evaluation metrics** - Perplexity, benchmarks, training visualizations ### Areas for Improvement -- **Advanced architectures** (multi-head attention, positional encoding, RoPE) +- **Advanced architectures** (~~multi-head attention~~✅, positional encoding, RoPE) - **Training improvements** (different optimizers, learning rate schedules, regularization) - **Data handling** (larger datasets, tokenizer improvements, streaming) - **Model analysis** (attention visualization, gradient analysis, interpretability) @@ -194,8 +196,8 @@ Contributions are welcome! This project is perfect for learning and experimentat ### Ideas for Contributions - 🚀 **Beginner**: Model save/load, more training data, config files -- 🔥 **Intermediate**: Beam search, positional encodings, training checkpoints -- ⚡ **Advanced**: Multi-head attention, layer parallelization, custom optimizations +- 🔥 **Intermediate**: Beam search, positional encodings (RoPE), training checkpoints +- ⚡ **Advanced**: Layer parallelization, Flash Attention, custom optimizations Questions? Open an issue or start a discussion! diff --git a/src/lib.rs b/src/lib.rs index a13d7dc..04ea656 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ pub mod embeddings; pub mod feed_forward; pub mod layer_norm; pub mod llm; +pub mod multi_head_attention; pub mod output_projection; pub mod self_attention; pub mod transformer; @@ -18,3 +19,4 @@ pub use vocab::Vocab; pub const MAX_SEQ_LEN: usize = 80; pub const EMBEDDING_DIM: usize = 128; pub const HIDDEN_DIM: usize = 256; +pub const NUM_HEADS: usize = 8; // Number of attention heads for multi-head attention diff --git a/src/llm.rs b/src/llm.rs index d0d6688..06cec09 100644 --- a/src/llm.rs +++ b/src/llm.rs @@ -3,8 +3,8 @@ use std::cmp::Ordering; use ndarray::{Array1, Array2, Axis}; use crate::{ - EMBEDDING_DIM, Embeddings, HIDDEN_DIM, MAX_SEQ_LEN, Vocab, output_projection::OutputProjection, - transformer::TransformerBlock, + EMBEDDING_DIM, Embeddings, HIDDEN_DIM, MAX_SEQ_LEN, NUM_HEADS, Vocab, + output_projection::OutputProjection, transformer::MultiHeadTransformerBlock, }; pub trait Layer { fn layer_type(&self) -> &str; @@ -24,7 +24,8 @@ pub struct LLM { impl Default for LLM { fn default() -> Self { - let transformer_block = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM); + let transformer_block = + MultiHeadTransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS); let output_projection = OutputProjection::new(EMBEDDING_DIM, Vocab::default_words().len()); Self { vocab: Vocab::default(), diff --git a/src/main.rs b/src/main.rs index 5babf3c..5c7fc83 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,11 @@ use std::io::Write; -use ::llm::{EMBEDDING_DIM, HIDDEN_DIM, MAX_SEQ_LEN}; +use ::llm::{EMBEDDING_DIM, HIDDEN_DIM, MAX_SEQ_LEN, NUM_HEADS}; use dataset_loader::{Dataset, DatasetType}; use crate::{ embeddings::Embeddings, llm::LLM, output_projection::OutputProjection, - transformer::TransformerBlock, vocab::Vocab, + transformer::MultiHeadTransformerBlock, vocab::Vocab, }; mod adam; @@ -14,6 +14,7 @@ mod embeddings; mod feed_forward; mod layer_norm; mod llm; +mod multi_head_attention; mod output_projection; mod self_attention; mod transformer; @@ -44,9 +45,9 @@ fn main() { let vocab_words_refs: Vec<&str> = vocab_words.iter().map(|s: &String| s.as_str()).collect(); let vocab = Vocab::new(vocab_words_refs); - let transformer_block_1 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM); - let transformer_block_2 = 
TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM); - let transformer_block_3 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM); + let transformer_block_1 = MultiHeadTransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS); + let transformer_block_2 = MultiHeadTransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS); + let transformer_block_3 = MultiHeadTransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS); let output_projection = OutputProjection::new(EMBEDDING_DIM, vocab.words.len()); let embeddings = Embeddings::new(vocab.clone()); let mut llm = LLM::new( @@ -63,8 +64,8 @@ fn main() { println!("\n=== MODEL INFORMATION ==="); println!("Network architecture: {}", llm.network_description()); println!( - "Model configuration -> max_seq_len: {}, embedding_dim: {}, hidden_dim: {}", - MAX_SEQ_LEN, EMBEDDING_DIM, HIDDEN_DIM + "Model configuration -> max_seq_len: {}, embedding_dim: {}, hidden_dim: {}, num_heads: {}", + MAX_SEQ_LEN, EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS ); println!("Total parameters: {}", llm.total_parameters()); diff --git a/src/multi_head_attention.rs b/src/multi_head_attention.rs new file mode 100644 index 0000000..d41ef71 --- /dev/null +++ b/src/multi_head_attention.rs @@ -0,0 +1,404 @@ +use std::f32; + +use ndarray::Array2; +use rand_distr::{Distribution, Normal}; + +use crate::{EMBEDDING_DIM, adam::Adam, llm::Layer}; + +/// Multi-Head Self-Attention implementation +/// +/// This layer splits the embedding dimension into multiple attention heads, +/// allowing the model to attend to information from different representation +/// subspaces at different positions. +/// +/// Architecture: +/// - Input: [seq_len, embedding_dim] +/// - Split into num_heads with head_dim = embedding_dim / num_heads +/// - Each head computes its own Q, K, V and attention +/// - Outputs are concatenated and projected through W_o +pub struct MultiHeadAttention { + pub embedding_dim: usize, + pub num_heads: usize, + pub head_dim: usize, + + // Weight matrices for Q, K, V projections + w_q: Array2, + w_k: Array2, + w_v: Array2, + // Output projection matrix + w_o: Array2, + + // Cache for backward pass + cached_input: Option>, + cached_q: Option>, + cached_k: Option>, + cached_v: Option>, + cached_attn_weights: Option>>, + + // Optimizers for each weight matrix + optimizer_w_q: Adam, + optimizer_w_k: Adam, + optimizer_w_v: Adam, + optimizer_w_o: Adam, +} + +impl Default for MultiHeadAttention { + fn default() -> Self { + MultiHeadAttention::new(EMBEDDING_DIM, 8) + } +} + +impl MultiHeadAttention { + /// Creates a new MultiHeadAttention layer + /// + /// # Arguments + /// * `embedding_dim` - The dimension of input embeddings + /// * `num_heads` - Number of attention heads (must divide embedding_dim evenly) + pub fn new(embedding_dim: usize, num_heads: usize) -> Self { + assert_eq!( + embedding_dim % num_heads, + 0, + "embedding_dim must be divisible by num_heads" + ); + + let head_dim = embedding_dim / num_heads; + let mut rng = rand::rng(); + // Xavier/He initialization: std = sqrt(2 / fan_in) + let std = (2.0 / embedding_dim as f32).sqrt(); + let normal = Normal::new(0.0, std).unwrap(); + + MultiHeadAttention { + embedding_dim, + num_heads, + head_dim, + w_q: Array2::from_shape_fn((embedding_dim, embedding_dim), |_| normal.sample(&mut rng)), + w_k: Array2::from_shape_fn((embedding_dim, embedding_dim), |_| normal.sample(&mut rng)), + w_v: Array2::from_shape_fn((embedding_dim, embedding_dim), |_| normal.sample(&mut rng)), + w_o: Array2::from_shape_fn((embedding_dim, embedding_dim), |_| 
normal.sample(&mut rng)), + cached_input: None, + cached_q: None, + cached_k: None, + cached_v: None, + cached_attn_weights: None, + optimizer_w_q: Adam::new((embedding_dim, embedding_dim)), + optimizer_w_k: Adam::new((embedding_dim, embedding_dim)), + optimizer_w_v: Adam::new((embedding_dim, embedding_dim)), + optimizer_w_o: Adam::new((embedding_dim, embedding_dim)), + } + } + + /// Computes Q, K, V projections from input + fn compute_qkv(&self, input: &Array2) -> (Array2, Array2, Array2) { + let q = input.dot(&self.w_q); // Q = X * W_Q + let k = input.dot(&self.w_k); // K = X * W_K + let v = input.dot(&self.w_v); // V = X * W_V + (q, k, v) + } + + /// Splits the input into multiple heads + /// + /// # Arguments + /// * `x` - Input of shape [seq_len, embedding_dim] + /// + /// # Returns + /// Vector of arrays, one per head, each of shape [seq_len, head_dim] + pub fn split_heads(&self, x: &Array2) -> Vec> { + let mut heads = Vec::new(); + + for h in 0..self.num_heads { + let start_idx = h * self.head_dim; + let end_idx = start_idx + self.head_dim; + let head = x.slice(ndarray::s![.., start_idx..end_idx]).to_owned(); + heads.push(head); + } + + heads + } + + /// Concatenates multiple heads back into a single array + /// + /// # Arguments + /// * `heads` - Vector of arrays, one per head, each of shape [seq_len, head_dim] + /// + /// # Returns + /// Array of shape [seq_len, embedding_dim] + pub fn concat_heads(&self, heads: &[Array2]) -> Array2 { + let seq_len = heads[0].shape()[0]; + let mut result = Array2::zeros((seq_len, self.embedding_dim)); + + for (h, head) in heads.iter().enumerate() { + let start_idx = h * self.head_dim; + let end_idx = start_idx + self.head_dim; + result + .slice_mut(ndarray::s![.., start_idx..end_idx]) + .assign(head); + } + + result + } + + /// Applies softmax function row-wise + fn softmax(&self, scores: &Array2) -> Array2 { + let mut result = scores.clone(); + + // Apply softmax row-wise + for mut row in result.rows_mut() { + let max_val = row.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); + // Calculate exp for each element + let exp_values: Vec = row.iter().map(|&x| (x - max_val).exp()).collect(); + let sum_exp: f32 = exp_values.iter().sum(); + + // Normalize by sum + for (i, &exp_val) in exp_values.iter().enumerate() { + row[i] = exp_val / sum_exp; + } + } + + result + } + + /// Computes gradient of softmax function + fn softmax_backward( + softmax_output: &Array2, // shape: [seq_len, seq_len] + grad_output: &Array2, // shape: [seq_len, seq_len] + ) -> Array2 { + let mut grad_input = softmax_output.clone(); + + for ((mut grad_row, softmax_row), grad_out_row) in grad_input + .outer_iter_mut() + .zip(softmax_output.outer_iter()) + .zip(grad_output.outer_iter()) + { + // dot product: y ⊙ dL/dy + let dot = softmax_row + .iter() + .zip(grad_out_row.iter()) + .map(|(&y_i, &dy_i)| y_i * dy_i) + .sum::(); + + for ((g, &y_i), &dy_i) in grad_row + .iter_mut() + .zip(softmax_row.iter()) + .zip(grad_out_row.iter()) + { + *g = y_i * (dy_i - dot); + } + } + + grad_input + } + + /// Performs attention for a single head + fn attention_head( + &self, + q_head: &Array2, + k_head: &Array2, + v_head: &Array2, + ) -> (Array2, Array2) { + let dk = (self.head_dim as f32).sqrt(); + let k_t = k_head.t(); + let mut scores = q_head.dot(&k_t) / dk; + + // Apply causal masking + let seq_len = scores.shape()[0]; + for i in 0..seq_len { + for j in (i + 1)..seq_len { + scores[[i, j]] = f32::NEG_INFINITY; + } + } + + let attn_weights = self.softmax(&scores); + let output = 
attn_weights.dot(v_head); + + (output, attn_weights) + } +} + +impl Layer for MultiHeadAttention { + fn layer_type(&self) -> &str { + "MultiHeadAttention" + } + + fn forward(&mut self, input: &Array2) -> Array2 { + // Cache input for backward pass + self.cached_input = Some(input.clone()); + + // Compute Q, K, V projections + let (q, k, v) = self.compute_qkv(input); + self.cached_q = Some(q.clone()); + self.cached_k = Some(k.clone()); + self.cached_v = Some(v.clone()); + + // Split into heads + let q_heads = self.split_heads(&q); + let k_heads = self.split_heads(&k); + let v_heads = self.split_heads(&v); + + // Apply attention for each head + let mut head_outputs = Vec::new(); + let mut attn_weights = Vec::new(); + for i in 0..self.num_heads { + let (head_output, head_weights) = + self.attention_head(&q_heads[i], &k_heads[i], &v_heads[i]); + head_outputs.push(head_output); + attn_weights.push(head_weights); + } + self.cached_attn_weights = Some(attn_weights); + + // Concatenate heads + let concat = self.concat_heads(&head_outputs); + + // Apply output projection + let output = concat.dot(&self.w_o); + + // Add residual connection + output + input + } + + fn backward(&mut self, grads: &Array2, lr: f32) -> Array2 { + let input = self.cached_input.as_ref().unwrap(); + let q = self.cached_q.as_ref().unwrap(); + let k = self.cached_k.as_ref().unwrap(); + let v = self.cached_v.as_ref().unwrap(); + let attn_weights = self.cached_attn_weights.as_ref().unwrap(); + + // Gradient through residual connection + let grad_output = grads.clone(); + + // Gradient through output projection: dL/dW_o = concat^T @ grad_output + let q_heads = self.split_heads(q); + let k_heads = self.split_heads(k); + let v_heads = self.split_heads(v); + + // Recompute head outputs for gradient calculation + let mut head_outputs = Vec::new(); + for i in 0..self.num_heads { + let head_output = attn_weights[i].dot(&v_heads[i]); + head_outputs.push(head_output); + } + let concat = self.concat_heads(&head_outputs); + + // Gradient of W_o + let grad_w_o = concat.t().dot(&grad_output); + + // Gradient w.r.t. concatenated heads + let grad_concat = grad_output.dot(&self.w_o.t()); + + // Split gradient back into heads + let grad_heads = self.split_heads(&grad_concat); + + // Backpropagate through each attention head + let mut grad_q_heads = Vec::new(); + let mut grad_k_heads = Vec::new(); + let mut grad_v_heads = Vec::new(); + + for i in 0..self.num_heads { + let dk = (self.head_dim as f32).sqrt(); + + // Gradient w.r.t. V: dL/dV = attn_weights^T @ grad_head + let grad_v_head = attn_weights[i].t().dot(&grad_heads[i]); + + // Gradient w.r.t. attention weights + let grad_attn_weights = grad_heads[i].dot(&v_heads[i].t()); + + // Gradient w.r.t. scores (through softmax) + let grad_scores = Self::softmax_backward(&attn_weights[i], &grad_attn_weights); + + // Gradient w.r.t. Q and K + let grad_q_head = grad_scores.dot(&k_heads[i]) / dk; + let grad_k_head = grad_scores.t().dot(&q_heads[i]) / dk; + + grad_q_heads.push(grad_q_head); + grad_k_heads.push(grad_k_head); + grad_v_heads.push(grad_v_head); + } + + // Concatenate head gradients + let grad_q = self.concat_heads(&grad_q_heads); + let grad_k = self.concat_heads(&grad_k_heads); + let grad_v = self.concat_heads(&grad_v_heads); + + // Gradient w.r.t. weight matrices + let grad_w_q = input.t().dot(&grad_q); + let grad_w_k = input.t().dot(&grad_k); + let grad_w_v = input.t().dot(&grad_v); + + // Gradient w.r.t. 
input (through Q, K, V projections) + let grad_input_attention = + grad_q.dot(&self.w_q.t()) + grad_k.dot(&self.w_k.t()) + grad_v.dot(&self.w_v.t()); + + // Add gradient from residual connection + let grad_input = grad_input_attention + grads; + + // Update weights using Adam optimizer + self.optimizer_w_q.step(&mut self.w_q, &grad_w_q, lr); + self.optimizer_w_k.step(&mut self.w_k, &grad_w_k, lr); + self.optimizer_w_v.step(&mut self.w_v, &grad_w_v, lr); + self.optimizer_w_o.step(&mut self.w_o, &grad_w_o, lr); + + grad_input + } + + fn parameters(&self) -> usize { + self.w_q.len() + self.w_k.len() + self.w_v.len() + self.w_o.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_split_and_concat_heads() { + let embedding_dim = 128; + let num_heads = 8; + let seq_len = 10; + + let mha = MultiHeadAttention::new(embedding_dim, num_heads); + let input = Array2::ones((seq_len, embedding_dim)); + + // Test split + let heads = mha.split_heads(&input); + assert_eq!(heads.len(), num_heads); + for head in &heads { + assert_eq!(head.shape(), [seq_len, mha.head_dim]); + } + + // Test concat + let concat = mha.concat_heads(&heads); + assert_eq!(concat.shape(), [seq_len, embedding_dim]); + + // Verify split and concat are inverses + for i in 0..seq_len { + for j in 0..embedding_dim { + assert!((concat[[i, j]] - input[[i, j]]).abs() < 1e-5); + } + } + } + + #[test] + fn test_multihead_attention_shapes() { + let embedding_dim = 128; + let num_heads = 8; + let seq_len = 10; + + let mut mha = MultiHeadAttention::new(embedding_dim, num_heads); + let input = Array2::ones((seq_len, embedding_dim)); + + let output = mha.forward(&input); + assert_eq!(output.shape(), [seq_len, embedding_dim]); + } + + #[test] + fn test_multihead_attention_parameter_count() { + let embedding_dim = 128; + let num_heads = 8; + + let mha = MultiHeadAttention::new(embedding_dim, num_heads); + + // Should have 4 weight matrices: W_q, W_k, W_v, W_o + // Each is embedding_dim x embedding_dim + let expected = 4 * embedding_dim * embedding_dim; + assert_eq!(mha.parameters(), expected); + } +} diff --git a/src/self_attention.rs b/src/self_attention.rs index 2e31324..9e17eb3 100644 --- a/src/self_attention.rs +++ b/src/self_attention.rs @@ -5,6 +5,10 @@ use rand_distr::{Distribution, Normal}; use crate::{EMBEDDING_DIM, adam::Adam, llm::Layer}; +/// Legacy single-head self-attention implementation. +/// This has been superseded by MultiHeadAttention. +/// Kept for reference and backward compatibility. 
+#[allow(dead_code)] pub struct SelfAttention { pub embedding_dim: usize, w_q: Array2, // Weight matrices for Q, K, V @@ -26,6 +30,7 @@ impl Default for SelfAttention { impl SelfAttention { /// Initializes a Transformer with random Q, K, V weights + #[allow(dead_code)] pub fn new(embedding_dim: usize) -> Self { let mut rng = rand::rng(); // Xavier/He initialization: std = sqrt(2 / fan_in) @@ -44,6 +49,7 @@ impl SelfAttention { } } + #[allow(dead_code)] fn compute_qkv(&self, input: &Array2) -> (Array2, Array2, Array2) { let q = input.dot(&self.w_q); // Q = X * W_Q let k = input.dot(&self.w_k); // K = X * W_K @@ -51,6 +57,7 @@ impl SelfAttention { (q, k, v) } + #[allow(dead_code)] fn attention(&self, q: &Array2, k: &Array2, v: &Array2) -> Array2 { let dk = (self.embedding_dim as f32).sqrt(); @@ -69,6 +76,7 @@ impl SelfAttention { weights.dot(v) } + #[allow(dead_code)] fn softmax(&self, scores: &Array2) -> Array2 { let mut result = scores.clone(); @@ -88,6 +96,7 @@ impl SelfAttention { result } + #[allow(dead_code)] fn softmax_backward( softmax_output: &Array2, // shape: [seq_len, vocab_size] grad_output: &Array2, // shape: [seq_len, vocab_size] diff --git a/src/transformer.rs b/src/transformer.rs index e700c8c..e083ecd 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -1,8 +1,14 @@ use ndarray::Array2; use crate::{ - feed_forward::FeedForward, layer_norm::LayerNorm, llm::Layer, self_attention::SelfAttention, + feed_forward::FeedForward, layer_norm::LayerNorm, llm::Layer, + multi_head_attention::MultiHeadAttention, self_attention::SelfAttention, }; + +/// Legacy Transformer Block with single-head self-attention. +/// This has been superseded by MultiHeadTransformerBlock. +/// Kept for reference and backward compatibility. +#[allow(dead_code)] pub struct TransformerBlock { attention: SelfAttention, feed_forward: FeedForward, @@ -11,6 +17,7 @@ pub struct TransformerBlock { } impl TransformerBlock { + #[allow(dead_code)] pub fn new(embedding_dim: usize, hidden_dim: usize) -> Self { TransformerBlock { attention: SelfAttention::new(embedding_dim), @@ -21,6 +28,25 @@ impl TransformerBlock { } } +/// Transformer Block with multi-head self-attention +pub struct MultiHeadTransformerBlock { + attention: MultiHeadAttention, + feed_forward: FeedForward, + norm1: LayerNorm, // After attention + norm2: LayerNorm, // After feed forward +} + +impl MultiHeadTransformerBlock { + pub fn new(embedding_dim: usize, hidden_dim: usize, num_heads: usize) -> Self { + MultiHeadTransformerBlock { + attention: MultiHeadAttention::new(embedding_dim, num_heads), + feed_forward: FeedForward::new(embedding_dim, hidden_dim), + norm1: LayerNorm::new(embedding_dim), + norm2: LayerNorm::new(embedding_dim), + } + } +} + impl Layer for TransformerBlock { fn layer_type(&self) -> &str { "TransformerBlock" @@ -58,3 +84,40 @@ impl Layer for TransformerBlock { + self.norm2.parameters() } } + +impl Layer for MultiHeadTransformerBlock { + fn layer_type(&self) -> &str { + "MultiHeadTransformerBlock" + } + + fn forward(&mut self, input: &Array2) -> Array2 { + // Standard Transformer architecture: attention + norm -> feedforward + norm + let attention_out = self.attention.forward(input); // includes residual + let norm1_out = self.norm1.normalize(&attention_out); + + let feed_forward_out = self.feed_forward.forward(&norm1_out); // includes residual + + self.norm2.normalize(&feed_forward_out) + } + + fn backward(&mut self, grads: &Array2, lr: f32) -> Array2 { + // Backward through second LayerNorm + let grad_norm2 = 
self.norm2.backward(grads, lr); + + // Backward through feed-forward (includes residual connection) + let grad_ffn = self.feed_forward.backward(&grad_norm2, lr); + + // Backward through first LayerNorm + let grad_norm1 = self.norm1.backward(&grad_ffn, lr); + + // Backward through attention (includes residual connection) + self.attention.backward(&grad_norm1, lr) + } + + fn parameters(&self) -> usize { + self.attention.parameters() + + self.feed_forward.parameters() + + self.norm1.parameters() + + self.norm2.parameters() + } +} diff --git a/tests/llm_test.rs b/tests/llm_test.rs index 1e2fec4..1375d77 100644 --- a/tests/llm_test.rs +++ b/tests/llm_test.rs @@ -1,6 +1,6 @@ use llm::{ - EMBEDDING_DIM, Embeddings, HIDDEN_DIM, LLM, Layer, MAX_SEQ_LEN, Vocab, - output_projection::OutputProjection, transformer::TransformerBlock, + EMBEDDING_DIM, Embeddings, HIDDEN_DIM, LLM, Layer, MAX_SEQ_LEN, NUM_HEADS, Vocab, + output_projection::OutputProjection, transformer::MultiHeadTransformerBlock, }; use ndarray::Array2; @@ -143,7 +143,11 @@ fn test_llm_total_parameters() { // Create an LLM with actual layers to get a meaningful parameter count let embeddings = Box::new(Embeddings::new(vocab.clone())); - let transformer_block = Box::new(TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM)); + let transformer_block = Box::new(MultiHeadTransformerBlock::new( + EMBEDDING_DIM, + HIDDEN_DIM, + NUM_HEADS, + )); let output_projection = Box::new(OutputProjection::new(EMBEDDING_DIM, vocab_size)); let llm = LLM::new( @@ -159,7 +163,7 @@ fn test_llm_total_parameters() { // source) let expected_embeddings_parameters = vocab_size * EMBEDDING_DIM + MAX_SEQ_LEN * EMBEDDING_DIM; let expected_transformer_block_parameters = (2 * EMBEDDING_DIM) + // LayerNorm - (3 * EMBEDDING_DIM * EMBEDDING_DIM) + // SelfAttention + (4 * EMBEDDING_DIM * EMBEDDING_DIM) + // MultiHeadAttention (W_q, W_k, W_v, W_o) (2 * EMBEDDING_DIM) + // LayerNorm (EMBEDDING_DIM * HIDDEN_DIM + HIDDEN_DIM + HIDDEN_DIM * EMBEDDING_DIM + EMBEDDING_DIM); // FeedForward let expected_output_projection_parameters = EMBEDDING_DIM * vocab_size + vocab_size; diff --git a/tests/multi_head_attention_test.rs b/tests/multi_head_attention_test.rs new file mode 100644 index 0000000..f41e6c1 --- /dev/null +++ b/tests/multi_head_attention_test.rs @@ -0,0 +1,281 @@ +use llm::{EMBEDDING_DIM, Layer, multi_head_attention::MultiHeadAttention}; +use ndarray::Array2; + +#[test] +fn test_multi_head_attention_forward() { + // Create multi-head attention module with 8 heads + let mut mha = MultiHeadAttention::new(EMBEDDING_DIM, 8); + + // Create input tensor (seq_len=3, embedding_dim=EMBEDDING_DIM) + let input = Array2::ones((3, EMBEDDING_DIM)); + + // Test forward pass + let output = mha.forward(&input); + + // Check output shape - should be same as input + assert_eq!(output.shape(), input.shape()); + + // Verify output is not all zeros + let output_sum: f32 = output.iter().sum(); + assert!(output_sum.abs() > 0.0, "Output should not be all zeros"); +} + +#[test] +fn test_multi_head_attention_with_different_sequence_lengths() { + // Create multi-head attention module with 4 heads + let mut mha = MultiHeadAttention::new(EMBEDDING_DIM, 4); + + // Test with different sequence lengths + for seq_len in 1..10 { + // Create input tensor + let input = Array2::ones((seq_len, EMBEDDING_DIM)); + + // Test forward pass + let output = mha.forward(&input); + + // Check output shape + assert_eq!(output.shape(), [seq_len, EMBEDDING_DIM]); + } +} + +#[test] +fn 
test_multi_head_attention_different_head_counts() { + // Test with different numbers of heads (must divide EMBEDDING_DIM evenly) + let seq_len = 5; + let valid_head_counts = vec![1, 2, 4, 8, 16, 32, 64, 128]; + + for num_heads in valid_head_counts { + if EMBEDDING_DIM.is_multiple_of(num_heads) { + let mut mha = MultiHeadAttention::new(EMBEDDING_DIM, num_heads); + let input = Array2::ones((seq_len, EMBEDDING_DIM)); + + let _output = mha.forward(&input); + + assert_eq!(mha.num_heads, num_heads); + assert_eq!(mha.head_dim, EMBEDDING_DIM / num_heads); + } + } +} + +#[test] +fn test_multi_head_attention_residual_connection() { + // Test that residual connection is working + let mut mha = MultiHeadAttention::new(EMBEDDING_DIM, 8); + let seq_len = 3; + let input = Array2::from_shape_fn((seq_len, EMBEDDING_DIM), |(i, j)| { + (i * EMBEDDING_DIM + j) as f32 + }); + + let output = mha.forward(&input); + + // Output should not be zero due to residual connection + let output_norm: f32 = output.iter().map(|&x| x * x).sum::().sqrt(); + assert!(output_norm > 0.0, "Output should not be zero"); + + // Output should be different from input (attention transforms it) + let mut has_difference = false; + for i in 0..seq_len { + for j in 0..EMBEDDING_DIM { + if (output[[i, j]] - input[[i, j]]).abs() > 1e-3 { + has_difference = true; + break; + } + } + } + assert!( + has_difference, + "Output should differ from input due to attention computation" + ); +} + +#[test] +fn test_multi_head_attention_backward_pass() { + // Test backward pass + let mut mha = MultiHeadAttention::new(EMBEDDING_DIM, 8); + let seq_len = 3; + let input = Array2::ones((seq_len, EMBEDDING_DIM)); + + // Forward pass + let _output = mha.forward(&input); + + // Backward pass with mock gradients + let grad_output = Array2::ones((seq_len, EMBEDDING_DIM)); + let grad_input = mha.backward(&grad_output, 0.001); + + // Check gradient shape + assert_eq!(grad_input.shape(), input.shape()); + + // Gradients should not be all zeros + let grad_norm: f32 = grad_input.iter().map(|&x| x * x).sum::().sqrt(); + assert!(grad_norm > 0.0, "Gradients should not be zero"); +} + +#[test] +fn test_multi_head_attention_parameter_count() { + let num_heads = 8; + let mha = MultiHeadAttention::new(EMBEDDING_DIM, num_heads); + + // Should have 4 weight matrices: W_q, W_k, W_v, W_o + // Each is EMBEDDING_DIM x EMBEDDING_DIM + let expected = 4 * EMBEDDING_DIM * EMBEDDING_DIM; + assert_eq!(mha.parameters(), expected); +} + +#[test] +fn test_multi_head_attention_causal_masking() { + // Test that causal masking prevents attention to future tokens + let mut mha = MultiHeadAttention::new(EMBEDDING_DIM, 4); + let seq_len = 5; + + // Create input with distinct patterns for each position + let mut input = Array2::zeros((seq_len, EMBEDDING_DIM)); + for i in 0..seq_len { + for j in 0..EMBEDDING_DIM { + input[[i, j]] = (i as f32 + 1.0) * 10.0 + (j as f32); + } + } + + let output = mha.forward(&input); + + // Check that output shape is correct + assert_eq!(output.shape(), [seq_len, EMBEDDING_DIM]); + + // The first token should only attend to itself (plus residual) + // We can't verify exact values due to random initialization, + // but we can verify the output is not zero + let first_token_norm: f32 = output.row(0).iter().map(|&x| x * x).sum::().sqrt(); + assert!(first_token_norm > 0.0); +} + +#[test] +fn test_multi_head_attention_layer_type() { + let mha = MultiHeadAttention::new(EMBEDDING_DIM, 8); + assert_eq!(mha.layer_type(), "MultiHeadAttention"); +} + +#[test] +fn 
test_split_and_concat_heads_consistency() { + // Test that split and concat are inverse operations + let num_heads = 8; + let seq_len = 5; + let mha = MultiHeadAttention::new(EMBEDDING_DIM, num_heads); + + // Create random input + let input = Array2::from_shape_fn((seq_len, EMBEDDING_DIM), |(i, j)| { + ((i * EMBEDDING_DIM + j) as f32 * 0.1).sin() + }); + + // Split and concat + let heads = mha.split_heads(&input); + let reconstructed = mha.concat_heads(&heads); + + // Should be identical + for i in 0..seq_len { + for j in 0..EMBEDDING_DIM { + let diff = (reconstructed[[i, j]] - input[[i, j]]).abs(); + assert!( + diff < 1e-5, + "Split-concat should preserve values. Diff at ({}, {}): {}", + i, + j, + diff + ); + } + } +} + +#[test] +fn test_multi_head_attention_training_step() { + // Simulate a mini training step + let mut mha = MultiHeadAttention::new(EMBEDDING_DIM, 8); + let seq_len = 3; + let lr = 0.001; + + // Initial forward pass + let input = Array2::from_shape_fn((seq_len, EMBEDDING_DIM), |(i, j)| { + ((i + j) as f32 * 0.1).sin() + }); + let _output1 = mha.forward(&input); + + // Backward pass with gradients + let grad_output = Array2::ones((seq_len, EMBEDDING_DIM)) * 0.1; + let _grad_input = mha.backward(&grad_output, lr); + + // Second forward pass - output should be different due to weight updates + let output2 = mha.forward(&input); + + // Verify output2 is valid (weights were updated) + let output2_norm: f32 = output2.iter().map(|&x| x * x).sum::().sqrt(); + assert!( + output2_norm > 0.0, + "Output should be valid after weight update" + ); +} + +#[test] +fn test_multi_head_attention_numerical_stability() { + // Test with extreme values to check numerical stability + let mut mha = MultiHeadAttention::new(EMBEDDING_DIM, 8); + let seq_len = 3; + + // Test with large values + let input_large = Array2::ones((seq_len, EMBEDDING_DIM)) * 100.0; + let output_large = mha.forward(&input_large); + assert!( + output_large.iter().all(|&x| x.is_finite()), + "Output should be finite with large inputs" + ); + + // Test with small values + let input_small = Array2::ones((seq_len, EMBEDDING_DIM)) * 0.001; + let output_small = mha.forward(&input_small); + assert!( + output_small.iter().all(|&x| x.is_finite()), + "Output should be finite with small inputs" + ); +} + +#[test] +#[should_panic(expected = "embedding_dim must be divisible by num_heads")] +fn test_multi_head_attention_invalid_head_count() { + // Should panic if num_heads doesn't divide embedding_dim + MultiHeadAttention::new(EMBEDDING_DIM, 7); // 128 is not divisible by 7 +} + +#[test] +fn test_multi_head_vs_single_head() { + // Compare single-head MHA with multi-head MHA + let seq_len = 3; + let input = Array2::from_shape_fn((seq_len, EMBEDDING_DIM), |(i, j)| { + ((i + j) as f32 * 0.1).sin() + }); + + // Single head + let mut mha_single = MultiHeadAttention::new(EMBEDDING_DIM, 1); + let output_single = mha_single.forward(&input); + + // Multiple heads + let mut mha_multi = MultiHeadAttention::new(EMBEDDING_DIM, 8); + let output_multi = mha_multi.forward(&input); + + // Both should have same shape + assert_eq!(output_single.shape(), output_multi.shape()); + + // Both should have non-zero outputs + let norm_single: f32 = output_single.iter().map(|&x| x * x).sum::().sqrt(); + let norm_multi: f32 = output_multi.iter().map(|&x| x * x).sum::().sqrt(); + assert!(norm_single > 0.0); + assert!(norm_multi > 0.0); + + // Outputs should differ (different initializations and computations) + let mut differs = false; + for i in 0..seq_len { + for j in 
0..EMBEDDING_DIM { + if (output_single[[i, j]] - output_multi[[i, j]]).abs() > 1e-3 { + differs = true; + break; + } + } + } + assert!(differs, "Single-head and multi-head outputs should differ"); +} diff --git a/tests/transformer_test.rs b/tests/transformer_test.rs index 0fa49d1..ed82363 100644 --- a/tests/transformer_test.rs +++ b/tests/transformer_test.rs @@ -1,4 +1,7 @@ -use llm::{EMBEDDING_DIM, HIDDEN_DIM, Layer, transformer::TransformerBlock}; +use llm::{ + EMBEDDING_DIM, HIDDEN_DIM, Layer, + transformer::{MultiHeadTransformerBlock, TransformerBlock}, +}; use ndarray::Array2; #[test] @@ -14,3 +17,63 @@ fn test_transformer_block() { // Check output shape assert_eq!(output.shape(), [1, EMBEDDING_DIM]); } + +#[test] +fn test_multi_head_transformer_block_forward() { + let num_heads = 8; + let mut transformer = MultiHeadTransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads); + + // Create some input + let input = Array2::ones((5, EMBEDDING_DIM)); + + let output = transformer.forward(&input); + + // Check output shape + assert_eq!(output.shape(), [5, EMBEDDING_DIM]); +} + +#[test] +fn test_multi_head_transformer_block_backward() { + let num_heads = 4; + let mut transformer = MultiHeadTransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads); + + let input = Array2::ones((3, EMBEDDING_DIM)); + let _output = transformer.forward(&input); + + // Backward pass + let grad_output = Array2::ones((3, EMBEDDING_DIM)); + let grad_input = transformer.backward(&grad_output, 0.001); + + // Check gradient shape + assert_eq!(grad_input.shape(), input.shape()); +} + +#[test] +fn test_multi_head_transformer_block_parameter_count() { + let num_heads = 8; + let transformer = MultiHeadTransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads); + + // Should have more parameters than the single-head version due to W_o matrix + let params = transformer.parameters(); + assert!(params > 0, "Should have non-zero parameters"); +} + +#[test] +fn test_multi_head_transformer_different_sequence_lengths() { + let num_heads = 8; + let mut transformer = MultiHeadTransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads); + + // Test with different sequence lengths + for seq_len in 1..10 { + let input = Array2::ones((seq_len, EMBEDDING_DIM)); + let output = transformer.forward(&input); + assert_eq!(output.shape(), [seq_len, EMBEDDING_DIM]); + } +} + +#[test] +fn test_multi_head_transformer_layer_type() { + let num_heads = 8; + let transformer = MultiHeadTransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads); + assert_eq!(transformer.layer_type(), "MultiHeadTransformerBlock"); +}
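
For reference, a minimal usage sketch of the API this patch introduces (not part of the diff above). It mirrors how `src/main.rs` and the new tests construct and exercise the layers, and assumes the crate-root exports shown in `src/lib.rs` in this patch (`EMBEDDING_DIM`, `HIDDEN_DIM`, `NUM_HEADS`, the `Layer` trait, and the `multi_head_attention` / `transformer` modules):

```rust
use llm::{EMBEDDING_DIM, HIDDEN_DIM, Layer, NUM_HEADS};
use llm::multi_head_attention::MultiHeadAttention;
use llm::transformer::MultiHeadTransformerBlock;
use ndarray::Array2;

fn main() {
    // Stand-alone attention layer: NUM_HEADS (8) heads over EMBEDDING_DIM (128),
    // so each head attends over a 16-dimensional slice of the embedding.
    let mut attention = MultiHeadAttention::new(EMBEDDING_DIM, NUM_HEADS);
    let input = Array2::<f32>::ones((5, EMBEDDING_DIM)); // [seq_len, embedding_dim]
    let attended = attention.forward(&input); // causal attention + residual, same shape
    assert_eq!(attended.shape(), [5, EMBEDDING_DIM]);

    // Full block as wired up in `main.rs`: attention -> LayerNorm -> feed-forward -> LayerNorm.
    let mut block = MultiHeadTransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS);
    let output = block.forward(&attended);
    assert_eq!(output.shape(), [5, EMBEDDING_DIM]);

    // One training-style step: push a dummy gradient back through the block (lr = 0.001);
    // the returned array is the gradient with respect to the block's input.
    let grad = Array2::<f32>::ones((5, EMBEDDING_DIM));
    let _grad_wrt_input = block.backward(&grad, 0.001);
}
```

The default `LLM` now builds a `MultiHeadTransformerBlock` with these same constants, so training through `LLM::default()` or the pipeline in `main.rs` picks up multi-head attention without further changes.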