diff --git a/src/rag.py b/src/rag.py
index 0f36934..f011f4e 100644
--- a/src/rag.py
+++ b/src/rag.py
@@ -1,4 +1,5 @@
-from typing import Any, List
+from dataclasses import dataclass, field
+from typing import Any, List, Optional, Tuple
 from uuid import uuid4
 
 import adalflow as adal
@@ -26,7 +27,6 @@ def __init__(self):
         self.current_conversation = Conversation()
 
     def call(self) -> List[DialogTurn]:
-
         all_dialog_turns = self.current_conversation.dialog_turns
 
         return all_dialog_turns
@@ -41,6 +41,14 @@ def add_dialog_turn(self, user_query: str, assistant_response: str):
         self.current_conversation.append_dialog_turn(dialog_turn)
 
 
+@dataclass
+class RAGAnswer(adal.DataClass):
+    rationale: str = field(default="", metadata={"desc": "Rationale for the answer."})
+    answer: str = field(default="", metadata={"desc": "Answer to the user query."})
+
+    __output_fields__ = ["rationale", "answer"]
+
+
 system_prompt = r"""
 You are a code assistant which answer's user question on a Github Repo.
 You will receive user query, relevant context, and past conversation history.
@@ -75,23 +83,12 @@ def add_dialog_turn(self, user_query: str, assistant_response: str):
 
 """
 
-from dataclasses import dataclass, field
-
-
-@dataclass
-class RAGAnswer(adal.DataClass):
-    rationale: str = field(default="", metadata={"desc": "Rationale for the answer."})
-    answer: str = field(default="", metadata={"desc": "Answer to the user query."})
-
-    __output_fields__ = ["rationale", "answer"]
-
 
 class RAG(adal.Component):
     __doc__ = """RAG with one repo. If you want to load a new repo.
     You need to call prepare_retriever(repo_url_or_path) first."""
 
     def __init__(self):
-
         super().__init__()
 
         # Initialize embedder, generator, and db_manager
@@ -119,6 +116,9 @@ def __init__(self):
             model_kwargs=configs["generator"]["model_kwargs"],
             output_processors=data_parser,
         )
+        # Retriever output cached by call() after a fresh retrieval and reused
+        # for clarification queries; None until the first retrieval.
+        self.previous_retrieved_documents: Optional[List[Any]] = None
 
     def initialize_db_manager(self):
         self.db_manager = DatabaseManager()
@@ -136,15 +136,81 @@ def prepare_retriever(self, repo_url_or_path: str):
             document_map_func=lambda doc: doc.vector,
         )
 
-    def call(self, query: str) -> Any:
+    def is_clarification_query(self, query: str) -> bool:
+        """Decide whether ``query`` is a follow-up to the conversation so far.
+
+        Args:
+            query: The latest user query.
+
+        Returns:
+            False when the conversation memory is empty or the generator
+            returned no parsed answer; otherwise whether that answer
+            contains "true" (case-insensitive).
+        """
+        # Read the memory once so the guard, the prompt text and the prompt
+        # kwargs all use the same history.
+        conversation_history = self.memory()
+        if not conversation_history:
+            return False
+
+        clarification_prompt = f"""
+        You are a clarification detector. Analyze if the query is a follow-up or clarification of the previous conversation.
+        Your response should include:
+        - A rationale explaining your reasoning
+        - A clear True/False answer
+
+        Output your response in this format:
+        {{
+            "rationale": "Your step-by-step reasoning here",
+            "answer": "True or False"
+        }}
+
+        Conversation History:
+        {conversation_history}
+
+        Query:
+        {query}
+        """
+        response = self.generator(
+            prompt_kwargs={
+                "conversation_history": conversation_history,
+                "system_prompt": clarification_prompt,
+            },
+        )
+
+        # NOTE(review): presumably response.data can be None when output
+        # parsing fails (confirm against adalflow); fall back to a fresh
+        # retrieval instead of raising AttributeError.
+        answer = getattr(response.data, "answer", None)
+        if not answer:
+            return False
+        return "true" in str(answer).lower()
+
+    def call(self, query: str) -> Tuple[Any, Any]:
+        """Retrieve context for ``query`` and return it with the final response.
+
+        Documents from the previous retrieval are reused when the query is
+        judged to be a clarification; otherwise the retriever runs again and
+        its result is cached for the next call.
+
+        Returns:
+            ``(final_response, retrieved_documents)``.
+        """
+        # Skip the extra generator call when there is nothing cached to reuse.
+        reuse_previous = bool(self.previous_retrieved_documents) and (
+            self.is_clarification_query(query)
+        )
 
-        retrieved_documents = self.retriever(query)
-
-        # fill in the document
-        retrieved_documents[0].documents = [
-            self.transformed_docs[doc_index]
-            for doc_index in retrieved_documents[0].doc_indices
-        ]
+        if reuse_previous:
+            retrieved_documents = self.previous_retrieved_documents
+        else:
+            retrieved_documents = self.retriever(query)
+            # fill in the document
+            retrieved_documents[0].documents = [
+                self.transformed_docs[doc_index]
+                for doc_index in retrieved_documents[0].doc_indices
+            ]
+            self.previous_retrieved_documents = retrieved_documents
 
         printc(f"retrieved_documents: {retrieved_documents[0].documents}")
         printc(f"memory: {self.memory()}")
@@ -168,6 +234,23 @@ def call(self, query: str) -> Any:
 
         return final_response, retrieved_documents
 
+    def _format_doc_paths(self, documents: List[Any]) -> str:
+        """Format each document's ``file_path`` as a bullet line for logging.
+
+        Args:
+            documents: Objects exposing a ``meta_data`` mapping.
+
+        Returns:
+            One ``- <path>`` entry per document, with "unknown" standing in
+            for a missing path; an empty string for no documents.
+        """
+        # NOTE(review): assumes meta_data may be None for some documents;
+        # treat that the same as a missing "file_path" key.
+        paths = (
+            (doc.meta_data or {}).get("file_path", "unknown") for doc in documents
+        )
+        return "\n ".join(f"- {path}" for path in paths)
+
 
 if __name__ == "__main__":
     from adalflow.utils import get_logger