Conversation

@Adesoji1 (Author)

Please test thoroughly before approving, and reject the PR if you find any errors.

As you suggested earlier, I completely refactored src/self_attention.rs to implement multi-head attention:

- Added num_heads and head_dim fields to track the attention heads
- Added an output projection matrix w_o for combining the head outputs
- Implemented the head splitting and concatenation logic
- Updated the forward and backward passes to handle multiple attention heads

In src/transformer.rs, I updated the constructor to accept a num_heads parameter, so TransformerBlock::new() now takes (embedding_dim, hidden_dim, num_heads). In src/main.rs, I updated the code to use multi-head attention with 8 heads as the default configuration, and in src/llm.rs I updated the Default implementation to use 8 heads as well.

The tests were updated too: tests/self_attention_test.rs, tests/transformer_test.rs, and tests/llm_test.rs all use the new constructor signature (see the sketch below).
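For reference, here is a minimal sketch of the split/combine shape handling described above (the field and parameter names follow the PR; the bodies are simplified illustrations assuming ndarray, not the exact implementation):

```rust
use ndarray::{concatenate, s, Array2, Axis};

struct SelfAttention {
    num_heads: usize, // number of attention heads (8 by default)
    head_dim: usize,  // embedding_dim / num_heads
    w_o: Array2<f32>, // output projection that mixes the concatenated heads
    // w_q, w_k, w_v and the cached_* fields elided
}

impl SelfAttention {
    /// Slice a (seq_len, embedding_dim) projection into per-head matrices:
    /// head h owns columns [h * head_dim, (h + 1) * head_dim).
    fn split_heads(&self, proj: &Array2<f32>) -> Vec<Array2<f32>> {
        (0..self.num_heads)
            .map(|h| {
                proj.slice(s![.., h * self.head_dim..(h + 1) * self.head_dim])
                    .to_owned()
            })
            .collect()
    }

    /// Concatenate per-head outputs back to (seq_len, embedding_dim),
    /// then mix them through the output projection w_o.
    fn combine_heads(&self, heads: &[Array2<f32>]) -> Array2<f32> {
        let views: Vec<_> = heads.iter().map(|h| h.view()).collect();
        concatenate(Axis(1), &views).unwrap().dot(&self.w_o)
    }
}
```

The updated call sites then look like TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, 8) and SelfAttention::new(EMBEDDING_DIM, 8).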

adesoji@adesoji-Lenovo-Legion-7-15IMH05:~/Documents/RustGPT$ cargo test 2>&1 | tail -40
     Running tests/output_projection_test.rs (target/debug/deps/output_projection_test-4e9d08b337cbd89d)

running 5 tests
test test_output_projection_creation ... ok
test test_output_projection_forward ... ok
test test_output_projection_with_different_sequence_lengths ... ok
test test_output_projection_backward ... ok
test test_output_projection_training ... ok

test result: ok. 5 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

     Running tests/self_attention_test.rs (target/debug/deps/self_attention_test-8f28365a0525528f)

running 2 tests
test test_multi_head_attention_forward ... ok
test test_multi_head_attention_with_different_sequence_lengths ... ok

test result: ok. 2 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.03s

     Running tests/transformer_test.rs (target/debug/deps/transformer_test-36edac1f2ee51240)

running 1 test
test test_transformer_block ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.06s

     Running tests/vocab_test.rs (target/debug/deps/vocab_test-791d2688aaccfe09)

running 2 tests
test test_vocab_default ... ok
test test_vocab_encode_decode ... ok

test result: ok. 2 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Doc-tests llm

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

adesoji@adesoji-Lenovo-Legion-7-15IMH05:~/Documents/RustGPT$ cargo test self_attention 2>&1
   Compiling llm v0.1.0 (/home/adesoji/Documents/RustGPT)
    Finished `test` profile [unoptimized + debuginfo] target(s) in 0.25s
     Running unittests src/lib.rs (target/debug/deps/llm-132257558c9799fd)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

     Running unittests src/main.rs (target/debug/deps/llm-0d36e5d4e517be12)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

     Running tests/adam_test.rs (target/debug/deps/adam_test-f057b38856edfa4f)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 5 filtered out; finished in 0.00s

     Running tests/dataset_loader_test.rs (target/debug/deps/dataset_loader_test-7322943525e74782)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 2 filtered out; finished in 0.00s

     Running tests/embeddings_test.rs (target/debug/deps/embeddings_test-fedc1afad78da023)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 5 filtered out; finished in 0.00s

     Running tests/feed_forward_test.rs (target/debug/deps/feed_forward_test-84b429fcfbfb2096)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 3 filtered out; finished in 0.00s

     Running tests/llm_test.rs (target/debug/deps/llm_test-12f94139712e0f73)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 5 filtered out; finished in 0.00s

     Running tests/output_projection_test.rs (target/debug/deps/output_projection_test-4e9d08b337cbd89d)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 5 filtered out; finished in 0.00s

     Running tests/self_attention_test.rs (target/debug/deps/self_attention_test-8f28365a0525528f)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 4 filtered out; finished in 0.00s

     Running tests/transformer_test.rs (target/debug/deps/transformer_test-36edac1f2ee51240)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 1 filtered out; finished in 0.00s

     Running tests/vocab_test.rs (target/debug/deps/vocab_test-791d2688aaccfe09)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 2 filtered out; finished in 0.00s

adesoji@adesoji-Lenovo-Legion-7-15IMH05:~/Documents/RustGPT$ cargo test --test self_attention_test 2>&1
    Finished `test` profile [unoptimized + debuginfo] target(s) in 0.03s
     Running tests/self_attention_test.rs (target/debug/deps/self_attention_test-8f28365a0525528f)

running 4 tests
test test_multi_head_attention_forward ... ok
test test_multi_head_attention_with_different_sequence_lengths ... ok
test test_multi_head_attention_backward ... ok
test test_multi_head_attention_different_head_counts ... ok

test result: ok. 4 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.23s

adesoji@adesoji-Lenovo-Legion-7-15IMH05:~/Documents/RustGPT$ cargo test 2>&1 | grep -E "(test result:|running)"
running 0 tests
test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
running 0 tests
test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
running 5 tests
test result: ok. 5 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
running 2 tests
test result: ok. 2 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
running 5 tests
test result: ok. 5 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.01s
running 3 tests
test result: ok. 3 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.05s
running 5 tests
test result: ok. 5 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.08s
running 5 tests
test result: ok. 5 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
running 4 tests
test result: ok. 4 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.27s
running 1 test
test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.06s
running 2 tests
test result: ok. 2 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
running 0 tests
test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s
adesoji@adesoji-Lenovo-Legion-7-15IMH05:~/Documents/RustGPT$ cargo build --release 2>&1 | tail -5
   Compiling rand v0.9.2
   Compiling rand_distr v0.5.1
   Compiling csv v1.3.1
   Compiling llm v0.1.0 (/home/adesoji/Documents/RustGPT)
    Finished `release` profile [optimized] target(s) in 11.35s

impl Default for LLM {
    fn default() -> Self {
        let transformer_block = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
        let num_heads = 8; // Default to 8 attention heads
@tekaratzas (Owner):

I feel this should live with the rest of the constants, e.g. MAX_SEQ_LEN and such.

        let transformer_block_2 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
        let transformer_block_3 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
        // Using 8 attention heads (EMBEDDING_DIM = 128 / 8 = 16 dims per head)
        let num_heads = 8;
@tekaratzas (Owner):

We should pull this into a shared, universal const.
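For example (a sketch; NUM_HEADS is a name proposed here, and I'm assuming EMBEDDING_DIM already lives with the other shared constants):

```rust
// Alongside MAX_SEQ_LEN, HIDDEN_DIM, etc.
pub const EMBEDDING_DIM: usize = 128;
pub const NUM_HEADS: usize = 8; // EMBEDDING_DIM / NUM_HEADS = 16 dims per head
```

Call sites then become TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS) and SelfAttention::new(EMBEDDING_DIM, NUM_HEADS), with no hard-coded 8s scattered around.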

impl Default for SelfAttention {
    fn default() -> Self {
-       SelfAttention::new(EMBEDDING_DIM)
+       SelfAttention::new(EMBEDDING_DIM, 8) // 8 attention heads by default
@tekaratzas (Owner):

Same here!

            w_o: Array2::from_shape_fn((embedding_dim, embedding_dim), |_| normal.sample(&mut rng)),
            cached_input: None,
            cached_q: None,
            cached_k: None,
@tekaratzas (Owner):

And caching!! Very cool
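The point of those cached_* fields: forward() stashes its intermediates so backward() can reuse them instead of recomputing the projections. A rough sketch of the pattern (the type name here is hypothetical and the attention math is elided):

```rust
use ndarray::Array2;

struct CachingAttention {
    cached_input: Option<Array2<f32>>,
    cached_q: Option<Array2<f32>>,
    cached_k: Option<Array2<f32>>,
}

impl CachingAttention {
    fn forward(&mut self, input: &Array2<f32>, w_q: &Array2<f32>, w_k: &Array2<f32>) {
        let q = input.dot(w_q);
        let k = input.dot(w_k);
        // Stash the intermediates; backward() reads them to form gradients
        // (e.g. the gradient w.r.t. w_q needs the cached input)
        // without a second forward pass.
        self.cached_input = Some(input.clone());
        self.cached_q = Some(q);
        self.cached_k = Some(k);
    }
}
```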

fn test_self_attention_forward() {
    // Create self-attention module
    let mut self_attention = SelfAttention::new(EMBEDDING_DIM);
// #[test]
@tekaratzas (Owner):

Let's get rid of this commented section: either uncomment it or delete it.

@tekaratzas (Owner)

Hey! Thanks for the PR!

Gave it a quick glance; I'll take a closer pass once I have some energy.

There are also some merge conflicts.

@Adesoji1 (Author) commented Oct 22, 2025 via email
