From e10c6d10e6c42d1f8d87e985cc30a4bd648e2a02 Mon Sep 17 00:00:00 2001
From: liutuo
Date: Tue, 8 Jul 2025 17:35:38 +0800
Subject: [PATCH] chore: Clean eagenerate and naivegenerate

---
 README.md                                    |   4 +-
 eagle/application/webui.py                   |   6 +-
 .../gen_baseline_answer_llama3chat.py        |   4 +-
 eagle/evaluation/gen_ea_answer_ds.py         |   4 +-
 eagle/evaluation/gen_ea_answer_llama2chat.py |   4 +-
 eagle/evaluation/gen_ea_answer_llama3chat.py |   4 +-
 eagle/evaluation/gen_ea_answer_mix.py        |   4 +-
 eagle/evaluation/gen_ea_answer_vicuna.py     |   4 +-
 eagle/model/ea_model.py                      | 204 ++----------------
 eagle/testbug/model/ea_model.py              |  12 +-
 eagle/testbug/model/ea_modelbs.py            |   4 +-
 eagle/testbug/testbbug.py                    |   2 +-
 12 files changed, 47 insertions(+), 209 deletions(-)

diff --git a/README.md b/README.md
index 2ada70b0..12143d5a 100644
--- a/README.md
+++ b/README.md
@@ -177,7 +177,7 @@ python -m eagle.application.webui --ea-model-path [path of EAGLE weight]\
 The *total-token* is the number of draft tokens. For smaller models and advanced GPUs, this value can be set larger. Adjusting according to the specific device and model can achieve better results. If set to -1, EAGLE-2 will automatically configure this parameter.
 
 ### With Code
-You can use our provided "eagenerate" for speedup generation just like using 'generate' from Hugging Face. Here is an example.
+You can use our provided "ea_generate" for speedup generation just like using 'generate' from Hugging Face. Here is an example.
 ```python
 from eagle.model.ea_model import EaModel
 from fastchat.model import get_conversation_template
@@ -197,7 +197,7 @@ conv.append_message(conv.roles[1], None)
 prompt = conv.get_prompt()
 input_ids=model.tokenizer([prompt]).input_ids
 input_ids = torch.as_tensor(input_ids).cuda()
-output_ids=model.eagenerate(input_ids,temperature=0.5,max_new_tokens=512)
+output_ids=model.ea_generate(input_ids,temperature=0.5,max_new_tokens=512)
 output=model.tokenizer.decode(output_ids[0])
 ```
 
diff --git a/eagle/application/webui.py b/eagle/application/webui.py
index a1b96dff..593468a7 100644
--- a/eagle/application/webui.py
+++ b/eagle/application/webui.py
@@ -91,7 +91,7 @@ def warmup(model):
         prompt += " "
     input_ids = model.tokenizer([prompt]).input_ids
     input_ids = torch.as_tensor(input_ids).cuda()
-    for output_ids in model.ea_generate(input_ids):
+    for output_ids in model.ea_generate(input_ids, streaming=True):
         ol=output_ids.shape[1]
 
 def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_state,):
@@ -154,7 +154,7 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat
 
     if use_EaInfer:
         for output_ids in model.ea_generate(input_ids, temperature=temperature, top_p=top_p,
-                                            max_new_tokens=args.max_new_token,is_llama3=args.model_type=="llama-3-instruct"):
+                                            max_new_tokens=args.max_new_token,is_llama3=args.model_type=="llama-3-instruct",streaming=True):
             totaltime+=(time.time()-start_time)
             total_ids+=1
             decode_ids = output_ids[0, input_len:].tolist()
@@ -183,7 +183,7 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat
 
     else:
         for output_ids in model.naive_generate(input_ids, temperature=temperature, top_p=top_p,
-                                               max_new_tokens=args.max_new_token,is_llama3=args.model_type=="llama-3-instruct"):
+                                               max_new_tokens=args.max_new_token,is_llama3=args.model_type=="llama-3-instruct",streaming=True):
            totaltime += (time.time() - start_time)
            total_ids+=1
            decode_ids = output_ids[0, input_len:].tolist()
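For reference, here is how a caller consumes the streaming interface that webui.py now opts into. This is a minimal sketch, assuming `model` is a loaded EaModel and `prompt` is an already chat-formatted string, as in the README example above:

```python
import torch

# Minimal streaming-consumption sketch (assumes `model` and `prompt` exist).
input_ids = torch.as_tensor(model.tokenizer([prompt]).input_ids).cuda()
input_len = input_ids.shape[1]

# With streaming=True the method is consumed as a generator: it yields the
# full running sequence after each accepted draft, so the caller can
# re-decode the tail and re-render it incrementally, as the webui does.
text = ""
for output_ids in model.ea_generate(input_ids, temperature=0.5,
                                    max_new_tokens=512, streaming=True):
    text = model.tokenizer.decode(output_ids[0, input_len:])
print(text)
```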
diff --git a/eagle/evaluation/gen_baseline_answer_llama3chat.py b/eagle/evaluation/gen_baseline_answer_llama3chat.py
index 340a2927..a1facc2b 100644
--- a/eagle/evaluation/gen_baseline_answer_llama3chat.py
+++ b/eagle/evaluation/gen_baseline_answer_llama3chat.py
@@ -158,7 +158,7 @@ def get_model_answers(
 
                 torch.cuda.synchronize()
                 start_time = time.time()
-                output_ids, new_token, idx = model.naivegenerate(
+                output_ids, new_token, idx = model.naive_generate(
                     torch.as_tensor(input_ids).cuda(),
                     temperature=temperature,
                     log=True,
@@ -240,7 +240,7 @@ def get_model_answers(
 
                 torch.cuda.synchronize()
                 start_time = time.time()
-                output_ids, new_token, idx = model.naivegenerate(
+                output_ids, new_token, idx = model.naive_generate(
                     torch.as_tensor(input_ids).cuda(),
                     temperature=temperature,
                     log=True,
diff --git a/eagle/evaluation/gen_ea_answer_ds.py b/eagle/evaluation/gen_ea_answer_ds.py
index 24aad091..4bda0caa 100644
--- a/eagle/evaluation/gen_ea_answer_ds.py
+++ b/eagle/evaluation/gen_ea_answer_ds.py
@@ -155,7 +155,7 @@ def get_model_answers(
 
                 torch.cuda.synchronize()
                 start_time = time.time()
-                output_ids, new_token, idx = model.eagenerate(
+                output_ids, new_token, idx = model.ea_generate(
                     torch.as_tensor(input_ids).cuda(),
                     temperature=temperature,
                     log=True,
@@ -234,7 +234,7 @@ def get_model_answers(
 
                 torch.cuda.synchronize()
                 start_time = time.time()
-                output_ids, new_token, idx = model.eagenerate(
+                output_ids, new_token, idx = model.ea_generate(
                     torch.as_tensor(input_ids).cuda(),
                     temperature=temperature,
                     log=True,
diff --git a/eagle/evaluation/gen_ea_answer_llama2chat.py b/eagle/evaluation/gen_ea_answer_llama2chat.py
index 7df2b2af..d8576726 100644
--- a/eagle/evaluation/gen_ea_answer_llama2chat.py
+++ b/eagle/evaluation/gen_ea_answer_llama2chat.py
@@ -144,7 +144,7 @@ def get_model_answers(
 
                 torch.cuda.synchronize()
                 start_time = time.time()
-                output_ids, new_token, idx = model.eagenerate(
+                output_ids, new_token, idx = model.ea_generate(
                     torch.as_tensor(input_ids).cuda(),
                     temperature=temperature,
                     log=True
@@ -210,7 +210,7 @@ def get_model_answers(
 
                 torch.cuda.synchronize()
                 start_time = time.time()
-                output_ids, new_token, idx = model.eagenerate(
+                output_ids, new_token, idx = model.ea_generate(
                     torch.as_tensor(input_ids).cuda(),
                     temperature=temperature,
                     log=True
diff --git a/eagle/evaluation/gen_ea_answer_llama3chat.py b/eagle/evaluation/gen_ea_answer_llama3chat.py
index 21e71d14..b04ec2cf 100644
--- a/eagle/evaluation/gen_ea_answer_llama3chat.py
+++ b/eagle/evaluation/gen_ea_answer_llama3chat.py
@@ -159,7 +159,7 @@ def get_model_answers(
 
                 torch.cuda.synchronize()
                 start_time = time.time()
-                output_ids, new_token, idx = model.eagenerate(
+                output_ids, new_token, idx = model.ea_generate(
                     torch.as_tensor(input_ids).cuda(),
                     temperature=temperature,
                     log=True,
@@ -241,7 +241,7 @@ def get_model_answers(
 
                 torch.cuda.synchronize()
                 start_time = time.time()
-                output_ids, new_token, idx = model.eagenerate(
+                output_ids, new_token, idx = model.ea_generate(
                     torch.as_tensor(input_ids).cuda(),
                     temperature=temperature,
                     log=True,
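The evaluation scripts above (and the two that follow) all share one call pattern. The sketch below restates it in isolation, assuming `model`, `input_ids`, and `temperature` are in scope as in the scripts; see also the note after the eagle/model/ea_model.py diff below on how this tuple-unpacking interacts with the new generator-based implementation:

```python
import time
import torch

# log=True asks the generate call for (output_ids, new_token, idx), and the
# torch.cuda.synchronize() calls bracket the measurement so asynchronous GPU
# work is fully included in the wall-clock time.
torch.cuda.synchronize()
start_time = time.time()
output_ids, new_token, idx = model.ea_generate(
    torch.as_tensor(input_ids).cuda(),
    temperature=temperature,
    log=True,
)
torch.cuda.synchronize()
total_time = time.time() - start_time
tokens_per_second = int(new_token) / total_time  # illustrative metric only
```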
diff --git a/eagle/evaluation/gen_ea_answer_mix.py b/eagle/evaluation/gen_ea_answer_mix.py
index 36def8f7..a803a8fe 100644
--- a/eagle/evaluation/gen_ea_answer_mix.py
+++ b/eagle/evaluation/gen_ea_answer_mix.py
@@ -145,7 +145,7 @@ def get_model_answers(
 
                 torch.cuda.synchronize()
                 start_time = time.time()
-                output_ids, new_token, idx = model.eagenerate(
+                output_ids, new_token, idx = model.ea_generate(
                     torch.as_tensor(input_ids).cuda(),
                     temperature=temperature,
                     log=True
@@ -211,7 +211,7 @@ def get_model_answers(
             try:
                 torch.cuda.synchronize()
                 start_time = time.time()
-                output_ids, new_token, idx = model.eagenerate(
+                output_ids, new_token, idx = model.ea_generate(
                     torch.as_tensor(input_ids).cuda(),
                     temperature=temperature,
                     log=True
diff --git a/eagle/evaluation/gen_ea_answer_vicuna.py b/eagle/evaluation/gen_ea_answer_vicuna.py
index 7bc615f2..8f267882 100644
--- a/eagle/evaluation/gen_ea_answer_vicuna.py
+++ b/eagle/evaluation/gen_ea_answer_vicuna.py
@@ -148,7 +148,7 @@ def get_model_answers(
 
                 torch.cuda.synchronize()
                 start_time = time.time()
-                output_ids, new_token, idx = model.eagenerate(
+                output_ids, new_token, idx = model.ea_generate(
                     torch.as_tensor(input_ids).cuda(),
                     temperature=temperature,
                     log=True
@@ -212,7 +212,7 @@ def get_model_answers(
 
                 torch.cuda.synchronize()
                 start_time = time.time()
-                output_ids, new_token, idx = model.eagenerate(
+                output_ids, new_token, idx = model.ea_generate(
                     torch.as_tensor(input_ids).cuda(),
                     temperature=temperature,
                     log=True
diff --git a/eagle/model/ea_model.py b/eagle/model/ea_model.py
index f0d2e516..1869d953 100644
--- a/eagle/model/ea_model.py
+++ b/eagle/model/ea_model.py
@@ -191,7 +191,7 @@ def forward(
         return outputs, hidden_states
 
     @torch.no_grad()
-    def eagenerate(
+    def ea_generate(
             self,
             input_ids,
             temperature=0.0,
@@ -201,6 +201,7 @@ def eagenerate(
             max_length=2048,
             log=False,
             is_llama3=False,
+            streaming=False,
 
     ):
         if is_llama3:
@@ -282,189 +283,8 @@ def eagenerate(
                 sample_p
             )
 
-            if is_llama3:
-                if stop_token_id in input_ids[0, input_len:].tolist():
-                    break
-
-            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
-                break
-            if new_token > max_new_tokens:
-                break
-            if input_ids.shape[1] > max_length:
-                break
-        if not log:
-            return input_ids
-        else:
-            return input_ids, new_token, idx
-
-    @torch.no_grad()
-    def naivegenerate(
-            self,
-            input_ids,
-            temperature=0.0,
-            top_p=0.0,
-            top_k=0.0,
-            max_new_tokens=512,
-            max_length=2048,
-            log=False,
-            is_llama3=False,
-
-    ):
-        if is_llama3:
-            stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
-
-
-        if temperature > 1e-5:
-            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
-        else:
-            logits_processor = None
-        # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
-        # Avoid modifying the input_ids in-place
-
-        padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
-        input_ids = input_ids.clone()
-        self.ea_layer.reset_kv()
-
-        # Initialize the past key and value states
-        if hasattr(self, "past_key_values"):
-            past_key_values = self.past_key_values
-            past_key_values_data = self.past_key_values_data
-            current_length_data = self.current_length_data
-            # Reset the past key and value states
-            current_length_data.zero_()
-        else:
-            (
-                past_key_values,
-                past_key_values_data,
-                current_length_data,
-            ) = initialize_past_key_values(self.base_model,max_length=max_length)
-            self.past_key_values = past_key_values
-            self.past_key_values_data = past_key_values_data
-            self.current_length_data = current_length_data
-
-        input_len = input_ids.shape[1]
-        reset_tree_mode(self)
-        outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
-        new_token = 0
-        max_length = max_length - self.ea_layer.total_tokens - 10
-        for idx in range(max_length):
-            if logits_processor is not None:
-                logits = outputs.logits[:, -1]
-                logits = logits_processor(None, logits)
-                probabilities = torch.nn.functional.softmax(logits, dim=-1)
-                input_id = torch.multinomial(probabilities, 1)
-            else:
-                input_id = outputs.logits[:, -1:].argmax(dim=-1)
-            outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
-            input_ids = torch.cat([input_ids, input_id], dim=-1)
-            new_token += 1
-
-            if is_llama3:
-                if stop_token_id in input_ids[0, input_len:].tolist():
-                    break
-
-            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
-                break
-            if new_token > max_new_tokens:
-                break
-            if input_ids.shape[1] > max_length:
-                break
-        if not log:
-            return input_ids
-        else:
-            return input_ids, new_token, idx
-
-    @torch.no_grad()
-    def ea_generate(
-            self,
-            input_ids,
-            temperature=0.0,
-            top_p=0.0,
-            top_k=0.0,
-            max_new_tokens=512,
-            max_length=2048,
-            log=False,
-            is_llama3=False,
-
-    ):
-        if is_llama3:
-            stop_token_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
-
-
-        if temperature > 1e-5:
-            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
-        else:
-            logits_processor = None
-        # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
-        # Avoid modifying the input_ids in-place
-
-        padding = (torch.zeros(1, 1, dtype=torch.long) - 1).to(input_ids.device)
-        input_ids = input_ids.clone()
-        self.ea_layer.reset_kv()
-
-        # Initialize the past key and value states
-        if hasattr(self, "past_key_values"):
-            past_key_values = self.past_key_values
-            past_key_values_data = self.past_key_values_data
-            current_length_data = self.current_length_data
-            # Reset the past key and value states
-            current_length_data.zero_()
-        else:
-            (
-                past_key_values,
-                past_key_values_data,
-                current_length_data,
-            ) = initialize_past_key_values(self.base_model,max_length=max_length)
-            self.past_key_values = past_key_values
-            self.past_key_values_data = past_key_values_data
-            self.current_length_data = current_length_data
-
-        input_len = input_ids.shape[1]
-        reset_tree_mode(self)
-        draft_tokens, retrieve_indices, tree_mask, tree_position_ids, logits, hidden_state, sample_token = initialize_tree(
-            input_ids, self, past_key_values, logits_processor
-        )
-        new_token = 0
-        max_length = max_length - self.ea_layer.total_tokens - 10
-        for idx in range(max_length):
-            # with Timer("all"):
-            self.base_model.model.tree_mask = tree_mask
-
-            draft_tokens = draft_tokens.to(input_ids.device)
-            # with Timer("tree_decoding"):
-            logits, hidden_state_new, outputs = tree_decoding(
-                self,
-                draft_tokens,
-                past_key_values,
-                tree_position_ids,
-                input_ids,
-                retrieve_indices,
-            )
-            # retrieve_indices=tree_buffers["retrieve_indices"]
-            # logits = logits[0, retrieve_indices]
-            draft_tokens = torch.cat((draft_tokens, padding), dim=1)
-            candidates = draft_tokens[0, retrieve_indices]
-            best_candidate, accept_length, sample_p = evaluate_posterior(
-                logits, candidates, logits_processor
-            )
-            # print(accept_length)
-            # with Timer("update_inference_inputs"):
-            input_ids, draft_tokens, retrieve_indices, tree_mask, tree_position_ids, new_token, hidden_state, sample_token = update_inference_inputs(
-                input_ids,
-                candidates,
-                best_candidate,
-                accept_length,
-                retrieve_indices,
-                logits_processor,
-                new_token,
-                past_key_values_data,
-                current_length_data,
-                self,
-                hidden_state_new,
-                sample_p
-            )
-
-            yield input_ids
+            if streaming:
+                yield input_ids
 
             if is_llama3:
                 if stop_token_id in input_ids[0, input_len:].tolist():
@@ -476,6 +296,11 @@ def ea_generate(
                 break
             if input_ids.shape[1] > max_length:
                 break
+        if not streaming:
+            if not log:
+                return input_ids
+            else:
+                return input_ids, new_token, idx
 
     @torch.no_grad()
     def naive_generate(
@@ -488,6 +313,7 @@ def naive_generate(
             max_length=2048,
             log=False,
             is_llama3=False,
+            streaming=False,
 
     ):
         if is_llama3:
@@ -535,12 +361,12 @@ def naive_generate(
                 input_id = torch.multinomial(probabilities, 1)
             else:
                 input_id = outputs.logits[:, -1:].argmax(dim=-1)
-
             outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
             input_ids = torch.cat([input_ids, input_id], dim=-1)
             new_token += 1
 
-            yield input_ids
+            if streaming:
+                yield input_ids
 
             if is_llama3:
                 if stop_token_id in input_ids[0, input_len:].tolist():
@@ -552,3 +378,9 @@ def naive_generate(
                 break
             if input_ids.shape[1] > max_length:
                 break
+
+        if not streaming:
+            if not log:
+                return input_ids
+            else:
+                return input_ids, new_token, idx
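One subtlety in this unification is worth a note: because the merged `ea_generate` and `naive_generate` bodies now contain `yield`, Python compiles them as generator functions even when `streaming=False`. A plain call therefore returns a generator object rather than ids, and the `return` values on the non-streaming path only surface as `StopIteration.value` once the generator is exhausted. Below is a minimal sketch of how a non-streaming caller can recover the return value under these semantics; `drain` is an invented helper name, and `model`, `input_ids`, and `temperature` are assumed to be in scope:

```python
def drain(gen):
    # Exhaust a generator and recover its `return` value, which Python
    # delivers as StopIteration.value when the generator finishes.
    while True:
        try:
            next(gen)
        except StopIteration as stop:
            return stop.value

# With streaming=False nothing is ever yielded: the first next() call runs
# the whole decode loop, and the method's return value arrives via
# StopIteration. With log=True that value is (output_ids, new_token, idx).
output_ids, new_token, idx = drain(
    model.ea_generate(input_ids, temperature=temperature, log=True)
)
```

Callers that pass `streaming=True` (such as webui.py above) are unaffected, since they already iterate the result with a `for` loop.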
diff --git a/eagle/testbug/model/ea_model.py b/eagle/testbug/model/ea_model.py
index 528c47a2..96bde13e 100644
--- a/eagle/testbug/model/ea_model.py
+++ b/eagle/testbug/model/ea_model.py
@@ -101,7 +101,7 @@ def forward(
         return outputs, orig, hidden_states
 
     @torch.no_grad()
-    def eagenerate(
+    def ea_generate(
             self,
             input_ids,
             temperature=0.0,
@@ -110,6 +110,7 @@ def eagenerate(
             max_new_tokens=512,
             max_length=2048,
             tree_choices=mc_sim_7b_63,
+            streaming=False,
     ):
 
         if temperature > 1e-5:
@@ -192,10 +193,15 @@ def eagenerate(
                 sample_p
             )
 
+            if streaming:
+                yield input_ids
+
             if new_token > max_new_tokens:
-                return input_ids
+                break
             if input_ids.shape[1] > max_length:
-                return input_ids
+                break
+        if not streaming:
+            return input_ids
 
 
     @torch.no_grad()
diff --git a/eagle/testbug/model/ea_modelbs.py b/eagle/testbug/model/ea_modelbs.py
index 262f0cda..b24843f0 100644
--- a/eagle/testbug/model/ea_modelbs.py
+++ b/eagle/testbug/model/ea_modelbs.py
@@ -149,7 +149,7 @@ def forward(
         return outputs, orig, hidden_states
 
     @torch.no_grad()
-    def eagenerate(
+    def ea_generate(
             self,
             input_ids,
             attention_mask=None,
@@ -302,7 +302,7 @@ def eagenerate(
 
 
     @torch.no_grad()
-    def naivegenerate(
+    def naive_generate(
             self,
             input_ids,
             attention_mask=None,
diff --git a/eagle/testbug/testbbug.py b/eagle/testbug/testbbug.py
index 97b6beda..20d8093b 100644
--- a/eagle/testbug/testbbug.py
+++ b/eagle/testbug/testbbug.py
@@ -44,7 +44,7 @@
 input_ids = torch.as_tensor([[1, 2]])
 s = time.time()
 for i in range(500000):
-    output_ids = model.eagenerate(input_ids, temperature=1.0, max_new_tokens=15)
+    output_ids = model.ea_generate(input_ids, temperature=1.0, max_new_tokens=15)
     outs.append(output_ids[:, input_ids.shape[1]:input_ids.shape[1] + 15])
     if i>0 and i%1000==0:
         outstensor = torch.cat(outs, dim=0)
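Finally, since the testbug suite touched above exercises repeated generation, a cheap consistency check along the same lines could assert that the streaming and non-streaming paths of the unified method agree under greedy decoding. This is only a suggested sketch, not part of the patch; it reuses the hypothetical `drain` helper from the note above and assumes a loaded `model`:

```python
import torch

def drain(gen):
    # Recover a generator's `return` value (see the note above).
    while True:
        try:
            next(gen)
        except StopIteration as stop:
            return stop.value

# Greedy decoding (the default temperature=0.0) is deterministic, so both
# paths should produce token-for-token identical sequences.
input_ids = torch.as_tensor(model.tokenizer(["Hello"]).input_ids).cuda()
streamed = None
for streamed in model.ea_generate(input_ids, max_new_tokens=32, streaming=True):
    pass  # keep only the final yielded sequence
returned = drain(model.ea_generate(input_ids, max_new_tokens=32))
assert streamed is not None and torch.equal(streamed, returned)
```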