Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 68 additions & 20 deletions src/rag.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List
from typing import Any, List, Tuple, Optional
from uuid import uuid4

import adalflow as adal
Expand All @@ -16,6 +16,7 @@
from config import configs
from src.data_pipeline import DatabaseManager
from adalflow.utils import printc
from dataclasses import dataclass, field


class Memory(DataComponent):
Expand All @@ -26,7 +27,6 @@ def __init__(self):
self.current_conversation = Conversation()

def call(self) -> List[DialogTurn]:
    """Return every dialog turn recorded in the current conversation."""
    return self.current_conversation.dialog_turns
Expand All @@ -41,6 +41,14 @@ def add_dialog_turn(self, user_query: str, assistant_response: str):
self.current_conversation.append_dialog_turn(dialog_turn)


@dataclass
class RAGAnswer(adal.DataClass):
    """Structured output of the RAG generator: a reasoning trace plus the final answer."""

    # Free-text justification produced before the answer itself.
    rationale: str = field(default="", metadata={"desc": "Rationale for the answer."})
    # The answer text ultimately shown to the user.
    answer: str = field(default="", metadata={"desc": "Answer to the user query."})

    # NOTE(review): presumably tells adalflow which fields (and in what order)
    # to render in the output schema / parse from the model response — confirm
    # against adal.DataClass documentation.
    __output_fields__ = ["rationale", "answer"]


system_prompt = r"""
You are a code assistant which answer's user question on a Github Repo.
You will receive user query, relevant context, and past conversation history.
Expand Down Expand Up @@ -75,23 +83,12 @@ def add_dialog_turn(self, user_query: str, assistant_response: str):
<END_OF_USER_PROMPT>
"""

from dataclasses import dataclass, field


@dataclass
class RAGAnswer(adal.DataClass):
rationale: str = field(default="", metadata={"desc": "Rationale for the answer."})
answer: str = field(default="", metadata={"desc": "Answer to the user query."})

__output_fields__ = ["rationale", "answer"]


class RAG(adal.Component):
__doc__ = """RAG with one repo.
If you want to load a new repo. You need to call prepare_retriever(repo_url_or_path) first."""

def __init__(self):

super().__init__()

# Initialize embedder, generator, and db_manager
Expand Down Expand Up @@ -119,6 +116,7 @@ def __init__(self):
model_kwargs=configs["generator"]["model_kwargs"],
output_processors=data_parser,
)
self.previous_retrieved_documents = None

def initialize_db_manager(self):
self.db_manager = DatabaseManager()
Expand All @@ -136,15 +134,59 @@ def prepare_retriever(self, repo_url_or_path: str):
document_map_func=lambda doc: doc.vector,
)

def call(self, query: str) -> Any:
def is_clarification_query(self, query: str) -> bool:
    """
    Determine whether *query* is a follow-up/clarification of the ongoing conversation.

    Builds a dedicated detector prompt around the conversation history and asks
    the LLM generator for a structured True/False verdict.

    Args:
        query: Raw user query to classify.

    Returns:
        True if the model's parsed answer contains "true" (case-insensitive);
        False when there is no conversation history yet, or when the model
        response could not be parsed into a usable answer.
    """
    # No prior turns -> by definition nothing to clarify.
    if not self.memory():
        return False

    clarification_prompt = f"""
You are a clarification detector. Analyze if the query is a follow-up or clarification of the previous conversation.
Your response should include:
- A rationale explaining your reasoning
- A clear True/False answer

Output your response in this format:
{{
"rationale": "Your step-by-step reasoning here",
"answer": "True or False"
}}

Conversation History:
{self.memory()}

Query:
{query}
"""
    response = self.generator(
        prompt_kwargs={
            "conversation_history": self.memory(),
            "system_prompt": clarification_prompt,
        },
    )

    # Removed a stray `self.retriever(query)` call whose result was discarded:
    # it performed an expensive retrieval with no effect on the verdict.
    # Defensive parse: output processing may fail and leave `data` unset/None,
    # in which case we fall back to "not a clarification" instead of crashing.
    data = getattr(response, "data", None)
    answer_text = getattr(data, "answer", "") or ""
    return "true" in str(answer_text).lower()

def call(self, query: str) -> Tuple[Any, Any]:
previous_context = (
self.previous_retrieved_documents[0].documents
if self.previous_retrieved_documents
else None
)

# fill in the document
retrieved_documents[0].documents = [
self.transformed_docs[doc_index]
for doc_index in retrieved_documents[0].doc_indices
]
is_clarification = self.is_clarification_query(query)

if is_clarification and self.previous_retrieved_documents:
retrieved_documents = self.previous_retrieved_documents
else:
retrieved_documents = self.retriever(query)
retrieved_documents[0].documents = [
self.transformed_docs[doc_index]
for doc_index in retrieved_documents[0].doc_indices
]
self.previous_retrieved_documents = retrieved_documents

printc(f"retrieved_documents: {retrieved_documents[0].documents}")
printc(f"memory: {self.memory()}")
Expand All @@ -168,6 +210,12 @@ def call(self, query: str) -> Any:

return final_response, retrieved_documents

def _format_doc_paths(self, documents: List[Any]) -> str:
"""Helper to format document paths for logging"""
return "\n ".join(
f"- {doc.meta_data.get('file_path', 'unknown')}" for doc in documents
)


if __name__ == "__main__":
from adalflow.utils import get_logger
Expand Down