32 changes: 17 additions & 15 deletions README.md
@@ -41,17 +41,18 @@ Input Text → Tokenization → Embeddings → Transformer Blocks → Output Pro

```
src/
├── main.rs # 🎯 Training pipeline and interactive mode
├── llm.rs # 🧠 Core LLM implementation and training logic
├── lib.rs # 📚 Library exports and constants
├── transformer.rs # 🔄 Transformer block (attention + feed-forward)
├── self_attention.rs # 👀 Multi-head self-attention mechanism
├── feed_forward.rs # ⚡ Position-wise feed-forward networks
├── embeddings.rs # 📊 Token embedding layer
├── output_projection.rs # 🎰 Final linear layer for vocabulary predictions
├── vocab.rs # 📝 Vocabulary management and tokenization
├── layer_norm.rs # 🧮 Layer normalization
└── adam.rs # 🏃 Adam optimizer implementation
├── main.rs # 🎯 Training pipeline and interactive mode
├── llm.rs # 🧠 Core LLM implementation and training logic
├── lib.rs # 📚 Library exports and constants
├── transformer.rs # 🔄 Transformer block (multi-head attention + feed-forward)
├── multi_head_attention.rs # 👀 Multi-head self-attention mechanism (default)
├── self_attention.rs # 👁️ Single-head attention (legacy)
├── feed_forward.rs # ⚡ Position-wise feed-forward networks
├── embeddings.rs # 📊 Token embedding layer
├── output_projection.rs # 🎰 Final linear layer for vocabulary predictions
├── vocab.rs # 📝 Vocabulary management and tokenization
├── layer_norm.rs # 🧮 Layer normalization
└── adam.rs # 🏃 Adam optimizer implementation

tests/
├── llm_test.rs # Tests for core LLM functionality
@@ -110,8 +111,9 @@ Model output: Rain is caused by water vapor in clouds condensing into droplets t
- **Vocabulary Size**: Dynamic (built from training data)
- **Embedding Dimension**: 128 (defined by `EMBEDDING_DIM` in `src/lib.rs`)
- **Hidden Dimension**: 256 (defined by `HIDDEN_DIM` in `src/lib.rs`)
- **Number of Attention Heads**: 8 (defined by `NUM_HEADS` in `src/lib.rs`)
- **Max Sequence Length**: 80 tokens (defined by `MAX_SEQ_LEN` in `src/lib.rs`)
- **Architecture**: 3 Transformer blocks + embeddings + output projection
- **Architecture**: 3 Multi-Head Transformer blocks + embeddings + output projection
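
With `EMBEDDING_DIM = 128` and `NUM_HEADS = 8`, each attention head presumably operates on a 128 / 8 = 16-dimensional slice of the embedding. A minimal sketch of that constraint, using the crate's exported constants (the even split itself is an assumption, not something shown in this diff):

```rust
use llm::{EMBEDDING_DIM, NUM_HEADS};

fn main() {
    // Multi-head attention only works if the embedding splits evenly across heads.
    assert_eq!(EMBEDDING_DIM % NUM_HEADS, 0);
    let head_dim = EMBEDDING_DIM / NUM_HEADS; // 128 / 8 = 16
    println!("per-head dimension: {head_dim}");
}
```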

### Training Details
- **Optimizer**: Adam with gradient clipping
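
The clipping scheme itself isn't visible in this diff, so the following is only a rough sketch of gradient clipping by global L2 norm over `ndarray` matrices (the function name, the by-norm choice, and the threshold handling are assumptions, not necessarily what `src/adam.rs` does):

```rust
use ndarray::Array2;

/// Scale all gradient matrices down so their combined L2 norm is at most `max_norm`.
fn clip_gradients(grads: &mut [Array2<f32>], max_norm: f32) {
    let total_sq: f32 = grads.iter().map(|g| g.mapv(|x| x * x).sum()).sum();
    let total_norm = total_sq.sqrt();
    if total_norm > max_norm {
        let scale = max_norm / total_norm;
        for g in grads.iter_mut() {
            g.mapv_inplace(|x| x * scale);
        }
    }
}
```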
@@ -174,7 +176,7 @@ Contributions are welcome! This project is perfect for learning and experimentat
- **📊 Evaluation metrics** - Perplexity, benchmarks, training visualizations

### Areas for Improvement
- **Advanced architectures** (multi-head attention, positional encoding, RoPE)
- **Advanced architectures** (~~multi-head attention~~✅, positional encoding, RoPE)
- **Training improvements** (different optimizers, learning rate schedules, regularization)
- **Data handling** (larger datasets, tokenizer improvements, streaming)
- **Model analysis** (attention visualization, gradient analysis, interpretability)
@@ -194,8 +196,8 @@ Contributions are welcome! This project is perfect for learning and experimentat

### Ideas for Contributions
- 🚀 **Beginner**: Model save/load, more training data, config files
- 🔥 **Intermediate**: Beam search, positional encodings, training checkpoints
- ⚡ **Advanced**: Multi-head attention, layer parallelization, custom optimizations
- 🔥 **Intermediate**: Beam search, positional encodings (RoPE), training checkpoints
- ⚡ **Advanced**: Layer parallelization, Flash Attention, custom optimizations

Questions? Open an issue or start a discussion!

2 changes: 2 additions & 0 deletions src/lib.rs
@@ -4,6 +4,7 @@ pub mod embeddings;
pub mod feed_forward;
pub mod layer_norm;
pub mod llm;
pub mod multi_head_attention;
pub mod output_projection;
pub mod self_attention;
pub mod transformer;
@@ -18,3 +19,4 @@ pub use vocab::Vocab;
pub const MAX_SEQ_LEN: usize = 80;
pub const EMBEDDING_DIM: usize = 128;
pub const HIDDEN_DIM: usize = 256;
pub const NUM_HEADS: usize = 8; // Number of attention heads for multi-head attention
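
The new `src/multi_head_attention.rs` is not expanded in this diff, so the following is only a rough sketch of how splitting `EMBEDDING_DIM` across `NUM_HEADS` heads of scaled dot-product attention typically works with `ndarray` (the function names, shapes, and the omission of learned Q/K/V projections are all assumptions, not this PR's implementation):

```rust
use ndarray::{s, Array2, Axis};

/// Scaled dot-product attention for one head over (seq_len, head_dim) inputs.
fn attention_head(q: &Array2<f32>, k: &Array2<f32>, v: &Array2<f32>) -> Array2<f32> {
    let scale = (q.ncols() as f32).sqrt();
    let mut scores = q.dot(&k.t()) / scale; // (seq_len, seq_len)
    for mut row in scores.axis_iter_mut(Axis(0)) {
        // Row-wise softmax, shifted by the row max for numerical stability.
        let max = row.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        row.mapv_inplace(|x| (x - max).exp());
        let sum = row.sum();
        row.mapv_inplace(|x| x / sum);
    }
    scores.dot(v)
}

/// Split a (seq_len, embedding_dim) input into `num_heads` column slices,
/// attend within each slice, and write the results back in place.
fn multi_head_attention(x: &Array2<f32>, num_heads: usize) -> Array2<f32> {
    let head_dim = x.ncols() / num_heads;
    let mut out = Array2::<f32>::zeros(x.raw_dim());
    for h in 0..num_heads {
        let (lo, hi) = (h * head_dim, (h + 1) * head_dim);
        // A real layer would project into per-head Q, K, V first; this sketch
        // reuses the raw slice for all three.
        let slice = x.slice(s![.., lo..hi]).to_owned();
        let head_out = attention_head(&slice, &slice, &slice);
        out.slice_mut(s![.., lo..hi]).assign(&head_out);
    }
    out
}
```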
7 changes: 4 additions & 3 deletions src/llm.rs
@@ -3,8 +3,8 @@ use std::cmp::Ordering;
use ndarray::{Array1, Array2, Axis};

use crate::{
EMBEDDING_DIM, Embeddings, HIDDEN_DIM, MAX_SEQ_LEN, Vocab, output_projection::OutputProjection,
transformer::TransformerBlock,
EMBEDDING_DIM, Embeddings, HIDDEN_DIM, MAX_SEQ_LEN, NUM_HEADS, Vocab,
output_projection::OutputProjection, transformer::MultiHeadTransformerBlock,
};
pub trait Layer {
fn layer_type(&self) -> &str;
@@ -24,7 +24,8 @@ pub struct LLM {

impl Default for LLM {
fn default() -> Self {
let transformer_block = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
let transformer_block =
MultiHeadTransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS);
let output_projection = OutputProjection::new(EMBEDDING_DIM, Vocab::default_words().len());
Self {
vocab: Vocab::default(),
15 changes: 8 additions & 7 deletions src/main.rs
@@ -1,11 +1,11 @@
use std::io::Write;

use ::llm::{EMBEDDING_DIM, HIDDEN_DIM, MAX_SEQ_LEN};
use ::llm::{EMBEDDING_DIM, HIDDEN_DIM, MAX_SEQ_LEN, NUM_HEADS};
use dataset_loader::{Dataset, DatasetType};

use crate::{
embeddings::Embeddings, llm::LLM, output_projection::OutputProjection,
transformer::TransformerBlock, vocab::Vocab,
transformer::MultiHeadTransformerBlock, vocab::Vocab,
};

mod adam;
@@ -14,6 +14,7 @@ mod embeddings;
mod feed_forward;
mod layer_norm;
mod llm;
mod multi_head_attention;
mod output_projection;
mod self_attention;
mod transformer;
@@ -44,9 +45,9 @@ fn main() {
let vocab_words_refs: Vec<&str> = vocab_words.iter().map(|s: &String| s.as_str()).collect();
let vocab = Vocab::new(vocab_words_refs);

let transformer_block_1 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
let transformer_block_2 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
let transformer_block_3 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
let transformer_block_1 = MultiHeadTransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS);
let transformer_block_2 = MultiHeadTransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS);
let transformer_block_3 = MultiHeadTransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS);
let output_projection = OutputProjection::new(EMBEDDING_DIM, vocab.words.len());
let embeddings = Embeddings::new(vocab.clone());
let mut llm = LLM::new(
@@ -63,8 +64,8 @@ fn main() {
println!("\n=== MODEL INFORMATION ===");
println!("Network architecture: {}", llm.network_description());
println!(
"Model configuration -> max_seq_len: {}, embedding_dim: {}, hidden_dim: {}",
MAX_SEQ_LEN, EMBEDDING_DIM, HIDDEN_DIM
"Model configuration -> max_seq_len: {}, embedding_dim: {}, hidden_dim: {}, num_heads: {}",
MAX_SEQ_LEN, EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS
);

println!("Total parameters: {}", llm.total_parameters());