From d93def06d54c5fb463178410127f7866e4f43358 Mon Sep 17 00:00:00 2001 From: Yaniv Galron Date: Sun, 26 Oct 2025 15:50:47 +0200 Subject: [PATCH 1/2] compact topk index and prob extraction without recomputing topk twice. --- eagle/modeling_eagle.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/eagle/modeling_eagle.py b/eagle/modeling_eagle.py index 5378cc4a..c68b4dea 100644 --- a/eagle/modeling_eagle.py +++ b/eagle/modeling_eagle.py @@ -903,11 +903,10 @@ def topK_genrate(self, hidden_states, input_ids, head, logits_processor, max_len for i in range(len(self.tree_buffer['tree_indices'])): if logits_processor is not None: - topk_index, topk_prob, op = self.sample(last_headout, logits_processor, k=top_k, ) + topk_index, topk_prob, op = self.sample(last_headout, logits_processor, k=top_k) else: - topk_index, topk_prob = torch.topk(last_headout, top_k, dim=-1).indices, torch.topk(last_headout, - top_k, - dim=-1).values + topk_results = torch.topk(last_headout, top_k, dim=-1) + topk_index, topk_prob = topk_results.indices, topk_results.values op = None ss_token.append(topk_index) @@ -1734,3 +1733,4 @@ def generate( return out_inputids[0] return out_inputids + From f13f3972e3a4a36ef7c35e4e9191daf9bf32d8c6 Mon Sep 17 00:00:00 2001 From: Yaniv Galron Date: Sun, 26 Oct 2025 16:05:49 +0200 Subject: [PATCH 2/2] updating the .topk in additional places --- eagle/modeling_eagle.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/eagle/modeling_eagle.py b/eagle/modeling_eagle.py index c68b4dea..42d9ef8a 100644 --- a/eagle/modeling_eagle.py +++ b/eagle/modeling_eagle.py @@ -942,8 +942,8 @@ def topK_genrate(self, hidden_states, input_ids, head, logits_processor, max_len if logits_processor is not None: topk_index, topk_prob, op = self.sample(last_headout, logits_processor, k=top_k, ) else: - topk_index, topk_prob = torch.topk(last_headout, top_k, dim=-1).indices, torch.topk(last_headout, top_k, - dim=-1).values + topk_results = 
torch.topk(last_headout, top_k, dim=-1) + topk_index, topk_prob = topk_results.indices, topk_results.values op = None ss_token.append(topk_index) ss_prob.append(topk_prob) @@ -1734,3 +1734,4 @@ def generate( return out_inputids +