Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"terminal.explorerKind": "both"
}
136 changes: 64 additions & 72 deletions service/add_citation.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,76 @@
import nltk
from llama_index.retrievers.bm25 import BM25Retriever
from typing import List, Tuple, Optional

def split_sentences(text: str) -> List[str]:
    """Split *text* into sentences using NLTK's sentence tokenizer.

    Args:
        text: The text to segment.

    Returns:
        The sentences produced by ``nltk.sent_tokenize``, in document order.
    """
    return nltk.sent_tokenize(text)

def add_citation_with_retrieved_node(retrieved_nodes: Optional[List], final_response: str) -> str:
    """Annotate a response with inline citations and a trailing reference list.

    Every sentence of ``final_response`` longer than 20 characters is used as
    a BM25 query (top 2 hits) against ``retrieved_nodes``. Hits scoring above
    the threshold are turned into citation markers that are inserted at the
    sentence, and every cited paper is listed once in a references section
    appended to the response.

    Args:
        retrieved_nodes: Nodes to build the BM25 index from. ``None`` or an
            empty list disables citation entirely.
        final_response: The text to annotate.

    Returns:
        The annotated response, or ``final_response`` unchanged when there is
        nothing to cite.
    """
    if not retrieved_nodes:
        return final_response

    bm25_retriever = BM25Retriever.from_defaults(nodes=retrieved_nodes, similarity_top_k=2)
    sentences = [sentence for sentence in split_sentences(final_response) if len(sentence) > 20]

    THRESHOLD = 13.5
    cited_papers = {}       # paper id -> citation number
    cited_paper_list = []   # (paper id, title) in first-citation order
    cursor = 0              # everything before this index is already processed

    for sentence in sentences:
        # Cheap checks first so the retriever is only queried when needed.
        if len(sentence.strip()) < 20:
            continue

        # Search forward from the cursor: a sentence whose text also occurs
        # earlier in the response must not be matched at that earlier spot.
        offset = final_response.find(sentence, cursor)
        if offset == -1:
            continue
        cursor = offset + len(sentence)

        relevant_nodes = bm25_retriever.retrieve(sentence)
        if not relevant_nodes:
            continue

        citations = process_relevant_nodes(relevant_nodes, THRESHOLD, cited_papers, cited_paper_list)
        if not citations:
            continue

        # Hand insert_citation only the tail that starts at this sentence, so
        # its own search necessarily hits the occurrence located above.
        tail = final_response[offset:]
        annotated_tail = insert_citation(tail, sentence, citations)
        final_response = final_response[:offset] + annotated_tail
        # Skip past whatever was inserted so it is never searched again.
        cursor += len(annotated_tail) - len(tail)

    if cited_paper_list:
        final_response += generate_references(cited_paper_list)

    return final_response

def process_relevant_nodes(relevant_nodes: List, threshold: float, cited_papers: dict, cited_paper_list: List) -> str:
    """Build the citation marker string for one sentence's retrieval hits.

    Only the first two hits are considered. When there is a single hit, or
    both hits carry the same ``metadata['id']``, the first hit alone is cited.

    Args:
        relevant_nodes: Retrieval hits for the sentence, best first.
        threshold: Score a hit must exceed to be cited.
        cited_papers: Mapping of paper id to citation number; updated by the
            helpers when a paper is cited for the first time.
        cited_paper_list: ``(paper id, title)`` pairs in first-citation order;
            updated alongside ``cited_papers``.

    Returns:
        The concatenated citation markers, or ``""`` when nothing qualifies
        (including when ``relevant_nodes`` is empty).
    """
    if not relevant_nodes:
        return ""

    top_hit = relevant_nodes[0]
    if len(relevant_nodes) == 1:
        return process_single_paper(top_hit, threshold, cited_papers, cited_paper_list)

    second_hit = relevant_nodes[1]
    # Two chunks of the same paper must yield one marker, not two.
    if top_hit.node.metadata['id'] == second_hit.node.metadata['id']:
        return process_single_paper(top_hit, threshold, cited_papers, cited_paper_list)

    return process_two_papers(top_hit, second_hit, threshold, cited_papers, cited_paper_list)

