Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 8 additions & 20 deletions experimental/knowledge_graph_rag/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

import os
import streamlit as st
from llama_index.core import SimpleDirectoryReader, KnowledgeGraphIndex
from utils.preprocessor import extract_triples

from utils.preprocessor import get_list_of_directories, has_pdf_files
from llama_index.core import ServiceContext
import multiprocessing
import pandas as pd
Expand All @@ -25,17 +25,8 @@
from vectorstore.search import SearchHandler
from langchain_nvidia_ai_endpoints import ChatNVIDIA

def load_data(input_dir, num_workers):
    """Load every document found under *input_dir*.

    Parameters
    ----------
    input_dir : str
        Directory whose files are ingested by llama-index's directory reader.
    num_workers : int
        Number of worker processes forwarded to the reader for parallel loading.

    Returns
    -------
    list
        The Document objects produced by the reader.
    """
    return SimpleDirectoryReader(input_dir=input_dir).load_data(num_workers=num_workers)

def has_pdf_files(directory):
    """Return True if *directory* contains at least one file named ``*.pdf``."""
    return any(entry.endswith(".pdf") for entry in os.listdir(directory))

st.set_page_config(page_title="Knowledge Graph RAG")
st.title("Knowledge Graph RAG")

st.subheader("Load Data from Files")
Expand All @@ -52,19 +43,16 @@ def has_pdf_files(directory):
llm = ChatNVIDIA(model=llm)

def app():
# Get the current working directory
cwd = os.getcwd()

# Get a list of visible directories in the current working directory
directories = [d for d in os.listdir(cwd) if os.path.isdir(os.path.join(cwd, d)) and not d.startswith('.') and '__' not in d]

directories = get_list_of_directories()
# Create a dropdown menu for directory selection
selected_dir = st.selectbox("Select a directory:", directories, index=0)
selected_dir = st.selectbox("Select a directory:", directories, index=None)

# Construct the full path of the selected directory
directory = os.path.join(cwd, selected_dir)
if selected_dir:
directory = os.path.join(os.getcwd(), selected_dir)

if st.button("Process Documents"):
if st.button("Process Documents", disabled=(selected_dir is None)):
# Check if the selected directory has PDF files
res = has_pdf_files(directory)
if not res:
Expand Down
155 changes: 80 additions & 75 deletions experimental/knowledge_graph_rag/pages/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import streamlit as st
import json
import networkx as nx
st.set_page_config(layout = "wide")
st.set_page_config(page_title="Knowledge Graph RAG", layout = "wide")

from langchain_community.callbacks.streamlit import StreamlitCallbackHandler

Expand All @@ -29,77 +29,82 @@

from vectorstore.search import SearchHandler

# Load the previously built knowledge graph from disk.
# NOTE(review): this raises if the file does not exist — presumably documents
# must be processed on the main app page first; confirm against app.py.
G = nx.read_graphml("knowledge_graph.graphml")
graph = NetworkxEntityGraph(G)

# Restrict the model picker to chat-capable "instruct" models from the catalog.
models = ChatNVIDIA.get_available_models()
available_models = [model.id for model in models if model.model_type=="chat" and "instruct" in model.id]

with st.sidebar:
    llm = st.selectbox("Choose an LLM", available_models, index=available_models.index("mistralai/mixtral-8x7b-instruct-v0.1"))
    st.write("You selected: ", llm)
    # Rebind the selected model id to an actual client instance.
    llm = ChatNVIDIA(model=llm)

st.subheader("Chat with your knowledge graph!")

# Persist the conversation across Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far on each rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

with st.sidebar:
    # Toggle between plain LLM chat and graph-augmented retrieval.
    use_kg = st.toggle("Use knowledge graph")

user_input = st.chat_input("Can you tell me how research helps users to solve problems?")

# Chain that answers questions by traversing the NetworkX entity graph.
graph_chain = GraphQAChain.from_llm(llm = llm, graph=graph, verbose=True)

prompt_template = ChatPromptTemplate.from_messages(
    [("system", "You are a helpful AI assistant named Envie. You will reply to questions only based on the context that you are provided. If something is out of context, you will refrain from replying and politely decline to respond to the user."), ("user", "{input}")]
)

chain = prompt_template | llm | StrOutputParser()
# Hybrid sparse/dense vector search with reranking over the document store.
search_handler = SearchHandler("hybrid_demo3", use_bge_m3=True, use_reranker=True)

