-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
110 lines (93 loc) · 3.67 KB
/
app.py
File metadata and controls
110 lines (93 loc) · 3.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import sqlite3
from pathlib import Path
from urllib.parse import quote_plus

import streamlit as st
from langchain.agents import create_sql_agent
from langchain.callbacks import StreamlitCallbackHandler
from langchain.sql_database import SQLDatabase
from langchain_groq import ChatGroq
from sqlalchemy import create_engine
# --- Page setup ---------------------------------------------------------------
# The same text is used for the browser tab title and the page heading.
_PAGE_TITLE = "Langchain: Chat with SQL DB"
st.set_page_config(page_title=_PAGE_TITLE)
st.title(_PAGE_TITLE)

# Sentinel values naming which kind of database the user selected;
# configure_db() dispatches on these.
LOCALDB = "USE_LOCALDB"
USE_MYSQL = "USE_MYSQL"

# --- Sidebar: database selection ---------------------------------------------
radio_opt = [
    "Use SQLite 3 Database - VALUE_MAINTENANCE.db",
    "Connect to your SQL Database",
]
selected_opt = st.sidebar.radio(
    label="Choose the DB you want to chat with",
    options=radio_opt,
)

# The second option (index 1) means "connect to MySQL", which needs the extra
# connection fields below; any other choice uses the bundled SQLite file.
wants_mysql = radio_opt.index(selected_opt) == 1
db_uri = USE_MYSQL if wants_mysql else LOCALDB
if wants_mysql:
    mysql_host = st.sidebar.text_input("Provide MySQL Host", value="localhost")
    mysql_user = st.sidebar.text_input("MySQL User")
    mysql_password = st.sidebar.text_input("MySQL Password", type="password")
    mysql_db = st.sidebar.text_input("MySQL Database")

# --- Sidebar: Groq credentials ------------------------------------------------
api_key = st.sidebar.text_input(label="Groq API Key", type="password")
if not api_key:
    # Everything below needs the LLM, so end this script run early.
    st.warning("Please add the Groq API key to proceed.")
    st.stop()

# --- Language model -----------------------------------------------------------
try:
    llm = ChatGroq(
        groq_api_key=api_key,
        model_name="Llama3-8b-8192",
        streaming=True,
    )
except Exception as exc:
    st.error(f"Error initializing the LLM: {exc}")
    st.stop()
# Cached Database Configuration
@st.cache_resource(ttl=7200)
def configure_db(db_uri, mysql_host=None, mysql_user=None, mysql_password=None, mysql_db=None):
    """Build a LangChain ``SQLDatabase`` for the selected backend.

    The result is cached with ``st.cache_resource`` (ttl 7200 s) per distinct
    argument set, so Streamlit reruns reuse the same database object.

    Args:
        db_uri: Either ``LOCALDB`` (the bundled ``VALUE_MAINTENANCE.db`` file
            next to this script) or ``USE_MYSQL``.
        mysql_host: MySQL host placed verbatim in the URL (so ``host:port``
            also works). Empty/None falls back to ``"localhost"``.
        mysql_user: MySQL user name (required for ``USE_MYSQL``).
        mysql_password: MySQL password (required for ``USE_MYSQL``).
        mysql_db: MySQL database name (required for ``USE_MYSQL``).

    Returns:
        A ``SQLDatabase`` connected to the chosen database.

    Raises:
        ValueError: If ``db_uri`` is neither ``LOCALDB`` nor ``USE_MYSQL``.

    A missing SQLite file or missing MySQL details are reported with
    ``st.error`` and the script run is halted with ``st.stop()``.
    """
    if db_uri == LOCALDB:
        db_filepath = Path(__file__).parent / "VALUE_MAINTENANCE.db"
        if not db_filepath.exists():
            st.error(f"SQLite database file not found: {db_filepath}")
            st.stop()
        # NOTE(review): the SQL agent built from this object runs LLM-generated
        # SQL with this connection's full privileges; consider a read-only
        # connection (and a least-privilege MySQL user below).
        return SQLDatabase.from_uri(f"sqlite:///{db_filepath}")

    if db_uri == USE_MYSQL:
        # Apply the host default before validating, so an empty host falls back
        # to localhost instead of being rejected by the check below.
        mysql_host = mysql_host or "localhost"
        if not (mysql_user and mysql_password and mysql_db):
            st.error("Please provide all MySQL connection details.")
            st.stop()
        # Percent-encode the credentials: a raw '@', ':', '/' or '#' in the
        # user name or password would otherwise corrupt the URL structure.
        user = quote_plus(mysql_user)
        password = quote_plus(mysql_password)
        mysql_uri = f"mysql+mysqlconnector://{user}:{password}@{mysql_host}/{mysql_db}"
        return SQLDatabase.from_uri(mysql_uri)

    # Fail loudly instead of implicitly returning None for an unknown type.
    raise ValueError(f"Unsupported database type: {db_uri!r}")
# Initialize database connection
try:
    if db_uri == USE_MYSQL:
        db = configure_db(db_uri, mysql_host, mysql_user,
                          mysql_password, mysql_db)
    else:
        db = configure_db(db_uri)
except Exception as e:
    st.error(f"Database connection error: {e}")
    st.stop()

# Toolkits and Agents
try:
    agent = create_sql_agent(
        llm=llm, db=db, agent_type="zero-shot-react-description", verbose=True)
except Exception as e:
    st.error(f"Error setting up the agent: {e}")
    st.stop()

# Chat interface
st.subheader("Chat with Your Database")
query = st.text_input("Enter your SQL query or ask a question:")
if st.button("Submit Query"):
    if not query.strip():
        st.warning("Please enter a valid query.")
    else:
        try:
            # Create the output container here, after the input widgets, so the
            # agent's streamed steps and the final answer render below the
            # button instead of above the chat heading.
            callback_container = st.container()
            callback_handler = StreamlitCallbackHandler(
                parent_container=callback_container)
            with callback_container:
                # invoke() replaces the deprecated Chain.run(); the agent
                # executor takes the question under "input" and returns the
                # answer under "output".
                result = agent.invoke(
                    {"input": query}, config={"callbacks": [callback_handler]})
                st.write(result["output"])
        except Exception as e:
            st.error(f"Error processing query: {e}")