def process_single_paper(paper, threshold: float, cited_papers: dict, cited_paper_list: List) -> str:
    """Return the citation marker for one retrieval hit, registering it if new.

    Args:
        paper: A retrieval hit exposing ``score`` and ``node.metadata`` (with
            an ``'id'`` key, and a ``'title'`` key for papers not cited yet).
        threshold: Score the hit must strictly exceed to be cited.
        cited_papers: Mapping of paper id to citation number; a new entry is
            added the first time a paper is cited.
        cited_paper_list: ``(paper id, title)`` pairs in first-citation order;
            appended to together with ``cited_papers``.

    Returns:
        A Markdown link of the form ``[[n]](https://arxiv.org/abs/<id>)``, or
        ``""`` when the hit does not clear the threshold.
    """
    score = paper.score
    # A hit may carry no score at all (None); treat that as "not relevant"
    # instead of letting the comparison raise TypeError.
    if score is None or score <= threshold:
        return ""

    metadata = paper.node.metadata
    paper_id = metadata['id']

    cite_number = cited_papers.get(paper_id)
    if cite_number is None:
        # Numbers are handed out in order of first citation, starting at 1.
        cite_number = len(cited_papers) + 1
        cited_papers[paper_id] = cite_number
        cited_paper_list.append((paper_id, metadata['title']))

    return f"[[{cite_number}]](https://arxiv.org/abs/{paper_id})"

def process_two_papers(paper1, paper2, threshold: float, cited_papers: dict, cited_paper_list: List) -> str:
    """Return the citation markers for two distinct retrieval hits.

    Each hit is passed to ``process_single_paper``, which alone decides
    whether it clears ``threshold``. The resulting markers are emitted in
    ascending citation-number order (e.g. ``[[1]]...[[3]]...``) regardless of
    which hit ranked higher.

    Args:
        paper1: First retrieval hit.
        paper2: Second retrieval hit.
        threshold: Score a hit must exceed to be cited.
        cited_papers: Mapping of paper id to citation number; updated for
            newly cited papers.
        cited_paper_list: ``(paper id, title)`` pairs in first-citation order;
            updated for newly cited papers.

    Returns:
        Zero, one or two concatenated citation markers.
    """
    numbered_markers = []
    for paper in (paper1, paper2):
        marker = process_single_paper(paper, threshold, cited_papers, cited_paper_list)
        if marker:
            # A non-empty marker means the paper is registered, so its
            # citation number can be looked up for ordering.
            cite_number = cited_papers[paper.node.metadata['id']]
            numbered_markers.append((cite_number, marker))

    numbered_markers.sort(key=lambda pair: pair[0])
    return "".join(marker for _, marker in numbered_markers)

def insert_citation(text: str, sentence: str, citation: str, start: int = 0) -> str:
    """Insert *citation* at the first occurrence of *sentence* in *text*.

    The marker is placed just before a trailing ``.``, ``!`` or ``?`` of the
    sentence so that it reads as part of the sentence; when the sentence has
    no such terminator the marker is appended directly after it.

    Args:
        text: The text to modify.
        sentence: The sentence to locate; must occur verbatim in ``text``.
        citation: The marker string to insert.
        start: Index in ``text`` at which the search begins. Lets a caller
            target a later occurrence of a sentence that appears more than
            once. Defaults to the beginning of the text.

    Returns:
        The text with the marker inserted, or ``text`` unchanged when
        ``sentence`` or ``citation`` is empty or the sentence is not found.
    """
    if not sentence or not citation:
        return text

    begin = text.find(sentence, start)
    if begin == -1:
        return text

    insert_at = begin + len(sentence)
    # Keep the sentence's closing punctuation after the marker.
    if sentence[-1] in ".!?":
        insert_at -= 1

    return f"{text[:insert_at]}{citation}{text[insert_at:]}"