if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    with st.chat_message("assistant"):
        if use_kg:
            # Ask the LLM to extract the entities mentioned in the query so we
            # can walk the graph outward from each of them.
            entity_string = llm.invoke("""Return a JSON with a single key 'entities' and list of entities within this user query. Each element in your list MUST BE part of the user's query. Do not provide any explanation. If the returned list is not parseable in Python, you will be heavily penalized. For example, input: 'What is the difference between Apple and Google?' output: ['Apple', 'Google']. Always follow this output format. Here's the user query: """ + user_input)
            try:
                entities = json.loads(entity_string.content)['entities']
                with st.expander("Extracted triples"):
                    st.code(entities)
                # Retrieve supporting passages alongside the graph triples.
                res = search_handler.search_and_rerank(user_input, k=5)
                with st.expander("Retrieved and Reranked Sparse-Dense Hybrid Search"):
                    st.write(res)
                context = "Here are the relevant passages from the knowledge base: \n\n" + "\n".join(item.text for item in res)
                all_triplets = []
                for entity in entities:
                    # depth=2: include relationships two hops away from the entity.
                    all_triplets.extend(graph_chain.graph.get_entity_knowledge(entity, depth=2))
                context += "\n\nHere are the relationships from the knowledge graph: " + "\n".join(all_triplets)
                with st.expander("All triplets"):
                    st.code(context)
            except Exception as e:
                # Fall back to an "ungrounded" disclaimer context if entity
                # parsing or retrieval fails.
                st.write("Faced exception: ", e)
                context = "No graph triples were available to extract from the knowledge graph. Always provide a disclaimer if you know the answer to the user's question, since it is not grounded in the knowledge you are provided from the graph."
            message_placeholder = st.empty()
            full_response = ""

            # Stream the grounded answer token-by-token with a cursor glyph.
            for response in chain.stream("Context: " + context + "\n\nUser query: " + user_input):
                full_response += response
                message_placeholder.markdown(full_response + "▌")
        else:
            # Plain LLM chat: no graph, no retrieval context.
            message_placeholder = st.empty()
            full_response = ""
            for response in chain.stream(user_input):
                full_response += response
                message_placeholder.markdown(full_response + "▌")
        # Replace the streaming cursor with the final answer.
        message_placeholder.markdown(full_response)

    st.session_state.messages.append({"role": "assistant", "content": full_response})

try:
    # The graph file only exists after documents have been processed on the
    # main page, so a load failure means "no knowledge base yet".
    G = nx.read_graphml("knowledge_graph.graphml")
except Exception:
    # NOTE(review): narrowed from a bare `except:` so SystemExit and
    # KeyboardInterrupt are no longer swallowed; any load failure still falls
    # through to the "upload documents" prompt below.
    st.subheader("Please upload documents to the knowledge base.")
