diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..3e430b5 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "terminal.explorerKind": "both" +} \ No newline at end of file diff --git a/service/add_citation.py b/service/add_citation.py index a5899ab..7944f1b 100644 --- a/service/add_citation.py +++ b/service/add_citation.py @@ -1,84 +1,76 @@ import nltk from llama_index.retrievers.bm25 import BM25Retriever +from typing import List, Tuple, Optional -def split_sentences(text): +def split_sentences(text: str) -> List[str]: return nltk.sent_tokenize(text) -def add_citation_with_retrieved_node(retrieved_nodes, final_response): - if retrieved_nodes is None or len(retrieved_nodes) <= 0: +def add_citation_with_retrieved_node(retrieved_nodes: Optional[List], final_response: str) -> str: + if not retrieved_nodes: return final_response + bm25_retriever = BM25Retriever.from_defaults(nodes=retrieved_nodes, similarity_top_k=2) sentences = [sentence for sentence in split_sentences(final_response) if len(sentence) > 20] - start = 0 - cite_cnt = 1 - threshold = 13.5 - cited_paper_id_to_cnt = {} + + THRESHOLD = 13.5 + cited_papers = {} cited_paper_list = [] + for sentence in sentences: - left = final_response.find(sentence, start) - right = left + len(sentence) relevant_nodes = bm25_retriever.retrieve(sentence) - if len(relevant_nodes) == 0 or len(sentence.strip()) < 20: - start = right - continue - if len(relevant_nodes) == 1 or relevant_nodes[0].node.metadata['id'] == relevant_nodes[1].node.metadata['id']: - paper1 = relevant_nodes[0] - paper1_id = paper1.node.metadata['id'] - paper1_title = paper1.node.metadata['title'] - if paper1.score > threshold: - if paper1_id not in cited_paper_id_to_cnt: - cited_paper_id_to_cnt[paper1_id] = cite_cnt - cited_paper_list.append((paper1_id, paper1_title)) - cite_cnt += 1 - paper1_cite_cnt = cited_paper_id_to_cnt[paper1_id] - cite_str = f"[[{paper1_cite_cnt}]](https://arxiv.org/abs/{paper1_id})" - final_response = final_response[:right - 1] + cite_str + final_response[right - 1:] - start = right + len(cite_str) + if not relevant_nodes or len(sentence.strip()) < 20: continue - paper1 = relevant_nodes[0] - paper2 = relevant_nodes[1] - paper1_id = paper1.node.metadata['id'] - paper1_title = paper1.node.metadata['title'] - paper2_id = paper2.node.metadata['id'] - paper2_title = paper2.node.metadata['title'] - if paper1.score > threshold and paper2.score > threshold: - if paper1_id not in cited_paper_id_to_cnt: - cited_paper_id_to_cnt[paper1_id] = cite_cnt - cited_paper_list.append((paper1_id, paper1_title)) - cite_cnt += 1 - if paper2_id not in cited_paper_id_to_cnt: - cited_paper_id_to_cnt[paper2_id] = cite_cnt - cited_paper_list.append((paper2_id, paper2_title)) - cite_cnt += 1 - paper1_cite_cnt = cited_paper_id_to_cnt[paper1_id] - paper2_cite_cnt = cited_paper_id_to_cnt[paper2_id] - if paper1_cite_cnt > paper2_cite_cnt: - paper1_cite_cnt, paper2_cite_cnt = paper2_cite_cnt, paper1_cite_cnt - paper1_id, paper2_id = paper2_id, paper1_id - cite_str = f"[[{paper1_cite_cnt}]](https://arxiv.org/abs/{paper1_id})[[{paper2_cite_cnt}]](https://arxiv.org/abs/{paper2_id})" - final_response = final_response[:right - 1] + cite_str + final_response[right - 1:] - start = right + len(cite_str) - elif paper1.score > threshold: - if paper1_id not in cited_paper_id_to_cnt: - cited_paper_id_to_cnt[paper1_id] = cite_cnt - cited_paper_list.append((paper1_id, paper1_title)) - cite_cnt += 1 - paper1_cite_cnt = cited_paper_id_to_cnt[paper1_id] - cite_str = f"[[{paper1_cite_cnt}]](https://arxiv.org/abs/{paper1_id})" - final_response = final_response[:right - 1] + cite_str + final_response[right - 1:] - start = right + len(cite_str) - elif paper2.score > threshold: - if paper2_id not in cited_paper_id_to_cnt: - cited_paper_id_to_cnt[paper2_id] = cite_cnt - cited_paper_list.append((paper2_id, paper2_title)) - cite_cnt += 1 - paper2_cite_cnt = cited_paper_id_to_cnt[paper2_id] - cite_str = f"[[{paper2_cite_cnt}]](https://arxiv.org/abs/{paper2_id})" - final_response = final_response[:right - 1] + cite_str + final_response[right - 1:] - start = right + len(cite_str) - cited_list_str = "" - for cite_idx, (cited_paper_id, cited_paper_title) in enumerate(cited_paper_list, start=1): - cited_list_str += f"""[[{cite_idx}] {cited_paper_title}](https://arxiv.org/abs/{cited_paper_id})\n\n""" - if len(cited_list_str) > 0: - final_response += "\n\n**REFERENCES**\n\n" + cited_list_str - return final_response \ No newline at end of file + + citations = process_relevant_nodes(relevant_nodes, THRESHOLD, cited_papers, cited_paper_list) + if citations: + final_response = insert_citation(final_response, sentence, citations) + + if cited_paper_list: + final_response += generate_references(cited_paper_list) + + return final_response + +def process_relevant_nodes(relevant_nodes: List, threshold: float, cited_papers: dict, cited_paper_list: List) -> str: + if len(relevant_nodes) == 1 or relevant_nodes[0].node.metadata['id'] == relevant_nodes[1].node.metadata['id']: + return process_single_paper(relevant_nodes[0], threshold, cited_papers, cited_paper_list) + else: + return process_two_papers(relevant_nodes[0], relevant_nodes[1], threshold, cited_papers, cited_paper_list) + +def process_single_paper(paper, threshold: float, cited_papers: dict, cited_paper_list: List) -> str: + if paper.score <= threshold: + return "" + + paper_id = paper.node.metadata['id'] + paper_title = paper.node.metadata['title'] + + if paper_id not in cited_papers: + new_cite_number = len(cited_papers) + 1 + cited_papers[paper_id] = new_cite_number + cited_paper_list.append((paper_id, paper_title)) + else: + new_cite_number = cited_papers[paper_id] + + return f"[[{new_cite_number}]](https://arxiv.org/abs/{paper_id})" + +def process_two_papers(paper1, paper2, threshold: float, cited_papers: dict, cited_paper_list: List) -> str: + citations = [] + for paper in (paper1, paper2): + if paper.score > threshold: + citation = process_single_paper(paper, threshold, cited_papers, cited_paper_list) + if citation: + citations.append(citation) + + return "".join(citations) + +def insert_citation(text: str, sentence: str, citation: str) -> str: + start = text.find(sentence) + if start == -1: + return text + end = start + len(sentence) + return f"{text[:end]}{citation}{text[end:]}" + +def generate_references(cited_paper_list: List[Tuple[str, str]]) -> str: + references = "\n\n**REFERENCES**\n\n" + for idx, (paper_id, paper_title) in enumerate(cited_paper_list, start=1): + references += f"[[{idx}] {paper_title}](https://arxiv.org/abs/{paper_id})\n\n" + return references \ No newline at end of file diff --git a/service/nodes_arrangement.py b/service/nodes_arrangement.py index f7a3b3d..110d40f 100644 --- a/service/nodes_arrangement.py +++ b/service/nodes_arrangement.py @@ -1,48 +1,40 @@ from functools import cmp_to_key +from collections import defaultdict def format_metadata(metadata): - meta_format = "" - if 'id' in metadata: - meta_format += f"id: {metadata['id']}\n" - if 'title' in metadata: - meta_format += f"title: {metadata['title']}\n" - if 'authors' in metadata: - meta_format += f"authors: {metadata['authors']}\n" - if 'journal-ref' in metadata: - meta_format += f"journal-ref: {metadata['journal-ref']}\n" - if 'categories' in metadata: - meta_format += f"categories: {metadata['categories']}\n" - if 'paper_time' in metadata: - meta_format += f"paper_time: {metadata['paper_time']}\n" - return meta_format - + return "\n".join( + f"{key}: {metadata[key]}" for key in [ + 'id', 'title', 'authors', 'journal-ref', 'categories', 'paper_time' + ] if key in metadata + ) def compare_node(a, b): - if a.score > b.score: - return -1 - return 1 + return -1 if a.score > b.score else 1 def nodes_arrangement(nodes): - paper_dict = {} + paper_dict = defaultdict(list) nodes = sorted(nodes, key=cmp_to_key(compare_node)) + for node in nodes: - paper_id = node.metadata['id'] - if paper_id not in paper_dict: - paper_dict[paper_id] = [] - paper_dict[paper_id].append(node) + paper_dict[node.metadata['id']].append(node) + contents = [] - vis_dict = {} - for node in nodes: - paper_id = node.metadata['id'] - if paper_id in vis_dict: - continue - vis_dict[paper_id] = True - # for paper_id, cur_papers in paper_dict.items(): - cur_papers = paper_dict[paper_id] - cur_papers.sort(key=lambda node: node.node_id) - metadata = cur_papers[0].metadata - cur_content = format_metadata(metadata) - for node in cur_papers: - cur_content += node.text + '\n' - contents.append(cur_content.strip()) - return contents \ No newline at end of file + vis_dict = {node.metadata['id']: False for node in nodes} + + actions = { + True: lambda paper_id, cur_papers: None, + False: lambda paper_id, cur_papers: handle_new_paper(paper_id, cur_papers, vis_dict, contents) + } + + for paper_id, cur_papers in paper_dict.items(): + actions[vis_dict[paper_id]](paper_id, cur_papers) + + return contents + +def handle_new_paper(paper_id, cur_papers, vis_dict, contents): + vis_dict[paper_id] = True + cur_papers.sort(key=lambda n: n.node_id) + metadata = cur_papers[0].metadata + cur_content = format_metadata(metadata) + cur_content += "\n" + "\n".join(node.text for node in cur_papers) + contents.append(cur_content.strip()) diff --git a/service/qdrant_retriever.py b/service/qdrant_retriever.py index 2231b11..e60f29c 100644 --- a/service/qdrant_retriever.py +++ b/service/qdrant_retriever.py @@ -1,136 +1,142 @@ +from typing import List, Optional, Tuple +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from functools import lru_cache + +import pickle +import requests from llama_index.core import QueryBundle from llama_index.core.schema import NodeWithScore -from llama_index.core.vector_stores import VectorStoreQuery from llama_index.core.retrievers import BaseRetriever +from llama_index.core.postprocessor import BaseNodePostprocessor + from config import category_list, qdrant_month_list, qdrant_collection_prefix -from typing import List -from llama_index.core.base.base_retriever import BaseRetriever -from utils.qdrant_helper import * -from config import qdrant_collection_prefix from service.field_selector import field_selector from service.date_selector import date_selector -import concurrent.futures -import pickle -import requests -def get_retrieve_nodes(retriever, query): - return retriever.retrieve(query) +@dataclass +class RetrievalResult: + nodes: List[NodeWithScore] + category: str + month: str -class SingleQdrantRetriever(BaseRetriever): +class QdrantRetriever(BaseRetriever): def __init__( self, - embed_model, - vector_store, - similarity_top_k, - hybrid_top_k + qdrant_api_url: str, + node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, ): - """Init params.""" + """Initialize the QdrantRetriever. + + Args: + qdrant_api_url (str): The URL for the Qdrant API endpoint. + node_postprocessors (Optional[List[BaseNodePostprocessor]]): List of node postprocessors. + """ super().__init__() - self.embed_model = embed_model - self.vector_store = vector_store - self.similarity_top_k = similarity_top_k - self.hybrid_top_k = hybrid_top_k - + self.qdrant_api_url = qdrant_api_url + self.node_postprocessors = node_postprocessors or [] - def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: - emb = self.embed_model.get_text_embedding_batch([query_bundle.query_str]) - q = VectorStoreQuery( - query_embedding=emb[0], - similarity_top_k=self.similarity_top_k, - query_str=query_bundle.query_str, - hybrid_top_k=self.hybrid_top_k, - mode='hybrid', + @lru_cache(maxsize=128) + def _get_collection_name(self, category: str, month: str) -> str: + """Generate a collection name based on category and month. + + Args: + category (str): The category of the collection. + month (str): The month of the collection. + + Returns: + str: The formatted collection name. + """ + return f"{qdrant_collection_prefix}_{category}_{month}".replace(" ", "_").lower() + + def _api_retrieve(self, query: str, collection_name: str) -> List[NodeWithScore]: + """Retrieve nodes from the Qdrant API. + + Args: + query (str): The query string. + collection_name (str): The name of the collection to query. + + Returns: + List[NodeWithScore]: List of retrieved nodes with scores. + """ + response = requests.post( + self.qdrant_api_url, + json={"query": query, "collection_name": collection_name} ) - result = self.vector_store.query(q) - nodes = [] - for i in range(len(result.nodes)): - nodes.append(NodeWithScore(node=result.nodes[i], - score=result.similarities[i])) - return nodes + return pickle.loads(response.content) + + def _retrieve_parallel(self, query: str, categories: List[str], months: List[str]) -> List[RetrievalResult]: + """Perform parallel retrieval across multiple categories and months. + + Args: + query (str): The query string. + categories (List[str]): List of categories to search. + months (List[str]): List of months to search. + + Returns: + List[RetrievalResult]: List of retrieval results. + """ + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit(self._api_retrieve, query, self._get_collection_name(category, month)) + for category in categories + for month in months + ] + results = [ + RetrievalResult(future.result(), category, month) + for future, (category, month) in zip(futures, + [(cat, mon) for cat in categories for mon in months]) + ] + return results + + def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + """Retrieve nodes based on the query bundle. + + Args: + query_bundle (QueryBundle): The query bundle containing the query string. + + Returns: + List[NodeWithScore]: List of retrieved nodes with scores. + """ + query_str = query_bundle.query_str + categories = field_selector(query_str) or category_list + months = date_selector(query_str, range_type="qdrant") or qdrant_month_list + + results = self._retrieve_parallel(query_str, categories, months) + nodes = [node for result in results for node in result.nodes] -class QdrantRetriever(BaseRetriever): - def __init__( - self, - qdrant_api_url - # embed_model, - # similarity_top_k, - # hybrid_top_k, - # node_postprocessors, - # insert_batch_size=10, - ): - """Init params.""" - super().__init__() - self.qdrant_api_url = qdrant_api_url - # client = QdrantClient(host=qdrant_host, port=qdrant_port) - # aclient = AsyncQdrantClient(host=qdrant_host, port=qdrant_port) - # self.node_postprocessors = node_postprocessors - # self.embed_model = embed_model - # self.similarity_top_k = similarity_top_k - # self.hybrid_top_k = hybrid_top_k - # self.retrievers_dict = {} - # for category in category_list: - # for month in qdrant_month_list: - # collection_name = "{}_{}_{}".format(qdrant_collection_prefix, category, month).replace(" ", "_").lower() - # vector_store = QdrantVectorStore( - # collection_name, - # client=client, - # aclient=aclient, - # enable_hybrid=True, - # batch_size=insert_batch_size, - # sparse_doc_fn=sparse_doc_vectors, - # sparse_query_fn=sparse_query_vectors, - # hybrid_fusion_fn=reciprocal_rank_fusion, - # ) - # retriever = SingleQdrantRetriever(embed_model=embed_model, - # vector_store=vector_store, - # similarity_top_k=similarity_top_k, - # hybrid_top_k=hybrid_top_k) - # self.retrievers_dict[collection_name] = retriever - - def _retrieve(self, query, **kwargs): - nodes = [] - query_str = query.query_str - need_categories = field_selector(query_str) - if len(need_categories) == 0: - need_categories = category_list - need_months = date_selector(query_str, range_type="qdrant") - if need_months is None: - need_months = qdrant_month_list - cur_retrievers = [] - for category in need_categories: - for month in need_months: - collection_name = "{}_{}_{}".format(qdrant_collection_prefix, category, month).replace(" ", "_").lower() - cur_retrievers.append(self.retrievers_dict[collection_name]) - for retriever in cur_retrievers: - nodes += retriever.retrieve(query, **kwargs) for postprocessor in self.node_postprocessors: - nodes = postprocessor.postprocess_nodes(nodes, query_bundle=query) - return nodes - - def api_retrieve(self, query: str, collection_name: str): - response = requests.post(self.qdrant_api_url, - json={"query": query, - "collection_name": collection_name}) - nodes = pickle.loads(response.content) - return nodes + nodes = postprocessor.postprocess_nodes(nodes, query_bundle=query_bundle) - def api_retrieve_wrapper(self, args): - query_str, collection_name = args - return self.api_retrieve(query=query_str, collection_name=collection_name) - - def custom_retrieve(self, query, need_categories, need_months): - nodes = [] - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [] - for category in need_categories: - for month in need_months: - collection_name = "{}_{}_{}".format(qdrant_collection_prefix, category, month).replace(" ", "_").lower() - future = executor.submit( - self.api_retrieve_wrapper, - (query.query_str, collection_name) - ) - futures.append((future, category, month)) - for future, category, month in futures: - result = future.result() - nodes.extend(result) return nodes + + def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + """Public method to retrieve nodes based on the query bundle. + + Args: + query_bundle (QueryBundle): The query bundle containing the query string. + + Returns: + List[NodeWithScore]: List of retrieved nodes with scores. + """ + return self._retrieve(query_bundle) + +# Example usage: +if __name__ == "__main__": + # Initialize the retriever + retriever = QdrantRetriever( + qdrant_api_url="https://your-qdrant-api-endpoint.com/retrieve", + node_postprocessors=[ + # Add any postprocessors here + ] + ) + + # Create a query bundle + query = QueryBundle(query_str="Example query string") + + # Retrieve nodes + retrieved_nodes = retriever.retrieve(query) + + # Process the retrieved nodes + for node in retrieved_nodes: + print(f"Node: {node.node.text}, Score: {node.score}") \ No newline at end of file diff --git a/service/query_router.py b/service/query_router.py index 4d51670..6a1544e 100644 --- a/service/query_router.py +++ b/service/query_router.py @@ -1,10 +1,10 @@ from llm.chat_llm import chat from config import query_router_model -prompt = """You are an intelligent assistant tasked with analyzing user queries in the context of their conversation history with an AI assistant. -Your goal is to categorize user's latest query into one of three types based on its complexity and the resources needed to answer it accurately. +prompt_template = """You are an intelligent assistant tasked with analyzing user queries in the context of their conversation history with an AI assistant. +Your goal is to categorize the user's latest query into one of three types based on its complexity and the resources needed to answer it accurately. -1. Simple Question: A query that you can answer directly using your general knowledge or according the conversation history. +1. Simple Question: A query that you can answer directly using your general knowledge or according to the conversation history. 2. Complex Question: A query that requires in-depth information or specific data that would benefit from using a RAG (Retrieval-Augmented Generation) system to provide a comprehensive and accurate answer. @@ -33,27 +33,24 @@ Query Type: """ - def query_router(query_str, history_messages): - multi_turn_content = "" - for message in history_messages: - if message['role'] == "user": - multi_turn_content += "User: " + message['content'] + "\n" - else: - multi_turn_content += "Assistant: " + message['content'] + "\n" + # Efficiently build the conversation history string without creating intermediate lists + multi_turn_content = "\n".join( + (f"User: {msg['content']}" if msg['role'] == "user" else f"Assistant: {msg['content']}") + for msg in history_messages + ) + + # Format the prompt in a single step + formatted_prompt = prompt_template.format(query_str=query_str, multi_turn_content=multi_turn_content) + + # Prepare the message payload messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt.format(query_str=query_str, - multi_turn_content=multi_turn_content)} + {"role": "user", "content": formatted_prompt} ] - response = "" - completion = chat(messages, model = query_router_model) - for chunk in completion: - response += chunk - if response == "LLM": - return 0 - if response == "CHAT": - return 2 - # 0 = llm, 1 = rag, 2 = chat with paper, if can not decision, use rag - return 1 - \ No newline at end of file + + # Efficient response collection without unnecessary concatenation + response = ''.join(chat(messages, model=query_router_model)).strip() + + # Return the appropriate code based on the response + return 0 if response == "LLM" else 2 if response == "CHAT" else 1 diff --git a/service/query_understanding.py b/service/query_understanding.py index db81500..e7df8da 100644 --- a/service/query_understanding.py +++ b/service/query_understanding.py @@ -1,23 +1,23 @@ from llm.chat_llm import chat from config import query_understanding_model -assitant_prompt = """You are an literature assistant. Your task is to ask question to user to get better understanding of user query. To provide the best assistance, follow these steps when interacting with users: +assistant_prompt = """You are a literature assistant. Your task is to ask a question to the user to get a better understanding of their query. To provide the best assistance, follow these steps when interacting with users: [RULE] -1. May need to ask for the specific academic field or subject area they are interested in, or inquire about the time range for the literature search (e.g., recent years, last decade, specific period). +1. Ask for the specific academic field or subject area they are interested in, or inquire about the time range for the literature search (e.g., recent years, last decade, specific period). 2. Request additional details or context regarding their inquiry to narrow down the search (e.g., specific theories, key terms, notable authors). 3. Make sure your questions are clear and concise. -4. If you need to ask the user a question, your response should start with '[NEED MORE INFORMATION]', followed by your question. -5. If you feel you have gathered enough information, please reply with '[DONE]' only. -6. Please ask only the necessary questions to get enough information to proceed. -7. Only can ask 1 questions. +4. If you need to ask the user a question, start your response with '[NEED MORE INFORMATION]', followed by your question. +5. If you have gathered enough information, reply with '[DONE]' only. +6. Ask only the necessary questions to gather sufficient information to proceed. +7. Ask only one question at a time. -Note that "assistant" in the conversation history refers to you, meaning that if the assistant has already asked N questions, it indicates that you have used N questioning opportunities. +Note: "Assistant" in the conversation history refers to you, meaning that if the assistant has already asked N questions, it indicates that you have used N questioning opportunities. [Example Begin] User: What is PPO? -Your: [NEED MORE INFORMATION]I need more information to better answer your question. Could you provide the academic field of 'PPO' you are inquiring about, or please explain the question in more detail? +Your: [NEED MORE INFORMATION] I need more information to better answer your question. Could you provide the academic field of 'PPO' you are inquiring about, or please explain the question in more detail? User: Computer field Your: [DONE] @@ -39,9 +39,9 @@ Your Question or [Done]: """ -query_rewrite_prompt = """You are tasked with analyzing the conversation history between a user and an AI assistant, and then rewriting the user's query to incorporate relevant context from the conversation. +query_rewrite_prompt = """You are tasked with analyzing the conversation history between a user and an AI assistant, and then rewriting the user's query to incorporate relevant context from the conversation. Your goal is to create a more informative and context-aware query that will help the AI provide a more accurate and helpful response. -Please note that you cannot miss any query information from the user. +Please note that you cannot omit any query information from the user. Your response must only contain the new query phrased from the user's perspective. Here is an example you can refer to: @@ -51,74 +51,65 @@ Assistant: Could you specify the academic field or domain you are interested in for these datasets? User: NLP in computer science -Rewrited query: -Recommend some datasets for verifying the correctness of model responses to factual questions in the field of NLP(Natural Language Processing). +Rewritten query: +Recommend some datasets for verifying the correctness of model responses to factual questions in the field of NLP (Natural Language Processing). [Example End] --- Conversation history: {multi_turn_content} -Rewrited query: +Rewritten query: """ def split_last_newline(input_string): + """Splits the input string at the last newline and returns the string without the last line.""" input_string = input_string.strip() last_newline_index = input_string.rfind('\n') - if last_newline_index != -1: - return input_string[:last_newline_index] - else: - return input_string + return input_string[:last_newline_index] if last_newline_index != -1 else input_string -class Multi_Turn_Query_Understanding: - def __init__(self) -> None: +class MultiTurnQueryUnderstanding: + def __init__(self): pass + def _build_conversation_history(self, history_messages): + """Builds a string representation of the conversation history.""" + return "\n".join( + f"{msg['role'].capitalize()}: {msg['content']}" + for msg in history_messages + if msg.get('role') and msg.get('content') + ) + def query_understanding_chat(self, history_messages): - multi_turn_content = "" - for message in history_messages: - if message['role'] == "user": - multi_turn_content += "User: " + message['content'] + "\n" - else: - multi_turn_content += "Assistant: " + message['content'] + "\n" + """Processes the conversation history and generates a question or a '[DONE]' response.""" + multi_turn_content = self._build_conversation_history(history_messages) + formatted_prompt = assistant_prompt.format(multi_turn_content=multi_turn_content) + messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": assitant_prompt.format( - multi_turn_content=multi_turn_content) - } + {"role": "user", "content": formatted_prompt} ] - completion = chat(messages, model = query_understanding_model, stop=["\n"]) + + completion = chat(messages, model=query_understanding_model, stop=["\n"]) chat_message = "" - skip_str_1 = "[NEED MORE INFORMATION]" - same_idx_1 = 0 - skip_str_2 = "[DONE]" - same_idx_2 = 0 for chunk in completion: chat_message += chunk - idx_1 = 0 - while same_idx_1 < len(skip_str_1) and idx_1 < len(chunk) and skip_str_1[same_idx_1] == chunk[idx_1]: - same_idx_1 += 1 - idx_1 += 1 - idx_2 = 0 - while same_idx_2 < len(skip_str_2) and idx_2 < len(chunk) and skip_str_2[same_idx_2] == chunk[idx_2]: - same_idx_2 += 1 - idx_2 += 1 - chunk = chunk[max(idx_1, idx_2):] - yield chunk - + if chat_message.startswith("[DONE]"): + return "[DONE]" + if chat_message.startswith("[NEED MORE INFORMATION]"): + return chat_message.strip() + return chat_message.strip() + def query_rewrite_according_messages(self, history_messages): - multi_turn_content = "" - for message in history_messages: - if 'role' not in message or 'content' not in message or message['role'] is None or message['content'] is None: - continue - if message['role'] == "user": - multi_turn_content += "User: " + message['content'] + "\n" - else: - multi_turn_content += "Assistant: " + message['content'] + "\n" + """Rewrites the user's query based on the conversation history.""" + multi_turn_content = split_last_newline(self._build_conversation_history(history_messages)) + formatted_prompt = query_rewrite_prompt.format(multi_turn_content=multi_turn_content) + messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": query_rewrite_prompt.format(multi_turn_content=split_last_newline(multi_turn_content))} + {"role": "user", "content": formatted_prompt} ] - completion = chat(messages, model = query_understanding_model) - for chunk in completion: - yield chunk + + completion = chat(messages, model=query_understanding_model) + return "".join(completion).strip() +