def generate_references(cited_paper_list: List[Tuple[str, str]]) -> str:
    """Render the references section appended to an annotated response.

    Args:
        cited_paper_list: ``(paper id, title)`` pairs in citation order; the
            entry at position ``i`` (1-based) is numbered ``i``.

    Returns:
        A ``**REFERENCES**`` heading followed by one Markdown link per paper,
        each terminated by a blank line.
    """
    entries = [
        f"[[{idx}] {paper_title}](https://arxiv.org/abs/{paper_id})\n\n"
        for idx, (paper_id, paper_title) in enumerate(cited_paper_list, start=1)
    ]
    return "\n\n**REFERENCES**\n\n" + "".join(entries)
68 changes: 30 additions & 38 deletions service/nodes_arrangement.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,40 @@
from functools import cmp_to_key
from collections import defaultdict

def format_metadata(metadata):
    """Render selected metadata fields as ``key: value`` lines.

    Args:
        metadata: Mapping of metadata fields. Only ``id``, ``title``,
            ``authors``, ``journal-ref``, ``categories`` and ``paper_time``
            are rendered, in that order; absent keys are skipped.

    Returns:
        The rendered lines joined by newlines (no trailing newline), or an
        empty string when none of the fields is present.
    """
    ordered_keys = ('id', 'title', 'authors', 'journal-ref', 'categories', 'paper_time')
    return "\n".join(
        f"{key}: {metadata[key]}" for key in ordered_keys if key in metadata
    )

def compare_node(a, b):
    """Three-way comparator that orders nodes by descending ``score``.

    Intended for use with ``functools.cmp_to_key``.

    Args:
        a: First node; must expose a comparable ``score``.
        b: Second node; must expose a comparable ``score``.

    Returns:
        ``-1`` when ``a`` scores higher (sorts first), ``1`` when ``b`` scores
        higher, and ``0`` when the scores tie.
    """
    if a.score > b.score:
        return -1
    if a.score < b.score:
        return 1
    # Report ties as equal so the comparator is consistent in both directions.
    return 0

def nodes_arrangement(nodes):
    """Group nodes by paper and render one text block per paper.

    Nodes are ranked by descending ``score`` (ties keep their input order) and
    grouped by ``metadata['id']``. Papers are emitted in the order of their
    best-ranked node; rendering of each paper is delegated to
    ``handle_new_paper``.

    Args:
        nodes: Iterable of nodes exposing ``score`` and ``metadata`` (with an
            ``'id'`` key), plus whatever ``handle_new_paper`` reads.

    Returns:
        A list with one rendered string per distinct paper.
    """
    # reverse=True keeps the sort stable, so equal scores stay in input order.
    ranked = sorted(nodes, key=lambda node: node.score, reverse=True)

    paper_dict = defaultdict(list)
    for node in ranked:
        paper_dict[node.metadata['id']].append(node)

    contents = []
    # Grouping already guarantees each paper is visited exactly once; this
    # dict exists only because handle_new_paper records visits in it.
    visited = {}
    for paper_id, cur_papers in paper_dict.items():
        handle_new_paper(paper_id, cur_papers, visited, contents)

    return contents

def handle_new_paper(paper_id, cur_papers, vis_dict, contents):
    """Render one paper's nodes and append the result to *contents*.

    Args:
        paper_id: Identifier of the paper; marked as visited in ``vis_dict``.
        cur_papers: The paper's nodes. Sorted in place by ``node_id``.
        vis_dict: Mapping updated with ``vis_dict[paper_id] = True``.
        contents: Output list that receives the rendered, stripped text.
    """
    vis_dict[paper_id] = True
    cur_papers.sort(key=lambda chunk: chunk.node_id)

    # Metadata header comes from the first node after sorting.
    header = format_metadata(cur_papers[0].metadata)
    body = "\n".join(chunk.text for chunk in cur_papers)
    contents.append(f"{header}\n{body}".strip())
Loading