diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py
index 154dfd17e..28495d29a 100644
--- a/specforge/core/eagle3.py
+++ b/specforge/core/eagle3.py
@@ -98,7 +98,7 @@ def forward(
         target: torch.Tensor,
         loss_mask: torch.Tensor,
         hidden_states: torch.Tensor,
-        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
         position_ids: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.Tensor] = None,
         is_vlm: bool = False,
@@ -127,6 +127,7 @@ def forward(
         batch_size, seq_length, _ = hidden_states.shape
         seq_length_with_past = seq_length
         past_key_values_length = 0
+        draft_num_hidden_layers = self.draft_model.num_hidden_layers
 
         # Step 2: project the concatenated hidden states to the target hidden size
         if self.attention_backend == "usp":
@@ -182,11 +183,11 @@ def forward(
         # for sequence paralle, position mask and input ids will split by sequence dim, need to keep origin for ttt shift
         global_input_ids = input_ids
         if self.attention_backend in ["sdpa", "fa", "usp"]:
-            cache_hidden = [[], []]
+            caches_hidden = [[[], []] for _ in range(draft_num_hidden_layers)]
            past_key_values = None
         elif self.attention_backend == "flex_attention":
-            cache_hidden = None
-            past_key_values = DynamicCache()
+            caches_hidden = None
+            past_key_values = [DynamicCache() for _ in range(draft_num_hidden_layers)]
         else:
             raise ValueError(f"Unknown attention backend: {self.attention_backend}")
 
@@ -207,7 +208,7 @@ def forward(
            hidden_states_out = self.draft_model.backbone(
                 input_embeds=inputs_embeds,
                 hidden_states=hidden_states,
-                cache_hidden=cache_hidden,
+                caches_hidden=caches_hidden,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
@@ -388,7 +389,7 @@ def forward(
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         loss_mask: torch.Tensor,
-        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
         position_ids: Optional[torch.Tensor] = None,
         pixel_values: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.Tensor] = None,
@@ -423,13 +424,14 @@ def forward(
         batch_size, seq_length, _ = hidden_states.shape
         seq_length_with_past = seq_length
         past_key_values_length = 0
+        draft_num_hidden_layers = self.draft_model.num_hidden_layers
 
         # Step 2: project the concatenated hidden states to the target hidden size
         hidden_states = self.draft_model.project_hidden_states(hidden_states)
 
         # Step 3: process kv cache, position ids and position ids
         if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
+            past_key_values_length = past_key_values[0][0][0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
 
         if position_ids is None:
@@ -479,11 +481,11 @@ def forward(
         vlosses = []
         acces = []
         if self.attention_backend in ["sdpa", "fa"]:
-            cache_hidden = [[], []]
+            caches_hidden = [[[], []] for _ in range(draft_num_hidden_layers)]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
-            cache_hidden = None
-            past_key_values = DynamicCache()
+            caches_hidden = None
+            past_key_values = [DynamicCache() for _ in range(draft_num_hidden_layers)]
         else:
             raise ValueError(f"Unknown attention backend: {self.attention_backend}")
 
@@ -500,7 +502,7 @@ def forward(
            hidden_states_out = self.draft_model.backbone(
                 input_embeds=inputs_embeds,
                 hidden_states=hidden_states,
-                cache_hidden=cache_hidden,
+                caches_hidden=caches_hidden,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
diff --git a/specforge/modeling/draft/base.py b/specforge/modeling/draft/base.py
index b5584a759..86e9035a4 100644
--- a/specforge/modeling/draft/base.py
+++ b/specforge/modeling/draft/base.py
@@ -24,7 +24,7 @@
 import json
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import List, Optional
 
 import torch
 from huggingface_hub import snapshot_download
@@ -98,10 +98,10 @@ def backbone(
         self,
         input_embeds: torch.Tensor,
         hidden_states: torch.Tensor,
-        cache_hidden: torch.Tensor,
+        caches_hidden: List[List[List[torch.Tensor]]],
         attention_mask: torch.Tensor,
         position_ids: torch.Tensor,
-        past_key_values: Optional[Cache] = None,
+        past_key_values: Optional[List[Cache]] = None,
         use_cache: bool = True,
     ) -> torch.Tensor:
         """
diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py
index 552a3cf86..776a5df8d 100644
--- a/specforge/modeling/draft/llama3_eagle.py
+++ b/specforge/modeling/draft/llama3_eagle.py
@@ -1336,7 +1336,13 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
         self.embed_tokens = nn.Embedding(
             config.vocab_size, config.hidden_size, config.pad_token_id
         )
-        self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend)
+        self.num_hidden_layers = config.num_hidden_layers
+        self.midlayers = nn.ModuleList(
+            [
+                LlamaDecoderLayer(config, attention_backend=attention_backend)
+                for _ in range(self.num_hidden_layers)
+            ]
+        )
 
         if hasattr(config, "target_hidden_size"):
             self.fc = torch.nn.Linear(
@@ -1374,11 +1380,15 @@ def forward(
         position_ids (`torch.LongTensor`, *optional*): position ids of shape `(batch, seq_len)`
         """
         if ttt_length == 1:
-            print_with_rank("using ttt_length 1, no need to cache hidden states")
-            cache_hidden = None
+            print_with_rank(
+                "using ttt_length 1, no need to cache hidden states for decoder layer(s)"
+            )
+            caches_hidden = None
         else:
-            print_with_rank(f"using ttt_length {ttt_length}, caching hidden states")
-            cache_hidden = [[], []]
+            print_with_rank(
+                f"using ttt_length {ttt_length}, caching hidden states for decoder layer(s)"
+            )
+            caches_hidden = [[[], []] for _ in range(self.num_hidden_layers)]
 
         batch_size, seq_length, _ = hidden_states.size()
 
@@ -1398,14 +1408,14 @@ def forward(
 
         # fc
         hidden_states = self.fc(hidden_states)
-        hidden_states = self.midlayer(
-            input_emb=inputs_embeds,
+
+        hidden_states = self.backbone(
+            input_embeds=inputs_embeds,
             hidden_states=hidden_states,
-            cache_hidden=cache_hidden,
+            caches_hidden=caches_hidden,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
-            output_attentions=False,
             use_cache=False,
         )
 
@@ -1430,19 +1440,23 @@ def backbone(
         self,
         input_embeds: torch.Tensor,
         hidden_states: torch.Tensor,
-        cache_hidden: torch.Tensor,
-        attention_mask: torch.Tensor,
-        position_ids: torch.Tensor,
-        past_key_values: Optional[Cache] = None,
+        caches_hidden: Optional[List[List[List[torch.Tensor]]]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[Cache]] = None,
         use_cache: bool = True,
     ) -> torch.Tensor:
-        return self.midlayer(
-            input_emb=input_embeds,
-            hidden_states=hidden_states,
-            cache_hidden=cache_hidden,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            output_attentions=False,
-            use_cache=False,
-        )
+        for i, layer in enumerate(self.midlayers):
+            hidden_states = layer(
+                input_emb=input_embeds,
+                hidden_states=hidden_states,
+                cache_hidden=caches_hidden[i] if caches_hidden is not None else None,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=(
+                    past_key_values[i] if past_key_values is not None else None
+                ),
+                output_attentions=False,
+                use_cache=False,
+            )
+        return hidden_states
diff --git a/tests/test_modeling/test_draft/test_llama3.py b/tests/test_modeling/test_draft/test_llama3.py
index b0fa86c80..505b1abd0 100644
--- a/tests/test_modeling/test_draft/test_llama3.py
+++ b/tests/test_modeling/test_draft/test_llama3.py
@@ -5,6 +5,7 @@
 from unittest.mock import patch
 
 import torch
+import torch.nn as nn
 from transformers import LlamaConfig
 
 from specforge.modeling.draft.llama3_eagle import (
@@ -35,7 +36,7 @@ def setUp(self):
             "model_type": "llama",
             "num_attention_heads": 32,
             "num_key_value_heads": 8,
-            "num_hidden_layers": 1,
+            "num_hidden_layers": 3,
             "pad_token_id": 0,
             "rms_norm_eps": 1e-05,
             "tie_word_embeddings": False,
@@ -53,13 +54,15 @@ def tearDown(self):
 
     def test_model_initialization(self):
         model = LlamaForCausalLMEagle3(self.config)
-
-        self.assertIsInstance(model.midlayer.self_attn, LlamaAttention)
-        self.assertIsInstance(model.midlayer.mlp, LlamaMLP)
-        self.assertIsInstance(model.midlayer.hidden_norm, LlamaRMSNorm)
-        self.assertIsInstance(model.midlayer.input_layernorm, LlamaRMSNorm)
-        self.assertIsInstance(model.midlayer.post_attention_layernorm, LlamaRMSNorm)
-        self.assertEqual(model.midlayer.hidden_size, self.config.hidden_size)
+        self.assertEqual(model.num_hidden_layers, self.config.num_hidden_layers)
+        self.assertIsInstance(model.midlayers, nn.ModuleList)
+        for layer in model.midlayers:
+            self.assertIsInstance(layer.self_attn, LlamaAttention)
+            self.assertIsInstance(layer.mlp, LlamaMLP)
+            self.assertIsInstance(layer.hidden_norm, LlamaRMSNorm)
+            self.assertIsInstance(layer.input_layernorm, LlamaRMSNorm)
+            self.assertIsInstance(layer.post_attention_layernorm, LlamaRMSNorm)
+            self.assertEqual(layer.hidden_size, self.config.hidden_size)
 
     def test_save_pretrained(self):
         """Test the model's save_pretrained functionality."""
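
For reference, below is a minimal standalone sketch of the per-layer cache layout this change introduces: one `[[], []]` pair per decoder layer, with layer `i` receiving only `caches_hidden[i]` inside the new `backbone` loop. `ToyLayer`, `hidden_size`, and the tensor shapes are hypothetical stand-ins for `LlamaDecoderLayer` and the real model config; only the calling convention mirrors the patch.

from typing import List, Optional

import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    # Hypothetical stand-in for LlamaDecoderLayer; it reproduces only the
    # keyword interface that the new backbone loop relies on.
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        input_emb: torch.Tensor,
        hidden_states: torch.Tensor,
        cache_hidden: Optional[List[List[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        if cache_hidden is not None:
            # Each layer only ever touches its own pair of cache lists.
            cache_hidden[0].append(hidden_states)
            cache_hidden[1].append(input_emb)
        return self.proj(hidden_states) + input_emb


num_hidden_layers, hidden_size = 3, 16
midlayers = nn.ModuleList(
    [ToyLayer(hidden_size) for _ in range(num_hidden_layers)]
)

# One pair of cache lists per decoder layer, mirroring
# caches_hidden = [[[], []] for _ in range(draft_num_hidden_layers)].
caches_hidden = [[[], []] for _ in range(num_hidden_layers)]

input_emb = torch.randn(2, 4, hidden_size)
hidden_states = torch.randn(2, 4, hidden_size)

# Mirrors the new backbone loop: layer i is given caches_hidden[i] only.
for i, layer in enumerate(midlayers):
    hidden_states = layer(
        input_emb=input_emb,
        hidden_states=hidden_states,
        cache_hidden=caches_hidden[i] if caches_hidden is not None else None,
    )

assert all(len(pair[0]) == 1 for pair in caches_hidden)
print(hidden_states.shape)  # torch.Size([2, 4, 16])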