diff --git a/eagle/model/cnets.py b/eagle/model/cnets.py
index a8e13ca1..b5c52881 100644
--- a/eagle/model/cnets.py
+++ b/eagle/model/cnets.py
@@ -667,7 +667,7 @@ def reset_kv(self):
         self.stable_kv = None
 
     @torch.no_grad()
-    def topK_genrate(self, hidden_states, input_ids, head, logits_processor):
+    def topK_generate(self, hidden_states, input_ids, head, logits_processor):
         input_ids = input_ids.to(hidden_states.device)
         total_tokens = self.total_tokens
diff --git a/eagle/model/cnets1.py b/eagle/model/cnets1.py
index 25cad25a..89e8fcea 100644
--- a/eagle/model/cnets1.py
+++ b/eagle/model/cnets1.py
@@ -670,7 +670,7 @@ def reset_kv(self):
         self.stable_kv = None
 
     @torch.no_grad()
-    def topK_genrate(self, hidden_states, input_ids, head, logits_processor):
+    def topK_generate(self, hidden_states, input_ids, head, logits_processor):
         input_ids = input_ids.to(hidden_states.device)
         total_tokens = self.total_tokens
diff --git a/eagle/model/utils.py b/eagle/model/utils.py
index dfc64bae..063fd59b 100644
--- a/eagle/model/utils.py
+++ b/eagle/model/utils.py
@@ -223,7 +223,7 @@ def initialize_tree0(input_ids, model, past_key_values, logits_processor):
 #     input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
 #     # Clone the output hidden states
 #
-#     draft_tokens, retrieve_indices,tree_mask,tree_position_ids = self.ea_layer.topK_genrate(hidden_states, input_ids, self.base_model.lm_head)
+#     draft_tokens, retrieve_indices,tree_mask,tree_position_ids = self.ea_layer.topK_generate(hidden_states, input_ids, self.base_model.lm_head)
 #     if output_orig:
 #         return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, outputs, orig, hidden_states, token
 #     return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, hidden_states, token
@@ -250,7 +250,7 @@ def initialize_tree(input_ids, model, past_key_values, logits_processor):
     if outputs["hidden_states"][0].device != ea_device:
         outputs["hidden_states"] = [x.to(ea_device) for x in outputs["hidden_states"]]
     hidden_states=torch.cat(outputs["hidden_states"],dim=-1)
-    draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head,logits_processor)
+    draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_generate(hidden_states, input_ids, model.base_model.lm_head,logits_processor)
     return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, orig, hidden_states, token
@@ -463,7 +463,7 @@ def update_inference_inputs(
         token = torch.argmax(prob)
         token = token[None, None]
     # hidden_state = torch.cat((hidden_state, accept_hidden_state_new), dim=1)
-    draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(accept_hidden_state_new,
+    draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_generate(accept_hidden_state_new,
                                                   input_ids=torch.cat((input_ids, token.to(input_ids.device)), dim=1),
                                                   head=model.base_model.lm_head,logits_processor=logits_processor)
diff --git a/eagle/modeling_eagle.py b/eagle/modeling_eagle.py
index 5378cc4a..d32b6d33 100644
--- a/eagle/modeling_eagle.py
+++ b/eagle/modeling_eagle.py
@@ -860,7 +860,7 @@ def sample(self, logits, logits_processor, k=1):
         return sampled_indices, sampled_probs, probabilities
 
     @torch.no_grad()
-    def topK_genrate(self, hidden_states, input_ids, head, logits_processor, max_length=4, use_cache=True,
+    def topK_generate(self, hidden_states, input_ids, head, logits_processor, max_length=4, use_cache=True,
                      attention_mask=None, len_posi=None, ):
         top_k = 5
         bs = input_ids.shape[0]
@@ -1272,7 +1272,7 @@ def initialize_tree(input_ids, model, logits_processor, attention_mask=None):
         token = token[:, None]
     input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
 
-    tree_logits = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head, logits_processor,
+    tree_logits = model.ea_layer.topK_generate(hidden_states, input_ids, model.base_model.lm_head, logits_processor,
                                               attention_mask=attention_mask)
@@ -1543,7 +1543,7 @@ def update_inference_inputs(
 
-    tree_logits = model.ea_layer.topK_genrate(draft_hidden,
+    tree_logits = model.ea_layer.topK_generate(draft_hidden,
                                               input_ids=draft_input_ids, head=model.base_model.lm_head,
                                               logits_processor=logits_processor,attention_mask=attention_mask,len_posi=input_ids.shape[1])
diff --git a/eagle/testbug/model/cnets.py b/eagle/testbug/model/cnets.py
index ac09075a..af0ea03a 100644
--- a/eagle/testbug/model/cnets.py
+++ b/eagle/testbug/model/cnets.py
@@ -569,7 +569,7 @@ def sample(self,logits, logits_processor,k=1, replacement=False):
 
     @torch.no_grad()
-    def topK_genrate(self, hidden_states, input_ids, head, logits_processor,max_length=4, use_cache=True):
+    def topK_generate(self, hidden_states, input_ids, head, logits_processor,max_length=4, use_cache=True):
         input_ids = input_ids[:, 1:]
         ss_token,ss_prob,ss_op = [],[],[]
         len_posi=input_ids.shape[1]
diff --git a/eagle/testbug/model/ea_model.py b/eagle/testbug/model/ea_model.py
index 528c47a2..042cf24a 100644
--- a/eagle/testbug/model/ea_model.py
+++ b/eagle/testbug/model/ea_model.py
@@ -92,7 +92,7 @@ def forward(
                 input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
             # Clone the output hidden states
 
-            ea_logits = self.ea_layer.topK_genrate(hidden_states, input_ids, None, logits_processor)
+            ea_logits = self.ea_layer.topK_generate(hidden_states, input_ids, None, logits_processor)
             if output_orig:
                 return ea_logits, outputs, orig, hidden_states, token
             return ea_logits, hidden_states, token
diff --git a/eagle/testbug/model/ea_modelbs.py b/eagle/testbug/model/ea_modelbs.py
index 262f0cda..a16c4f97 100644
--- a/eagle/testbug/model/ea_modelbs.py
+++ b/eagle/testbug/model/ea_modelbs.py
@@ -140,7 +140,7 @@ def forward(
             input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
             # Clone the output hidden states
 
-            ea_logits = self.ea_layer.topK_genrate(hidden_states, input_ids, self.base_model.lm_head, logits_processor,attention_mask=attention_mask)
+            ea_logits = self.ea_layer.topK_generate(hidden_states, input_ids, self.base_model.lm_head, logits_processor,attention_mask=attention_mask)
             if output_orig:
                 return ea_logits, outputs, orig, hidden_states, token
             return ea_logits, hidden_states, token
diff --git a/eagle/testbug/model/utils.py b/eagle/testbug/model/utils.py
index 5279039f..c978ead0 100644
--- a/eagle/testbug/model/utils.py
+++ b/eagle/testbug/model/utils.py
@@ -441,7 +441,7 @@ def update_inference_inputs(
             token = torch.argmax(prob)
             token = token[None, None]
         # hidden_state = torch.cat((hidden_state, accept_hidden_state_new), dim=1)
-    tree_logits = model.ea_layer.topK_genrate(None,
+    tree_logits = model.ea_layer.topK_generate(None,
                                               input_ids=torch.cat((input_ids, token.to(input_ids.device)), dim=1),
                                               head=None, logits_processor=logits_processor)