From 1fd30217c410a62341b00ba5f5a4d7f44206ca39 Mon Sep 17 00:00:00 2001 From: yanxing Date: Fri, 13 Jun 2025 11:00:19 +0800 Subject: [PATCH 1/2] bugfix: bugfix of ui arg no_eagle3. --- eagle/application/webui.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/eagle/application/webui.py b/eagle/application/webui.py index a1b96dff..1b68e61b 100644 --- a/eagle/application/webui.py +++ b/eagle/application/webui.py @@ -1,5 +1,5 @@ import os -os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" +# os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" import time import gradio as gr @@ -280,15 +280,15 @@ def clear(history,session_state): load_in_4bit=args.load_in_4bit, load_in_8bit=args.load_in_8bit, device_map="auto", - use_eagle3=args.no_eagle3, + use_eagle3=(not args.no_eagle3), ) model.eval() warmup(model) custom_css = """ #speed textarea { - color: red; - font-size: 30px; + color: red; + font-size: 30px; }""" with gr.Blocks(css=custom_css) as demo: @@ -303,7 +303,7 @@ def clear(history,session_state): highlight_EaInfer = gr.Checkbox(label="Highlight the tokens generated by EAGLE-3", value=True) temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="temperature", value=0.5) top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="top_p", value=0.9) - note=gr.Markdown(show_label=False,interactive=False,value='''The Compression Ratio is defined as the number of generated tokens divided by the number of forward passes in the original LLM. If "Highlight the tokens generated by EAGLE-2" is checked, the tokens correctly guessed by EAGLE-2 + note=gr.Markdown(show_label=False,interactive=False,value='''The Compression Ratio is defined as the number of generated tokens divided by the number of forward passes in the original LLM. If "Highlight the tokens generated by EAGLE-2" is checked, the tokens correctly guessed by EAGLE-2 will be displayed in orange. Note: Checking this option may cause special formatting rendering issues in a few cases, especially when generating code''') From 3ef424ce549c7a1c02e9f779531f37885f8a8d36 Mon Sep 17 00:00:00 2001 From: yanxing Date: Fri, 13 Jun 2025 16:26:12 +0800 Subject: [PATCH 2/2] bugfix: add ds-llama-3 type, and del its system_prompt. --- eagle/application/webui.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/eagle/application/webui.py b/eagle/application/webui.py index 1b68e61b..367b3e36 100644 --- a/eagle/application/webui.py +++ b/eagle/application/webui.py @@ -1,5 +1,5 @@ import os -# os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" +#os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" import time import gradio as gr @@ -111,11 +111,13 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat elif args.model_type == "llama-3-instruct": messages = [ {"role": "system", - "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."}, + "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."}, ] + elif args.model_type == "ds-llama-3": + messages = [] # ds-llama-3 no system prompt for query, response in pure_history: - if args.model_type == "llama-3-instruct": + if args.model_type in ["llama-3-instruct", "ds-llama-3"]: messages.append({ "role": "user", "content": query @@ -131,7 +133,7 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat response = " " + response conv.append_message(conv.roles[1], response) - if args.model_type == "llama-3-instruct": + if args.model_type in ["llama-3-instruct", "ds-llama-3"]: prompt = model.tokenizer.apply_chat_template( messages, tokenize=False, @@ -154,12 +156,12 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat if use_EaInfer: for output_ids in model.ea_generate(input_ids, temperature=temperature, top_p=top_p, - max_new_tokens=args.max_new_token,is_llama3=args.model_type=="llama-3-instruct"): + max_new_tokens=args.max_new_token,is_llama3=(args.model_type in ["llama-3-instruct", "ds-llama-3"])): totaltime+=(time.time()-start_time) total_ids+=1 decode_ids = output_ids[0, input_len:].tolist() decode_ids = truncate_list(decode_ids, model.tokenizer.eos_token_id) - if args.model_type == "llama-3-instruct": + if args.model_type in ["llama-3-instruct", "ds-llama-3"]: decode_ids = truncate_list(decode_ids, model.tokenizer.convert_tokens_to_ids("<|eot_id|>")) text = model.tokenizer.decode(decode_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True, ) @@ -183,7 +185,7 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat else: for output_ids in model.naive_generate(input_ids, temperature=temperature, top_p=top_p, - max_new_tokens=args.max_new_token,is_llama3=args.model_type=="llama-3-instruct"): + max_new_tokens=args.max_new_token,is_llama3=(args.model_type in ["llama-3-instruct", "ds-llama-3"])): totaltime += (time.time() - start_time) total_ids+=1 decode_ids = output_ids[0, input_len:].tolist() @@ -256,7 +258,7 @@ def clear(history,session_state): parser.add_argument( "--no-eagle3", action="store_true", help=" Not use EAGLE-3" ) -parser.add_argument("--model-type", type=str, default="vicuna",choices=["llama-2-chat","vicuna","mixtral","llama-3-instruct"]) +parser.add_argument("--model-type", type=str, default="vicuna",choices=["llama-2-chat","vicuna","mixtral","llama-3-instruct","ds-llama-3"]) parser.add_argument( "--total-token", type=int, @@ -293,7 +295,7 @@ def clear(history,session_state): with gr.Blocks(css=custom_css) as demo: gs = gr.State({"pure_history": []}) - gr.Markdown('''## EAGLE-2 Chatbot''') + gr.Markdown('''## EAGLE-3 Chatbot''') with gr.Row(): speed_box = gr.Textbox(label="Speed", elem_id="speed", interactive=False, value="0.00 tokens/s") compression_box = gr.Textbox(label="Compression Ratio", elem_id="speed", interactive=False, value="0.00") @@ -329,4 +331,4 @@ def clear(history,session_state): ) stop_button.click(fn=None, inputs=None, outputs=None, cancels=[send_event,regenerate_event,enter_event]) demo.queue() -demo.launch(share=True) +demo.launch(share=True, server_name="0.0.0.0")