-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmvp_chat.py
More file actions
100 lines (77 loc) · 2.9 KB
/
mvp_chat.py
File metadata and controls
100 lines (77 loc) · 2.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from models import MODEL_IDS
# 1. Load Model and Tokenizer
model_id = MODEL_IDS["tinyLlama_1.1b"]
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Prefer the Apple-Silicon GPU (MPS) when available; otherwise fall back
# to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
model.eval()
# 2. Define the Streaming Generator
def stream_generate(prompt, max_new_tokens=50, temperature=0.7):
    """Yield generated text token-by-token.

    Args:
        prompt: Text prompt to continue.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Softmax temperature (> 0) applied to the logits
            before sampling; lower values make output more deterministic.

    Yields:
        str: The decoded text of each newly sampled token, as soon as it
        is sampled, to simulate streaming.
    """
    # Encode prompt as input IDs on the same device as the model.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Step through generation manually so each token can be yielded
    # immediately. Feed past_key_values back into the model so every step
    # only processes the single new token (the original re-ran the whole
    # prefix each iteration, which is quadratic in sequence length).
    past_key_values = None
    step_input = input_ids
    with torch.no_grad():  # pure inference: skip autograd bookkeeping
        for _ in range(max_new_tokens):
            outputs = model(
                step_input,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values
            # Temperature-scaled sampling over the last position's logits.
            next_token_logits = outputs.logits[:, -1, :] / temperature
            next_token = torch.distributions.Categorical(
                logits=next_token_logits).sample()
            # Stop cleanly at end-of-sequence instead of sampling past it.
            if (tokenizer.eos_token_id is not None
                    and next_token.item() == tokenizer.eos_token_id):
                break
            # Thanks to the KV cache, only the new token is fed next step.
            step_input = next_token.unsqueeze(0)
            # Decode the latest token and yield it.
            yield tokenizer.decode(next_token)
def chat_with_model(prompt):
    """Handle a single prompt on behalf of Gradio.

    Acts as a generator, re-yielding each partial response produced by
    the token-streaming generator.
    """
    yield from stream_generate(prompt)
# 3. Gradio Interface
def chat_interface():
    """Build the Gradio UI: a prompt textbox wired to a textbox that is
    updated progressively as the model streams tokens.
    """
    with gr.Blocks() as ui:
        gr.Markdown("## Token-Streaming LLM Chat Demo")

        # Where the user types their message.
        question = gr.Textbox(
            label="Enter your prompt",
            placeholder="Type your message here...",
        )
        # Where the streamed response appears.
        answer = gr.Textbox(label="Model Response")

        def _on_submit(prompt):
            # Accumulate tokens and yield the growing text each time so
            # Gradio re-renders the output box with every new token.
            so_far = ""
            for piece in chat_with_model(prompt):
                so_far += piece
                yield so_far

        # Submitting the prompt box triggers the streaming handler.
        question.submit(fn=_on_submit, inputs=question, outputs=answer)
    return ui
if __name__ == "__main__":
    # 4. Launch the Gradio App
    chat_interface().launch()