From 48f6b8f6b73f682ac8e6dec06f9d8df32cbb9590 Mon Sep 17 00:00:00 2001 From: Jay Prakash Date: Mon, 16 Mar 2026 09:16:33 +0530 Subject: [PATCH 1/3] chore: add mypy type checking for src/rag/ with all errors fixed --- mypy.ini | 15 +++++++++ requirements.txt | 3 ++ src/rag/llm_manager.py | 16 ++++----- src/rag/neet_rag.py | 14 ++++---- src/rag/vector_store.py | 75 +++++++++++++++++++++++++++-------------- 5 files changed, 83 insertions(+), 40 deletions(-) create mode 100644 mypy.ini diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..185445b --- /dev/null +++ b/mypy.ini @@ -0,0 +1,15 @@ +[mypy] +files = src/rag/ +python_version = 3.11 +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +check_untyped_defs = true +ignore_missing_imports = true +follow_imports = silent + +[mypy-src.processors.*] +ignore_errors = true + +[mypy-src.utils.*] +ignore_errors = true diff --git a/requirements.txt b/requirements.txt index 6c1fc28..32e8ff8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -45,3 +45,6 @@ redis>=4.0.0 boto3>=1.34.0 markdownify>=0.11.0 pandas>=2.0.0 + +# Type checking (dev) +mypy>=1.10.0 diff --git a/src/rag/llm_manager.py b/src/rag/llm_manager.py index cbbcb56..69995df 100644 --- a/src/rag/llm_manager.py +++ b/src/rag/llm_manager.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict, Any, List +from typing import Optional, Dict, Any, List, cast import os import base64 @@ -10,13 +10,13 @@ def __init__( model: str = "llama3.2", api_key: Optional[str] = None, base_url: Optional[str] = None, - ): + ) -> None: self.provider = provider self.model = model - self.llm = None + self.llm: Optional[Any] = None self._initialize_llm(api_key, base_url) - def _initialize_llm(self, api_key: Optional[str], base_url: Optional[str]): + def _initialize_llm(self, api_key: Optional[str], base_url: Optional[str]) -> None: if self.provider == "ollama": try: from langchain_community.llms import Ollama @@ -33,7 +33,7 @@ def _initialize_llm(self, api_key: Optional[str], base_url: Optional[str]): self.llm = ChatOpenAI( model=self.model, - api_key=api_key or os.getenv("OPENAI_API_KEY"), + api_key=api_key or os.getenv("OPENAI_API_KEY"), # type: ignore[arg-type] # LangChain accepts raw str API keys at runtime base_url=base_url or os.getenv("OPENAI_BASE_URL"), temperature=0.7, ) @@ -44,9 +44,9 @@ def _initialize_llm(self, api_key: Optional[str], base_url: Optional[str]): try: from langchain_anthropic import ChatAnthropic - self.llm = ChatAnthropic( + self.llm = ChatAnthropic( # type: ignore[call-arg] # Runtime accepts model kwarg; stubs may differ model=self.model, - api_key=api_key or os.getenv("ANTHROPIC_API_KEY"), + api_key=api_key or os.getenv("ANTHROPIC_API_KEY"), # type: ignore[arg-type] # LangChain accepts raw str API keys at runtime temperature=0.7, ) except ImportError: @@ -115,7 +115,7 @@ def generate( response = self.llm.invoke(prompt) if hasattr(response, "content"): - return response.content + return cast(str, response.content) return str(response) def extract_image_context( diff --git a/src/rag/neet_rag.py b/src/rag/neet_rag.py index d328214..f01b529 100644 --- a/src/rag/neet_rag.py +++ b/src/rag/neet_rag.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, List, Optional, Union +from typing import Dict, Any, List, Optional, Union, Tuple from pathlib import Path import os import re @@ -57,6 +57,7 @@ def __init__( youtube_subdir = os.path.join(resolved_persist_dir, "youtube") csv_subdir = os.path.join(resolved_persist_dir, "csv") has_split_indexes = os.path.isdir(youtube_subdir) or os.path.isdir(csv_subdir) + self.vector_manager: Union[VectorStoreManager, CompositeVectorStoreManager] if has_split_indexes: self.vector_manager = build_composite_manager( @@ -80,7 +81,7 @@ def __init__( self.prompt_builder = RAGPromptBuilder() self._vectorstore_loaded = False self.logger = logging.getLogger(__name__) - self._source_manager = None + self._source_manager: Optional[Any] = None self._source_title_cache: Dict[str, str] = {} @staticmethod @@ -97,7 +98,7 @@ def _is_meaningful_title(title: str) -> bool: return False return True - def _get_source_manager(self): + def _get_source_manager(self) -> Optional[Any]: if self._source_manager is not None: return self._source_manager try: @@ -360,6 +361,7 @@ def _dedupe_docs(self, docs: List[Document]) -> List[Document]: deduped = [] seen = set() for doc in docs: + key: Tuple[Any, ...] source_type = doc.metadata.get("source_type") or doc.metadata.get( "content_type", "" ) @@ -464,8 +466,8 @@ def _retrieve_docs_blended(self, question: str, top_k: int) -> List[Document]: @staticmethod def _is_youtube_doc(doc: Document) -> bool: - source_type = doc.metadata.get("source_type") or doc.metadata.get( - "content_type", "" + source_type = str( + doc.metadata.get("source_type") or doc.metadata.get("content_type", "") ) return source_type == "youtube" @@ -756,6 +758,6 @@ def get_stats(self) -> Dict[str, Any]: except Exception as e: return {"error": str(e)} - def reset_knowledge_base(self): + def reset_knowledge_base(self) -> None: self.vector_manager.delete_collection() self._vectorstore_loaded = False diff --git a/src/rag/vector_store.py b/src/rag/vector_store.py index bf74103..9d55db8 100644 --- a/src/rag/vector_store.py +++ b/src/rag/vector_store.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, List, Optional +from typing import Dict, Any, List, Optional, Tuple from pathlib import Path import contextlib import io @@ -6,7 +6,7 @@ import os from langchain_core.documents import Document -from langchain_core.embeddings import FakeEmbeddings +from langchain_core.embeddings import Embeddings, FakeEmbeddings from langchain_community.vectorstores import FAISS from langchain_huggingface import HuggingFaceEmbeddings from langchain_openai import OpenAIEmbeddings @@ -17,11 +17,11 @@ class VectorStoreManager: def __init__( self, - persist_directory: str = None, + persist_directory: Optional[str] = None, embedding_provider: str = "huggingface", embedding_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", embedding_dimension: int = 384, - ): + ) -> None: self.persist_directory = persist_directory or os.path.join( os.environ.get("DATA_DIR", "./data"), "faiss_index" ) @@ -29,11 +29,11 @@ def __init__( self.embedding_model = embedding_model self.embedding_dimension = embedding_dimension - self.embeddings = None - self.vectorstore = None + self.embeddings: Optional[Embeddings] = None + self.vectorstore: Optional[FAISS] = None self._initialize_embeddings() - def _initialize_embeddings(self): + def _initialize_embeddings(self) -> None: if self.embedding_provider == "huggingface": os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1") os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") @@ -52,7 +52,8 @@ def _build_hf_embeddings() -> HuggingFaceEmbeddings: self.embeddings = _build_hf_embeddings() elif self.embedding_provider == "openai": self.embeddings = OpenAIEmbeddings( - model="text-embedding-3-small", api_key=os.getenv("OPENAI_API_KEY") + model="text-embedding-3-small", + api_key=os.getenv("OPENAI_API_KEY"), # type: ignore[arg-type] # pyright: ignore[reportArgumentType] # LangChain accepts raw str API keys at runtime ) elif self.embedding_provider == "fake": self.embeddings = FakeEmbeddings(size=self.embedding_dimension) @@ -63,33 +64,41 @@ def _build_hf_embeddings() -> HuggingFaceEmbeddings: def create_vectorstore( self, documents: List[Document], collection_name: str = "neet_knowledge" - ): + ) -> FAISS: if not documents: raise ValueError("No documents provided") + if self.embeddings is None: + raise ValueError("Embeddings not initialized") + embeddings = self.embeddings + assert embeddings is not None Path(self.persist_directory).mkdir(parents=True, exist_ok=True) self.vectorstore = FAISS.from_documents( documents=documents, - embedding=self.embeddings, + embedding=embeddings, ) self.vectorstore.save_local(self.persist_directory) return self.vectorstore - def load_vectorstore(self, collection_name: str = "neet_knowledge"): + def load_vectorstore(self, collection_name: str = "neet_knowledge") -> FAISS: if not os.path.exists(self.persist_directory): raise FileNotFoundError(f"No vectorstore found at {self.persist_directory}") + if self.embeddings is None: + raise ValueError("Embeddings not initialized") + embeddings = self.embeddings + assert embeddings is not None self.vectorstore = FAISS.load_local( self.persist_directory, - self.embeddings, + embeddings, allow_dangerous_deserialization=True, ) return self.vectorstore - def add_documents(self, documents: List[Document]): + def add_documents(self, documents: List[Document]) -> FAISS: if self.vectorstore is None: return self.create_vectorstore(documents) @@ -114,7 +123,7 @@ def similarity_search_with_score( k: int = 5, filter: Optional[Dict[str, Any]] = None, fetch_k: Optional[int] = None, - ) -> List[tuple]: + ) -> List[Tuple[Document, float]]: if self.vectorstore is None: raise ValueError( "Vectorstore not initialized. Load or create a vectorstore first." @@ -129,7 +138,7 @@ def similarity_search_with_score( return self.vectorstore.similarity_search_with_score(**kwargs) - def delete_collection(self, collection_name: str = "neet_knowledge"): + def delete_collection(self, collection_name: str = "neet_knowledge") -> None: if os.path.exists(self.persist_directory): import shutil @@ -161,8 +170,15 @@ def delete_by_source_id_and_question_id( if self.vectorstore is None: self.load_vectorstore() + if self.vectorstore is None: + return 0 + if self.embeddings is None: + raise ValueError("Embeddings not initialized") + vectorstore = self.vectorstore + embeddings = self.embeddings + assert embeddings is not None - doc_map = getattr(self.vectorstore.docstore, "_dict", {}) + doc_map = getattr(vectorstore.docstore, "_dict", {}) all_docs = [doc for doc in doc_map.values() if isinstance(doc, Document)] keep_docs: List[Document] = [] @@ -183,7 +199,7 @@ def delete_by_source_id_and_question_id( if keep_docs: self.vectorstore = FAISS.from_documents( documents=keep_docs, - embedding=self.embeddings, + embedding=embeddings, ) self.vectorstore.save_local(self.persist_directory) else: @@ -204,8 +220,15 @@ def _delete_by_metadata_key( if self.vectorstore is None: self.load_vectorstore() + if self.vectorstore is None: + return 0 + if self.embeddings is None: + raise ValueError("Embeddings not initialized") + vectorstore = self.vectorstore + embeddings = self.embeddings + assert embeddings is not None - doc_map = getattr(self.vectorstore.docstore, "_dict", {}) + doc_map = getattr(vectorstore.docstore, "_dict", {}) all_docs = [doc for doc in doc_map.values() if isinstance(doc, Document)] keep_docs: List[Document] = [] @@ -230,7 +253,7 @@ def _delete_by_metadata_key( if keep_docs: self.vectorstore = FAISS.from_documents( documents=keep_docs, - embedding=self.embeddings, + embedding=embeddings, ) self.vectorstore.save_local(self.persist_directory) else: @@ -300,13 +323,13 @@ def _strip_source_type_filter( return rest or None @property - def vectorstore(self): + def vectorstore(self) -> Optional[FAISS]: for mgr in self._managers.values(): if mgr.vectorstore is not None: return mgr.vectorstore return None - def load_vectorstore(self, collection_name: str = "neet_knowledge"): + def load_vectorstore(self, collection_name: str = "neet_knowledge") -> None: errors: List[str] = [] for label, mgr in self._managers.items(): try: @@ -326,7 +349,7 @@ def load_vectorstore(self, collection_name: str = "neet_knowledge"): def create_vectorstore( self, documents: List[Document], collection_name: str = "neet_knowledge" - ): + ) -> None: buckets: Dict[str, List[Document]] = {k: [] for k in self._managers} for doc in documents: st = doc.metadata.get("source_type", self._default_source_type) @@ -339,7 +362,7 @@ def create_vectorstore( docs, collection_name=collection_name ) - def add_documents(self, documents: List[Document]): + def add_documents(self, documents: List[Document]) -> None: buckets: Dict[str, List[Document]] = {k: [] for k in self._managers} for doc in documents: st = doc.metadata.get("source_type", self._default_source_type) @@ -374,7 +397,7 @@ def similarity_search_with_score( k: int = 5, filter: Optional[Dict[str, Any]] = None, fetch_k: Optional[int] = None, - ) -> List[tuple]: + ) -> List[Tuple[Document, float]]: target = self._manager_for_filter(filter) if target is not None: remaining_filter = self._strip_source_type_filter(filter) @@ -382,7 +405,7 @@ def similarity_search_with_score( query=query, k=k, filter=remaining_filter, fetch_k=fetch_k ) - all_scored: List[tuple] = [] + all_scored: List[Tuple[Document, float]] = [] for mgr in self._managers.values(): if mgr.vectorstore is None: continue @@ -397,7 +420,7 @@ def similarity_search_with_score( all_scored.sort(key=lambda pair: pair[1]) return all_scored[:k] - def delete_collection(self, collection_name: str = "neet_knowledge"): + def delete_collection(self, collection_name: str = "neet_knowledge") -> None: for mgr in self._managers.values(): mgr.delete_collection(collection_name=collection_name) From c9c6f14cc6e8d30a4ae69018cee1c8eeb4dba38b Mon Sep 17 00:00:00 2001 From: Jay Prakash Date: Mon, 16 Mar 2026 09:33:38 +0530 Subject: [PATCH 2/3] test: add pytest suite for RAGPromptBuilder, CompositeVectorStoreManager, and NEETRAG (38 tests) --- requirements.txt | 4 + tests/conftest.py | 13 +- tests/test_composite_vector_store.py | 311 +++++++++++++++++++++++++++ tests/test_neet_rag_integration.py | 187 ++++++++++++++++ tests/test_prompt_builder.py | 173 +++++++++++++++ 5 files changed, 679 insertions(+), 9 deletions(-) create mode 100644 tests/test_composite_vector_store.py create mode 100644 tests/test_neet_rag_integration.py create mode 100644 tests/test_prompt_builder.py diff --git a/requirements.txt b/requirements.txt index 32e8ff8..14aebce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -48,3 +48,7 @@ pandas>=2.0.0 # Type checking (dev) mypy>=1.10.0 + +# Testing +pytest>=8.0.0 +pytest-cov>=5.0.0 diff --git a/tests/conftest.py b/tests/conftest.py index ba608c1..4bc0fc8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,7 @@ import os -import shutil + def pytest_configure(config): - """Set up test environment variables to isolate storage before any tests run.""" - # Force the app to use the isolated test_data directory - os.environ["DATA_DIR"] = os.path.abspath(os.path.join(os.path.dirname(__file__), "test_data")) - - # Optional: ensure it's clean before starting - test_faiss_dir = os.path.join(os.environ["DATA_DIR"], "faiss_index") - if os.path.exists(test_faiss_dir): - shutil.rmtree(test_faiss_dir) + os.environ["DATA_DIR"] = os.path.abspath( + os.path.join(os.path.dirname(__file__), "test_data") + ) diff --git a/tests/test_composite_vector_store.py b/tests/test_composite_vector_store.py new file mode 100644 index 0000000..ddf15ff --- /dev/null +++ b/tests/test_composite_vector_store.py @@ -0,0 +1,311 @@ +# pyright: reportMissingImports=false, reportUnknownVariableType=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportAny=false, reportPrivateUsage=false, reportArgumentType=false, reportAttributeAccessIssue=false, reportUnusedCallResult=false, reportUnannotatedClassAttribute=false +import os +import sys +from pathlib import Path + +import pytest +from langchain_core.documents import Document + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from src.rag.vector_store import ( + CompositeVectorStoreManager, + VectorStoreManager, + build_composite_manager, +) + + +class RecordingManager: + def __init__(self, label: str, *, has_vectorstore: bool = True): + self.label = label + self.embedding_model = f"{label}-embedding" + self.embeddings = object() + self.persist_directory = f"/tmp/{label}" + self.vectorstore = object() if has_vectorstore else None + + self.added_batches = [] + self.search_calls = [] + self.search_with_score_calls = [] + self.delete_calls = [] + + def add_documents(self, documents): + self.added_batches.append(documents) + + def similarity_search(self, query, k=5, filter=None): + self.search_calls.append({"query": query, "k": k, "filter": filter}) + return [ + Document( + page_content=f"{self.label}-result-{i}", + metadata={"source_type": self.label, "rank": i}, + ) + for i in range(k) + ] + + def similarity_search_with_score(self, query, k=5, filter=None, fetch_k=None): + self.search_with_score_calls.append( + { + "query": query, + "k": k, + "filter": filter, + "fetch_k": fetch_k, + } + ) + return [ + ( + Document( + page_content=f"{self.label}-scored-{i}", + metadata={"source_type": self.label, "rank": i}, + ), + float(i), + ) + for i in range(k) + ] + + def delete_by_source(self, source, track_id=None): + self.delete_calls.append({"source": source, "track_id": track_id}) + return 1 + + def get_collection_info(self): + return { + "collection_name": self.label, + "persist_directory": self.persist_directory, + } + + +def _make_doc(text: str, source_type: str | None = None, **metadata): + doc_metadata = dict(metadata) + if source_type is not None: + doc_metadata["source_type"] = source_type + return Document(page_content=text, metadata=doc_metadata) + + +def _make_real_manager(path: Path) -> VectorStoreManager: + return VectorStoreManager( + persist_directory=str(path), + embedding_provider="fake", + embedding_dimension=16, + ) + + +def _stored_docs(manager: VectorStoreManager): + if manager.vectorstore is None: + return [] + doc_map = getattr(manager.vectorstore.docstore, "_dict", {}) + return [doc for doc in doc_map.values() if isinstance(doc, Document)] + + +def test_composite_requires_at_least_one_manager(): + with pytest.raises(ValueError, match="At least one sub-manager is required"): + CompositeVectorStoreManager(managers={}) + + +def test_manager_for_filter_routes_by_source_type(): + yt = RecordingManager("youtube") + csv = RecordingManager("csv") + composite = CompositeVectorStoreManager({"youtube": yt, "csv": csv}) + + selected = composite._manager_for_filter({"source_type": "csv"}) + + assert selected is csv + + +def test_manager_for_filter_returns_none_without_source_type(): + yt = RecordingManager("youtube") + csv = RecordingManager("csv") + composite = CompositeVectorStoreManager({"youtube": yt, "csv": csv}) + + assert composite._manager_for_filter({"source": "chapter-1"}) is None + assert composite._manager_for_filter(None) is None + + +def test_strip_source_type_filter_preserves_other_fields(): + stripped = CompositeVectorStoreManager._strip_source_type_filter( + {"source_type": "youtube", "source": "abc", "track_id": "t1"} + ) + + assert stripped == {"source": "abc", "track_id": "t1"} + + +def test_strip_source_type_filter_returns_none_when_only_source_type(): + assert ( + CompositeVectorStoreManager._strip_source_type_filter( + {"source_type": "youtube"} + ) + is None + ) + + +def test_add_documents_routes_docs_by_source_type(): + yt = RecordingManager("youtube") + csv = RecordingManager("csv") + composite = CompositeVectorStoreManager({"youtube": yt, "csv": csv}) + + docs = [ + _make_doc("yt-1", source_type="youtube", source="yt1"), + _make_doc("csv-1", source_type="csv", source="csv1"), + _make_doc("unknown", source_type="pdf", source="pdf1"), + _make_doc("missing", source="no-type"), + ] + + composite.add_documents(docs) + + assert len(yt.added_batches) == 1 + assert len(csv.added_batches) == 1 + assert [doc.metadata.get("source") for doc in yt.added_batches[0]] == [ + "yt1", + "pdf1", + "no-type", + ] + assert [doc.metadata.get("source") for doc in csv.added_batches[0]] == ["csv1"] + + +def test_similarity_search_routes_to_specific_manager_with_source_type_filter(): + yt = RecordingManager("youtube") + csv = RecordingManager("csv") + composite = CompositeVectorStoreManager({"youtube": yt, "csv": csv}) + + results = composite.similarity_search( + query="find topic", k=2, filter={"source_type": "csv", "source": "sheet1"} + ) + + assert len(csv.search_calls) == 1 + assert csv.search_calls[0]["filter"] == {"source": "sheet1"} + assert yt.search_calls == [] + assert all(doc.metadata["source_type"] == "csv" for doc in results) + + +def test_similarity_search_with_score_routes_with_source_type_filter(): + yt = RecordingManager("youtube") + csv = RecordingManager("csv") + composite = CompositeVectorStoreManager({"youtube": yt, "csv": csv}) + + results = composite.similarity_search_with_score( + query="rank docs", + k=3, + filter={"source_type": "youtube", "source": "lecture-1"}, + fetch_k=77, + ) + + assert len(yt.search_with_score_calls) == 1 + assert yt.search_with_score_calls[0]["filter"] == {"source": "lecture-1"} + assert yt.search_with_score_calls[0]["fetch_k"] == 77 + assert csv.search_with_score_calls == [] + assert all(doc.metadata["source_type"] == "youtube" for doc, _ in results) + + +def test_similarity_search_without_filter_fans_out_to_all_managers(): + yt = RecordingManager("youtube") + csv = RecordingManager("csv") + composite = CompositeVectorStoreManager({"youtube": yt, "csv": csv}) + + results = composite.similarity_search(query="biology", k=3) + + assert len(yt.search_calls) == 1 + assert len(csv.search_calls) == 1 + assert len(results) == 3 + assert [doc.page_content for doc in results] == [ + "youtube-result-0", + "youtube-result-1", + "youtube-result-2", + ] + + +def test_delete_by_source_fans_out_to_all_sub_managers(): + yt = RecordingManager("youtube") + csv = RecordingManager("csv") + + def _yt_delete(source, track_id=None): + yt.delete_calls.append({"source": source, "track_id": track_id}) + return 2 + + def _csv_delete(source, track_id=None): + csv.delete_calls.append({"source": source, "track_id": track_id}) + return 3 + + yt.delete_by_source = _yt_delete + csv.delete_by_source = _csv_delete + + composite = CompositeVectorStoreManager({"youtube": yt, "csv": csv}) + removed = composite.delete_by_source("shared-source", track_id="track-9") + + assert removed == 5 + assert yt.delete_calls == [{"source": "shared-source", "track_id": "track-9"}] + assert csv.delete_calls == [{"source": "shared-source", "track_id": "track-9"}] + + +def test_get_collection_info_includes_composite_and_sub_indexes(): + yt = RecordingManager("youtube") + csv = RecordingManager("csv") + composite = CompositeVectorStoreManager({"youtube": yt, "csv": csv}) + + info = composite.get_collection_info() + + assert info["type"] == "composite" + assert set(info["sub_indexes"].keys()) == {"youtube", "csv"} + assert info["sub_indexes"]["youtube"]["collection_name"] == "youtube" + assert info["sub_indexes"]["csv"]["collection_name"] == "csv" + assert info["embedding_model"] == yt.embedding_model + + +def test_build_composite_manager_creates_youtube_csv_with_shared_embeddings(tmp_path): + composite = build_composite_manager( + base_persist_directory=str(tmp_path / "indexes"), + embedding_provider="fake", + embedding_dimension=24, + ) + + assert isinstance(composite, CompositeVectorStoreManager) + assert set(composite._managers.keys()) == {"youtube", "csv"} + + yt = composite._managers["youtube"] + csv = composite._managers["csv"] + + assert yt.embeddings is csv.embeddings + assert yt.persist_directory.endswith(os.path.join("indexes", "youtube")) + assert csv.persist_directory.endswith(os.path.join("indexes", "csv")) + + +def test_create_vectorstore_routes_docs_to_correct_sub_managers(tmp_path): + composite = build_composite_manager( + base_persist_directory=str(tmp_path / "composite"), + embedding_provider="fake", + embedding_dimension=16, + ) + + docs = [ + _make_doc("yt-doc", source_type="youtube", source="yt-1"), + _make_doc("csv-doc", source_type="csv", source="csv-1"), + _make_doc("fallback-unknown", source_type="pdf", source="pdf-1"), + _make_doc("fallback-missing", source="none-1"), + ] + + composite.create_vectorstore(docs) + + yt_docs = _stored_docs(composite._managers["youtube"]) + csv_docs = _stored_docs(composite._managers["csv"]) + + assert len(yt_docs) == 3 + assert len(csv_docs) == 1 + assert {doc.metadata.get("source") for doc in yt_docs} == { + "yt-1", + "pdf-1", + "none-1", + } + assert {doc.metadata.get("source") for doc in csv_docs} == {"csv-1"} + + +def test_vectorstore_property_returns_first_non_none_vectorstore(tmp_path): + youtube_mgr = _make_real_manager(tmp_path / "youtube") + csv_mgr = _make_real_manager(tmp_path / "csv") + + composite = CompositeVectorStoreManager({"youtube": youtube_mgr, "csv": csv_mgr}) + + assert composite.vectorstore is None + + csv_mgr.create_vectorstore([_make_doc("csv-one", source_type="csv", source="c1")]) + assert composite.vectorstore is csv_mgr.vectorstore + + youtube_mgr.create_vectorstore( + [_make_doc("yt-one", source_type="youtube", source="y1")] + ) + assert composite.vectorstore is youtube_mgr.vectorstore diff --git a/tests/test_neet_rag_integration.py b/tests/test_neet_rag_integration.py new file mode 100644 index 0000000..8d02791 --- /dev/null +++ b/tests/test_neet_rag_integration.py @@ -0,0 +1,187 @@ +# pyright: reportMissingImports=false + +from pathlib import Path +import os +import sys +from unittest.mock import patch + +import pytest +from langchain_core.documents import Document + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from src.rag.llm_manager import LLMManager +from src.rag.neet_rag import NEETRAG +from src.rag.vector_store import CompositeVectorStoreManager, VectorStoreManager + + +EMBEDDING_DIM = 16 + + +@pytest.fixture(autouse=True) +def stub_llm_initialization(monkeypatch): + def _fake_initialize_llm(self, api_key=None, base_url=None): + self.llm = object() + + monkeypatch.setattr(LLMManager, "_initialize_llm", _fake_initialize_llm) + + +def _build_csv_doc(idx: int) -> Document: + return Document( + page_content=f"Biology QA pair {idx}: Cell is the basic unit of life.", + metadata={ + "source": f"csv://neet_bio_{idx}.csv", + "source_type": "csv", + "content_type": "csv_qa_pair", + "chapter_name": "Biology", + "question_id": str(idx), + }, + ) + + +def _build_youtube_doc(idx: int, video_id: str = "abc123def45") -> Document: + return Document( + page_content=f"YouTube transcript chunk {idx}: photosynthesis explained.", + metadata={ + "source": f"https://www.youtube.com/watch?v={video_id}", + "source_type": "youtube", + "content_type": "youtube", + "video_id": video_id, + "start_time": float(idx * 60), + }, + ) + + +def _create_index(index_dir: Path, docs: list[Document]) -> None: + manager = VectorStoreManager( + persist_directory=str(index_dir), + embedding_provider="fake", + embedding_dimension=EMBEDDING_DIM, + ) + manager.create_vectorstore(docs) + + +def _create_split_indexes( + tmp_path: Path, csv_count: int = 3, youtube_count: int = 2 +) -> tuple[Path, Path]: + youtube_dir = tmp_path / "youtube" + csv_dir = tmp_path / "csv" + + _create_index( + youtube_dir, [_build_youtube_doc(i) for i in range(1, youtube_count + 1)] + ) + _create_index(csv_dir, [_build_csv_doc(i) for i in range(1, csv_count + 1)]) + + return youtube_dir, csv_dir + + +def _create_single_index( + tmp_path: Path, csv_count: int = 2, youtube_count: int = 1 +) -> None: + docs = [_build_csv_doc(i) for i in range(1, csv_count + 1)] + [ + _build_youtube_doc(i) for i in range(1, youtube_count + 1) + ] + _create_index(tmp_path, docs) + + +def _build_rag(tmp_path: Path) -> NEETRAG: + rag = NEETRAG( + persist_directory=str(tmp_path), + embedding_provider="fake", + embedding_dimension=EMBEDDING_DIM, + llm_provider="openai", + llm_model="mock-model", + ) + rag.similarity_threshold = 0.0 + return rag + + +def test_neetrag_initializes_without_error_with_fake_embeddings(tmp_path): + _create_single_index(tmp_path) + + rag = _build_rag(tmp_path) + + assert rag is not None + assert isinstance(rag.vector_manager, VectorStoreManager) + + +def test_neetrag_detects_split_indexes_when_youtube_and_csv_subdirs_exist(tmp_path): + _create_split_indexes(tmp_path) + + rag = _build_rag(tmp_path) + + assert isinstance(rag.vector_manager, CompositeVectorStoreManager) + + +def test_neetrag_falls_back_to_single_index_without_split_subdirs(tmp_path): + _create_single_index(tmp_path) + + rag = _build_rag(tmp_path) + + assert isinstance(rag.vector_manager, VectorStoreManager) + assert not isinstance(rag.vector_manager, CompositeVectorStoreManager) + + +def test_retrieve_docs_blended_returns_only_csv_docs(tmp_path): + _create_split_indexes(tmp_path, csv_count=4, youtube_count=3) + rag = _build_rag(tmp_path) + rag.vector_manager.load_vectorstore() + + docs = rag._retrieve_docs_blended("cell biology question", top_k=5) + + assert docs + assert all(doc.metadata.get("source_type") == "csv" for doc in docs) + + +def test_retrieve_docs_blended_returns_max_three_docs(tmp_path): + _create_split_indexes(tmp_path, csv_count=6, youtube_count=1) + rag = _build_rag(tmp_path) + rag.vector_manager.load_vectorstore() + + docs = rag._retrieve_docs_blended("neet biology", top_k=10) + + assert len(docs) == 3 + + +def test_is_youtube_doc_identifies_youtube_documents(tmp_path): + _create_single_index(tmp_path) + doc = _build_youtube_doc(1) + + assert NEETRAG._is_youtube_doc(doc) is True + + +def test_is_youtube_doc_identifies_non_youtube_documents(tmp_path): + _create_single_index(tmp_path) + doc = _build_csv_doc(1) + + assert NEETRAG._is_youtube_doc(doc) is False + + +def test_query_returns_answer_sources_and_question_sources_keys(tmp_path): + _create_split_indexes(tmp_path, csv_count=4, youtube_count=3) + rag = _build_rag(tmp_path) + + with patch.object(LLMManager, "generate", return_value="Mocked LLM answer"): + result = rag.query("What is the basic unit of life?", top_k=5) + + assert result["answer"] == "Mocked LLM answer" + assert "sources" in result + assert "question_sources" in result + assert isinstance(result["sources"], list) + assert isinstance(result["question_sources"], list) + + +def test_format_youtube_url_generates_timestamped_url(): + url = NEETRAG._format_youtube_url( + "https://www.youtube.com/watch?v=abc123", "abc123", 120.0 + ) + + assert url == "https://www.youtube.com/watch?v=abc123&t=120s" + + +def test_format_youtube_url_generates_url_without_timestamp(): + url = NEETRAG._format_youtube_url( + "https://www.youtube.com/watch?v=abc123", "abc123", 0.0 + ) + + assert url == "https://www.youtube.com/watch?v=abc123" diff --git a/tests/test_prompt_builder.py b/tests/test_prompt_builder.py new file mode 100644 index 0000000..50f2b47 --- /dev/null +++ b/tests/test_prompt_builder.py @@ -0,0 +1,173 @@ +import os +import sys +from importlib import import_module +from typing import cast + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from src.rag.llm_manager import RAGPromptBuilder + + +def _doc(content: str, **metadata: object) -> object: + document_cls = cast( + type, getattr(import_module("langchain_core.documents"), "Document") + ) + return cast(object, document_cls(page_content=content, metadata=metadata)) + + +def test_youtube_docs_are_excluded_from_prompt(): + builder = RAGPromptBuilder() + docs = [ + _doc("YouTube transcript snippet", source_type="youtube"), + _doc("CSV PYQ content", source_type="csv"), + ] + + prompt = builder.build_prompt("Explain this", docs) + + assert "YouTube transcript snippet" not in prompt + assert "CSV PYQ content" in prompt + + +def test_csv_docs_get_previous_year_question_label(): + builder = RAGPromptBuilder() + + prompt = builder.build_prompt( + "Question?", [_doc("PYQ statement", source_type="csv")] + ) + + assert "--- Previous Year Question ---" in prompt + + +def test_csv_docs_with_chapter_name_get_label_with_chapter(): + builder = RAGPromptBuilder() + + prompt = builder.build_prompt( + "Question?", + [_doc("PYQ statement", source_type="csv", chapter_name="Organic Chemistry")], + ) + + assert "--- Previous Year Question (Organic Chemistry) ---" in prompt + + +def test_csv_docs_without_chapter_name_get_plain_label(): + builder = RAGPromptBuilder() + + prompt = builder.build_prompt( + "Question?", [_doc("Another PYQ", source_type="csv", chapter_name="")] + ) + + assert "--- Previous Year Question ---" in prompt + assert "--- Previous Year Question (" not in prompt + + +def test_all_youtube_docs_return_no_matching_message(): + builder = RAGPromptBuilder() + docs = [ + _doc("YT 1", source_type="youtube"), + _doc("YT 2", source_type="youtube"), + ] + + prompt = builder.build_prompt("Find PYQ", docs) + + assert prompt == "Question: Find PYQ\n\nNo matching previous year questions found." + + +def test_no_docs_return_no_matching_message(): + builder = RAGPromptBuilder() + + prompt = builder.build_prompt("Find PYQ", []) + + assert prompt == "Question: Find PYQ\n\nNo matching previous year questions found." + + +def test_prompt_ends_with_expected_guidance_suffix(): + builder = RAGPromptBuilder() + + prompt = builder.build_prompt("Help", [_doc("PYQ text", source_type="csv")]) + + assert prompt.endswith("Analyze the PYQs above and provide concise guidance.") + + +def test_query_text_appears_in_prompt(): + builder = RAGPromptBuilder() + query = "How to solve this NEET PYQ?" + + prompt = builder.build_prompt(query, [_doc("PYQ text", source_type="csv")]) + + assert f"Question: {query}" in prompt + + +def test_doc_content_appears_in_prompt(): + builder = RAGPromptBuilder() + content = "Given f(x)=x^2, find derivative." + + prompt = builder.build_prompt("Solve", [_doc(content, source_type="csv")]) + + assert f"Content: {content}" in prompt + + +def test_multiple_csv_docs_all_appear_in_prompt(): + builder = RAGPromptBuilder() + docs = [ + _doc("PYQ A", source_type="csv"), + _doc("PYQ B", source_type="csv"), + _doc("PYQ C", source_type="csv", chapter_name="Mechanics"), + ] + + prompt = builder.build_prompt("Compare", docs) + + assert "PYQ A" in prompt + assert "PYQ B" in prompt + assert "PYQ C" in prompt + + +def test_mixed_youtube_and_csv_only_csv_appears(): + builder = RAGPromptBuilder() + docs = [ + _doc("YouTube explanation", source_type="youtube"), + _doc("PYQ from CSV", source_type="csv"), + _doc("Another YouTube piece", content_type="youtube"), + ] + + prompt = builder.build_prompt("Analyze", docs) + + assert "PYQ from CSV" in prompt + assert "YouTube explanation" not in prompt + assert "Another YouTube piece" not in prompt + + +def test_build_with_history_includes_chat_history(): + builder = RAGPromptBuilder() + history: list[tuple[str, str]] = [ + ("What is osmosis?", "Movement of solvent across semipermeable membrane."), + ("And diffusion?", "Movement from high concentration to low concentration."), + ] + + prompt = builder.build_with_history( # pyright: ignore[reportUnknownMemberType] + "Give a quick comparison", + [_doc("PYQ on osmosis vs diffusion", source_type="csv")], + chat_history=history, + ) + + assert "Previous conversation:" in prompt + assert "User: What is osmosis?" in prompt + assert "Assistant: Movement of solvent across semipermeable membrane." in prompt + assert "User: And diffusion?" in prompt + assert "Assistant: Movement from high concentration to low concentration." in prompt + + +def test_custom_system_prompt_via_constructor(): + custom_prompt = "Custom instruction for strict PYQ tutoring." + + builder = RAGPromptBuilder(system_prompt=custom_prompt) + + assert builder.default_system_prompt == custom_prompt + + +def test_default_system_prompt_mentions_pyqs_not_video_or_youtube(): + builder = RAGPromptBuilder() + system_prompt = builder.default_system_prompt.lower() + + assert "pyqs" in system_prompt + assert "video" not in system_prompt + assert "youtube" not in system_prompt From 6900e52419a1eed9db48d859da13a5a15274f451 Mon Sep 17 00:00:00 2001 From: Jay Prakash Date: Mon, 16 Mar 2026 09:54:27 +0530 Subject: [PATCH 3/3] ci: add GitHub Actions workflow for mypy + pytest on PRs --- .github/workflows/test.yml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..1d86a29 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,29 @@ +name: Tests + +on: + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: pip-${{ hashFiles('requirements.txt') }} + restore-keys: pip- + + - run: pip install -r requirements.txt + + - name: Type check + run: mypy --config-file mypy.ini + + - name: Unit and integration tests + run: pytest tests/test_prompt_builder.py tests/test_composite_vector_store.py tests/test_neet_rag_integration.py -v