From 24fe8b0c263c090fbc8e47e77eee9cd2e30cf855 Mon Sep 17 00:00:00 2001
From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com>
Date: Thu, 4 Sep 2025 22:52:20 -0400
Subject: [PATCH 1/6] Update README.md

---
 README.md | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 146e3aa9..a7cf187a 100644
--- a/README.md
+++ b/README.md
@@ -197,6 +197,7 @@ The *total-token* is the number of draft tokens. For smaller models and advanced
 ### With Code
 You can use our provided "eagenerate" for speedup generation just like using 'generate' from Hugging Face. Here is an example.
 ```python
+import torch
 from eagle.model.ea_model import EaModel
 from fastchat.model import get_conversation_template
 model = EaModel.from_pretrained(
@@ -205,7 +206,10 @@ model = EaModel.from_pretrained(
     torch_dtype=torch.float16,
     low_cpu_mem_usage=True,
     device_map="auto",
-    total_token=-1
+    total_token=-1,
+    # If using EAGLE, please uncomment the following line
+    use_eagle3=False
+
 )
 model.eval()
 your_message="Hello"

From d611bbeced8c196510d79f49478fd3df4ca0fd2a Mon Sep 17 00:00:00 2001
From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com>
Date: Thu, 4 Sep 2025 22:57:18 -0400
Subject: [PATCH 2/6] Update README.md

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index a7cf187a..d6b9ea75 100644
--- a/README.md
+++ b/README.md
@@ -207,8 +207,8 @@ model = EaModel.from_pretrained(
     low_cpu_mem_usage=True,
     device_map="auto",
     total_token=-1,
-    # If using EAGLE, please uncomment the following line
-    use_eagle3=False
+    # If using EAGLE (not EAGLE 3), please uncomment the following line
+    # use_eagle3=False
 
 )
 model.eval()

From 93c00829a2385560628ecea5bcf8c4f7496fec27 Mon Sep 17 00:00:00 2001
From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com>
Date: Fri, 5 Sep 2025 18:57:48 -0400
Subject: [PATCH 3/6] Update cnets1.py

---
 eagle/model/cnets1.py | 27 +++++++++++++++++++++++++--
 1 file changed, 25 insertions(+), 2 deletions(-)

diff --git a/eagle/model/cnets1.py b/eagle/model/cnets1.py
index 25cad25a..58c3ffcb 100644
--- a/eagle/model/cnets1.py
+++ b/eagle/model/cnets1.py
@@ -777,8 +777,30 @@ def topK_genrate(self, hidden_states, input_ids, head, logits_processor):
         tree_position_ids = torch.sum(tree_mask, dim=1) - 1
 
         tree_mask = tree_mask.float()[None, None]
+        ########################################
+        # Change
+        ########################################
+        # top_scores_index = torch.sort(top_scores.indices).values
+
+        # get the cumulative log-probs (log-densities) for the selected drafts
+        draft_log_scores = scores_list[top_scores_index]  # shape [total_tokens]
+        # convert to probabilities if you prefer:
+        # draft_probs = torch.exp(draft_log_scores)  # shape [total_tokens]
+
+        # then build draft_tokens as before
+        draft_tokens = ss_token_list[top_scores_index]
+        draft_tokens = torch.cat((sample_token, draft_tokens), dim=0)
+
+        # make sure device is consistent
+        draft_log_scores = draft_log_scores.to(draft_tokens.device)[None]  # shape [1, total_tokens]
+        # draft_probs = draft_probs.to(draft_tokens.device)[None]
+        ########################################
+        # Change
+        ########################################
+
         draft_tokens = draft_tokens[None]
+
         del parents_list, scores_list, ss_token, ss_token_list, draft_parents
 
         # with Timer("retrieve"):
@@ -819,7 +841,7 @@ def custom_sort(lst):
 
         del mask_index, mask_index_list, noleaf_index, noleaf_num, leaf_num, max_depth, rid
 
         tree_position_ids = tree_position_ids.to(hidden_states.device)
-        return draft_tokens, retrieve_indices, tree_mask, tree_position_ids
+        return draft_tokens, retrieve_indices, tree_mask, tree_position_ids, draft_log_scores
 
@@ -832,4 +854,5 @@ def count_parameters(model):
 if __name__ == "__main__":
     config = EConfig.from_pretrained('config.json')
     model = Model(config, load_emb=False)
-    print(model)
\ No newline at end of file
+
+    print(model)
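Note on [PATCH 3/6]: `topK_genrate` now returns a fifth value, `draft_log_scores`, the cumulative log-probabilities of the draft tokens selected for the tree. Because `sample_token` is prepended to the selected tokens *after* the scores are gathered, `draft_log_scores[0, i]` scores `draft_tokens[0, i + 1]`, not `draft_tokens[0, i]`. A minimal sketch of that alignment, with toy values standing in for the real model outputs:

```python
import torch

# Toy stand-ins for the values built in the hunk above.
sample_token = torch.tensor([11])                          # root token of the draft tree
selected_tokens = torch.tensor([7, 13, 42, 99])            # ss_token_list[top_scores_index]
draft_log_scores = torch.tensor([-0.1, -0.9, -1.3, -2.0])  # scores_list[top_scores_index]

draft_tokens = torch.cat((sample_token, selected_tokens), dim=0)[None]  # shape [1, 5]
draft_log_scores = draft_log_scores[None]                               # shape [1, 4]

# draft_log_scores[0, i] is the cumulative log-prob of draft_tokens[0, i + 1]:
for i in range(draft_log_scores.shape[1]):
    print(int(draft_tokens[0, i + 1]), float(draft_log_scores[0, i]))
```

Any consumer that indexes `draft_log_scores` with positions taken from the concatenated `draft_tokens` (for example via `retrieve_indices` later in this series) needs to account for that one-position shift.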
From 23f4abb6e13216462acd06eff9f499e442082916 Mon Sep 17 00:00:00 2001
From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com>
Date: Fri, 5 Sep 2025 18:58:31 -0400
Subject: [PATCH 4/6] Update utils.py

---
 eagle/model/utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/eagle/model/utils.py b/eagle/model/utils.py
index dfc64bae..d92a48fb 100644
--- a/eagle/model/utils.py
+++ b/eagle/model/utils.py
@@ -250,8 +250,8 @@ def initialize_tree(input_ids, model, past_key_values, logits_processor):
     if outputs["hidden_states"][0].device != ea_device:
         outputs["hidden_states"] = [x.to(ea_device) for x in outputs["hidden_states"]]
     hidden_states=torch.cat(outputs["hidden_states"],dim=-1)
-    draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head,logits_processor)
-    return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, orig, hidden_states, token
+    draft_tokens, retrieve_indices,tree_mask,tree_position_ids, draft_log_scores = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head,logits_processor)
+    return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, orig, hidden_states, token, draft_log_scores
 
 
 def reset_tree_mode(

From 654f8f063a4c70de77b791f51491506a45efc618 Mon Sep 17 00:00:00 2001
From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com>
Date: Fri, 5 Sep 2025 19:37:01 -0400
Subject: [PATCH 5/6] Update ea_model.py

---
 eagle/model/ea_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/eagle/model/ea_model.py b/eagle/model/ea_model.py
index f0d2e516..9a4c5e7f 100644
--- a/eagle/model/ea_model.py
+++ b/eagle/model/ea_model.py
@@ -238,7 +238,7 @@ def eagenerate(
         input_len = input_ids.shape[1]
         reset_tree_mode(self)
         # prefill
-        draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token = initialize_tree(
+        draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token, draft_log_scores = initialize_tree(
            input_ids, self, past_key_values, logits_processor
         )
         new_token = 0
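Note on [PATCH 4/6] and [PATCH 5/6]: these two patches thread `draft_log_scores` from `topK_genrate` through `initialize_tree` and into `eagenerate`'s prefill. Since `topK_genrate` now returns five values, every remaining call site must unpack the extra value or it will raise `ValueError: too many values to unpack`. A sketch of the required change at such a call site (the surrounding names are hypothetical; discard the scores with `_` where they are not needed):

```python
# Before: four return values.
# draft_tokens, retrieve_indices, tree_mask, tree_position_ids = model.ea_layer.topK_genrate(
#     hidden_states, input_ids, model.base_model.lm_head, logits_processor)

# After: five return values; `_` discards draft_log_scores where it is unused.
draft_tokens, retrieve_indices, tree_mask, tree_position_ids, _ = model.ea_layer.topK_genrate(
    hidden_states, input_ids, model.base_model.lm_head, logits_processor
)
```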
From 47de7b0e9ef609e93256b95c8226c963e09ae524 Mon Sep 17 00:00:00 2001
From: Jiankun Wei <72998341+david-wei-01001@users.noreply.github.com>
Date: Fri, 5 Sep 2025 20:34:11 -0400
Subject: [PATCH 6/6] Update ea_model.py

---
 eagle/model/ea_model.py | 334 +++++++++++++++++++++++++++++++---------
 1 file changed, 262 insertions(+), 72 deletions(-)

diff --git a/eagle/model/ea_model.py b/eagle/model/ea_model.py
index 9a4c5e7f..1016f5a2 100644
--- a/eagle/model/ea_model.py
+++ b/eagle/model/ea_model.py
@@ -190,10 +190,118 @@ def forward(
         else:
             return outputs, hidden_states
 
-    @torch.no_grad()
+    # @torch.no_grad()
+    # def eagenerate(
+    #         self,
+    #         input_ids,
+    #         temperature=0.0,
+    #         top_p=0.0,
+    #         top_k=0.0,
+    #         max_new_tokens=512,
+    #         max_length=2048,
+    #         log=False,
+    #         is_llama3=False,
+
+    # ):
+    #     if is_llama3:
+    #         stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
+
+
+    #     if temperature > 1e-5:
+    #         logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
+    #     else:
+    #         logits_processor = None
+    #     # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
+    #     # Avoid modifying the input_ids in-place
+
+    #     padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
+    #     input_ids = input_ids.clone()
+    #     self.ea_layer.reset_kv()
+
+    #     # Initialize the past key and value states
+    #     if hasattr(self, "past_key_values"):
+    #         past_key_values = self.past_key_values
+    #         past_key_values_data = self.past_key_values_data
+    #         current_length_data = self.current_length_data
+    #         # Reset the past key and value states
+    #         current_length_data.zero_()
+    #     else:
+    #         (
+    #             past_key_values,
+    #             past_key_values_data,
+    #             current_length_data,
+    #         ) = initialize_past_key_values(self.base_model,max_length=max_length)
+    #         self.past_key_values = past_key_values
+    #         self.past_key_values_data = past_key_values_data
+    #         self.current_length_data = current_length_data
+
+    #     input_len = input_ids.shape[1]
+    #     reset_tree_mode(self)
+    #     # prefill
+    #     draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token, draft_log_scores = initialize_tree(
+    #         input_ids, self, past_key_values, logits_processor
+    #     )
+    #     new_token = 0
+    #     max_length = max_length - self.ea_layer.total_tokens - 10
+    #     for idx in range(max_length):
+    #         # with Timer("all"):
+    #         self.base_model.model.tree_mask = tree_mask
+
+    #         draft_tokens = draft_tokens.to(input_ids.device)
+    #         # Target model forward, get logits
+    #         logits, hidden_state_new, outputs = tree_decoding(
+    #             self,
+    #             draft_tokens,
+    #             past_key_values,
+    #             tree_position_ids,
+    #             input_ids,
+    #             retrieve_indices,
+    #         )
+    #         # retrieve_indices=tree_buffers["retrieve_indices"]
+    #         # logits = logits[0, retrieve_indices]
+    #         draft_tokens = torch.cat((draft_tokens, padding), dim=1)
+    #         candidates = draft_tokens[0, retrieve_indices]
+    #         # verification
+    #         best_candidate, accept_length, sample_p = evaluate_posterior(
+    #             logits, candidates, logits_processor
+    #         )
+    #         # print(accept_length)
+    #         # Adjusting the input sequence, draft model forward
+    #         input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs(
+    #             input_ids,
+    #             candidates,
+    #             best_candidate,
+    #             accept_length,
+    #             retrieve_indices,
+    #             logits_processor,
+    #             new_token,
+    #             past_key_values_data,
+    #             current_length_data,
+    #             self,
+    #             hidden_state_new,
+    #             sample_p
+    #         )
+
+    #         if is_llama3:
+    #             if stop_token_id in input_ids[0, input_len:].tolist():
+    #                 break
+
+    #         if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
+    #             break
+    #         if new_token > max_new_tokens:
+    #             break
+    #         if input_ids.shape[1] > max_length:
+    #             break
+    #     if not log:
+    #         return input_ids
+    #     else:
+    #         return input_ids, new_token, idx

+    # NOTE: remove any @torch.no_grad() decorator above this function
     def eagenerate(
             self,
-            input_ids,
+            input_ids=None,
+            inputs_embeds=None,  # optional: continuous embeddings for prompt (for external optimization)
             temperature=0.0,
             top_p=0.0,
             top_k=0.0,
@@ -201,101 +309,183 @@ def eagenerate(
             max_length=2048,
             log=False,
             is_llama3=False,
-
+            enable_grad=False,         # if True keep grad-enabled (for external backprop)
+            return_step_scores=False,  # if True return per-step (draft_logp_first, verify_logp_first) lists
     ):
         if is_llama3:
             stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
-
-
+
         if temperature > 1e-5:
             logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
         else:
             logits_processor = None
-        # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
-        # Avoid modifying the input_ids in-place
-
-        padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
-        input_ids = input_ids.clone()
+
+        # Input handling: require exactly one of input_ids or inputs_embeds
+        assert (input_ids is None) ^ (inputs_embeds is None), "Provide exactly one of input_ids or inputs_embeds"
+
+        device = None
+        if input_ids is not None:
+            device = input_ids.device
+        else:
+            device = inputs_embeds.device
+
+        padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(device)
+
+        # copy input_ids if provided (do not detach or clone inputs_embeds)
+        if input_ids is not None:
+            input_ids = input_ids.clone()
+
         self.ea_layer.reset_kv()
-
-        # Initialize the past key and value states
+
+        # Initialize past key-values
         if hasattr(self, "past_key_values"):
             past_key_values = self.past_key_values
             past_key_values_data = self.past_key_values_data
             current_length_data = self.current_length_data
-            # Reset the past key and value states
             current_length_data.zero_()
         else:
-            (
-                past_key_values,
-                past_key_values_data,
-                current_length_data,
-            ) = initialize_past_key_values(self.base_model,max_length=max_length)
+            (past_key_values, past_key_values_data, current_length_data) = initialize_past_key_values(self.base_model, max_length=max_length)
             self.past_key_values = past_key_values
             self.past_key_values_data = past_key_values_data
             self.current_length_data = current_length_data
-
-        input_len = input_ids.shape[1]
+
+        if input_ids is not None:
+            input_len = input_ids.shape[1]
+        else:
+            input_len = inputs_embeds.shape[1]
+
         reset_tree_mode(self)
-        # prefill
+
+        # PREFILL: initialize_tree must be updated to accept inputs_embeds (if provided)
+        # and to return draft_log_scores (and any other draft score arrays)
+        # We assume initialize_tree returns:
+        #     (draft_tokens, retrieve_indices, tree_mask, tree_position_ids,
+        #      logits, hidden_state, sample_token, draft_log_scores)
+        # where draft_log_scores is shape [1, total_tokens] aligned with draft_tokens indices.
         draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token, draft_log_scores = initialize_tree(
-            input_ids, self, past_key_values, logits_processor
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            model=self,
+            past_key_values=past_key_values,
+            logits_processor=logits_processor
         )
+
         new_token = 0
         max_length = max_length - self.ea_layer.total_tokens - 10
-        for idx in range(max_length):
-            # with Timer("all"):
-            self.base_model.model.tree_mask = tree_mask
-
-            draft_tokens = draft_tokens.to(input_ids.device)
-            # Target model forward, get logits
-            logits, hidden_state_new, outputs = tree_decoding(
-                self,
-                draft_tokens,
-                past_key_values,
-                tree_position_ids,
-                input_ids,
-                retrieve_indices,
-            )
-            # retrieve_indices=tree_buffers["retrieve_indices"]
-            # logits = logits[0, retrieve_indices]
-            draft_tokens = torch.cat((draft_tokens, padding), dim=1)
-            candidates = draft_tokens[0, retrieve_indices]
-            # verification
-            best_candidate, accept_length, sample_p = evaluate_posterior(
-                logits, candidates, logits_processor
-            )
-            # print(accept_length)
-            # Adjusting the input sequence, draft model forward
-            input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs(
-                input_ids,
-                candidates,
-                best_candidate,
-                accept_length,
-                retrieve_indices,
-                logits_processor,
-                new_token,
-                past_key_values_data,
-                current_length_data,
-                self,
-                hidden_state_new,
-                sample_p
-            )
-
-            if is_llama3:
-                if stop_token_id in input_ids[0, input_len:].tolist():
+
+        # Prepare containers for external scoring
+        step_draft_logps = []   # list of tensors (scalars) or floats (kept as tensors if enable_grad True)
+        step_verify_logps = []
+
+        # Use grad enabled depending on enable_grad
+        torch.set_grad_enabled(enable_grad)
+        try:
+            for idx in range(max_length):
+                self.base_model.model.tree_mask = tree_mask
+
+                draft_tokens = draft_tokens.to(device)
+
+                # Target model forward -> tree_decoding MUST accept inputs_embeds if provided
+                logits, hidden_state_new, outputs = tree_decoding(
+                    self,
+                    draft_tokens,
+                    past_key_values,
+                    tree_position_ids,
+                    input_ids=input_ids,
+                    inputs_embeds=inputs_embeds,
+                    retrieve_indices=retrieve_indices,
+                )
+                # logits shape: [num_candidates, seq_len, vocab]
+
+                # prepare candidates matrix
+                draft_tokens = torch.cat((draft_tokens, padding), dim=1)
+                candidates = draft_tokens[0, retrieve_indices]  # [num_candidates, seq_len]
+
+                # --- compute verifier log-probs (log-space) ---
+                log_probs = torch.log_softmax(logits, dim=-1)  # shape [num_candidates, seq_len, vocab]
+
+                # first EA-generated token for candidate j is candidates[j, 1]
+                first_ea_token_ids = candidates[:, 1].to(log_probs.device)  # shape [num_candidates]
+
+                # verifier log-prob for first EA token per candidate: logits[:,0] predicts candidates[:,1]
+                logp_verify_first_per_candidate = log_probs[torch.arange(candidates.shape[0], device=log_probs.device), 0, first_ea_token_ids]
+                # shape [num_candidates]
+
+                # --- compute draft first-token log-prob per candidate from draft_log_scores ---
+                # we assume draft_log_scores is aligned with draft_tokens indices;
+                # The first EA token index for candidate j is retrieve_indices[j, 1].
+                # So:
+                #     draft_first_logp_per_candidate = draft_log_scores[0, retrieve_indices[:,1]]
+                # (make sure dtype/device align)
+                draft_log_scores = draft_log_scores.to(device)
+                first_indices = retrieve_indices[:, 1].to(device)  # shape [num_candidates]
+                draft_first_logp_per_candidate = draft_log_scores[0, first_indices]  # shape [num_candidates]
+
+                # Now call evaluate_posterior to pick best candidate and accept_length and sample_p
+                best_candidate, accept_length, sample_p = evaluate_posterior(
+                    logits, candidates, logits_processor
+                )
+                # make best index python int for indexing
+                best_idx = int(best_candidate) if isinstance(best_candidate, torch.Tensor) else best_candidate
+
+                # pick the two scalar log-probs for the selected candidate
+                logp_verify_first = logp_verify_first_per_candidate[best_idx]  # tensor (requires_grad if upstream)
+                logp_draft_first = draft_first_logp_per_candidate[best_idx]  # tensor (requires_grad if upstream)
+
+                # store them (keep tensors — they will carry grad if enable_grad=True)
+                if return_step_scores:
+                    step_draft_logps.append(logp_draft_first)
+                    step_verify_logps.append(logp_verify_first)
+
+                # Update inference state as usual. update_inference_inputs may need inputs_embeds handling
+                input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs(
+                    input_ids=input_ids,
+                    candidates=candidates,
+                    best_candidate=best_candidate,
+                    accept_length=accept_length,
+                    retrieve_indices=retrieve_indices,
+                    logits_processor=logits_processor,
+                    new_token=new_token,
+                    past_key_values_data_list=past_key_values_data,
+                    current_length_data=current_length_data,
+                    model=self,
+                    hidden_state_new=hidden_state_new,
+                    sample_p=sample_p
+                )
+
+                # stopping criteria
+                if is_llama3:
+                    if stop_token_id in input_ids[0, input_len:].tolist():
+                        break
+                if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                     break
-
-            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
-                break
-            if new_token > max_new_tokens:
-                break
-            if input_ids.shape[1] > max_length:
-                break
+                if new_token > max_new_tokens:
+                    break
+                if input_ids.shape[1] > max_length:
+                    break
+        finally:
+            # restore grad mode
+            torch.set_grad_enabled(True)
+
+        # Convert step lists to tensors if desired (stack)
+        if return_step_scores:
+            # stack into tensors (shape [num_steps])
+            # if enable_grad=True these tensors will require grad if their elements do
+            draft_scores_tensor = torch.stack(step_draft_logps) if len(step_draft_logps) > 0 else torch.empty(0, device=device)
+            verify_scores_tensor = torch.stack(step_verify_logps) if len(step_verify_logps) > 0 else torch.empty(0, device=device)
+
+        # Return values
         if not log:
-            return input_ids
+            if return_step_scores:
+                return input_ids, draft_scores_tensor, verify_scores_tensor
+            else:
+                return input_ids
         else:
-            return input_ids, new_token, idx
+            if return_step_scores:
+                return input_ids, new_token, idx, draft_scores_tensor, verify_scores_tensor
+            else:
+                return input_ids, new_token, idx
+
 
     @torch.no_grad()
     def naivegenerate(
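For reference, a usage sketch of the extended `eagenerate` after [PATCH 6/6]. Model construction follows the README example from [PATCH 1/6] (`base_model_path` and `EAGLE_model_path` are assumed to be defined as there); `enable_grad` and `return_step_scores` are the new flags, and the squared-error objective at the end is purely illustrative (it assumes the draft model's parameters require grad and that the graph survives the generation loop), not something the patches themselves prescribe:

```python
import torch
from eagle.model.ea_model import EaModel

model = EaModel.from_pretrained(
    base_model_path=base_model_path,   # as in the README example above
    ea_model_path=EAGLE_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    total_token=-1,
)
model.eval()

input_ids = model.tokenizer(["Hello"], return_tensors="pt").input_ids.cuda()

# Unchanged fast path: identical to the pre-patch behaviour.
output_ids = model.eagenerate(input_ids=input_ids, temperature=0.0, max_new_tokens=64)

# Scored path: also returns the per-step first-token log-probs of the
# draft model and of the verifier for the accepted candidate at each step.
output_ids, draft_logps, verify_logps = model.eagenerate(
    input_ids=input_ids,
    temperature=0.0,
    max_new_tokens=64,
    enable_grad=True,          # keep autograd on inside the generation loop
    return_step_scores=True,
)

# Illustrative external objective: pull the draft log-probs toward the
# verifier's on the first speculated token of each step.
loss = (draft_logps - verify_logps.detach()).pow(2).mean()
loss.backward()
```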