diff --git a/README.md b/README.md index 7ac1bf34..767d2ba5 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,36 @@ guidelines that you should follow and tips that you may find helpful. **Implemented project-level `.vectorcode/` and `.git` as root anchor** - [ ] ability to view and delete files in a collection (atm you can only `drop` and `vectorise` again); -- [x] joint search (kinda, using codecompanion.nvim/MCP). +- [x] joint search (kinda, using codecompanion.nvim/MCP); +- [x] custom reranker support for query results. + +## Custom Rerankers + +VectorCode v0.5.6+ supports custom reranker implementations that can be used to reorder query results. The following rerankers are built-in: + +- **NaiveReranker**: A simple reranker that sorts documents by their mean distance (default when no reranker is specified). +- **CrossEncoderReranker**: Uses sentence-transformers `CrossEncoder` models for reranking. +- **LlamaCppReranker**: A placeholder for a reranker backed by a llama.cpp reranking API endpoint (not yet implemented; its `rerank` method currently raises `NotImplementedError`). + +To select a reranker, specify it in your `config.json`: + +```json +{ + "reranker": "LlamaCppReranker", + "reranker_params": { + "model_name": "http://localhost:8085/v1/reranking" + } +} +``` + +You can also create your own reranker by: + +1. Creating a Python file with a class that inherits from `RerankerBase` +2. Implementing the `rerank(self, results)` method, which returns the document IDs in the desired order +3. Registering it under a name with the `@register_reranker("<name>")` decorator +4. Making sure the file is in your `PYTHONPATH` and gets imported, since registration only happens when the module is loaded + +For more details, see the [CLI documentation](./docs/cli.md). ## Credit diff --git a/docs/cli.md b/docs/cli.md index b189f8ca..3dde7156 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -239,14 +239,22 @@ The JSON configuration file may hold the following values: guarantees the return of `n` documents, but with the risk of including too many less-relevant chunks that may affect the document selection. 
Default: `-1` (any negative value means selecting documents based on all indexed chunks); -- `reranker`: string, a reranking model supported by - [`CrossEncoder`](https://sbert.net/docs/package_reference/cross_encoder/index.html). - A list of available models is available on their documentation. The default - model is `"cross-encoder/ms-marco-MiniLM-L-6-v2"`. You can disable the use of - `CrossEncoder` by setting this option to a falsy value that is not `null`, - such as `false` or `""` (empty string); +- `reranker`: string, specifies which reranker to use for result sorting. This can be: + - A model name starting with `cross-encoder/`, which is loaded by [`CrossEncoderReranker`](https://sbert.net/docs/package_reference/cross_encoder/index.html) + (e.g., `"cross-encoder/ms-marco-MiniLM-L-6-v2"`) + - A built-in reranker class name (`"NaiveReranker"`, `"CrossEncoderReranker"`, or `"LlamaCppReranker"`) + - The name under which a custom reranker was registered with `@register_reranker` + - A falsy value that is not `null`, such as `false` or `""` (empty string), which selects the default `NaiveReranker` - `reranker_params`: dictionary, similar to `embedding_params`. The options - passed to `CrossEncoder` class constructor; + passed to the reranker class constructor. For example: + ```json + { + "reranker": "LlamaCppReranker", + "reranker_params": { + "model_name": "http://localhost:8085/v1/reranking" + } + } + ``` - `db_settings`: dictionary, works in a similar way to `embedding_params`, but for Chromadb client settings so that you can configure [authentication for remote Chromadb](https://docs.trychroma.com/production/administration/auth); diff --git a/src/vectorcode/rerankers/__init__.py b/src/vectorcode/rerankers/__init__.py new file mode 100644 index 00000000..77bca8be --- /dev/null +++ b/src/vectorcode/rerankers/__init__.py @@ -0,0 +1,73 @@ +"""VectorCode rerankers module. + +This module provides reranker implementations for VectorCode. +Rerankers are used to reorder query results to improve relevance. 
+""" + +from .base import ( + RerankerBase, + get_reranker_class, + list_available_rerankers, + register_reranker, +) +from .builtins import CrossEncoderReranker, NaiveReranker +from .llama_cpp import LlamaCppReranker + +# Map of legacy names to new registration names +_LEGACY_NAMES = { + "NaiveReranker": "naive", + "CrossEncoderReranker": "crossencoder", + "LlamaCppReranker": "llamacpp", +} + + +def create_reranker(name: str, configs=None, query_chunks=None, **kwargs): + """Create a reranker instance by name. + + Handles both legacy class names (e.g., 'NaiveReranker') and + registration names (e.g., 'naive'). + + Args: + name: The name of the reranker class or registered reranker + configs: Optional Config object + query_chunks: Optional list of query chunks for CrossEncoderReranker + **kwargs: Additional keyword arguments to pass to the reranker + + Returns: + An instance of the requested reranker + + Raises: + ValueError: If the reranker name is unknown or not registered + """ + # Check for legacy names + registry_name = _LEGACY_NAMES.get(name, name) + + try: + # Try to get class from registry + reranker_class = get_reranker_class(registry_name) + + # Special case for CrossEncoderReranker which needs query_chunks + if registry_name == "crossencoder" and query_chunks is not None: + return reranker_class(configs=configs, query_chunks=query_chunks, **kwargs) + else: + return reranker_class(configs=configs, **kwargs) + + except ValueError: + # Handle case where we're using a fully qualified module path + # This is part of the dynamic import system + raise ValueError( + f"Reranker '{name}' not found in registry. 
" + f"Available rerankers: {list_available_rerankers()}" + ) + + +__all__ = [ + "RerankerBase", + "register_reranker", + "get_reranker_class", + "list_available_rerankers", + "create_reranker", + "NaiveReranker", + "CrossEncoderReranker", + "LlamaCppReranker", +] diff --git a/src/vectorcode/rerankers/base.py b/src/vectorcode/rerankers/base.py new file mode 100644 index 00000000..8fcbb9d2 --- /dev/null +++ b/src/vectorcode/rerankers/base.py @@ -0,0 +1,78 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Type + + +class RerankerBase(ABC): + """Base class for all rerankers in VectorCode. + + All rerankers should inherit from this class and implement the rerank method. + """ + + def __init__(self, **kwargs): + """Initialize the reranker with kwargs. + + Args: + **kwargs: Arbitrary keyword arguments to configure the reranker. + """ + self.kwargs = kwargs + + @abstractmethod + def rerank(self, results: Dict[str, Any]) -> List[str]: + """Rerank the query results. + + Args: + results: The query results from ChromaDB, typically containing ids, documents, + metadatas, and distances. + + Returns: + A list of document IDs sorted in the desired order. + """ + raise NotImplementedError("Rerankers must implement rerank method") + + +# Registry for reranker classes +_RERANKER_REGISTRY: Dict[str, Type[RerankerBase]] = {} + + +def register_reranker(name: str): + """Decorator to register a reranker class. + + Args: + name: The name to register the reranker under. This name can be used + in configuration to specify which reranker to use. + + Returns: + A decorator function that registers the decorated class. + """ + + def decorator(cls): + _RERANKER_REGISTRY[name] = cls + return cls + + return decorator + + +def get_reranker_class(name: str) -> Type[RerankerBase]: + """Get a reranker class by name. + + Args: + name: The name of the reranker class to get. + + Returns: + The reranker class. + + Raises: + ValueError: If the reranker name is not registered. 
+ """ + if name not in _RERANKER_REGISTRY: + raise ValueError(f"Unknown reranker: {name}") + return _RERANKER_REGISTRY[name] + + +def list_available_rerankers() -> List[str]: + """List all available registered reranker names. + + Returns: + A list of registered reranker names. + """ + return list(_RERANKER_REGISTRY.keys()) diff --git a/src/vectorcode/rerankers/builtins.py b/src/vectorcode/rerankers/builtins.py new file mode 100644 index 00000000..8cf17e32 --- /dev/null +++ b/src/vectorcode/rerankers/builtins.py @@ -0,0 +1,129 @@ +import heapq +from collections import defaultdict +from typing import DefaultDict, List + +import numpy +from chromadb.api.types import QueryResult + +from vectorcode.cli_utils import Config, QueryInclude + +from .base import RerankerBase, register_reranker + + +@register_reranker("naive") +class NaiveReranker(RerankerBase): + """A simple reranker that ranks documents by their mean distance.""" + + def __init__(self, configs: Config = None, **kwargs): + super().__init__(**kwargs) + self.configs = configs + self.n_result = configs.n_result if configs else kwargs.get("n_result", 10) + + def rerank(self, results: QueryResult) -> List[str]: + """Rerank the query results by mean distance. + + Args: + results: The query results from ChromaDB. + + Returns: + A list of document IDs sorted by mean distance. + """ + assert results["metadatas"] is not None + assert results["distances"] is not None + documents: DefaultDict[str, list[float]] = defaultdict(list) + + include = getattr(self.configs, "include", None) if self.configs else None + + for query_chunk_idx in range(len(results["ids"])): + chunk_ids = results["ids"][query_chunk_idx] + chunk_metas = results["metadatas"][query_chunk_idx] + chunk_distances = results["distances"][query_chunk_idx] + # NOTE: distances, smaller is better. 
+ paths = [str(meta["path"]) for meta in chunk_metas] + assert len(paths) == len(chunk_distances) + for distance, identifier in zip( + chunk_distances, + chunk_ids if include and QueryInclude.chunk in include else paths, + ): + if identifier is None: + # so that vectorcode doesn't break on old collections. + continue + documents[identifier].append(distance) + + top_k = int(numpy.mean(tuple(len(i) for i in documents.values()))) + for key in documents.keys(): + documents[key] = heapq.nsmallest(top_k, documents[key]) + + return heapq.nsmallest( + self.n_result, documents.keys(), lambda x: float(numpy.mean(documents[x])) + ) + + +@register_reranker("crossencoder") +class CrossEncoderReranker(RerankerBase): + """A reranker that uses a cross-encoder model for reranking.""" + + def __init__( + self, + configs: Config = None, + query_chunks: List[str] = None, + model_name: str = None, + **kwargs, + ): + super().__init__(**kwargs) + self.configs = configs + self.n_result = configs.n_result if configs else kwargs.get("n_result", 10) + + # Handle model_name correctly + self.model_name = model_name or kwargs.get("model_name") + if not self.model_name: + raise ValueError("model_name must be provided") + + self.query_chunks = query_chunks or kwargs.get("query_chunks", []) + if not self.query_chunks: + raise ValueError("query_chunks must be provided") + + # Import here to avoid requiring sentence-transformers for all rerankers + from sentence_transformers import CrossEncoder + + self.model = CrossEncoder(self.model_name, **kwargs) + + def rerank(self, results: QueryResult) -> List[str]: + """Rerank the query results using a cross-encoder model. + + Args: + results: The query results from ChromaDB. + + Returns: + A list of document IDs sorted by cross-encoder scores. 
+ """ + assert results["metadatas"] is not None + assert results["documents"] is not None + documents: DefaultDict[str, list[float]] = defaultdict(list) + + include = getattr(self.configs, "include", None) if self.configs else None + + for query_chunk_idx in range(len(self.query_chunks)): + chunk_ids = results["ids"][query_chunk_idx] + chunk_metas = results["metadatas"][query_chunk_idx] + chunk_docs = results["documents"][query_chunk_idx] + ranks = self.model.rank( + self.query_chunks[query_chunk_idx], chunk_docs, apply_softmax=True + ) + for rank in ranks: + if include and QueryInclude.chunk in include: + documents[chunk_ids[rank["corpus_id"]]].append(float(rank["score"])) + else: + documents[chunk_metas[rank["corpus_id"]]["path"]].append( + float(rank["score"]) + ) + + top_k = int(numpy.mean(tuple(len(i) for i in documents.values()))) + for key in documents.keys(): + documents[key] = heapq.nlargest(top_k, documents[key]) + + return heapq.nlargest( + self.n_result, + documents.keys(), + key=lambda x: float(numpy.mean(documents[x])), + ) diff --git a/src/vectorcode/rerankers/llama_cpp.py b/src/vectorcode/rerankers/llama_cpp.py new file mode 100644 index 00000000..0a24e0cf --- /dev/null +++ b/src/vectorcode/rerankers/llama_cpp.py @@ -0,0 +1,51 @@ +import os +import sys +from typing import Any, Dict, List + +from .base import RerankerBase, register_reranker + + +@register_reranker("llamacpp") +class LlamaCppReranker(RerankerBase): + """A reranker that uses a Llama.cpp server for reranking. + + This is a simplified placeholder implementation for the PR. + In a real-world scenario, this would make API calls to a reranking endpoint. + """ + + def __init__(self, model_name=None, **kwargs): + """Initialize the LlamaCppReranker. + + Args: + model_name: The model name or API URL for the reranker. + **kwargs: Additional keyword arguments. 
+ """ + super().__init__(**kwargs) + # Handle both positional and keyword model_name to avoid TypeError + self.api_url = model_name or kwargs.get( + "model_name", + os.environ.get( + "VECTORCODE_RERANKING_API_URL", "http://localhost:8085/v1/reranking" + ), + ) + print( + f"LlamaCppReranker initialized with API URL: {self.api_url}", + file=sys.stderr, + ) + + def rerank(self, results: Dict[str, Any]) -> List[str]: + """Rerank the query results. + + In a real implementation, this would call an external API. + For the PR, this is a simplified placeholder. + + Args: + results: The query results from ChromaDB. + + Returns: + A list of document IDs. + + Raises: + NotImplementedError: This reranker is not yet implemented. + """ + raise NotImplementedError("LlamaCppReranker is not yet implemented.") diff --git a/src/vectorcode/subcommands/query/__init__.py b/src/vectorcode/subcommands/query/__init__.py index 4d23306b..6760abed 100644 --- a/src/vectorcode/subcommands/query/__init__.py +++ b/src/vectorcode/subcommands/query/__init__.py @@ -62,16 +62,29 @@ async def get_query_result_files( # no results found return [] - if not configs.reranker: - from .reranker import NaiveReranker + # For backward compatibility with existing code and tests + from .reranker import NaiveReranker, CrossEncoderReranker - aggregated_results = NaiveReranker(configs).rerank(results) - else: - from .reranker import CrossEncoderReranker - - aggregated_results = CrossEncoderReranker( - configs, query_chunks, configs.reranker, **configs.reranker_params - ).rerank(results) + # Try to use the reranker based on configuration + try: + if not configs.reranker: + # Default to NaiveReranker + reranker = NaiveReranker(configs) + elif configs.reranker.startswith("cross-encoder/"): + # Use CrossEncoder reranker + reranker = CrossEncoderReranker(configs, query_chunks, configs.reranker) + else: + # Try to use a custom reranker from the new module system + from vectorcode.rerankers import create_reranker + reranker 
= create_reranker(configs.reranker, configs=configs, query_chunks=query_chunks) + except Exception as e: + # Fall back to NaiveReranker on any error + print(f"Error loading reranker: {e}", file=sys.stderr) + print("Falling back to NaiveReranker", file=sys.stderr) + reranker = NaiveReranker(configs) + + # Apply the reranker + aggregated_results = reranker.rerank(results) return aggregated_results diff --git a/src/vectorcode/subcommands/query/reranker.py b/src/vectorcode/subcommands/query/reranker.py index 335ff3f2..79aaa176 100644 --- a/src/vectorcode/subcommands/query/reranker.py +++ b/src/vectorcode/subcommands/query/reranker.py @@ -1,92 +1,32 @@ -import heapq -from abc import abstractmethod -from collections import defaultdict -from typing import Any, DefaultDict - -import numpy -from chromadb.api.types import QueryResult - -from vectorcode.cli_utils import Config, QueryInclude - - -class RerankerBase: - def __init__(self, configs: Config, **kwargs: Any): - self.configs = configs - self.n_result = configs.n_result - - @abstractmethod - def rerank(self, results: QueryResult) -> list[str]: - raise NotImplementedError - - -class NaiveReranker(RerankerBase): - def __init__(self, configs: Config, **kwargs: Any): - super().__init__(configs) - - def rerank(self, results: QueryResult) -> list[str]: - assert results["metadatas"] is not None - assert results["distances"] is not None - documents: DefaultDict[str, list[float]] = defaultdict(list) - for query_chunk_idx in range(len(results["ids"])): - chunk_ids = results["ids"][query_chunk_idx] - chunk_metas = results["metadatas"][query_chunk_idx] - chunk_distances = results["distances"][query_chunk_idx] - # NOTE: distances, smaller is better. 
- paths = [str(meta["path"]) for meta in chunk_metas] - assert len(paths) == len(chunk_distances) - for distance, identifier in zip( - chunk_distances, - chunk_ids if QueryInclude.chunk in self.configs.include else paths, - ): - if identifier is None: - # so that vectorcode doesn't break on old collections. - continue - documents[identifier].append(distance) - - top_k = int(numpy.mean(tuple(len(i) for i in documents.values()))) - for key in documents.keys(): - documents[key] = heapq.nsmallest(top_k, documents[key]) - - return heapq.nsmallest( - self.n_result, documents.keys(), lambda x: float(numpy.mean(documents[x])) - ) - - -class CrossEncoderReranker(RerankerBase): - def __init__( - self, configs: Config, query_chunks: list[str], model_name: str, **kwargs: Any - ): - super().__init__(configs) - from sentence_transformers import CrossEncoder - - self.model = CrossEncoder(model_name, **kwargs) - self.query_chunks = query_chunks - - def rerank(self, results: QueryResult) -> list[str]: - assert results["metadatas"] is not None - assert results["documents"] is not None - documents: DefaultDict[str, list[float]] = defaultdict(list) - for query_chunk_idx in range(len(self.query_chunks)): - chunk_ids = results["ids"][query_chunk_idx] - chunk_metas = results["metadatas"][query_chunk_idx] - chunk_docs = results["documents"][query_chunk_idx] - ranks = self.model.rank( - self.query_chunks[query_chunk_idx], chunk_docs, apply_softmax=True - ) - for rank in ranks: - if QueryInclude.chunk in self.configs.include: - documents[chunk_ids[rank["corpus_id"]]].append(float(rank["score"])) - else: - documents[chunk_metas[rank["corpus_id"]]["path"]].append( - float(rank["score"]) - ) - - top_k = int(numpy.mean(tuple(len(i) for i in documents.values()))) - for key in documents.keys(): - documents[key] = heapq.nlargest(top_k, documents[key]) - - return heapq.nlargest( - self.n_result, - documents.keys(), - key=lambda x: float(numpy.mean(documents[x])), - ) +""" +Backward compatibility 
module for rerankers. + +This module re-exports the reranker classes from the new vectorcode.rerankers module +to maintain backward compatibility with existing code. + +For new code, please use the vectorcode.rerankers module directly. +""" + +import warnings + +# Import from the new module +from vectorcode.rerankers import ( + CrossEncoderReranker, + NaiveReranker, + RerankerBase, +) + +# Emit a deprecation warning +warnings.warn( + "The vectorcode.subcommands.query.reranker module is deprecated. " + "Please use vectorcode.rerankers instead.", + DeprecationWarning, + stacklevel=2, +) + +# Make sure we export all the classes +__all__ = [ + "RerankerBase", + "NaiveReranker", + "CrossEncoderReranker", +] diff --git a/tests/subcommands/query/test_reranker.py b/tests/subcommands/query/test_reranker.py index f715527a..e7343059 100644 --- a/tests/subcommands/query/test_reranker.py +++ b/tests/subcommands/query/test_reranker.py @@ -3,7 +3,7 @@ import pytest from vectorcode.cli_utils import Config, QueryInclude -from vectorcode.subcommands.query.reranker import ( +from vectorcode.rerankers import ( CrossEncoderReranker, NaiveReranker, RerankerBase, @@ -36,13 +36,11 @@ def query_chunks(): return ["query chunk 1", "query chunk 2"] -# The RerankerBase isn't actually preventing instantiation, -# but it will raise NotImplementedError when rerank is called +# Since RerankerBase is now a proper abstract class, we can't instantiate it directly def test_reranker_base_method_is_abstract(config): - """Test that RerankerBase.rerank raises NotImplementedError""" - base_reranker = RerankerBase(config) - with pytest.raises(NotImplementedError): - base_reranker.rerank({}) + """Test that RerankerBase cannot be instantiated directly""" + with pytest.raises(TypeError): + RerankerBase(config=config) def test_naive_reranker_initialization(config): diff --git a/tests/subcommands/query/test_reranker_error_handling.py b/tests/subcommands/query/test_reranker_error_handling.py new file mode 100644 
index 00000000..a2f41cfa --- /dev/null +++ b/tests/subcommands/query/test_reranker_error_handling.py @@ -0,0 +1,147 @@ +"""Tests for error handling in the query module's reranker integration.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from chromadb.api.models.AsyncCollection import AsyncCollection + +from vectorcode.cli_utils import Config, QueryInclude +from vectorcode.subcommands.query import get_query_result_files + + +@pytest.fixture +def mock_collection(): + collection = AsyncMock(spec=AsyncCollection) + collection.count.return_value = 10 + collection.query.return_value = { + "ids": [["id1", "id2", "id3"]], + "distances": [[0.1, 0.2, 0.3]], + "metadatas": [ + [ + {"path": "file1.py", "start": 1, "end": 1}, + {"path": "file2.py", "start": 1, "end": 1}, + {"path": "file3.py", "start": 1, "end": 1}, + ], + ], + "documents": [ + ["content1", "content2", "content3"], + ], + } + return collection + + +@pytest.fixture +def mock_config(): + return Config( + query=["test query"], + n_result=3, + query_multiplier=2, + chunk_size=100, + overlap_ratio=0.2, + project_root="/test/project", + pipe=False, + include=[QueryInclude.path, QueryInclude.document], + query_exclude=[], + reranker=None, + reranker_params={}, + use_absolute_path=False, + ) + + +@pytest.mark.asyncio +async def test_get_query_result_files_registry_error(mock_collection, mock_config): + """Test graceful handling of a reranker not found in registry.""" + # Set a custom reranker to trigger the error path + mock_config.reranker = "custom-reranker" + + # Mock stderr to capture error messages + with patch("sys.stderr") as mock_stderr: + # Mock the NaiveReranker for fallback + with patch("vectorcode.subcommands.query.reranker.NaiveReranker") as mock_naive: + mock_reranker_instance = MagicMock() + mock_reranker_instance.rerank.return_value = ["file1.py", "file2.py"] + mock_naive.return_value = mock_reranker_instance + + # This should fall back to NaiveReranker + result = await 
get_query_result_files(mock_collection, mock_config) + + # Verify the error was logged + assert mock_stderr.write.called + assert "not found in registry" in "".join( + [c[0][0] for c in mock_stderr.write.call_args_list] + ) + + # Verify fallback to NaiveReranker happened + assert mock_naive.called + + # Check the result contains the expected files + assert result == ["file1.py", "file2.py"] + + +@pytest.mark.asyncio +async def test_get_query_result_files_general_exception(mock_collection, mock_config): + """Test handling of unexpected exceptions during reranker loading.""" + # Set a custom reranker to trigger the import path + mock_config.reranker = "buggy-reranker" + + # Create a patching context that raises an unexpected exception + with patch("vectorcode.rerankers", new=MagicMock()) as mock_rerankers: + # Configure the mock to raise RuntimeError when create_reranker is called + mock_rerankers.create_reranker.side_effect = RuntimeError("Unexpected error") + + # Mock stderr to capture error messages + with patch("sys.stderr") as mock_stderr: + # Mock the NaiveReranker for fallback + with patch( + "vectorcode.subcommands.query.reranker.NaiveReranker" + ) as mock_naive: + mock_reranker_instance = MagicMock() + mock_reranker_instance.rerank.return_value = ["file1.py", "file2.py"] + mock_naive.return_value = mock_reranker_instance + + # This should catch the exception and fall back to NaiveReranker + result = await get_query_result_files(mock_collection, mock_config) + + # Verify the error was logged + assert mock_stderr.write.called + + # Verify fallback to NaiveReranker happened + assert mock_naive.called + + # Check the result contains the expected files + assert result == ["file1.py", "file2.py"] + + +@pytest.mark.asyncio +async def test_get_query_result_files_cross_encoder_error(mock_collection, mock_config): + """Test the CrossEncoder special case with error handling.""" + # Set a cross encoder model to trigger that code path + mock_config.reranker = 
"cross-encoder/model-name" + + # Mock the CrossEncoderReranker to raise an exception + with patch( + "vectorcode.subcommands.query.reranker.CrossEncoderReranker" + ) as mock_cross_encoder: + mock_cross_encoder.side_effect = ValueError("Model not found") + + # Mock stderr to capture error messages + with patch("sys.stderr") as mock_stderr: + # Mock the NaiveReranker for fallback + with patch( + "vectorcode.subcommands.query.reranker.NaiveReranker" + ) as mock_naive: + mock_reranker_instance = MagicMock() + mock_reranker_instance.rerank.return_value = ["file1.py", "file2.py"] + mock_naive.return_value = mock_reranker_instance + + # This should catch the exception and fall back + result = await get_query_result_files(mock_collection, mock_config) + + # Verify the error was logged + assert mock_stderr.write.called + + # Verify fallback to NaiveReranker happened + assert mock_naive.called + + # Check the result contains the expected files + assert result == ["file1.py", "file2.py"] diff --git a/tests/test_reranker_edge_cases.py b/tests/test_reranker_edge_cases.py new file mode 100644 index 00000000..bfb986b7 --- /dev/null +++ b/tests/test_reranker_edge_cases.py @@ -0,0 +1,102 @@ +"""Tests for edge cases in the rerankers modules.""" + +import pytest + +from vectorcode.cli_utils import Config +from vectorcode.rerankers import ( + CrossEncoderReranker, + LlamaCppReranker, + NaiveReranker, + create_reranker, + list_available_rerankers, +) + + +class TestRerankerEdgeCases: + """Tests for edge cases and error handling in reranker implementations.""" + + def test_naive_reranker_none_path(self): + """Test NaiveReranker handling of None paths in metadata.""" + # Create a config + config = Config(n_result=2) + + # Create a reranker + reranker = NaiveReranker(config) + + # Create results with a None path in metadata + results = { + "ids": [["id1", "id2", "id3"]], + "metadatas": [ + [ + {"path": "file1.py"}, + {"path": None}, # None path here + {"path": "file3.py"}, + ] + ], + 
"distances": [[0.1, 0.2, 0.3]], + "documents": [["doc1", "doc2", "doc3"]], + } + + # This should not raise any exceptions + ranked_results = reranker.rerank(results) + + # Verify we get valid results (excluding the None path) + assert len(ranked_results) <= 2 # n_result=2 + assert None not in ranked_results + + def test_create_reranker_not_found(self): + """Test error handling when a reranker can't be found.""" + # Try to create a reranker with a name that doesn't exist + with pytest.raises(ValueError) as exc_info: + create_reranker("nonexistent-reranker") + + # Verify the error message includes available rerankers + assert "not found in registry" in str(exc_info.value) + assert "Available rerankers" in str(exc_info.value) + + # Available rerankers list should be included + for reranker_name in list_available_rerankers(): + assert reranker_name in str(exc_info.value) + + def test_llama_cpp_reranker_empty_results(self): + """Test LlamaCppReranker with empty results.""" + # Create the reranker + reranker = LlamaCppReranker(model_name="test-model") + + # Mock empty results + results = {"ids": [], "documents": []} + + # This should raise NotImplementedError + with pytest.raises(NotImplementedError): + reranker.rerank(results) + + def test_llama_cpp_reranker_missing_fields(self): + """Test LlamaCppReranker with missing fields in results.""" + # Create the reranker + reranker = LlamaCppReranker(model_name="test-model") + + # Missing 'documents' field + results = { + "ids": [["id1", "id2"]], + # documents field is missing + } + + # This should raise NotImplementedError + with pytest.raises(NotImplementedError): + reranker.rerank(results) + + def test_crossencoder_validation_error(self): + """Test CrossEncoderReranker validation errors.""" + # Try to create a reranker without required parameters + with pytest.raises(ValueError): + CrossEncoderReranker() + + # Try with model_name but no query_chunks + with pytest.raises(ValueError) as exc_info: + 
CrossEncoderReranker(model_name="cross-encoder/model") + assert "query_chunks must be provided" in str(exc_info.value) + + # Try with query_chunks but no model_name + with pytest.raises(ValueError) as exc_info: + CrossEncoderReranker(query_chunks=["query"]) + assert "model_name must be provided" in str(exc_info.value) diff --git a/tests/test_rerankers.py b/tests/test_rerankers.py new file mode 100644 index 00000000..8f3e884d --- /dev/null +++ b/tests/test_rerankers.py @@ -0,0 +1,152 @@ +import os +import sys +import tempfile +import unittest +from unittest.mock import Mock, patch + +import pytest + +from vectorcode.rerankers import ( + LlamaCppReranker, + NaiveReranker, + RerankerBase, + create_reranker, + get_reranker_class, + list_available_rerankers, + register_reranker, +) + + +class TestRerankers(unittest.TestCase): + def setUp(self): + # Create a simple query result for testing + self.query_result = { + "ids": [["id1", "id2", "id3"]], + "documents": [["doc1", "doc2", "doc3"]], + "metadatas": [ + [ + {"path": "path1"}, + {"path": "path2"}, + {"path": "path3"}, + ] + ], + "distances": [[0.1, 0.2, 0.3]], + } + + # Create a mock config + self.mock_config = Mock() + self.mock_config.n_result = 2 + self.mock_config.include = [] + + def test_base_reranker(self): + """Test that RerankerBase is abstract and cannot be instantiated.""" + with pytest.raises(TypeError): + # Should raise TypeError because rerank is abstract + RerankerBase() + + def test_naive_reranker(self): + """Test the NaiveReranker implementation.""" + reranker = NaiveReranker(configs=self.mock_config) + results = reranker.rerank(self.query_result) + + # Should return a list of document IDs + assert isinstance(results, list) + # Should return n_result items + assert len(results) == self.mock_config.n_result + # Should contain paths from metadatas + assert all(item in ["path1", "path2", "path3"] for item in results) + + def test_llama_cpp_reranker(self): + """Test the LlamaCppReranker implementation.""" 
+ # Test initialization with model_name as positional arg + reranker1 = LlamaCppReranker("test_model") + assert reranker1.api_url == "test_model" + + # Test initialization with model_name as keyword arg + reranker2 = LlamaCppReranker(model_name="test_model2") + assert reranker2.api_url == "test_model2" + + # Test initialization with no model_name (should use default) + with patch.dict(os.environ, {"VECTORCODE_RERANKING_API_URL": "env_test_url"}): + reranker3 = LlamaCppReranker() + assert reranker3.api_url == "env_test_url" + + # Test rerank method raises NotImplementedError + with pytest.raises(NotImplementedError): + reranker1.rerank(self.query_result) + + def test_registry(self): + """Test the reranker registry functionality.""" + # Test listing available rerankers + rerankers = list_available_rerankers() + assert "naive" in rerankers + assert "crossencoder" in rerankers + assert "llamacpp" in rerankers + + # Test getting a reranker class + naive_class = get_reranker_class("naive") + assert naive_class == NaiveReranker + + # Test registering a new reranker + @register_reranker("test_reranker") + class TestReranker(RerankerBase): + def rerank(self, results): + return ["test1", "test2"] + + # Check that it's properly registered + assert "test_reranker" in list_available_rerankers() + assert get_reranker_class("test_reranker") == TestReranker + + def test_create_reranker(self): + """Test the create_reranker factory function.""" + # Test creating a NaiveReranker + reranker1 = create_reranker("naive", configs=self.mock_config) + assert isinstance(reranker1, NaiveReranker) + + # Test using legacy name + reranker2 = create_reranker("NaiveReranker", configs=self.mock_config) + assert isinstance(reranker2, NaiveReranker) + + # Test with invalid name + with pytest.raises(ValueError): + create_reranker("invalid_reranker", configs=self.mock_config) + + def test_dynamic_loading(self): + """Test dynamic loading of custom rerankers.""" + # Create a temporary reranker module + 
with tempfile.NamedTemporaryFile(suffix=".py", mode="w+") as f: + f.write(""" +from vectorcode.rerankers import RerankerBase, register_reranker + +@register_reranker("custom_test") +class CustomTestReranker(RerankerBase): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def rerank(self, results): + return ["custom1", "custom2"] +""") + f.flush() + + # Add its directory to Python path + sys.path.append(os.path.dirname(f.name)) + + try: + # Import the module + module_name = os.path.basename(f.name)[:-3] # Remove .py + __import__(module_name) + + # Now test creating the custom reranker + reranker = create_reranker("custom_test") + + # Check that it works + assert reranker.rerank({}) == ["custom1", "custom2"] + finally: + # Clean up + sys.path.remove(os.path.dirname(f.name)) + if module_name in sys.modules: + del sys.modules[module_name] + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rerankers_coverage.py b/tests/test_rerankers_coverage.py new file mode 100644 index 00000000..5eadc13d --- /dev/null +++ b/tests/test_rerankers_coverage.py @@ -0,0 +1,102 @@ +"""Tests specifically targeting coverage gaps in the rerankers modules.""" + +import pytest + +from vectorcode.cli_utils import Config +from vectorcode.rerankers import ( + CrossEncoderReranker, + LlamaCppReranker, + NaiveReranker, + create_reranker, + list_available_rerankers, +) + + +class TestRerankersCoverage: + """Tests for coverage gaps in reranker modules.""" + + def test_naive_reranker_none_path(self): + """Test NaiveReranker handling of None paths in metadata.""" + # Create a config + config = Config(n_result=2) + + # Create a reranker + reranker = NaiveReranker(config) + + # Create results with a None path in metadata + results = { + "ids": [["id1", "id2", "id3"]], + "metadatas": [ + [ + {"path": "file1.py"}, + {"path": None}, # None path here + {"path": "file3.py"}, + ] + ], + "distances": [[0.1, 0.2, 0.3]], + "documents": [["doc1", "doc2", "doc3"]], + } + + # This 
should not raise any exceptions + ranked_results = reranker.rerank(results) + + # Verify we get valid results (excluding the None path) + assert len(ranked_results) <= 2 # n_result=2 + assert None not in ranked_results + + def test_create_reranker_not_found(self): + """Test error handling when a reranker can't be found.""" + # Try to create a reranker with a name that doesn't exist + with pytest.raises(ValueError) as exc_info: + create_reranker("nonexistent-reranker") + + # Verify the error message includes available rerankers + assert "not found in registry" in str(exc_info.value) + assert "Available rerankers" in str(exc_info.value) + + # Available rerankers list should be included + for reranker_name in list_available_rerankers(): + assert reranker_name in str(exc_info.value) + + def test_llama_cpp_reranker_empty_results(self): + """Test LlamaCppReranker with empty results.""" + # Create the reranker + reranker = LlamaCppReranker(model_name="test-model") + + # Mock empty results + results = {"ids": [], "documents": []} + + # This should raise NotImplementedError + with pytest.raises(NotImplementedError): + reranker.rerank(results) + + def test_llama_cpp_reranker_missing_fields(self): + """Test LlamaCppReranker with missing fields in results.""" + # Create the reranker + reranker = LlamaCppReranker(model_name="test-model") + + # Missing 'documents' field + results = { + "ids": [["id1", "id2"]], + # documents field is missing + } + + # This should raise NotImplementedError + with pytest.raises(NotImplementedError): + reranker.rerank(results) + + def test_crossencoder_validation_error(self): + """Test CrossEncoderReranker validation errors.""" + # Try to create a reranker without required parameters + with pytest.raises(ValueError): + CrossEncoderReranker() + + # Try with model_name but no query_chunks + with pytest.raises(ValueError) as exc_info: + CrossEncoderReranker(model_name="cross-encoder/model") + assert "query_chunks must be provided" in 
str(exc_info.value) + + # Try with query_chunks but no model_name + with pytest.raises(ValueError) as exc_info: + CrossEncoderReranker(query_chunks=["query"]) + assert "model_name must be provided" in str(exc_info.value)