added multisectionhead #24
base: main
Conversation
impl Default for LLM {
    fn default() -> Self {
        let transformer_block = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
        let num_heads = 8; // Default to 8 attention heads
I feel this should live with the rest of the constants, e.g. MAX_SEQ_LEN and such.
let transformer_block_2 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
let transformer_block_3 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
// Using 8 attention heads (EMBEDDING_DIM=128 / 8 = 16 dim per head)
let num_heads = 8;
We should share this as a universal const.
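A minimal sketch of what that shared constant could look like, assuming it sits next to the existing EMBEDDING_DIM / HIDDEN_DIM / MAX_SEQ_LEN constants (the name NUM_HEADS and its location are assumptions, not part of the PR):

```rust
// Hypothetical shared constant, placed alongside EMBEDDING_DIM, HIDDEN_DIM and MAX_SEQ_LEN.
pub const NUM_HEADS: usize = 8;
// With EMBEDDING_DIM = 128 this gives 128 / 8 = 16 dimensions per head.
```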
impl Default for SelfAttention {
    fn default() -> Self {
        SelfAttention::new(EMBEDDING_DIM)
        SelfAttention::new(EMBEDDING_DIM, 8) // 8 attention heads by default
Same here!
    w_o: Array2::from_shape_fn((embedding_dim, embedding_dim), |_| normal.sample(&mut rng)),
    cached_input: None,
    cached_q: None,
    cached_k: None,
And caching!! Very cool
fn test_self_attention_forward() {
    // Create self-attention module
    let mut self_attention = SelfAttention::new(EMBEDDING_DIM);
    // #[test]
Let's get rid of this commented-out section: either uncomment it or delete it.
Hey! Thanks for the PR! Gave it a quick glance. Will take another closer pass once I get some energy. Also got some merge conflicts.
Alright, I will work on the modifications.
Please test thoroughly before you approve, and reject the PR if you hit any errors.
In src/self_attention.rs, I completely refactored the module to implement multi-head attention, as you suggested earlier:
- Added num_heads and head_dim fields to track the attention heads.
- Added an output projection matrix w_o for combining the head outputs.
- Implemented the head splitting and concatenation logic.
- Updated the forward and backward passes to handle multiple attention heads.
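To make that layout concrete, here is a hedged, forward-only sketch of a multi-head SelfAttention, assuming the project uses ndarray and rand_distr and that the head count divides EMBEDDING_DIM evenly. The field names follow the PR description; the softmax helper, the initialization scale, and the omission of the backward pass are illustrative choices, not the PR's actual code:

```rust
use ndarray::{concatenate, s, Array2, Axis};
use rand::thread_rng;
use rand_distr::{Distribution, Normal};

pub struct SelfAttention {
    num_heads: usize,
    head_dim: usize,                   // embedding_dim / num_heads
    w_q: Array2<f32>,
    w_k: Array2<f32>,
    w_v: Array2<f32>,
    w_o: Array2<f32>,                  // output projection combining the head outputs
    cached_input: Option<Array2<f32>>, // cached for the backward pass (the PR also caches q/k/v)
}

impl SelfAttention {
    pub fn new(embedding_dim: usize, num_heads: usize) -> Self {
        assert_eq!(embedding_dim % num_heads, 0, "num_heads must divide embedding_dim");
        let normal = Normal::new(0.0_f32, 0.02).unwrap();
        let mut rng = thread_rng();
        let mut init =
            || Array2::from_shape_fn((embedding_dim, embedding_dim), |_| normal.sample(&mut rng));
        Self {
            num_heads,
            head_dim: embedding_dim / num_heads,
            w_q: init(),
            w_k: init(),
            w_v: init(),
            w_o: init(),
            cached_input: None,
        }
    }

    /// Forward pass: project to Q/K/V, attend per head, concatenate, then apply w_o.
    pub fn forward(&mut self, x: &Array2<f32>) -> Array2<f32> {
        self.cached_input = Some(x.clone());
        let (q, k, v) = (x.dot(&self.w_q), x.dot(&self.w_k), x.dot(&self.w_v));
        let heads: Vec<Array2<f32>> = (0..self.num_heads)
            .map(|h| {
                let (lo, hi) = (h * self.head_dim, (h + 1) * self.head_dim);
                let (qh, kh, vh) = (
                    q.slice(s![.., lo..hi]),
                    k.slice(s![.., lo..hi]),
                    v.slice(s![.., lo..hi]),
                );
                // Scaled dot-product attention for this head only.
                let scores = qh.dot(&kh.t()) / (self.head_dim as f32).sqrt();
                softmax_rows(&scores).dot(&vh)
            })
            .collect();
        // Concatenate the per-head outputs back to embedding_dim, then project with w_o.
        let views: Vec<_> = heads.iter().map(|h| h.view()).collect();
        concatenate(Axis(1), &views).unwrap().dot(&self.w_o)
    }
}

/// Row-wise softmax (illustrative helper, not taken from the PR).
fn softmax_rows(m: &Array2<f32>) -> Array2<f32> {
    let mut out = m.clone();
    for mut row in out.rows_mut() {
        let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        row.mapv_inplace(|v| (v - max).exp());
        let sum = row.sum();
        row.mapv_inplace(|v| v / sum);
    }
    out
}
```

The w_o projection is what lets the per-head outputs, each of width head_dim, be recombined into a single embedding_dim-wide representation.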
In src/transformer.rs, I updated the constructor to accept a num_heads parameter, so TransformerBlock::new() now takes (embedding_dim, hidden_dim, num_heads).
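For clarity, a rough sketch of the new constructor shape, building on the SelfAttention sketch above; the FeedForward stand-in and the block's exact fields are assumptions, not the PR's real code:

```rust
use ndarray::Array2;

// Illustrative stand-in for the block's MLP sub-layer; the real field names/types may differ.
pub struct FeedForward {
    w1: Array2<f32>,
    w2: Array2<f32>,
}

impl FeedForward {
    pub fn new(embedding_dim: usize, hidden_dim: usize) -> Self {
        Self {
            w1: Array2::zeros((embedding_dim, hidden_dim)),
            w2: Array2::zeros((hidden_dim, embedding_dim)),
        }
    }
}

pub struct TransformerBlock {
    attention: SelfAttention, // the multi-head module sketched earlier
    feed_forward: FeedForward,
}

impl TransformerBlock {
    /// New signature: (embedding_dim, hidden_dim, num_heads); the head count is
    /// forwarded to the attention sub-layer, everything else stays as before.
    pub fn new(embedding_dim: usize, hidden_dim: usize, num_heads: usize) -> Self {
        Self {
            attention: SelfAttention::new(embedding_dim, num_heads),
            feed_forward: FeedForward::new(embedding_dim, hidden_dim),
        }
    }
}
```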
In src/main.rs, I updated everything to use multi-head attention, with 8 attention heads as the default configuration. Furthermore, in src/llm.rs, I updated the Default implementation to use 8 heads as well.
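Tying this back to the review comments about a shared constant, the defaults could then read roughly like this (NUM_HEADS is the hypothetical constant sketched above, and this is a sketch rather than the PR's exact code):

```rust
impl Default for SelfAttention {
    fn default() -> Self {
        // Previously SelfAttention::new(EMBEDDING_DIM); the head count is now explicit.
        SelfAttention::new(EMBEDDING_DIM, NUM_HEADS)
    }
}

// Inside LLM::default(), the hard-coded `let num_heads = 8;` could likewise become:
// let transformer_block = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS);
```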
The tests were updated too: tests/self_attention_test.rs, tests/transformer_test.rs, and tests/llm_test.rs all use the new constructor signature.
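As an example of what the constructor change looks like in a test, a hedged sketch (the test name, input, and shape assertion below are illustrative, not taken from the actual test files):

```rust
use ndarray::Array2;

#[test]
fn test_self_attention_forward_shape() {
    // New two-argument constructor: (embedding_dim, num_heads).
    let mut self_attention = SelfAttention::new(EMBEDDING_DIM, NUM_HEADS);
    let input = Array2::<f32>::zeros((4, EMBEDDING_DIM)); // 4 tokens
    let output = self_attention.forward(&input);
    // Multi-head attention preserves the (seq_len, embedding_dim) shape.
    assert_eq!(output.shape(), &[4, EMBEDDING_DIM]);
}
```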