eagle/application/webui.py (30 changes: 16 additions & 14 deletions)

diff --git a/eagle/application/webui.py b/eagle/application/webui.py
--- a/eagle/application/webui.py
+++ b/eagle/application/webui.py
@@ -1,5 +1,5 @@
 import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
+#os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
 import time
 
 import gradio as gr
@@ -111,11 +111,13 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat
     elif args.model_type == "llama-3-instruct":
         messages = [
             {"role": "system",
-             "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."},
+             "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."},
         ]
+    elif args.model_type == "ds-llama-3":
+        messages = [] # ds-llama-3 no system prompt
 
     for query, response in pure_history:
-        if args.model_type == "llama-3-instruct":
+        if args.model_type in ["llama-3-instruct", "ds-llama-3"]:
             messages.append({
                 "role": "user",
                 "content": query
@@ -131,7 +133,7 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat
                 response = " " + response
             conv.append_message(conv.roles[1], response)
 
-    if args.model_type == "llama-3-instruct":
+    if args.model_type in ["llama-3-instruct", "ds-llama-3"]:
         prompt = model.tokenizer.apply_chat_template(
             messages,
             tokenize=False,
@@ -154,12 +156,12 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat
     if use_EaInfer:
 
         for output_ids in model.ea_generate(input_ids, temperature=temperature, top_p=top_p,
-                                            max_new_tokens=args.max_new_token,is_llama3=args.model_type=="llama-3-instruct"):
+                                            max_new_tokens=args.max_new_token,is_llama3=(args.model_type in ["llama-3-instruct", "ds-llama-3"])):
             totaltime+=(time.time()-start_time)
             total_ids+=1
             decode_ids = output_ids[0, input_len:].tolist()
             decode_ids = truncate_list(decode_ids, model.tokenizer.eos_token_id)
-            if args.model_type == "llama-3-instruct":
+            if args.model_type in ["llama-3-instruct", "ds-llama-3"]:
                 decode_ids = truncate_list(decode_ids, model.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
             text = model.tokenizer.decode(decode_ids, skip_special_tokens=True, spaces_between_special_tokens=False,
                                           clean_up_tokenization_spaces=True, )
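Decoding truncates at the tokenizer's eos token and, for the Llama-3-family types now including ds-llama-3, additionally at the <|eot_id|> turn delimiter. The truncate_list helper is defined elsewhere in the repo and does not appear in this diff; a plausible sketch of the behavior the loop relies on, assumed rather than taken from the repo:

```python
# Assumed behavior of truncate_list (the repo's real implementation may differ):
# cut a token-id list at the first occurrence of a stop id, dropping the stop
# id and everything after it; leave the list unchanged if the id is absent.
def truncate_list(ids, stop_id):
    return ids[: ids.index(stop_id)] if stop_id in ids else ids

assert truncate_list([5, 6, 7, 8], 7) == [5, 6]
assert truncate_list([5, 6, 7, 8], 9) == [5, 6, 7, 8]
```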
@@ -183,7 +185,7 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat
 
     else:
         for output_ids in model.naive_generate(input_ids, temperature=temperature, top_p=top_p,
-                                               max_new_tokens=args.max_new_token,is_llama3=args.model_type=="llama-3-instruct"):
+                                               max_new_tokens=args.max_new_token,is_llama3=(args.model_type in ["llama-3-instruct", "ds-llama-3"])):
             totaltime += (time.time() - start_time)
             total_ids+=1
             decode_ids = output_ids[0, input_len:].tolist()
@@ -256,7 +258,7 @@ def clear(history,session_state):
 parser.add_argument(
     "--no-eagle3", action="store_true", help=" Not use EAGLE-3"
 )
-parser.add_argument("--model-type", type=str, default="vicuna",choices=["llama-2-chat","vicuna","mixtral","llama-3-instruct"])
+parser.add_argument("--model-type", type=str, default="vicuna",choices=["llama-2-chat","vicuna","mixtral","llama-3-instruct","ds-llama-3"])
 parser.add_argument(
     "--total-token",
     type=int,
@@ -280,20 +282,20 @@ def clear(history,session_state):
     load_in_4bit=args.load_in_4bit,
     load_in_8bit=args.load_in_8bit,
     device_map="auto",
-    use_eagle3=args.no_eagle3,
+    use_eagle3=(not args.no_eagle3),
 )
 model.eval()
 warmup(model)
 
 custom_css = """
 #speed textarea {
-    color: red;
-    font-size: 30px;
+    color: red;
+    font-size: 30px;
 }"""
 
 with gr.Blocks(css=custom_css) as demo:
     gs = gr.State({"pure_history": []})
-    gr.Markdown('''## EAGLE-2 Chatbot''')
+    gr.Markdown('''## EAGLE-3 Chatbot''')
     with gr.Row():
         speed_box = gr.Textbox(label="Speed", elem_id="speed", interactive=False, value="0.00 tokens/s")
         compression_box = gr.Textbox(label="Compression Ratio", elem_id="speed", interactive=False, value="0.00")
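The use_eagle3 line in this hunk is a genuine bug fix riding along with the ds-llama-3 support: --no-eagle3 is a store_true flag, so passing args.no_eagle3 through unnegated enabled EAGLE-3 exactly when the user asked to disable it, and disabled it by default. A self-contained illustration of the corrected logic:

```python
# Why the negation matters: with action="store_true", args.no_eagle3 is True
# only when --no-eagle3 is passed, so it must be inverted before being handed
# to use_eagle3.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no-eagle3", action="store_true", help="Do not use EAGLE-3")

args = parser.parse_args([])               # flag absent
assert not args.no_eagle3                  # EAGLE-3 enabled by default

args = parser.parse_args(["--no-eagle3"])  # flag present
assert args.no_eagle3                      # EAGLE-3 genuinely disabled now
```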
@@ -303,7 +305,7 @@ def clear(history,session_state):
         highlight_EaInfer = gr.Checkbox(label="Highlight the tokens generated by EAGLE-3", value=True)
         temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="temperature", value=0.5)
         top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="top_p", value=0.9)
-    note=gr.Markdown(show_label=False,interactive=False,value='''The Compression Ratio is defined as the number of generated tokens divided by the number of forward passes in the original LLM. If "Highlight the tokens generated by EAGLE-2" is checked, the tokens correctly guessed by EAGLE-2
+    note=gr.Markdown(show_label=False,interactive=False,value='''The Compression Ratio is defined as the number of generated tokens divided by the number of forward passes in the original LLM. If "Highlight the tokens generated by EAGLE-2" is checked, the tokens correctly guessed by EAGLE-2
     will be displayed in orange. Note: Checking this option may cause special formatting rendering issues in a few cases, especially when generating code''')
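The note in this hunk defines the Compression Ratio shown in the UI: the number of generated tokens divided by the number of forward passes of the original LLM. A worked example with hypothetical numbers:

```python
# Hypothetical numbers for illustration: if 96 tokens are emitted while the
# base LLM performs only 30 forward passes (EAGLE-3's accepted draft tokens
# supply the rest), the UI's Compression Ratio box reads 3.20.
generated_tokens = 96
forward_passes = 30
print(f"{generated_tokens / forward_passes:.2f}")  # 3.20
```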


@@ … @@
     )
     stop_button.click(fn=None, inputs=None, outputs=None, cancels=[send_event,regenerate_event,enter_event])
 demo.queue()
-demo.launch(share=True)
+demo.launch(share=True, server_name="0.0.0.0")
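The final change binds the Gradio server to all network interfaces instead of localhost only, so the demo becomes reachable from other machines on the network; share=True additionally creates a temporary public gradio.live link. A standalone sketch of the same launch pattern, with a toy app that is illustrative only:

```python
# Toy Gradio app using the launch pattern the diff adopts: serving on 0.0.0.0
# exposes the app on every interface of the host, not just 127.0.0.1.
import gradio as gr

demo = gr.Interface(fn=lambda text: text[::-1], inputs="text", outputs="text")
demo.queue()
demo.launch(share=True, server_name="0.0.0.0")
```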