diff --git a/eagle/application/webui.py b/eagle/application/webui.py index a1b96dff..367b3e36 100644 --- a/eagle/application/webui.py +++ b/eagle/application/webui.py @@ -1,5 +1,5 @@ import os -os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" +#os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" import time import gradio as gr @@ -111,11 +111,13 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat elif args.model_type == "llama-3-instruct": messages = [ {"role": "system", - "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."}, + "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."}, ] + elif args.model_type == "ds-llama-3": + messages = [] # ds-llama-3 no system prompt for query, response in pure_history: - if args.model_type == "llama-3-instruct": + if args.model_type in ["llama-3-instruct", "ds-llama-3"]: messages.append({ "role": "user", "content": query @@ -131,7 +133,7 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat response = " " + response conv.append_message(conv.roles[1], response) - if args.model_type == "llama-3-instruct": + if args.model_type in ["llama-3-instruct", "ds-llama-3"]: prompt = model.tokenizer.apply_chat_template( messages, tokenize=False, @@ -154,12 +156,12 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat if use_EaInfer: for output_ids in model.ea_generate(input_ids, temperature=temperature, top_p=top_p, - max_new_tokens=args.max_new_token,is_llama3=args.model_type=="llama-3-instruct"): + max_new_tokens=args.max_new_token,is_llama3=(args.model_type in ["llama-3-instruct", "ds-llama-3"])): totaltime+=(time.time()-start_time) total_ids+=1 decode_ids = output_ids[0, input_len:].tolist() decode_ids = truncate_list(decode_ids, model.tokenizer.eos_token_id) - if args.model_type == "llama-3-instruct": + if args.model_type in ["llama-3-instruct", "ds-llama-3"]: decode_ids = truncate_list(decode_ids, model.tokenizer.convert_tokens_to_ids("<|eot_id|>")) text = model.tokenizer.decode(decode_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True, ) @@ -183,7 +185,7 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat else: for output_ids in model.naive_generate(input_ids, temperature=temperature, top_p=top_p, - max_new_tokens=args.max_new_token,is_llama3=args.model_type=="llama-3-instruct"): + max_new_tokens=args.max_new_token,is_llama3=(args.model_type in ["llama-3-instruct", "ds-llama-3"])): totaltime += (time.time() - start_time) total_ids+=1 decode_ids = output_ids[0, input_len:].tolist() @@ -256,7 +258,7 @@ def clear(history,session_state): parser.add_argument( "--no-eagle3", action="store_true", help=" Not use EAGLE-3" ) -parser.add_argument("--model-type", type=str, default="vicuna",choices=["llama-2-chat","vicuna","mixtral","llama-3-instruct"]) +parser.add_argument("--model-type", type=str, default="vicuna",choices=["llama-2-chat","vicuna","mixtral","llama-3-instruct","ds-llama-3"]) parser.add_argument( "--total-token", type=int, @@ -280,20 +282,20 @@ def clear(history,session_state): load_in_4bit=args.load_in_4bit, load_in_8bit=args.load_in_8bit, device_map="auto", - use_eagle3=args.no_eagle3, + use_eagle3=(not args.no_eagle3), ) model.eval() warmup(model) custom_css = """ #speed textarea { - color: red; - font-size: 30px; + color: red; + font-size: 30px; }""" with gr.Blocks(css=custom_css) as demo: gs = gr.State({"pure_history": []}) - gr.Markdown('''## EAGLE-2 Chatbot''') + gr.Markdown('''## EAGLE-3 Chatbot''') with gr.Row(): speed_box = gr.Textbox(label="Speed", elem_id="speed", interactive=False, value="0.00 tokens/s") compression_box = gr.Textbox(label="Compression Ratio", elem_id="speed", interactive=False, value="0.00") @@ -303,7 +305,7 @@ def clear(history,session_state): highlight_EaInfer = gr.Checkbox(label="Highlight the tokens generated by EAGLE-3", value=True) temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="temperature", value=0.5) top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="top_p", value=0.9) - note=gr.Markdown(show_label=False,interactive=False,value='''The Compression Ratio is defined as the number of generated tokens divided by the number of forward passes in the original LLM. If "Highlight the tokens generated by EAGLE-2" is checked, the tokens correctly guessed by EAGLE-2 + note=gr.Markdown(show_label=False,interactive=False,value='''The Compression Ratio is defined as the number of generated tokens divided by the number of forward passes in the original LLM. If "Highlight the tokens generated by EAGLE-2" is checked, the tokens correctly guessed by EAGLE-2 will be displayed in orange. Note: Checking this option may cause special formatting rendering issues in a few cases, especially when generating code''') @@ -329,4 +331,4 @@ def clear(history,session_state): ) stop_button.click(fn=None, inputs=None, outputs=None, cancels=[send_event,regenerate_event,enter_event]) demo.queue() -demo.launch(share=True) +demo.launch(share=True, server_name="0.0.0.0")