-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
84 lines (62 loc) · 2.58 KB
/
app.py
File metadata and controls
84 lines (62 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import asyncio

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from langchain.memory import ConversationBufferMemory
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from models import MODEL_IDS
# Initialize FastAPI app
app = FastAPI()

# Load the Hugging Face checkpoint (tokenizer first, then the causal LM)
# registered under the "llama_3.2_1b_instruct" key of MODEL_IDS.
MODEL_NAME = MODEL_IDS["llama_3.2_1b_instruct"]
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModelForCausalLM.from_pretrained(MODEL_NAME),
)

# Single module-level LangChain buffer; every request handled by this process
# reads and writes the same conversation history.
memory = ConversationBufferMemory()
# Pydantic model for request body
class PromptRequest(BaseModel):
    """JSON body accepted by the ``POST /stream-chat/`` endpoint."""

    prompt: str  # user text that is forwarded to the model
# Function to generate streaming responses
async def generate_response(prompt: str, max_new_tokens: int = 50):
    """Stream a sampled model reply to *prompt* as incremental text chunks.

    The prompt is appended to ``memory`` as a user message before generation.
    Generation stops at the tokenizer's EOS token or after ``max_new_tokens``.
    The generated reply (without the prompt) is then appended as an AI message.

    Args:
        prompt: Raw user text. It is tokenized as-is (no chat template) and
            truncated to 512 tokens.
        max_new_tokens: Maximum number of tokens to generate. Defaults to 50,
            the limit that was previously hard-coded.

    Yields:
        str: The text decoded since the previous chunk.
    """
    memory.chat_memory.add_user_message(prompt)

    inputs = tokenizer(
        prompt, return_tensors="pt", max_length=512, truncation=True
    )
    input_ids = inputs["input_ids"]
    attention_mask = inputs.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    prompt_len = input_ids.shape[1]
    loop = asyncio.get_running_loop()

    def _sample_next_token(ids, mask):
        """Sample one token after ``ids``; blocking, so run off the event loop."""
        # Grad mode is thread-local, so it must be disabled inside the worker.
        with torch.no_grad():
            outputs = model.generate(
                ids,
                attention_mask=mask,
                max_new_tokens=1,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                top_k=50,
            )
        return outputs[0, -1].item()

    def _decode_reply(ids):
        """Decode only the generated tokens (everything after the prompt)."""
        return tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)

    emitted = ""
    for _ in range(max_new_tokens):
        # Run the blocking generate() call in the default thread pool so other
        # requests keep being served while this one samples.
        new_token = await loop.run_in_executor(
            None, _sample_next_token, input_ids, attention_mask
        )
        if new_token == tokenizer.eos_token_id:
            break
        # Create the new column on the same device as the prompt tensors.
        new_col = torch.tensor([[new_token]], device=input_ids.device)
        input_ids = torch.cat([input_ids, new_col], dim=1)
        attention_mask = torch.cat(
            [attention_mask, torch.ones_like(new_col)], dim=1
        )
        # Decode the whole reply and emit only the new suffix. Decoding tokens
        # one at a time can drop word-leading spaces or split multi-byte chars.
        text = _decode_reply(input_ids)
        if text.endswith("\ufffd"):
            continue  # incomplete UTF-8 sequence; wait for the next token
        if len(text) > len(emitted):
            yield text[len(emitted):]
            emitted = text

    full_response = _decode_reply(input_ids)
    if len(full_response) > len(emitted):
        # Flush any text held back above (e.g. a trailing partial character).
        yield full_response[len(emitted):]
    # Record only the model's reply; the prompt was stored separately above.
    memory.chat_memory.add_ai_message(full_response)
# FastAPI endpoint for streaming responses
@app.post("/stream-chat/")
async def stream_chat(request: PromptRequest):
    """Validate the request body and stream the model's reply as plain text.

    Raises:
        HTTPException: status 400 when ``request.prompt`` is empty.
    """
    prompt = request.prompt
    if prompt:
        token_stream = generate_response(prompt)
        return StreamingResponse(token_stream, media_type="text/plain")
    raise HTTPException(status_code=400, detail="Prompt cannot be empty")
# FastAPI endpoint to get conversation history
@app.get("/conversation-history/")
async def get_conversation_history():
    """Return the shared memory's variables wrapped under the ``"history"`` key."""
    history = memory.load_memory_variables({})
    return {"history": history}
# Run the FastAPI app
if __name__ == "__main__":
    # Local dev entry point: serve on every interface, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")