From c0ebfcfe67eb021b5765f20336c58388fbc26dcb Mon Sep 17 00:00:00 2001 From: xiaoxi Date: Thu, 1 Jan 2026 08:58:18 +0000 Subject: [PATCH 1/9] implement multi-layer mlp for llama3 draft --- scripts/train_eagle3.py | 8 ++++++++ specforge/modeling/draft/llama3_eagle.py | 14 +++++++++----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 881f3f7aa..873fabea0 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -92,6 +92,12 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]: choices=["sglang", "hf", "custom"], help="The backend of the target model", ) + model_group.add_argument( + "--num-draft-hidden-layers", + type=int, + default=3, + help="The number of MLPs in the draft model decoder" + ) # dataset arguments dataset_group = parser.add_argument_group("dataset") @@ -370,12 +376,14 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module] draft_model = AutoEagle3DraftModel.from_pretrained( draft_model_last_checkpoint, attention_backend=args.attention_backend, + num_draft_hidden_layers=args.num_draft_hidden_layers, torch_dtype=torch.bfloat16, ).cuda() else: draft_model = AutoEagle3DraftModel.from_config( draft_model_config, attention_backend=args.attention_backend, + num_draft_hidden_layers=args.num_draft_hidden_layers, torch_dtype=torch.bfloat16, ).cuda() diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 2701a7add..54c655d50 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -1234,7 +1234,7 @@ def forward(self, hidden_states): class LlamaDecoderLayer(nn.Module): - def __init__(self, config, attention_backend: str = "sdpa"): + def __init__(self, config, num_draft_hidden_layers=1, attention_backend: str = "sdpa"): super().__init__() self.hidden_size = config.hidden_size @@ -1249,7 +1249,7 @@ def __init__(self, config, attention_backend: str = "sdpa"): raise ValueError(f"Unknown attention backend {attention_backend}") self.attention_backend = attention_backend - self.mlp = LlamaMLP(config) + self.mlps = nn.Sequential(*[LlamaMLP(config) for _ in range(num_draft_hidden_layers)]) # self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size) self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -1307,7 +1307,7 @@ def forward( # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) + hidden_states = self.mlps(hidden_states) hidden_states = residual + hidden_states # outputs = (hidden_states, return_hidden) @@ -1318,7 +1318,7 @@ class LlamaForCausalLMEagle3(Eagle3DraftModel): config_class = LlamaConfig - def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: + def __init__(self, config, num_draft_hidden_layers=1, quant_config=None, attention_backend="sdpa") -> None: super().__init__(config) self.config = config self.quant_config = quant_config @@ -1328,7 +1328,11 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: self.embed_tokens = nn.Embedding( config.vocab_size, config.hidden_size, config.pad_token_id ) - self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend) + self.midlayer = LlamaDecoderLayer( + config, + num_draft_hidden_layers=num_draft_hidden_layers, + attention_backend=attention_backend + ) if 
hasattr(config, "target_hidden_size"): self.fc = torch.nn.Linear( From fb5f3894f8a8dbf5c4b61ddf4b8e47b4d373200f Mon Sep 17 00:00:00 2001 From: xiaoxi Date: Thu, 1 Jan 2026 08:59:46 +0000 Subject: [PATCH 2/9] update old hidden_layers comment in config utils.py --- specforge/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/specforge/utils.py b/specforge/utils.py index 57a423bbd..20d2fbb1f 100644 --- a/specforge/utils.py +++ b/specforge/utils.py @@ -155,7 +155,8 @@ def generate_draft_model_config( draft_config[draft_param] = value # Special handling for some parameters - # Ensure num_hidden_layers is always 1 (EAGLE3 feature) + # Ensure (target) num_hidden_layers is 1. However, the draft model can + # deviate from this configuration with the command line arg num_draft_hidden_layers draft_config["num_hidden_layers"] = 1 # Keep some fixed draft model specific parameters From ae050ee6dcc8e9f1f24fa1ffe893cfdb33567411 Mon Sep 17 00:00:00 2001 From: xiaoxi Date: Thu, 1 Jan 2026 09:00:24 +0000 Subject: [PATCH 3/9] update unittest for draft llama3 for the num_draft_hidden_layers param --- tests/test_modeling/test_draft/test_llama3.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/test_modeling/test_draft/test_llama3.py b/tests/test_modeling/test_draft/test_llama3.py index b0fa86c80..242f3f051 100644 --- a/tests/test_modeling/test_draft/test_llama3.py +++ b/tests/test_modeling/test_draft/test_llama3.py @@ -5,6 +5,7 @@ from unittest.mock import patch import torch +import torch.nn as nn from transformers import LlamaConfig from specforge.modeling.draft.llama3_eagle import ( @@ -47,15 +48,19 @@ def setUp(self): } self.config = LlamaConfig(**config_dict) + self.num_draft_hidden_layers = 3 def tearDown(self): shutil.rmtree(self.temp_dir) def test_model_initialization(self): - model = LlamaForCausalLMEagle3(self.config) + model = LlamaForCausalLMEagle3(self.config, num_draft_hidden_layers=self.num_draft_hidden_layers) self.assertIsInstance(model.midlayer.self_attn, LlamaAttention) - self.assertIsInstance(model.midlayer.mlp, LlamaMLP) + self.assertIsInstance(model.midlayer.mlps, nn.Sequential) + self.assertEqual(len(model.midlayer.mlps), self.num_draft_hidden_layers) + for i in range(self.num_draft_hidden_layers): + self.assertIsInstance(model.midlayer.mlps[i], LlamaMLP) self.assertIsInstance(model.midlayer.hidden_norm, LlamaRMSNorm) self.assertIsInstance(model.midlayer.input_layernorm, LlamaRMSNorm) self.assertIsInstance(model.midlayer.post_attention_layernorm, LlamaRMSNorm) @@ -63,7 +68,7 @@ def test_model_initialization(self): def test_save_pretrained(self): """Test the model's save_pretrained functionality.""" - model = LlamaForCausalLMEagle3(self.config) + model = LlamaForCausalLMEagle3(self.config, num_draft_hidden_layers=self.num_draft_hidden_layers) self.config.save_pretrained(self.temp_dir) @@ -76,7 +81,7 @@ def test_save_pretrained(self): @patch("transformers.modeling_utils.PreTrainedModel.from_pretrained") def test_from_pretrained_mock(self, mock_from_pretrained): """mock""" - mock_model = LlamaForCausalLMEagle3(self.config) + mock_model = LlamaForCausalLMEagle3(self.config, num_draft_hidden_layers=self.num_draft_hidden_layers) mock_from_pretrained.return_value = mock_model loaded_model = LlamaForCausalLMEagle3.from_pretrained(self.temp_dir) @@ -85,7 +90,7 @@ def test_from_pretrained_mock(self, mock_from_pretrained): def test_model_forward_pass(self): """forward""" - model = LlamaForCausalLMEagle3(self.config) + model = 
LlamaForCausalLMEagle3(self.config, num_draft_hidden_layers=self.num_draft_hidden_layers) model.eval() batch_size = 2 @@ -105,8 +110,8 @@ def test_model_forward_pass(self): self.assertEqual(outputs.shape, (batch_size, seq_len, self.config.hidden_size)) def test_state_dict_compatibility(self): - model1 = LlamaForCausalLMEagle3(self.config) - model2 = LlamaForCausalLMEagle3(self.config) + model1 = LlamaForCausalLMEagle3(self.config, num_draft_hidden_layers=self.num_draft_hidden_layers) + model2 = LlamaForCausalLMEagle3(self.config, num_draft_hidden_layers=self.num_draft_hidden_layers) state_dict = model1.state_dict() From 17e8e6f62876f96d74be827debb8ba927bdc8a6e Mon Sep 17 00:00:00 2001 From: xiaoxi Date: Fri, 2 Jan 2026 13:13:15 +0000 Subject: [PATCH 4/9] reformat code with pre-commit hooks --- scripts/train_eagle3.py | 2 +- specforge/modeling/draft/llama3_eagle.py | 22 +++++++++++----- specforge/utils.py | 2 +- tests/test_modeling/test_draft/test_llama3.py | 26 ++++++++++++++----- 4 files changed, 37 insertions(+), 15 deletions(-) diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 873fabea0..9b0a85313 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -96,7 +96,7 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]: "--num-draft-hidden-layers", type=int, default=3, - help="The number of MLPs in the draft model decoder" + help="The number of MLPs in the draft model decoder", ) # dataset arguments diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 54c655d50..0b7ff13ff 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -1234,7 +1234,9 @@ def forward(self, hidden_states): class LlamaDecoderLayer(nn.Module): - def __init__(self, config, num_draft_hidden_layers=1, attention_backend: str = "sdpa"): + def __init__( + self, config, num_draft_hidden_layers=1, attention_backend: str = "sdpa" + ): super().__init__() self.hidden_size = config.hidden_size @@ -1249,7 +1251,9 @@ def __init__(self, config, num_draft_hidden_layers=1, attention_backend: str = " raise ValueError(f"Unknown attention backend {attention_backend}") self.attention_backend = attention_backend - self.mlps = nn.Sequential(*[LlamaMLP(config) for _ in range(num_draft_hidden_layers)]) + self.mlps = nn.Sequential( + *[LlamaMLP(config) for _ in range(num_draft_hidden_layers)] + ) # self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size) self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -1318,7 +1322,13 @@ class LlamaForCausalLMEagle3(Eagle3DraftModel): config_class = LlamaConfig - def __init__(self, config, num_draft_hidden_layers=1, quant_config=None, attention_backend="sdpa") -> None: + def __init__( + self, + config, + num_draft_hidden_layers=1, + quant_config=None, + attention_backend="sdpa", + ) -> None: super().__init__(config) self.config = config self.quant_config = quant_config @@ -1329,9 +1339,9 @@ def __init__(self, config, num_draft_hidden_layers=1, quant_config=None, attenti config.vocab_size, config.hidden_size, config.pad_token_id ) self.midlayer = LlamaDecoderLayer( - config, - num_draft_hidden_layers=num_draft_hidden_layers, - attention_backend=attention_backend + config, + num_draft_hidden_layers=num_draft_hidden_layers, + attention_backend=attention_backend, ) if hasattr(config, "target_hidden_size"): diff --git a/specforge/utils.py b/specforge/utils.py index 
20d2fbb1f..feee734d9 100644 --- a/specforge/utils.py +++ b/specforge/utils.py @@ -155,7 +155,7 @@ def generate_draft_model_config( draft_config[draft_param] = value # Special handling for some parameters - # Ensure (target) num_hidden_layers is 1. However, the draft model can + # Ensure (target) num_hidden_layers is 1. However, the draft model can # deviate from this configuration with the command line arg num_draft_hidden_layers draft_config["num_hidden_layers"] = 1 diff --git a/tests/test_modeling/test_draft/test_llama3.py b/tests/test_modeling/test_draft/test_llama3.py index 242f3f051..62c002dd7 100644 --- a/tests/test_modeling/test_draft/test_llama3.py +++ b/tests/test_modeling/test_draft/test_llama3.py @@ -54,13 +54,15 @@ def tearDown(self): shutil.rmtree(self.temp_dir) def test_model_initialization(self): - model = LlamaForCausalLMEagle3(self.config, num_draft_hidden_layers=self.num_draft_hidden_layers) + model = LlamaForCausalLMEagle3( + self.config, num_draft_hidden_layers=self.num_draft_hidden_layers + ) self.assertIsInstance(model.midlayer.self_attn, LlamaAttention) self.assertIsInstance(model.midlayer.mlps, nn.Sequential) self.assertEqual(len(model.midlayer.mlps), self.num_draft_hidden_layers) for i in range(self.num_draft_hidden_layers): - self.assertIsInstance(model.midlayer.mlps[i], LlamaMLP) + self.assertIsInstance(model.midlayer.mlps[i], LlamaMLP) self.assertIsInstance(model.midlayer.hidden_norm, LlamaRMSNorm) self.assertIsInstance(model.midlayer.input_layernorm, LlamaRMSNorm) self.assertIsInstance(model.midlayer.post_attention_layernorm, LlamaRMSNorm) @@ -68,7 +70,9 @@ def test_model_initialization(self): def test_save_pretrained(self): """Test the model's save_pretrained functionality.""" - model = LlamaForCausalLMEagle3(self.config, num_draft_hidden_layers=self.num_draft_hidden_layers) + model = LlamaForCausalLMEagle3( + self.config, num_draft_hidden_layers=self.num_draft_hidden_layers + ) self.config.save_pretrained(self.temp_dir) @@ -81,7 +85,9 @@ def test_save_pretrained(self): @patch("transformers.modeling_utils.PreTrainedModel.from_pretrained") def test_from_pretrained_mock(self, mock_from_pretrained): """mock""" - mock_model = LlamaForCausalLMEagle3(self.config, num_draft_hidden_layers=self.num_draft_hidden_layers) + mock_model = LlamaForCausalLMEagle3( + self.config, num_draft_hidden_layers=self.num_draft_hidden_layers + ) mock_from_pretrained.return_value = mock_model loaded_model = LlamaForCausalLMEagle3.from_pretrained(self.temp_dir) @@ -90,7 +96,9 @@ def test_from_pretrained_mock(self, mock_from_pretrained): def test_model_forward_pass(self): """forward""" - model = LlamaForCausalLMEagle3(self.config, num_draft_hidden_layers=self.num_draft_hidden_layers) + model = LlamaForCausalLMEagle3( + self.config, num_draft_hidden_layers=self.num_draft_hidden_layers + ) model.eval() batch_size = 2 @@ -110,8 +118,12 @@ def test_model_forward_pass(self): self.assertEqual(outputs.shape, (batch_size, seq_len, self.config.hidden_size)) def test_state_dict_compatibility(self): - model1 = LlamaForCausalLMEagle3(self.config, num_draft_hidden_layers=self.num_draft_hidden_layers) - model2 = LlamaForCausalLMEagle3(self.config, num_draft_hidden_layers=self.num_draft_hidden_layers) + model1 = LlamaForCausalLMEagle3( + self.config, num_draft_hidden_layers=self.num_draft_hidden_layers + ) + model2 = LlamaForCausalLMEagle3( + self.config, num_draft_hidden_layers=self.num_draft_hidden_layers + ) state_dict = model1.state_dict() From 8c886a1f878d1acd31bf910d2a8e62b3a4712568 Mon Sep 17 
00:00:00 2001 From: xiaoxi Date: Sun, 4 Jan 2026 10:44:22 +0000 Subject: [PATCH 5/9] revert old changes based on num_draft_hidden_layers --- scripts/train_eagle3.py | 8 ----- specforge/modeling/draft/llama3_eagle.py | 24 +++----------- specforge/utils.py | 3 +- tests/test_modeling/test_draft/test_llama3.py | 31 +++++-------------- 4 files changed, 13 insertions(+), 53 deletions(-) diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 9b0a85313..881f3f7aa 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -92,12 +92,6 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]: choices=["sglang", "hf", "custom"], help="The backend of the target model", ) - model_group.add_argument( - "--num-draft-hidden-layers", - type=int, - default=3, - help="The number of MLPs in the draft model decoder", - ) # dataset arguments dataset_group = parser.add_argument_group("dataset") @@ -376,14 +370,12 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module] draft_model = AutoEagle3DraftModel.from_pretrained( draft_model_last_checkpoint, attention_backend=args.attention_backend, - num_draft_hidden_layers=args.num_draft_hidden_layers, torch_dtype=torch.bfloat16, ).cuda() else: draft_model = AutoEagle3DraftModel.from_config( draft_model_config, attention_backend=args.attention_backend, - num_draft_hidden_layers=args.num_draft_hidden_layers, torch_dtype=torch.bfloat16, ).cuda() diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 0b7ff13ff..2701a7add 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -1234,9 +1234,7 @@ def forward(self, hidden_states): class LlamaDecoderLayer(nn.Module): - def __init__( - self, config, num_draft_hidden_layers=1, attention_backend: str = "sdpa" - ): + def __init__(self, config, attention_backend: str = "sdpa"): super().__init__() self.hidden_size = config.hidden_size @@ -1251,9 +1249,7 @@ def __init__( raise ValueError(f"Unknown attention backend {attention_backend}") self.attention_backend = attention_backend - self.mlps = nn.Sequential( - *[LlamaMLP(config) for _ in range(num_draft_hidden_layers)] - ) + self.mlp = LlamaMLP(config) # self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size) self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -1311,7 +1307,7 @@ def forward( # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlps(hidden_states) + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states # outputs = (hidden_states, return_hidden) @@ -1322,13 +1318,7 @@ class LlamaForCausalLMEagle3(Eagle3DraftModel): config_class = LlamaConfig - def __init__( - self, - config, - num_draft_hidden_layers=1, - quant_config=None, - attention_backend="sdpa", - ) -> None: + def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: super().__init__(config) self.config = config self.quant_config = quant_config @@ -1338,11 +1328,7 @@ def __init__( self.embed_tokens = nn.Embedding( config.vocab_size, config.hidden_size, config.pad_token_id ) - self.midlayer = LlamaDecoderLayer( - config, - num_draft_hidden_layers=num_draft_hidden_layers, - attention_backend=attention_backend, - ) + self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend) if hasattr(config, "target_hidden_size"): 
self.fc = torch.nn.Linear( diff --git a/specforge/utils.py b/specforge/utils.py index feee734d9..57a423bbd 100644 --- a/specforge/utils.py +++ b/specforge/utils.py @@ -155,8 +155,7 @@ def generate_draft_model_config( draft_config[draft_param] = value # Special handling for some parameters - # Ensure (target) num_hidden_layers is 1. However, the draft model can - # deviate from this configuration with the command line arg num_draft_hidden_layers + # Ensure num_hidden_layers is always 1 (EAGLE3 feature) draft_config["num_hidden_layers"] = 1 # Keep some fixed draft model specific parameters diff --git a/tests/test_modeling/test_draft/test_llama3.py b/tests/test_modeling/test_draft/test_llama3.py index 62c002dd7..b0fa86c80 100644 --- a/tests/test_modeling/test_draft/test_llama3.py +++ b/tests/test_modeling/test_draft/test_llama3.py @@ -5,7 +5,6 @@ from unittest.mock import patch import torch -import torch.nn as nn from transformers import LlamaConfig from specforge.modeling.draft.llama3_eagle import ( @@ -48,21 +47,15 @@ def setUp(self): } self.config = LlamaConfig(**config_dict) - self.num_draft_hidden_layers = 3 def tearDown(self): shutil.rmtree(self.temp_dir) def test_model_initialization(self): - model = LlamaForCausalLMEagle3( - self.config, num_draft_hidden_layers=self.num_draft_hidden_layers - ) + model = LlamaForCausalLMEagle3(self.config) self.assertIsInstance(model.midlayer.self_attn, LlamaAttention) - self.assertIsInstance(model.midlayer.mlps, nn.Sequential) - self.assertEqual(len(model.midlayer.mlps), self.num_draft_hidden_layers) - for i in range(self.num_draft_hidden_layers): - self.assertIsInstance(model.midlayer.mlps[i], LlamaMLP) + self.assertIsInstance(model.midlayer.mlp, LlamaMLP) self.assertIsInstance(model.midlayer.hidden_norm, LlamaRMSNorm) self.assertIsInstance(model.midlayer.input_layernorm, LlamaRMSNorm) self.assertIsInstance(model.midlayer.post_attention_layernorm, LlamaRMSNorm) @@ -70,9 +63,7 @@ def test_model_initialization(self): def test_save_pretrained(self): """Test the model's save_pretrained functionality.""" - model = LlamaForCausalLMEagle3( - self.config, num_draft_hidden_layers=self.num_draft_hidden_layers - ) + model = LlamaForCausalLMEagle3(self.config) self.config.save_pretrained(self.temp_dir) @@ -85,9 +76,7 @@ def test_save_pretrained(self): @patch("transformers.modeling_utils.PreTrainedModel.from_pretrained") def test_from_pretrained_mock(self, mock_from_pretrained): """mock""" - mock_model = LlamaForCausalLMEagle3( - self.config, num_draft_hidden_layers=self.num_draft_hidden_layers - ) + mock_model = LlamaForCausalLMEagle3(self.config) mock_from_pretrained.return_value = mock_model loaded_model = LlamaForCausalLMEagle3.from_pretrained(self.temp_dir) @@ -96,9 +85,7 @@ def test_from_pretrained_mock(self, mock_from_pretrained): def test_model_forward_pass(self): """forward""" - model = LlamaForCausalLMEagle3( - self.config, num_draft_hidden_layers=self.num_draft_hidden_layers - ) + model = LlamaForCausalLMEagle3(self.config) model.eval() batch_size = 2 @@ -118,12 +105,8 @@ def test_model_forward_pass(self): self.assertEqual(outputs.shape, (batch_size, seq_len, self.config.hidden_size)) def test_state_dict_compatibility(self): - model1 = LlamaForCausalLMEagle3( - self.config, num_draft_hidden_layers=self.num_draft_hidden_layers - ) - model2 = LlamaForCausalLMEagle3( - self.config, num_draft_hidden_layers=self.num_draft_hidden_layers - ) + model1 = LlamaForCausalLMEagle3(self.config) + model2 = LlamaForCausalLMEagle3(self.config) state_dict = 
model1.state_dict() From 88620b49cf82a300c777b1998b8f73023635d39e Mon Sep 17 00:00:00 2001 From: xiaoxi Date: Sun, 11 Jan 2026 07:27:13 +0000 Subject: [PATCH 6/9] llama3 draft model with multi-layer decoders implementation and tests --- specforge/modeling/draft/base.py | 6 +- specforge/modeling/draft/llama3_eagle.py | 61 ++++++++++--------- tests/test_modeling/test_draft/test_llama3.py | 19 +++--- 3 files changed, 47 insertions(+), 39 deletions(-) diff --git a/specforge/modeling/draft/base.py b/specforge/modeling/draft/base.py index b5584a759..87542bea5 100644 --- a/specforge/modeling/draft/base.py +++ b/specforge/modeling/draft/base.py @@ -24,7 +24,7 @@ import json import os from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, List import torch from huggingface_hub import snapshot_download @@ -98,10 +98,10 @@ def backbone( self, input_embeds: torch.Tensor, hidden_states: torch.Tensor, - cache_hidden: torch.Tensor, + caches_hidden: List[List[List[torch.Tensor]]], attention_mask: torch.Tensor, position_ids: torch.Tensor, - past_key_values: Optional[Cache] = None, + past_key_values: List[Optional[Cache]] = None, use_cache: bool = True, ) -> torch.Tensor: """ diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 2701a7add..3981a04be 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -1328,7 +1328,10 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: self.embed_tokens = nn.Embedding( config.vocab_size, config.hidden_size, config.pad_token_id ) - self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend) + self.num_hidden_layers = config.num_hidden_layers + self.midlayers = nn.ModuleList( + [LlamaDecoderLayer(config, attention_backend=attention_backend) for _ in range(self.num_hidden_layers)] + ) if hasattr(config, "target_hidden_size"): self.fc = torch.nn.Linear( @@ -1366,11 +1369,11 @@ def forward( position_ids (`torch.LongTensor`, *optional*): position ids of shape `(batch, seq_len)` """ if ttt_length == 1: - print_with_rank("using ttt_length 1, no need to cache hidden states") - cache_hidden = None + print_with_rank("using ttt_length 1, no need to cache hidden states for decoder layer(s)") + caches_hidden = None else: - print_with_rank(f"using ttt_length {ttt_length}, caching hidden states") - cache_hidden = [[], []] + print_with_rank(f"using ttt_length {ttt_length}, caching hidden states for decoder layer(s)") + caches_hidden = [[[], []] for _ in range(self.num_hidden_layers)] batch_size, seq_length, _ = hidden_states.size() @@ -1390,15 +1393,15 @@ def forward( # fc hidden_states = self.fc(hidden_states) - hidden_states = self.midlayer( - input_emb=inputs_embeds, - hidden_states=hidden_states, - cache_hidden=cache_hidden, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=None, - output_attentions=False, - use_cache=False, + + hidden_states = self.backbone( + input_embeds=inputs_embeds, + hidden_states=hidden_states, + caches_hidden=caches_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + use_cache=False, ) # norm @@ -1422,19 +1425,21 @@ def backbone( self, input_embeds: torch.Tensor, hidden_states: torch.Tensor, - cache_hidden: torch.Tensor, - attention_mask: torch.Tensor, - position_ids: torch.Tensor, - past_key_values: Optional[Cache] = None, + caches_hidden: Optional[List[List[List[torch.Tensor]]]] = None, + 
attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[Cache]] = None, use_cache: bool = True, ) -> torch.Tensor: - return self.midlayer( - input_emb=input_embeds, - hidden_states=hidden_states, - cache_hidden=cache_hidden, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - output_attentions=False, - use_cache=False, - ) + for i, layer in enumerate(self.midlayers): + hidden_states = layer( + input_emb=input_embeds, + hidden_states=hidden_states, + cache_hidden=caches_hidden[i] if caches_hidden is not None else None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values[i] if past_key_values is not None else None, + output_attentions=False, + use_cache=False, + ) + return hidden_states \ No newline at end of file diff --git a/tests/test_modeling/test_draft/test_llama3.py b/tests/test_modeling/test_draft/test_llama3.py index b0fa86c80..ad15b7a39 100644 --- a/tests/test_modeling/test_draft/test_llama3.py +++ b/tests/test_modeling/test_draft/test_llama3.py @@ -5,6 +5,7 @@ from unittest.mock import patch import torch +import torch.nn as nn from transformers import LlamaConfig from specforge.modeling.draft.llama3_eagle import ( @@ -35,7 +36,7 @@ def setUp(self): "model_type": "llama", "num_attention_heads": 32, "num_key_value_heads": 8, - "num_hidden_layers": 1, + "num_hidden_layers": 3, "pad_token_id": 0, "rms_norm_eps": 1e-05, "tie_word_embeddings": False, @@ -53,13 +54,15 @@ def tearDown(self): def test_model_initialization(self): model = LlamaForCausalLMEagle3(self.config) - - self.assertIsInstance(model.midlayer.self_attn, LlamaAttention) - self.assertIsInstance(model.midlayer.mlp, LlamaMLP) - self.assertIsInstance(model.midlayer.hidden_norm, LlamaRMSNorm) - self.assertIsInstance(model.midlayer.input_layernorm, LlamaRMSNorm) - self.assertIsInstance(model.midlayer.post_attention_layernorm, LlamaRMSNorm) - self.assertEqual(model.midlayer.hidden_size, self.config.hidden_size) + self.assertEqual(model.num_hidden_layers, self.config.num_hidden_layers) + self.assertIsInstance(model.midlayers, nn.ModuleList) + for layer in model.midlayers: + self.assertIsInstance(layer.self_attn, LlamaAttention) + self.assertIsInstance(layer.mlp, LlamaMLP) + self.assertIsInstance(layer.hidden_norm, LlamaRMSNorm) + self.assertIsInstance(layer.input_layernorm, LlamaRMSNorm) + self.assertIsInstance(layer.post_attention_layernorm, LlamaRMSNorm) + self.assertEqual(layer.hidden_size, self.config.hidden_size) def test_save_pretrained(self): """Test the model's save_pretrained functionality.""" From c02bc6d9c395f429f00233015fb525f1d4a09efa Mon Sep 17 00:00:00 2001 From: xiaoxi Date: Sun, 11 Jan 2026 07:56:19 +0000 Subject: [PATCH 7/9] precommit format --- specforge/modeling/draft/base.py | 2 +- specforge/modeling/draft/llama3_eagle.py | 51 +++++++++++-------- tests/test_modeling/test_draft/test_llama3.py | 12 ++--- 3 files changed, 37 insertions(+), 28 deletions(-) diff --git a/specforge/modeling/draft/base.py b/specforge/modeling/draft/base.py index 87542bea5..3ad9c6564 100644 --- a/specforge/modeling/draft/base.py +++ b/specforge/modeling/draft/base.py @@ -24,7 +24,7 @@ import json import os from abc import ABC, abstractmethod -from typing import Optional, List +from typing import List, Optional import torch from huggingface_hub import snapshot_download diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 
3981a04be..12e877c8f 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -1330,7 +1330,10 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: ) self.num_hidden_layers = config.num_hidden_layers self.midlayers = nn.ModuleList( - [LlamaDecoderLayer(config, attention_backend=attention_backend) for _ in range(self.num_hidden_layers)] + [ + LlamaDecoderLayer(config, attention_backend=attention_backend) + for _ in range(self.num_hidden_layers) + ] ) if hasattr(config, "target_hidden_size"): @@ -1369,10 +1372,14 @@ def forward( position_ids (`torch.LongTensor`, *optional*): position ids of shape `(batch, seq_len)` """ if ttt_length == 1: - print_with_rank("using ttt_length 1, no need to cache hidden states for decoder layer(s)") + print_with_rank( + "using ttt_length 1, no need to cache hidden states for decoder layer(s)" + ) caches_hidden = None else: - print_with_rank(f"using ttt_length {ttt_length}, caching hidden states for decoder layer(s)") + print_with_rank( + f"using ttt_length {ttt_length}, caching hidden states for decoder layer(s)" + ) caches_hidden = [[[], []] for _ in range(self.num_hidden_layers)] batch_size, seq_length, _ = hidden_states.size() @@ -1395,13 +1402,13 @@ def forward( hidden_states = self.fc(hidden_states) hidden_states = self.backbone( - input_embeds=inputs_embeds, - hidden_states=hidden_states, - caches_hidden=caches_hidden, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=None, - use_cache=False, + input_embeds=inputs_embeds, + hidden_states=hidden_states, + caches_hidden=caches_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + use_cache=False, ) # norm @@ -1432,14 +1439,16 @@ def backbone( use_cache: bool = True, ) -> torch.Tensor: for i, layer in enumerate(self.midlayers): - hidden_states = layer( - input_emb=input_embeds, - hidden_states=hidden_states, - cache_hidden=caches_hidden[i] if caches_hidden is not None else None, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values[i] if past_key_values is not None else None, - output_attentions=False, - use_cache=False, - ) - return hidden_states \ No newline at end of file + hidden_states = layer( + input_emb=input_embeds, + hidden_states=hidden_states, + cache_hidden=caches_hidden[i] if caches_hidden is not None else None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=( + past_key_values[i] if past_key_values is not None else None + ), + output_attentions=False, + use_cache=False, + ) + return hidden_states diff --git a/tests/test_modeling/test_draft/test_llama3.py b/tests/test_modeling/test_draft/test_llama3.py index ad15b7a39..505b1abd0 100644 --- a/tests/test_modeling/test_draft/test_llama3.py +++ b/tests/test_modeling/test_draft/test_llama3.py @@ -57,12 +57,12 @@ def test_model_initialization(self): self.assertEqual(model.num_hidden_layers, self.config.num_hidden_layers) self.assertIsInstance(model.midlayers, nn.ModuleList) for layer in model.midlayers: - self.assertIsInstance(layer.self_attn, LlamaAttention) - self.assertIsInstance(layer.mlp, LlamaMLP) - self.assertIsInstance(layer.hidden_norm, LlamaRMSNorm) - self.assertIsInstance(layer.input_layernorm, LlamaRMSNorm) - self.assertIsInstance(layer.post_attention_layernorm, LlamaRMSNorm) - self.assertEqual(layer.hidden_size, self.config.hidden_size) + self.assertIsInstance(layer.self_attn, LlamaAttention) + 
self.assertIsInstance(layer.mlp, LlamaMLP) + self.assertIsInstance(layer.hidden_norm, LlamaRMSNorm) + self.assertIsInstance(layer.input_layernorm, LlamaRMSNorm) + self.assertIsInstance(layer.post_attention_layernorm, LlamaRMSNorm) + self.assertEqual(layer.hidden_size, self.config.hidden_size) def test_save_pretrained(self): """Test the model's save_pretrained functionality.""" From 71e21cb6c9ad4d50a65fa801f57af6aa9d5b5498 Mon Sep 17 00:00:00 2001 From: xiaoxi Date: Sun, 11 Jan 2026 08:14:07 +0000 Subject: [PATCH 8/9] fix signature of past_key_values --- specforge/modeling/draft/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specforge/modeling/draft/base.py b/specforge/modeling/draft/base.py index 3ad9c6564..86e9035a4 100644 --- a/specforge/modeling/draft/base.py +++ b/specforge/modeling/draft/base.py @@ -101,7 +101,7 @@ def backbone( caches_hidden: List[List[List[torch.Tensor]]], attention_mask: torch.Tensor, position_ids: torch.Tensor, - past_key_values: List[Optional[Cache]] = None, + past_key_values: Optional[List[Cache]] = None, use_cache: bool = True, ) -> torch.Tensor: """ From e1f7fe6fc2df0150dad778711dd0b84dd2e2bfc0 Mon Sep 17 00:00:00 2001 From: xiaoxi Date: Sun, 11 Jan 2026 08:54:39 +0000 Subject: [PATCH 9/9] fix specforge core online eagle3 model and QWen model --- specforge/core/eagle3.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py index aa740bb04..65cc04163 100644 --- a/specforge/core/eagle3.py +++ b/specforge/core/eagle3.py @@ -96,7 +96,7 @@ def forward( target: torch.Tensor, loss_mask: torch.Tensor, hidden_states: torch.Tensor, - past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, position_ids: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: @@ -123,6 +123,7 @@ def forward( batch_size, seq_length, _ = hidden_states.shape seq_length_with_past = seq_length past_key_values_length = 0 + draft_num_hidden_layers = self.draft_model.num_hidden_layers # Step 2: project the concatenated hidden states to the target hidden size hidden_states = self.draft_model.project_hidden_states(hidden_states) @@ -166,13 +167,13 @@ def forward( # for sequence paralle, position mask and input ids will split by sequence dim, need to keep origin for ttt shift global_input_ids = input_ids if self.attention_backend == "sdpa": - cache_hidden = [[], []] + caches_hidden = [[[], []] for _ in range(draft_num_hidden_layers)] past_key_values = None elif self.attention_backend == "flex_attention": - cache_hidden = None - past_key_values = DynamicCache() + caches_hidden = None + past_key_values = [DynamicCache() for _ in range(draft_num_hidden_layers)] elif self.attention_backend == "usp": - cache_hidden = [[], []] + caches_hidden = [[[], []] for _ in range(draft_num_hidden_layers)] past_key_values = None hidden_states = self.prepare_usp_input(hidden_states) @@ -193,7 +194,7 @@ def forward( hidden_states_out = self.draft_model.backbone( input_embeds=inputs_embeds, hidden_states=hidden_states, - cache_hidden=cache_hidden, + caches_hidden=caches_hidden, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -374,7 +375,7 @@ def forward( input_ids: torch.Tensor, attention_mask: torch.Tensor, loss_mask: torch.Tensor, - past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + 
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, position_ids: Optional[torch.Tensor] = None, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[torch.Tensor] = None, @@ -409,13 +410,14 @@ def forward( batch_size, seq_length, _ = hidden_states.shape seq_length_with_past = seq_length past_key_values_length = 0 + draft_num_hidden_layers = self.draft_model.num_hidden_layers # Step 2: project the concatenated hidden states to the target hidden size hidden_states = self.draft_model.project_hidden_states(hidden_states) # Step 3: process kv cache, position ids and position ids if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + past_key_values_length = past_key_values[0][0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: @@ -465,11 +467,11 @@ def forward( vlosses = [] acces = [] if self.attention_backend == "sdpa": - cache_hidden = [[], []] + caches_hidden = [[[], []] for _ in range(draft_num_hidden_layers)] past_key_values = None elif self.attention_backend == "flex_attention": - cache_hidden = None - past_key_values = DynamicCache() + caches_hidden = None + past_key_values = [DynamicCache() for _ in range(draft_num_hidden_layers)] for idx in range(self.length): target_p = target_p_padded[:, idx : idx + seq_length, :].contiguous() @@ -484,7 +486,7 @@ def forward( hidden_states_out = self.draft_model.backbone( input_embeds=inputs_embeds, hidden_states=hidden_states, - cache_hidden=cache_hidden, + caches_hidden=caches_hidden, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values,
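
Taken together, patches 6/9 through 9/9 replace the draft model's single midlayer with an nn.ModuleList of decoder layers whose count comes from the draft config's num_hidden_layers (the updated unit test sets it to 3), rather than from the --num-draft-hidden-layers flag that patches 1/9 through 5/9 introduced and then reverted. Correspondingly, the per-TTT-step hidden-state cache and the flex-attention KV cache become one entry per layer. The snippet below is a minimal, runnable sketch of that stacking pattern only: ToyDecoderLayer and ToyMultiLayerBackbone are illustrative stand-ins, not SpecForge classes, and they mirror just the calling convention (every layer sees the same input embeddings, hidden states flow from layer to layer, each layer indexes its own cache), not the real attention or MLP computation.

from typing import List, Optional

import torch
import torch.nn as nn


class ToyDecoderLayer(nn.Module):
    """Stand-in for a decoder layer: concatenates the shared input
    embeddings with the incoming hidden states and projects back down."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size * 2, hidden_size)

    def forward(
        self,
        input_emb: torch.Tensor,
        hidden_states: torch.Tensor,
        cache_hidden: Optional[List[List[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        if cache_hidden is not None:
            # Each layer writes only into its own [[], []] pair.
            cache_hidden[0].append(hidden_states)
        return self.proj(torch.cat([input_emb, hidden_states], dim=-1))


class ToyMultiLayerBackbone(nn.Module):
    """Mirrors the shape of the backbone() loop: shared input embeddings,
    flowing hidden states, and per-layer caches indexed by layer position."""

    def __init__(self, hidden_size: int, num_hidden_layers: int):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.midlayers = nn.ModuleList(
            [ToyDecoderLayer(hidden_size) for _ in range(num_hidden_layers)]
        )

    def forward(self, input_embeds, hidden_states, caches_hidden=None):
        for i, layer in enumerate(self.midlayers):
            hidden_states = layer(
                input_emb=input_embeds,
                hidden_states=hidden_states,
                cache_hidden=caches_hidden[i] if caches_hidden is not None else None,
            )
        return hidden_states


if __name__ == "__main__":
    hidden_size, num_layers = 64, 3
    backbone = ToyMultiLayerBackbone(hidden_size, num_layers)
    # One [[], []] cache per layer, matching the caches_hidden allocation
    # added to specforge/core/eagle3.py in patch 9/9.
    caches_hidden = [[[], []] for _ in range(num_layers)]
    input_embeds = torch.randn(2, 8, hidden_size)
    hidden_states = torch.randn(2, 8, hidden_size)
    out = backbone(input_embeds, hidden_states, caches_hidden)
    assert out.shape == (2, 8, hidden_size)
    assert all(len(cache[0]) == 1 for cache in caches_hidden)

Because every layer owns its caches_hidden[i] (and, under flex attention, its own DynamicCache), growing the stack does not change how any single layer reads or writes its cache; the training loop only has to allocate the per-layer lists, which is what patch 9/9 does in specforge/core/eagle3.py.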