-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchatbot_class.py
More file actions
59 lines (49 loc) · 2.25 KB
/
chatbot_class.py
File metadata and controls
59 lines (49 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# chatbot_class.py -- SymptomPredictor: retrieval-augmented QA over a local FAISS index
import os
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from dotenv import load_dotenv, find_dotenv
# Load environment variables
load_dotenv(find_dotenv())
class SymptomPredictor:
    """Answer symptom questions with a retrieval-augmented QA chain.

    On construction this wires together a Hugging Face hosted LLM, a
    sentence-transformers embedding model, a FAISS index loaded from
    ``vectorstore/db_faiss`` and a ``RetrievalQA`` chain that uses a custom
    prompt. ``predict`` then runs a single query through that chain.
    """

    def __init__(
        self,
        *,
        temperature: float = 1.5,
        max_new_tokens: int = 512,
        top_k: int = 3,
    ):
        """Build the LLM client, load the vector store and create the chain.

        Args:
            temperature: Sampling temperature forwarded to the endpoint.
            max_new_tokens: Upper bound on the number of generated tokens.
            top_k: Number of documents the retriever fetches per query.
        """
        # NOTE(review): the default temperature of 1.5 is kept for backward
        # compatibility, but it is high for a prompt that asks the model not
        # to make up answers -- consider passing a lower value.

        # Load Hugging Face model
        self.HF_TOKEN = os.environ.get("HF_TOKEN")
        self.HUGGINGFACE_REPO_ID = "mistralai/Mistral-7B-Instruct-v0.3"
        self.llm = HuggingFaceEndpoint(
            repo_id=self.HUGGINGFACE_REPO_ID,
            temperature=temperature,
            # The token and the generation limit are dedicated fields of the
            # endpoint. Passing them through ``model_kwargs`` forwards them
            # to the generation call itself, and the limit must be an int,
            # not the string "512".
            max_new_tokens=max_new_tokens,
            huggingfacehub_api_token=self.HF_TOKEN,
        )

        # Load FAISS database
        self.DB_FAISS_PATH = "vectorstore/db_faiss"
        self.embedding_model = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        # SECURITY: allow_dangerous_deserialization=True unpickles the stored
        # index. Only load an index file that this project produced itself;
        # never point DB_FAISS_PATH at untrusted data.
        self.db = FAISS.load_local(
            self.DB_FAISS_PATH,
            self.embedding_model,
            allow_dangerous_deserialization=True,
        )

        # Set custom prompt
        self.CUSTOM_PROMPT_TEMPLATE = (
            "\n"
            "Use the pieces of information provided in the context to answer user's question.\n"
            "If you don't know the answer, just say that you don't know. Don't try to make up an answer.\n"
            "Don't provide anything out of the given context.\n"
            "Context: {context}\n"
            "Question: {question}\n"
            "Start the answer directly. No small talk please.\n"
        )
        prompt = PromptTemplate(
            template=self.CUSTOM_PROMPT_TEMPLATE,
            input_variables=["context", "question"],
        )

        # Create QA chain
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.db.as_retriever(search_kwargs={"k": top_k}),
            # Sources are always requested from the chain so that ``predict``
            # can expose them on demand.
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )

    def predict(self, user_query, *, include_sources: bool = False) -> dict:
        """Run ``user_query`` through the QA chain and return its answer.

        Args:
            user_query: The text passed to the chain under the ``query`` key.
            include_sources: When true, also return the chain's
                ``source_documents`` entry alongside the answer.

        Returns:
            A dict with the key ``"result"`` and, if ``include_sources`` is
            set, the key ``"source_documents"``.
        """
        response = self.qa_chain.invoke({"query": user_query})
        answer = {"result": response["result"]}
        if include_sources:
            answer["source_documents"] = response["source_documents"]
        return answer