@zhenyu-02

🚀 Add Multi-Head Attention Implementation

Overview

This PR implements a complete Multi-Head Self-Attention mechanism and makes it the default architecture for the RustGPT project, replacing the previous single-head attention implementation.

🎯 Motivation

As mentioned in README.md#L177, multi-head attention was listed as a desired feature under "Areas for Improvement". Multi-head attention is a core component of modern Transformer architectures (GPT, BERT, etc.) and provides:

  • Better representation learning: Multiple heads can attend to different aspects of the input simultaneously
  • Improved model capacity: More parameters dedicated to attention computation
  • Standard architecture: Aligns with mainstream Transformer implementations

📋 Changes

New Files

  • src/multi_head_attention.rs (405 lines) - Complete multi-head attention implementation with forward/backward passes
  • tests/multi_head_attention_test.rs (278 lines) - Comprehensive test suite with 13 test cases
  • 📚 MULTI_HEAD_ATTENTION.md - Technical documentation and usage guide
  • 📚 MIGRATION_TO_MULTI_HEAD.md - Migration guide and architecture comparison

Modified Files

  • src/lib.rs - Added NUM_HEADS constant (default: 8)
  • src/llm.rs - Updated to use MultiHeadTransformerBlock by default
  • src/main.rs - All 3 transformer layers now use multi-head attention
  • src/transformer.rs - Added MultiHeadTransformerBlock (legacy TransformerBlock kept for backward compatibility)
  • tests/llm_test.rs - Updated tests to match new architecture
  • tests/transformer_test.rs - Added 5 new tests for MultiHeadTransformerBlock
  • README.md - Updated documentation to reflect multi-head attention implementation

🔧 Technical Details

Architecture

Input [seq_len, embedding_dim=128]
  ↓
Multi-Head Attention (8 heads, head_dim=16)
  ├─ Head 1: Q₁, K₁, V₁ → Attention₁
  ├─ Head 2: Q₂, K₂, V₂ → Attention₂
  ├─ ...
  └─ Head 8: Q₈, K₈, V₈ → Attention₈
  ↓
Concatenate → Output Projection (W_o)
  ↓
Residual Connection + Layer Norm
  ↓
Feed-Forward Network
  ↓
Residual Connection + Layer Norm
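
For concreteness, the attention half of this block looks roughly like the minimal sketch below (per-head projections, causal masking, concatenation, output projection). It assumes the ndarray crate and per-head weight matrices of shape [embedding_dim, head_dim] = [128, 16]; it is illustrative only, not the PR's src/multi_head_attention.rs, and it omits the residual connections, layer norm, and feed-forward network shown above.

use ndarray::{concatenate, Array2, Axis};

// Row-wise softmax with max-subtraction for numerical stability.
fn softmax_rows(mut x: Array2<f32>) -> Array2<f32> {
    for mut row in x.rows_mut() {
        let max = row.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        row.mapv_inplace(|v| (v - max).exp());
        let sum = row.sum();
        row.mapv_inplace(|v| v / sum);
    }
    x
}

// One head: Q = X·W_q, K = X·W_k, V = X·W_v, then causally masked
// softmax(Q·Kᵀ / sqrt(head_dim))·V.
fn head_forward(
    x: &Array2<f32>,
    w_q: &Array2<f32>,
    w_k: &Array2<f32>,
    w_v: &Array2<f32>,
) -> Array2<f32> {
    let (q, k, v) = (x.dot(w_q), x.dot(w_k), x.dot(w_v));
    let mut scores = q.dot(&k.t()) / (w_q.ncols() as f32).sqrt();
    // Causal mask: position i may only attend to positions j <= i.
    for i in 0..scores.nrows() {
        for j in (i + 1)..scores.ncols() {
            scores[[i, j]] = f32::NEG_INFINITY;
        }
    }
    softmax_rows(scores).dot(&v)
}

// All heads: run each head over the full input, concatenate the per-head
// outputs along the feature axis, then apply the output projection W_o.
fn multi_head_forward(
    x: &Array2<f32>,
    heads: &[(Array2<f32>, Array2<f32>, Array2<f32>)], // (W_q, W_k, W_v) per head
    w_o: &Array2<f32>,
) -> Array2<f32> {
    let per_head: Vec<Array2<f32>> = heads
        .iter()
        .map(|(w_q, w_k, w_v)| head_forward(x, w_q, w_k, w_v))
        .collect();
    let views: Vec<_> = per_head.iter().map(|h| h.view()).collect();
    concatenate(Axis(1), &views)
        .expect("all heads share seq_len")
        .dot(w_o)
}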

Key Features

  • Configurable number of heads via NUM_HEADS constant
  • Causal masking for autoregressive generation
  • Residual connections for training stability
  • Full gradient computation for backpropagation
  • Adam optimizer integration for all weight matrices (see the optimizer sketch after this list)
  • Numerical stability (tested with extreme values)
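
Adam integration means each of W_q, W_k, W_v, and W_o keeps its own first- and second-moment state. A minimal sketch of that kind of per-matrix update is shown below; it assumes the ndarray crate and the standard Adam hyperparameters (β₁ = 0.9, β₂ = 0.999, ε = 1e-8) and is not the crate's actual optimizer code.

use ndarray::Array2;

struct AdamState {
    m: Array2<f32>, // first-moment (mean of gradients) estimate
    v: Array2<f32>, // second-moment (mean of squared gradients) estimate
    t: i32,         // timestep, used for bias correction
}

impl AdamState {
    fn step(&mut self, w: &mut Array2<f32>, grad: &Array2<f32>, lr: f32) {
        let (b1, b2, eps) = (0.9_f32, 0.999_f32, 1e-8_f32);
        self.t += 1;
        // Update biased moment estimates.
        let m_new = &self.m * b1 + grad * (1.0 - b1);
        let v_new = &self.v * b2 + &grad.mapv(|g| g * g) * (1.0 - b2);
        self.m = m_new;
        self.v = v_new;
        // Bias-corrected estimates, then the parameter update.
        let m_hat = &self.m / (1.0 - b1.powi(self.t));
        let v_hat = &self.v / (1.0 - b2.powi(self.t));
        *w -= &(&m_hat * lr / (v_hat.mapv(f32::sqrt) + eps));
    }
}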

Parameter Changes

Before (Single-Head):

  • 3 weight matrices (W_q, W_k, W_v): 49,152 parameters per transformer block

After (Multi-Head):

  • 4 weight matrices (W_q, W_k, W_v, W_o): 65,536 parameters per transformer block
  • +33% parameters dedicated to the attention mechanism (49,152 → 65,536 per block; quick arithmetic below)
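
These counts follow directly from embedding_dim = 128 and head_dim = 16; the short sketch below is just a sanity check of the numbers above, per transformer block.

const EMBEDDING_DIM: usize = 128;
// Single-head: three full 128 × 128 projections (W_q, W_k, W_v).
const SINGLE_HEAD_PARAMS: usize = 3 * EMBEDDING_DIM * EMBEDDING_DIM; // 49_152
// Multi-head: per head each projection is 128 × 16; across 8 heads that still
// totals 128 × 128 per matrix, plus the new 128 × 128 output projection W_o.
const MULTI_HEAD_PARAMS: usize = 4 * EMBEDDING_DIM * EMBEDDING_DIM; // 65_536
// (65_536 - 49_152) / 49_152 = 1/3, i.e. the +33% attention parameters per block.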

✅ Testing

All tests pass (49/49):

$ cargo test
...
running 13 tests  # multi_head_attention_test
running 6 tests   # transformer_test (including 5 new tests)
running 5 tests   # llm_test
...
test result: ok. 49 passed; 0 failed; 0 ignored

Test Coverage

  • ✅ Forward/backward propagation
  • ✅ Different sequence lengths (1-10)
  • ✅ Different head counts (1, 2, 4, 8, 16, 32, 64, 128)
  • ✅ Residual connections
  • ✅ Causal masking
  • ✅ Parameter counting
  • ✅ Numerical stability
  • ✅ Head splitting/concatenation
  • ✅ Training step simulation
  • ✅ Edge cases (should panic when the number of heads doesn't divide embedding_dim; see the test sketch below)
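
For example, the shape-preservation and divisibility checks could be sketched as follows. This is illustrative only, not the contents of tests/multi_head_attention_test.rs, and it assumes the layer consumes and produces ndarray Array2<f32> matrices and that construction panics on an invalid head count.

use llm::multi_head_attention::MultiHeadAttention;
use llm::{EMBEDDING_DIM, NUM_HEADS};
use ndarray::Array2;

#[test]
fn forward_preserves_shape() {
    let mut mha = MultiHeadAttention::new(EMBEDDING_DIM, NUM_HEADS);
    // Assumed input type: a [seq_len, EMBEDDING_DIM] matrix of activations.
    let input = Array2::<f32>::zeros((10, EMBEDDING_DIM));
    let output = mha.forward(&input);
    assert_eq!(output.dim(), (10, EMBEDDING_DIM));
}

#[test]
#[should_panic]
fn panics_when_heads_do_not_divide_embedding_dim() {
    // 7 does not divide EMBEDDING_DIM = 128, so construction should panic.
    let _ = MultiHeadAttention::new(EMBEDDING_DIM, 7);
}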

🔄 Backward Compatibility

The legacy SelfAttention and TransformerBlock implementations are preserved but marked with #[allow(dead_code)] for reference. Users can still access them if needed, but the default behavior now uses multi-head attention.

📊 Performance Impact

  • Increased parameters: +26% total model size (due to additional W_o matrix)
  • Improved expressiveness: 8 heads can learn different attention patterns
  • Training time: Slightly increased due to more computations per forward pass

🎨 Code Quality

  • ✅ No compilation errors
  • ✅ All clippy warnings addressed
  • ✅ Code formatted with cargo fmt
  • ✅ Comprehensive documentation with examples
  • ✅ Clear comments explaining complex logic

📚 Documentation

Extensive documentation provided:

  • MULTI_HEAD_ATTENTION.md: Technical implementation details, API documentation, usage examples
  • MIGRATION_TO_MULTI_HEAD.md: Migration guide, architecture comparison, rollback instructions
  • Inline code comments: Detailed explanations of the attention mechanism
  • README updates: Reflected new architecture and marked feature as complete ✅

🚀 Usage Example

use llm::{LLM, EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS};
use llm::multi_head_attention::MultiHeadAttention;

// Default LLM now uses multi-head attention (8 heads)
let llm = LLM::default();

// Or create a custom multi-head attention layer directly;
// `input` here is a [seq_len, EMBEDDING_DIM] matrix (see the architecture above)
let mut mha = MultiHeadAttention::new(EMBEDDING_DIM, NUM_HEADS);
let output = mha.forward(&input);

🔍 Related Issues

Addresses the "multi-head attention" item mentioned in README.md under "Areas for Improvement" (line 177).


Model Output Example:

=== MODEL INFORMATION ===
Network architecture: Embeddings -> MultiHeadTransformerBlock -> MultiHeadTransformerBlock -> MultiHeadTransformerBlock -> OutputProjection
Model configuration -> max_seq_len: 80, embedding_dim: 128, hidden_dim: 256, num_heads: 8
Total parameters: ~234,000

@zhenyu-02 (Author)

@tekaratzas please review, thanks!
