Add multi head attention #21
Open
🚀 Add Multi-Head Attention Implementation
Overview
This PR implements a complete Multi-Head Self-Attention mechanism and makes it the default architecture for the RustGPT project, replacing the previous single-head attention implementation.
🎯 Motivation
As mentioned in README.md#L177, multi-head attention was listed as a desired feature under "Areas for Improvement". Multi-head attention is a core component of modern Transformer architectures (GPT, BERT, etc.) and provides:
- the ability to attend to information from several representation subspaces in parallel
- richer contextual representations than a single attention head
- closer alignment with the standard Transformer architecture from "Attention Is All You Need"
📋 Changes
New Files
- `src/multi_head_attention.rs` (405 lines) - Complete multi-head attention implementation with forward/backward passes
- `tests/multi_head_attention_test.rs` (278 lines) - Comprehensive test suite with 13 test cases
- `MULTI_HEAD_ATTENTION.md` - Technical documentation and usage guide
- `MIGRATION_TO_MULTI_HEAD.md` - Migration guide and architecture comparison

Modified Files
- `src/lib.rs` - Added `NUM_HEADS` constant (default: 8)
- `src/llm.rs` - Updated to use `MultiHeadTransformerBlock` by default
- `src/main.rs` - All 3 transformer layers now use multi-head attention
- `src/transformer.rs` - Added `MultiHeadTransformerBlock` (legacy `TransformerBlock` kept for backward compatibility)
- `tests/llm_test.rs` - Updated tests to match new architecture
- `tests/transformer_test.rs` - Added 5 new tests for `MultiHeadTransformerBlock`
- `README.md` - Updated documentation to reflect multi-head attention implementation

🔧 Technical Details
Architecture
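For reference, here is a minimal sketch of the standard multi-head self-attention forward pass this layer implements: the input is projected to Q, K, and V, split column-wise into `NUM_HEADS` slices, scaled dot-product attention is applied per head, and the concatenated head outputs go through an output projection. The `ndarray`-based types, field names, and signatures below are illustrative assumptions for this sketch, not the exact API of `src/multi_head_attention.rs`.

```rust
use ndarray::{s, Array2};

/// Illustrative multi-head self-attention forward pass.
/// Assumed shapes: input is (seq_len, embedding_dim); all weights are square.
struct MultiHeadAttention {
    w_q: Array2<f32>, // (embedding_dim, embedding_dim)
    w_k: Array2<f32>,
    w_v: Array2<f32>,
    w_o: Array2<f32>, // output projection applied to the concatenated heads
    num_heads: usize,
}

impl MultiHeadAttention {
    fn forward(&self, x: &Array2<f32>) -> Array2<f32> {
        let (seq_len, dim) = x.dim();
        let head_dim = dim / self.num_heads;
        // One full-width projection each for Q, K, V; heads are column slices of these.
        let (q, k, v) = (x.dot(&self.w_q), x.dot(&self.w_k), x.dot(&self.w_v));
        let mut concat = Array2::<f32>::zeros((seq_len, dim));
        for h in 0..self.num_heads {
            let (lo, hi) = (h * head_dim, (h + 1) * head_dim);
            let qh = q.slice(s![.., lo..hi]);
            let kh = k.slice(s![.., lo..hi]);
            let vh = v.slice(s![.., lo..hi]);
            // Scaled dot-product attention scores for this head.
            let mut scores = qh.dot(&kh.t()) / (head_dim as f32).sqrt();
            // Row-wise softmax with the usual max-subtraction for numerical stability.
            for mut row in scores.rows_mut() {
                let max = row.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
                row.mapv_inplace(|v| (v - max).exp());
                let sum = row.sum();
                row.mapv_inplace(|v| v / sum);
            }
            concat.slice_mut(s![.., lo..hi]).assign(&scores.dot(&vh));
        }
        // Mix information across heads with the output projection.
        concat.dot(&self.w_o)
    }
}
```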
Key Features
- Number of attention heads configurable via the `NUM_HEADS` constant

Parameter Changes
Before (Single-Head):
After (Multi-Head):
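As a rough sketch of the shape bookkeeping behind this change (the dimension values below are assumed for illustration and are not the project's actual constants):

```rust
fn main() {
    // Assumed illustrative sizes; RustGPT's real constants may differ.
    let embedding_dim = 128;
    let num_heads = 8; // NUM_HEADS default added by this PR
    let head_dim = embedding_dim / num_heads; // 16 columns per head

    // Single-head: one Q, K, V projection, each (embedding_dim x embedding_dim).
    let single_head = 3 * embedding_dim * embedding_dim;

    // Multi-head: the same three projections, logically split into `num_heads`
    // column slices of width `head_dim`, plus an output projection W_O.
    // (Whether the single-head version already had an output projection depends
    // on the original implementation.)
    let multi_head = 3 * embedding_dim * embedding_dim + embedding_dim * embedding_dim;

    println!("head_dim = {head_dim}");
    println!("single-head params = {single_head}, multi-head params = {multi_head}");
}
```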
✅ Testing
All tests pass (49/49).
Test Coverage
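As one illustrative example of the kind of invariant such a suite typically checks (a hypothetical sketch, not copied from `tests/multi_head_attention_test.rs`):

```rust
#[cfg(test)]
mod tests {
    // Hypothetical sanity check: the embedding width must split evenly across
    // heads, and the concatenated head outputs must reassemble to the full width.
    #[test]
    fn embedding_dim_splits_evenly_across_heads() {
        let embedding_dim = 128; // assumed illustrative value
        let num_heads = 8; // NUM_HEADS default from this PR
        assert_eq!(embedding_dim % num_heads, 0);
        assert_eq!((embedding_dim / num_heads) * num_heads, embedding_dim);
    }
}
```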
🔄 Backward Compatibility
The legacy `SelfAttention` and `TransformerBlock` implementations are preserved but marked with `#[allow(dead_code)]` for reference. Users can still access them if needed, but the default behavior now uses multi-head attention.

📊 Performance Impact
🎨 Code Quality
- Code formatted with `cargo fmt`

📚 Documentation
Extensive documentation provided:
- `MULTI_HEAD_ATTENTION.md`: Technical implementation details, API documentation, usage examples
- `MIGRATION_TO_MULTI_HEAD.md`: Migration guide, architecture comparison, rollback instructions

🚀 Usage Example
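As a stand-in usage example built on the illustrative `MultiHeadAttention` sketch from the Architecture section above (assuming it is in scope in the same module); this is not the crate's real `MultiHeadTransformerBlock` API, which is documented in MULTI_HEAD_ATTENTION.md:

```rust
use ndarray::Array2;

fn main() {
    let (seq_len, dim, heads) = (4, 32, 8);
    // Identity matrices stand in for trained weights in this toy example.
    let attention = MultiHeadAttention {
        w_q: Array2::eye(dim),
        w_k: Array2::eye(dim),
        w_v: Array2::eye(dim),
        w_o: Array2::eye(dim),
        num_heads: heads,
    };
    let input = Array2::<f32>::ones((seq_len, dim));
    let output = attention.forward(&input);
    assert_eq!(output.dim(), (seq_len, dim)); // attention preserves the input shape
}
```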
🔍 Related Issues
Addresses the "multi-head attention" item mentioned in README.md under "Areas for Improvement" (line 177).
Model Output Example: