diff --git a/eagle/modeling_eagle.py b/eagle/modeling_eagle.py index 5378cc4a..42d9ef8a 100644 --- a/eagle/modeling_eagle.py +++ b/eagle/modeling_eagle.py @@ -903,11 +903,10 @@ def topK_genrate(self, hidden_states, input_ids, head, logits_processor, max_len for i in range(len(self.tree_buffer['tree_indices'])): if logits_processor is not None: - topk_index, topk_prob, op = self.sample(last_headout, logits_processor, k=top_k, ) + topk_index, topk_prob, op = self.sample(last_headout, logits_processor, k=top_k) else: - topk_index, topk_prob = torch.topk(last_headout, top_k, dim=-1).indices, torch.topk(last_headout, - top_k, - dim=-1).values + topk_results = torch.topk(last_headout, top_k, dim=-1) + topk_index, topk_prob = topk_results.indices, topk_results.values op = None ss_token.append(topk_index) @@ -943,8 +942,8 @@ def topK_genrate(self, hidden_states, input_ids, head, logits_processor, max_len if logits_processor is not None: topk_index, topk_prob, op = self.sample(last_headout, logits_processor, k=top_k, ) else: - topk_index, topk_prob = torch.topk(last_headout, top_k, dim=-1).indices, torch.topk(last_headout, top_k, - dim=-1).values + topk_results = torch.topk(last_headout, top_k, dim=-1) + topk_index, topk_prob = topk_results.indices, topk_results.values op = None ss_token.append(topk_index) ss_prob.append(topk_prob) @@ -1734,3 +1733,5 @@ def generate( return out_inputids[0] return out_inputids + +