From dd957d4ba53c2b4e809553d460ca3cdda2736768 Mon Sep 17 00:00:00 2001 From: Siddarth Date: Fri, 7 Feb 2025 12:58:01 -0500 Subject: [PATCH 1/3] Update rag.py --- src/rag.py | 89 +++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 71 insertions(+), 18 deletions(-) diff --git a/src/rag.py b/src/rag.py index 0f36934..9b7294f 100644 --- a/src/rag.py +++ b/src/rag.py @@ -1,5 +1,7 @@ -from typing import Any, List +from typing import Any, List, Tuple, Optional from uuid import uuid4 +import os +from datetime import datetime import adalflow as adal from adalflow.core.types import ( @@ -16,6 +18,7 @@ from config import configs from src.data_pipeline import DatabaseManager from adalflow.utils import printc +from dataclasses import dataclass, field class Memory(DataComponent): @@ -41,6 +44,14 @@ def add_dialog_turn(self, user_query: str, assistant_response: str): self.current_conversation.append_dialog_turn(dialog_turn) +@dataclass +class RAGAnswer(adal.DataClass): + rationale: str = field(default="", metadata={"desc": "Rationale for the answer."}) + answer: str = field(default="", metadata={"desc": "Answer to the user query."}) + + __output_fields__ = ["rationale", "answer"] + + system_prompt = r""" You are a code assistant which answer's user question on a Github Repo. You will receive user query, relevant context, and past conversation history. @@ -75,16 +86,6 @@ def add_dialog_turn(self, user_query: str, assistant_response: str): """ -from dataclasses import dataclass, field - - -@dataclass -class RAGAnswer(adal.DataClass): - rationale: str = field(default="", metadata={"desc": "Rationale for the answer."}) - answer: str = field(default="", metadata={"desc": "Answer to the user query."}) - - __output_fields__ = ["rationale", "answer"] - class RAG(adal.Component): __doc__ = """RAG with one repo. @@ -119,6 +120,7 @@ def __init__(self): model_kwargs=configs["generator"]["model_kwargs"], output_processors=data_parser, ) + self.previous_retrieved_documents = None def initialize_db_manager(self): self.db_manager = DatabaseManager() @@ -136,15 +138,59 @@ def prepare_retriever(self, repo_url_or_path: str): document_map_func=lambda doc: doc.vector, ) - def call(self, query: str) -> Any: + def is_clarification_query(self, query: str) -> bool: + """ + Determines if the current query is a clarification of a previous query. + """ + if not self.memory(): + return False + + clarification_prompt = f""" + You are a clarification detector. Analyze if the query is a follow-up or clarification of the previous conversation. + Your response should include: + - A rationale explaining your reasoning + - A clear True/False answer + + Output your response in this format: + {{ + "rationale": "Your step-by-step reasoning here", + "answer": "True or False" + }} + + Conversation History: + {self.memory()} + + Query: + {query} + """ + response = self.generator( + prompt_kwargs={ + "conversation_history": self.memory(), + "system_prompt": clarification_prompt, + }, + ) + + is_clarification = "true" in response.data.answer.lower() + return is_clarification - retrieved_documents = self.retriever(query) + def call(self, query: str) -> Tuple[Any, Any]: + previous_context = ( + self.previous_retrieved_documents[0].documents + if self.previous_retrieved_documents + else None + ) - # fill in the document - retrieved_documents[0].documents = [ - self.transformed_docs[doc_index] - for doc_index in retrieved_documents[0].doc_indices - ] + is_clarification = self.is_clarification_query(query) + + if is_clarification and self.previous_retrieved_documents: + retrieved_documents = self.previous_retrieved_documents + else: + retrieved_documents = self.retriever(query) + retrieved_documents[0].documents = [ + self.transformed_docs[doc_index] + for doc_index in retrieved_documents[0].doc_indices + ] + self.previous_retrieved_documents = retrieved_documents printc(f"retrieved_documents: {retrieved_documents[0].documents}") printc(f"memory: {self.memory()}") @@ -168,6 +214,13 @@ def call(self, query: str) -> Any: return final_response, retrieved_documents + def _format_doc_paths(self, documents: List[Any]) -> str: + """Helper to format document paths for logging""" + return "\n ".join( + f"- {doc.meta_data.get('file_path', 'unknown')}" + for doc in documents + ) + if __name__ == "__main__": from adalflow.utils import get_logger From 734bd66877af238e189b86d7a16ff14d39ee91d8 Mon Sep 17 00:00:00 2001 From: Siddarth Date: Fri, 7 Feb 2025 12:58:30 -0500 Subject: [PATCH 2/3] Update rag.py --- src/rag.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/rag.py b/src/rag.py index 9b7294f..31c6f39 100644 --- a/src/rag.py +++ b/src/rag.py @@ -1,7 +1,5 @@ from typing import Any, List, Tuple, Optional from uuid import uuid4 -import os -from datetime import datetime import adalflow as adal from adalflow.core.types import ( From 7bc86f2df33b133462645c1f1ba87b4904ebb95e Mon Sep 17 00:00:00 2001 From: Siddarth Date: Fri, 7 Feb 2025 12:59:27 -0500 Subject: [PATCH 3/3] Update rag.py --- src/rag.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/rag.py b/src/rag.py index 31c6f39..f011f4e 100644 --- a/src/rag.py +++ b/src/rag.py @@ -27,7 +27,6 @@ def __init__(self): self.current_conversation = Conversation() def call(self) -> List[DialogTurn]: - all_dialog_turns = self.current_conversation.dialog_turns return all_dialog_turns @@ -90,7 +89,6 @@ class RAG(adal.Component): If you want to load a new repo. You need to call prepare_retriever(repo_url_or_path) first.""" def __init__(self): - super().__init__() # Initialize embedder, generator, and db_manager @@ -215,8 +213,7 @@ def call(self, query: str) -> Tuple[Any, Any]: def _format_doc_paths(self, documents: List[Any]) -> str: """Helper to format document paths for logging""" return "\n ".join( - f"- {doc.meta_data.get('file_path', 'unknown')}" - for doc in documents + f"- {doc.meta_data.get('file_path', 'unknown')}" for doc in documents )