else:
    graph = NetworkxEntityGraph(G)

    # Restrict the model picker to chat-capable "instruct" models.
    models = ChatNVIDIA.get_available_models()
    available_models = [model.id for model in models if model.model_type=="chat" and "instruct" in model.id]

    with st.sidebar:
        llm = st.selectbox("Choose an LLM", available_models, index=available_models.index("mistralai/mixtral-8x7b-instruct-v0.1"))
        st.write("You selected: ", llm)
        # Rebind the selected model id to an actual client instance.
        llm = ChatNVIDIA(model=llm)

    st.subheader("Chat with your knowledge graph!")

    # Persist the conversation across Streamlit reruns.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far on each rerun.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    with st.sidebar:
        # Toggle between plain LLM chat and graph-augmented retrieval.
        use_kg = st.toggle("Use knowledge graph")

    user_input = st.chat_input("Can you tell me how research helps users to solve problems?")

    # Chain that answers questions by traversing the NetworkX entity graph.
    graph_chain = GraphQAChain.from_llm(llm = llm, graph=graph, verbose=True)

    prompt_template = ChatPromptTemplate.from_messages(
        [("system", "You are a helpful AI assistant named Envie. You will reply to questions only based on the context that you are provided. If something is out of context, you will refrain from replying and politely decline to respond to the user."), ("user", "{input}")]
    )

    chain = prompt_template | llm | StrOutputParser()
    # Hybrid sparse/dense vector search with reranking over the document store.
    search_handler = SearchHandler("hybrid_demo3", use_bge_m3=True, use_reranker=True)

    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)

        with st.chat_message("assistant"):
            if use_kg:
                # Ask the LLM to extract the entities mentioned in the query so
                # we can walk the graph outward from each of them.
                entity_string = llm.invoke("""Return a JSON with a single key 'entities' and list of entities within this user query. Each element in your list MUST BE part of the user's query. Do not provide any explanation. If the returned list is not parseable in Python, you will be heavily penalized. For example, input: 'What is the difference between Apple and Google?' output: ['Apple', 'Google']. Always follow this output format. Here's the user query: """ + user_input)
                try:
                    entities = json.loads(entity_string.content)['entities']
                    with st.expander("Extracted triples"):
                        st.code(entities)
                    # Retrieve supporting passages alongside the graph triples.
                    res = search_handler.search_and_rerank(user_input, k=5)
                    with st.expander("Retrieved and Reranked Sparse-Dense Hybrid Search"):
                        st.write(res)
                    context = "Here are the relevant passages from the knowledge base: \n\n" + "\n".join(item.text for item in res)
                    all_triplets = []
                    for entity in entities:
                        # depth=2: include relationships two hops from the entity.
                        all_triplets.extend(graph_chain.graph.get_entity_knowledge(entity, depth=2))
                    context += "\n\nHere are the relationships from the knowledge graph: " + "\n".join(all_triplets)
                    with st.expander("All triplets"):
                        st.code(context)
                except Exception as e:
                    # Fall back to an "ungrounded" disclaimer context if entity
                    # parsing or retrieval fails.
                    st.write("Faced exception: ", e)
                    context = "No graph triples were available to extract from the knowledge graph. Always provide a disclaimer if you know the answer to the user's question, since it is not grounded in the knowledge you are provided from the graph."
                message_placeholder = st.empty()
                full_response = ""

                # Stream the grounded answer token-by-token with a cursor glyph.
                for response in chain.stream("Context: " + context + "\n\nUser query: " + user_input):
                    full_response += response
                    message_placeholder.markdown(full_response + "▌")
            else:
                # Plain LLM chat: no graph, no retrieval context.
                message_placeholder = st.empty()
                full_response = ""
                for response in chain.stream(user_input):
                    full_response += response
                    message_placeholder.markdown(full_response + "▌")
            # Replace the streaming cursor with the final answer.
            message_placeholder.markdown(full_response)

        st.session_state.messages.append({"role": "assistant", "content": full_response})
70 changes: 36 additions & 34 deletions experimental/knowledge_graph_rag/pages/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import random

import streamlit as st
from llama_index.core import SimpleDirectoryReader, KnowledgeGraphIndex
from utils.preprocessor import generate_qa_pair
from llama_index.core import ServiceContext
import multiprocessing
import altair as alt
import matplotlib.pyplot as plt
import pandas as pd
import networkx as nx
from utils.lc_graph import process_documents, save_triples_to_csvs
from vectorstore.search import SearchHandler
from langchain_nvidia_ai_endpoints import ChatNVIDIA
import random
import pandas as pd
import time
import json
from concurrent.futures import ThreadPoolExecutor
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
from langchain_core.output_parsers import StrOutputParser
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, get_entities
from langchain_core.prompts import ChatPromptTemplate
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from openai import OpenAI

from vectorstore.search import SearchHandler
from utils.lc_graph import process_documents
from utils.preprocessor import get_list_of_directories, has_pdf_files, generate_qa_pair

from concurrent.futures import ThreadPoolExecutor, as_completed

from openai import OpenAI
st.set_page_config(page_title="Knowledge Graph RAG")

reward_client = OpenAI(
base_url = "https://integrate.api.nvidia.com/v1",
api_key = os.environ["NVIDIA_API_KEY"]
Expand Down Expand Up @@ -80,17 +79,6 @@ def process_question(question, answer):
[("system", "You are a helpful AI assistant named Envie. You will reply to questions only based on the context that you are provided. If something is out of context, you will refrain from replying and politely decline to respond to the user."), ("user", "{input}")]
)

def load_data(input_dir, num_workers):
    """Ingest all files in *input_dir* via llama-index's SimpleDirectoryReader.

    Parameters
    ----------
    input_dir : str
        Directory to scan for documents.
    num_workers : int
        Parallel worker count passed through to ``load_data``.

    Returns
    -------
    list
        The loaded Document objects.
    """
    directory_reader = SimpleDirectoryReader(input_dir=input_dir)
    return directory_reader.load_data(num_workers=num_workers)

def has_pdf_files(directory):
    """Return True when at least one ``.pdf`` file lives directly in *directory*."""
    return any(name.endswith(".pdf") for name in os.listdir(directory))

def get_text_RAG_response(question):
chain = prompt_template | llm | StrOutputParser()

Expand Down Expand Up @@ -154,18 +142,16 @@ def get_combined_RAG_response(question):
num_data = st.slider("How many Q&A pairs to generate?", 10, 100, 50, step=10)

def app():
# Get the current working directory
cwd = os.getcwd()

# Get a list of visible directories in the current working directory
directories = [d for d in os.listdir(cwd) if os.path.isdir(os.path.join(cwd, d)) and not d.startswith('.') and '__' not in d]

# Get a list of visible directories in the current working directory
directories = get_list_of_directories()
# Create a dropdown menu for directory selection
selected_dir = st.selectbox("Select a directory:", directories, index=0)
selected_dir = st.selectbox("Select a directory:", directories, index=None)

# Construct the full path of the selected directory
directory = os.path.join(cwd, selected_dir)
if st.button("Process Documents"):
if selected_dir:
directory = os.path.join(os.getcwd(), selected_dir)

if st.button("Process Documents", disabled=(selected_dir is None)):
# Check if the selected directory has PDF files
res = has_pdf_files(directory)
if not res:
Expand Down Expand Up @@ -261,5 +247,21 @@ def app():
st.write("First few rows of the updated data:")
st.dataframe(combined_results.head())

average_scores = combined_results.mean(axis=0, numeric_only=True)
rows = []
for index, value in average_scores.items():
metric, category = tuple(index.split("_"))
rows.append([metric, category, value])

final_df = pd.DataFrame(rows, columns=["metric", "category", "average score"])

gp_chart = alt.Chart(final_df).mark_bar().encode(
x="metric:N",
y="average score:Q",
xOffset="category:N",
color="category:N"
)
st.altair_chart(gp_chart, use_container_width=True)

# Script entry point: render the evaluation page when run directly.
if __name__ == "__main__":
    app()
2 changes: 1 addition & 1 deletion experimental/knowledge_graph_rag/pages/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import streamlit as st
import streamlit.components.v1 as components

st.set_page_config(layout="wide")
st.set_page_config(page_title="Knowledge Graph RAG", layout="wide")

def app():
st.title("Visualize the Knowledge Graph!")
Expand Down
4 changes: 3 additions & 1 deletion experimental/knowledge_graph_rag/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ llama_index==0.10.50
networkx==3.2.1
numpy==1.24.1
pandas==2.2.2
psutil==6.0.0
pymilvus==2.4.3
Requests==2.32.3
pymilvus[model]
Requests==2.31.0
streamlit==1.30.0
unstructured[all-docs]
tqdm==4.66.1
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,21 @@ def download_paper(result, download_dir, max_retries=3, retry_delay=5):
def download_papers(search_terms, start_date, end_date, max_results=10, download_dir='papers', num_threads=4, max_retries=3, retry_delay=5):
# Create the search query based on search terms and dates
search_query = f"({search_terms}) AND submittedDate:[{start_date.strftime('%Y%m%d')} TO {end_date.strftime('%Y%m%d')}]"
client = arxiv.Client()

search = arxiv.Search(
query=search_query,
max_results=max_results,
sort_by=arxiv.SortCriterion.SubmittedDate,
)

# Create the download directory if it doesn't exist
os.makedirs(download_dir, exist_ok=True)

# Use a thread pool to download papers in parallel
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = []
for result in tqdm(search.results(), total=max_results, unit='paper'):
for result in tqdm(client.results(search), total=max_results, unit='paper'):
print(result.title)
# Submit download tasks to the executor
future = executor.submit(download_paper, result, download_dir, max_retries, retry_delay)
futures.append(future)
Expand Down Expand Up @@ -110,6 +111,7 @@ def download_papers(search_terms, start_date, end_date, max_results=10, download
else:
end_date = datetime.now() # Default to today

search_query = ' AND '.join([f'all:"{term}"' for term in args.search_terms.split(",")])
# Call the download_papers function with the provided arguments
download_papers(args.search_terms, start_date, end_date, args.max_results, args.download_dir, args.num_threads, args.max_retries, args.retry_delay)
download_papers(search_query, start_date, end_date, args.max_results, args.download_dir, args.num_threads, args.max_retries, args.retry_delay)

Loading