22 changes: 12 additions & 10 deletions specforge/core/eagle3.py
@@ -98,7 +98,7 @@ def forward(
         target: torch.Tensor,
         loss_mask: torch.Tensor,
         hidden_states: torch.Tensor,
-        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
         position_ids: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.Tensor] = None,
         is_vlm: bool = False,
@@ -127,6 +127,7 @@ def forward(
         batch_size, seq_length, _ = hidden_states.shape
         seq_length_with_past = seq_length
         past_key_values_length = 0
+        draft_num_hidden_layers = self.draft_model.num_hidden_layers

         # Step 2: project the concatenated hidden states to the target hidden size
         if self.attention_backend == "usp":
@@ -182,11 +183,11 @@ def forward(
         # for sequence paralle, position mask and input ids will split by sequence dim, need to keep origin for ttt shift
         global_input_ids = input_ids
         if self.attention_backend in ["sdpa", "fa", "usp"]:
-            cache_hidden = [[], []]
+            caches_hidden = [[[], []] for _ in range(draft_num_hidden_layers)]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
-            cache_hidden = None
-            past_key_values = DynamicCache()
+            caches_hidden = None
+            past_key_values = [DynamicCache() for _ in range(draft_num_hidden_layers)]
         else:
             raise ValueError(f"Unknown attention backend: {self.attention_backend}")

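The hunk above turns the single shared `cache_hidden = [[], []]` container into one container per draft decoder layer, and likewise allocates one `DynamicCache` per layer for the flex_attention backend. A minimal sketch of the resulting structures (the layer count of 2 is an arbitrary illustration, not a value from the PR):

```python
from transformers import DynamicCache

draft_num_hidden_layers = 2  # illustrative; the PR reads this from the draft model

# sdpa / fa / usp backends: one [[], []] hidden-state cache per draft decoder layer
caches_hidden = [[[], []] for _ in range(draft_num_hidden_layers)]

# flex_attention backend: one DynamicCache per draft decoder layer instead
past_key_values = [DynamicCache() for _ in range(draft_num_hidden_layers)]

# The draft backbone then hands layer i its own slot: caches_hidden[i] / past_key_values[i]
assert len(caches_hidden) == len(past_key_values) == draft_num_hidden_layers
```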
@@ -207,7 +208,7 @@ def forward(
             hidden_states_out = self.draft_model.backbone(
                 input_embeds=inputs_embeds,
                 hidden_states=hidden_states,
-                cache_hidden=cache_hidden,
+                caches_hidden=caches_hidden,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
@@ -388,7 +389,7 @@ def forward(
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         loss_mask: torch.Tensor,
-        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
         position_ids: Optional[torch.Tensor] = None,
         pixel_values: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.Tensor] = None,
@@ -423,13 +424,14 @@ def forward(
         batch_size, seq_length, _ = hidden_states.shape
         seq_length_with_past = seq_length
         past_key_values_length = 0
+        draft_num_hidden_layers = self.draft_model.num_hidden_layers

         # Step 2: project the concatenated hidden states to the target hidden size
         hidden_states = self.draft_model.project_hidden_states(hidden_states)

         # Step 3: process kv cache, position ids and position ids
         if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
+            past_key_values_length = past_key_values[0][0][0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length

         if position_ids is None:
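The extra `[0]` reflects the added per-layer nesting of `past_key_values`: the outer index now selects a draft decoder layer before reaching a (key, value) pair. A toy illustration with assumed tensor shapes (the `[batch, heads, seq, head_dim]` layout is the usual KV-cache convention, not something stated in this diff):

```python
import torch

batch, heads, cached_len, head_dim = 2, 8, 16, 64  # toy shapes
key = torch.zeros(batch, heads, cached_len, head_dim)
value = torch.zeros_like(key)

# past_key_values[layer][pair] -> (key, value); here: one layer, one cached pair
past_key_values = [[(key, value)]]

# Mirrors the updated line: dim 2 of the key tensor is the cached sequence length
past_key_values_length = past_key_values[0][0][0].shape[2]
assert past_key_values_length == cached_len
```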
@@ -479,11 +481,11 @@ def forward(
         vlosses = []
         acces = []
         if self.attention_backend in ["sdpa", "fa"]:
-            cache_hidden = [[], []]
+            caches_hidden = [[[], []] for _ in range(draft_num_hidden_layers)]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
             cache_hidden = None
-            past_key_values = DynamicCache()
+            past_key_values = [DynamicCache() for _ in range(draft_num_hidden_layers)]
         else:
             raise ValueError(f"Unknown attention backend: {self.attention_backend}")

@@ -500,7 +502,7 @@ def forward(
             hidden_states_out = self.draft_model.backbone(
                 input_embeds=inputs_embeds,
                 hidden_states=hidden_states,
-                cache_hidden=cache_hidden,
+                caches_hidden=caches_hidden,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
6 changes: 3 additions & 3 deletions specforge/modeling/draft/base.py
@@ -24,7 +24,7 @@
 import json
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import List, Optional

 import torch
 from huggingface_hub import snapshot_download
@@ -98,10 +98,10 @@ def backbone(
         self,
         input_embeds: torch.Tensor,
         hidden_states: torch.Tensor,
-        cache_hidden: torch.Tensor,
+        caches_hidden: List[List[List[torch.Tensor]]],
         attention_mask: torch.Tensor,
         position_ids: torch.Tensor,
-        past_key_values: Optional[Cache] = None,
+        past_key_values: Optional[List[Cache]] = None,
         use_cache: bool = True,
     ) -> torch.Tensor:
         """
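For readers decoding the new `List[List[List[torch.Tensor]]]` annotation: a hedged reading, consistent with the `[[[], []] for _ in range(...)]` initialisation used elsewhere in this PR, is layer → the two per-layer hidden caches → tensors accumulated over TTT steps. A small sketch under that assumption:

```python
from typing import List

import torch

# Assumed nesting (inferred from the initialisation in this PR, not stated explicitly):
#   caches_hidden[i]       -> caches for draft decoder layer i
#   caches_hidden[i][j]    -> one of the two per-layer caches ([[], []])
#   caches_hidden[i][j][t] -> tensor appended at TTT step t
CachesHidden = List[List[List[torch.Tensor]]]

num_layers = 2  # illustrative
caches_hidden: CachesHidden = [[[], []] for _ in range(num_layers)]
caches_hidden[0][0].append(torch.zeros(1, 4, 8))  # toy tensor; shape is arbitrary here
assert len(caches_hidden[0][0]) == 1
```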
60 changes: 37 additions & 23 deletions specforge/modeling/draft/llama3_eagle.py
@@ -1336,7 +1336,13 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
         self.embed_tokens = nn.Embedding(
             config.vocab_size, config.hidden_size, config.pad_token_id
         )
-        self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend)
+        self.num_hidden_layers = config.num_hidden_layers
+        self.midlayers = nn.ModuleList(
+            [
+                LlamaDecoderLayer(config, attention_backend=attention_backend)
+                for _ in range(self.num_hidden_layers)
+            ]
+        )

         if hasattr(config, "target_hidden_size"):
             self.fc = torch.nn.Linear(
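Using `nn.ModuleList` here (rather than a plain Python list) matters because listed submodules are registered with the parent module, so their parameters appear in `parameters()`, in the state dict, and follow `.to(device)` calls. A standalone illustration with a placeholder wrapper (names are illustrative, not from the PR):

```python
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, num_layers: int = 3):
        super().__init__()
        # nn.ModuleList registers each layer, so its parameters are tracked;
        # a plain Python list of modules would be invisible to .parameters(), .to(), etc.
        self.midlayers = nn.ModuleList(nn.Linear(4, 4) for _ in range(num_layers))

w = Wrapper()
assert sum(p.numel() for p in w.parameters()) == 3 * (4 * 4 + 4)  # 60 registered parameters
```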
@@ -1374,11 +1380,15 @@ def forward(
             position_ids (`torch.LongTensor`, *optional*): position ids of shape `(batch, seq_len)`
         """
         if ttt_length == 1:
-            print_with_rank("using ttt_length 1, no need to cache hidden states")
-            cache_hidden = None
+            print_with_rank(
+                "using ttt_length 1, no need to cache hidden states for decoder layer(s)"
+            )
+            caches_hidden = None
         else:
-            print_with_rank(f"using ttt_length {ttt_length}, caching hidden states")
-            cache_hidden = [[], []]
+            print_with_rank(
+                f"using ttt_length {ttt_length}, caching hidden states for decoder layer(s)"
+            )
+            caches_hidden = [[[], []] for _ in range(self.num_hidden_layers)]

         batch_size, seq_length, _ = hidden_states.size()

@@ -1398,14 +1408,14 @@ def forward(

         # fc
         hidden_states = self.fc(hidden_states)
-        hidden_states = self.midlayer(
-            input_emb=inputs_embeds,
+
+        hidden_states = self.backbone(
+            input_embeds=inputs_embeds,
             hidden_states=hidden_states,
-            cache_hidden=cache_hidden,
+            caches_hidden=caches_hidden,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
-            output_attentions=False,
             use_cache=False,
         )

@@ -1430,19 +1440,23 @@ def backbone(
         self,
         input_embeds: torch.Tensor,
         hidden_states: torch.Tensor,
-        cache_hidden: torch.Tensor,
-        attention_mask: torch.Tensor,
-        position_ids: torch.Tensor,
-        past_key_values: Optional[Cache] = None,
+        caches_hidden: Optional[List[List[List[torch.Tensor]]]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[Cache]] = None,
         use_cache: bool = True,
     ) -> torch.Tensor:
-        return self.midlayer(
-            input_emb=input_embeds,
-            hidden_states=hidden_states,
-            cache_hidden=cache_hidden,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            output_attentions=False,
-            use_cache=False,
-        )
+        for i, layer in enumerate(self.midlayers):
+            hidden_states = layer(
+                input_emb=input_embeds,
+                hidden_states=hidden_states,
+                cache_hidden=caches_hidden[i] if caches_hidden is not None else None,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=(
+                    past_key_values[i] if past_key_values is not None else None
+                ),
+                output_attentions=False,
+                use_cache=False,
+            )
+        return hidden_states
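The new `backbone` threads `hidden_states` through every layer in `self.midlayers`, giving layer `i` its own `caches_hidden[i]` and `past_key_values[i]` slot. A toy, self-contained sketch of that dispatch pattern (the layer type and shapes are placeholders, not the PR's `LlamaDecoderLayer`):

```python
import torch
import torch.nn as nn

class ToyBackbone(nn.Module):
    def __init__(self, num_layers: int, hidden: int = 8):
        super().__init__()
        self.midlayers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))

    def forward(self, hidden_states, caches_hidden=None):
        for i, layer in enumerate(self.midlayers):
            if caches_hidden is not None:
                caches_hidden[i][0].append(hidden_states)  # stash this layer's input in its own slot
            hidden_states = layer(hidden_states)
        return hidden_states

model = ToyBackbone(num_layers=3)
caches = [[[], []] for _ in range(3)]
out = model(torch.randn(2, 8), caches_hidden=caches)
assert out.shape == (2, 8) and all(len(c[0]) == 1 for c in caches)
```

Handing each layer its own slot keeps cached state isolated per layer, which is why the single shared `cache_hidden = [[], []]` had to become a per-layer list once the draft model can have more than one decoder layer.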
19 changes: 11 additions & 8 deletions tests/test_modeling/test_draft/test_llama3.py
@@ -5,6 +5,7 @@
 from unittest.mock import patch

 import torch
+import torch.nn as nn
 from transformers import LlamaConfig

 from specforge.modeling.draft.llama3_eagle import (
@@ -35,7 +36,7 @@ def setUp(self):
             "model_type": "llama",
             "num_attention_heads": 32,
             "num_key_value_heads": 8,
-            "num_hidden_layers": 1,
+            "num_hidden_layers": 3,
             "pad_token_id": 0,
             "rms_norm_eps": 1e-05,
             "tie_word_embeddings": False,
@@ -53,13 +54,15 @@ def tearDown(self):

     def test_model_initialization(self):
         model = LlamaForCausalLMEagle3(self.config)
-
-        self.assertIsInstance(model.midlayer.self_attn, LlamaAttention)
-        self.assertIsInstance(model.midlayer.mlp, LlamaMLP)
-        self.assertIsInstance(model.midlayer.hidden_norm, LlamaRMSNorm)
-        self.assertIsInstance(model.midlayer.input_layernorm, LlamaRMSNorm)
-        self.assertIsInstance(model.midlayer.post_attention_layernorm, LlamaRMSNorm)
-        self.assertEqual(model.midlayer.hidden_size, self.config.hidden_size)
+        self.assertEqual(model.num_hidden_layers, self.config.num_hidden_layers)
+        self.assertIsInstance(model.midlayers, nn.ModuleList)
+        for layer in model.midlayers:
+            self.assertIsInstance(layer.self_attn, LlamaAttention)
+            self.assertIsInstance(layer.mlp, LlamaMLP)
+            self.assertIsInstance(layer.hidden_norm, LlamaRMSNorm)
+            self.assertIsInstance(layer.input_layernorm, LlamaRMSNorm)
+            self.assertIsInstance(layer.post_attention_layernorm, LlamaRMSNorm)
+            self.assertEqual(layer.hidden_size, self.config.hidden_size)

     def test_save_pretrained(self):
         """Test the model's save_pretrained functionality."""
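One further assertion that could be dropped into the same `TestCase` (hypothetical, not part of this PR) is a direct check that the number of instantiated layers matches the config:

```python
    def test_midlayer_count(self):
        """Hypothetical extra check: midlayers length tracks num_hidden_layers."""
        model = LlamaForCausalLMEagle3(self.config)
        self.assertEqual(len(model.midlayers), self.config.num_hidden_layers)
```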