# parse_query.py
import google.generativeai as genai
from dotenv import load_dotenv
import os
from gemini_context_manager import GeminiContextManager
import re
load_dotenv()
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)
safety_settings = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_NONE"
    }
]
model = genai.GenerativeModel("gemini-pro", safety_settings=safety_settings)
cols = {
    "usertable": ["user_id"],
    "userinfo": ["fullname", "date_of_birth"],
    "notes": ["note", "note_date"],
    "medicine": ["med_name", "med_dosage", "med_frequency", "med_date"],
    "vaccine": ["vac_name", "vac_date"],
    "lab_result": ["lab_result", "lab_date"],
    "surgeries": ["surgery", "surgery_date"],
    "emergencies": ["emergency_name", "emergency_date"],
    "vitals": ["vital_name", "vital_value", "vital_date"],
    "diagnosis": ["diagnosis", "diag_date"],
    "symptoms": ["symptom", "symptom_date"]
}
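# A minimal sketch (not part of the original script) of a post-check for the
# prompt rule "DO NOT INCLUDE COLUMNS THAT ARE NOT IN THE DICTIONARY": it pulls
# the table names that follow FROM/JOIN out of a generated query and reports
# any that are not keys of a schema dict shaped like `cols`. The helper name
# `unknown_tables` is hypothetical, and the regex is a rough heuristic rather
# than a SQL parser (it ignores subqueries, quoted identifiers, and aliases).
import re


def unknown_tables(query, schema):
    """Return the sorted table names after FROM/JOIN that are not keys of `schema`."""
    referenced = re.findall(
        r"\b(?:FROM|JOIN)\s+([A-Za-z_][A-Za-z0-9_]*)", query, flags=re.IGNORECASE
    )
    # Compare names exactly as written in the query against the schema keys.
    return sorted({name for name in referenced if name not in schema})


# Usage: an empty result from unknown_tables(query, cols) means every FROM/JOIN
# target is a table listed in the dictionary given to the model.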
context_manager = GeminiContextManager()
context_manager.add_context("user",f"""
This is a system prompt.
You are a natural language processor. Your job is to generate valid SQL queries based on the user's input. The user will provide a dictionary of tables and columns. If the request asks for all patients with a certain condition, the returned query should look like this: SELECT userinfo.fullname FROM userinfo JOIN (necessary table name) ON userinfo.user_id = (necessary table name).user_id WHERE (necessary table name).(something) = 'something required';
When the request asks for notes, also select note_date so that the date of each note is included in the query.
When asked for symptoms for a patient, the query should be something like this: SELECT symptoms.symptom, diagnosis.diagnosis FROM symptoms JOIN userinfo ON symptoms.user_id = userinfo.user_id JOIN diagnosis ON symptoms.diag_id = diagnosis.diag_id WHERE userinfo.fullname = 'name provided';
Below is a dictionary in which each key is the name of a table in the database and each value is the list of columns in that table. Only use the tables and columns in this dictionary to generate the SQL query.
Dictionary:
{str(cols)}
return format:
["~SQL query here~"]
example:
["SELECT * FROM usertable"]
["SELECT surgery from surgeries"]
IMPORTANT: FOLLOW THE PROVIDED FORMAT EXACTLY.
IMPORTANT: ONLY RETURN THE QUERY STRING. DO NOT RETURN ANYTHING ELSE.
IMPORTANT: DO NOT INCLUDE BACK QUOTES OR STRING UNRELATED TO THE QUERY.
IMPORTANT: DO NOT INCLUDE COLUMNS THAT ARE NOT IN THE DICTIONARY ABOVE.
""")
context_manager.add_context("model", "Certainly, I can help with creating the SQL query.")
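# GeminiContextManager lives in a separate local module that is not shown here.
# The class below is a hypothetical stand-in, sketched on the assumption that
# get_context() returns history in the shape that start_chat(history=...)
# accepts: a list of {"role": ..., "parts": [...]} dicts whose role is "user"
# or "model". It is not the project's actual implementation.
class SketchContextManager:
    """Hypothetical minimal history store for seeding a Gemini chat."""

    def __init__(self):
        self._history = []

    def add_context(self, role, text):
        # Chat history turns are attributed to either the user or the model.
        if role not in ("user", "model"):
            raise ValueError(f"unsupported role: {role!r}")
        self._history.append({"role": role, "parts": [text]})

    def get_context(self):
        # Return copies so a caller cannot mutate the stored history.
        return [dict(turn, parts=list(turn["parts"])) for turn in self._history]


# Usage mirrors the calls above: add_context("user", prompt), then
# add_context("model", acknowledgement), then pass get_context() to start_chat.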
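def sanitize(text):
    """Flatten the model reply to one line and strip Markdown code-fence residue."""
    text = text.replace("\n", " ")
    # Remove only the language tag of a ```sql fence, so "sql" inside
    # identifiers or values is left untouched.
    text = re.sub(r"```\s*sql", "```", text, flags=re.IGNORECASE)
    text = text.replace("`", "")
    #text = text.replace("*"," * ")
    #text = re.sub("\W+", " ", text)
    return text.strip()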
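# The prompt asks the model to answer in the form ["~SQL query here~"], so the
# sanitized reply is still a list literal rather than a bare query. A minimal
# sketch of unwrapping it, assuming the reply follows that format; the helper
# name `extract_sql` is hypothetical. ast.literal_eval only evaluates Python
# literals, so it never executes the model's text as code.
import ast


def extract_sql(text):
    """Return the first query string from a reply like '["SELECT ..."]'.

    Falls back to the stripped text when the reply is not such a literal.
    """
    stripped = text.strip()
    try:
        value = ast.literal_eval(stripped)
    except (ValueError, SyntaxError, TypeError):
        # Not a Python literal, e.g. the model returned a bare query.
        return stripped
    if isinstance(value, list) and value and isinstance(value[0], str):
        return value[0].strip()
    if isinstance(value, str):
        return value.strip()
    return stripped


# Usage: extract_sql(sanitize(reply_text)) yields the query string itself.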
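def parse_query(input_text):
    """Send a natural-language request to Gemini and return the sanitized reply."""
    # Each call starts a new chat seeded with the stored context (the system
    # prompt and the model's acknowledgement above).
    chat = model.start_chat(history=context_manager.get_context())
    response = chat.send_message(input_text)
    return sanitize(response.text)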
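# parse_query returns model-generated SQL without checking it. Before such a
# string reaches a database, a conservative guard like this hypothetical
# `is_single_select` helper can reject anything other than one SELECT
# statement. It is a coarse textual sketch, not a SQL parser; running queries
# through a read-only database user is the more robust control.
import re


def is_single_select(query):
    """Return True if `query` looks like exactly one SELECT statement."""
    body = query.strip().rstrip(";").strip()
    if ";" in body:
        # Multiple statements (or a semicolon inside a literal, which this
        # coarse check also rejects).
        return False
    return bool(re.match(r"SELECT\b", body, flags=re.IGNORECASE))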
def main():
    response = parse_query("I want all the medicine prescriptions")
    print(response)
if __name__ == "__main__":
    main()