Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,36 @@ guidelines that you should follow and tips that you may find helpful.
**Implemented project-level `.vectorcode/` and `.git` as root anchor**
- [ ] ability to view and delete files in a collection (atm you can only `drop`
and `vectorise` again);
- [x] joint search (kinda, using codecompanion.nvim/MCP).
- [x] joint search (kinda, using codecompanion.nvim/MCP);
- [x] custom reranker support for query results.

## Custom Rerankers

VectorCode v0.5.6+ supports custom reranker implementations that can be used to reorder query results. The following rerankers are built-in:

- **NaiveReranker**: A simple reranker that sorts documents by their mean distance (default when no reranker is specified).
- **CrossEncoderReranker**: Uses sentence-transformers crossencoder models for reranking.
- **LlamaCppReranker**: A reranker designed to work with llama.cpp API endpoints.

To use a custom reranker, specify it in your config.json:

```json
{
"reranker": "LlamaCppReranker",
"reranker_params": {
"model_name": "http://localhost:8085/v1/reranking"
}
}
```

You can also create your own reranker by:

1. Creating a Python file with a class that inherits from `RerankerBase`
2. Implementing the `rerank(self, results)` method
3. Registering it with the `@register_reranker` decorator
4. Making sure the file is in your PYTHONPATH

For more details, see the [CLI documentation](./docs/cli.md).

## Credit

Expand Down
22 changes: 15 additions & 7 deletions docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,14 +239,22 @@ The JSON configuration file may hold the following values:
guarantees the return of `n` documents, but with the risk of including too
many less-relevant chunks that may affect the document selection. Default:
`-1` (any negative value means selecting documents based on all indexed chunks);
- `reranker`: string, a reranking model supported by
[`CrossEncoder`](https://sbert.net/docs/package_reference/cross_encoder/index.html).
A list of available models is available on their documentation. The default
model is `"cross-encoder/ms-marco-MiniLM-L-6-v2"`. You can disable the use of
`CrossEncoder` by setting this option to a falsy value that is not `null`,
such as `false` or `""` (empty string);
- `reranker`: string, specifies which reranker to use for result sorting. This can be:
- A model name for [`CrossEncoderReranker`](https://sbert.net/docs/package_reference/cross_encoder/index.html)
(e.g., `"cross-encoder/ms-marco-MiniLM-L-6-v2"`)
- A built-in reranker class name (`"NaiveReranker"`, `"CrossEncoderReranker"`, or `"LlamaCppReranker"`)
- A custom reranker class name that can be dynamically loaded
- You can disable reranking by setting this to a falsy value that is not `null`, such as `false` or `""` (empty string)
- `reranker_params`: dictionary, similar to `embedding_params`. The options
passed to `CrossEncoder` class constructor;
passed to the reranker class constructor. For example:
```json
{
"reranker": "LlamaCppReranker",
"reranker_params": {
"model_name": "http://localhost:8085/v1/reranking"
}
}
```
- `db_settings`: dictionary, works in a similar way to `embedding_params`, but
for Chromadb client settings so that you can configure
[authentication for remote Chromadb](https://docs.trychroma.com/production/administration/auth);
Expand Down
73 changes: 73 additions & 0 deletions src/vectorcode/rerankers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""VectorCode rerankers module.

This module provides reranker implementations for VectorCode.
Rerankers are used to reorder query results to improve relevance.
"""

from .base import (
RerankerBase,
get_reranker_class,
list_available_rerankers,
register_reranker,
)
from .builtins import CrossEncoderReranker, NaiveReranker
from .llama_cpp import LlamaCppReranker

# Map of legacy names to new registration names
_LEGACY_NAMES = {
"NaiveReranker": "naive",
"CrossEncoderReranker": "crossencoder",
"LlamaCppReranker": "llamacpp",
}


def create_reranker(name: str, configs=None, query_chunks=None, **kwargs):
"""Create a reranker instance by name.

Handles both legacy class names (e.g., 'NaiveReranker') and
registration names (e.g., 'naive').

Args:
name: The name of the reranker class or registered reranker
configs: Optional Config object
query_chunks: Optional list of query chunks for CrossEncoderReranker
**kwargs: Additional keyword arguments to pass to the reranker

Returns:
An instance of the requested reranker

Raises:
ValueError: If the reranker name is unknown or not registered
"""
# Check for legacy names
registry_name = _LEGACY_NAMES.get(name, name)

try:
# Try to get class from registry
reranker_class = get_reranker_class(registry_name)

# Special case for CrossEncoderReranker which needs query_chunks
if registry_name == "crossencoder" and query_chunks is not None:
return reranker_class(configs=configs, query_chunks=query_chunks, **kwargs)

Check warning on line 51 in src/vectorcode/rerankers/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/vectorcode/rerankers/__init__.py#L51

Added line #L51 was not covered by tests
else:
return reranker_class(configs=configs, **kwargs)

except ValueError:
# Handle case where we're using a fully qualified module path
# This is part of the dynamic import system
raise ValueError(
f"Reranker '{name}' not found in registry. "
f"Available rerankers: {list_available_rerankers()}"
)


__all__ = [
"RerankerBase",
"register_reranker",
"get_reranker_class",
"list_available_rerankers",
"create_reranker",
"NaiveReranker",
"CrossEncoderReranker",
"LlamaCppReranker",
]
78 changes: 78 additions & 0 deletions src/vectorcode/rerankers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type


class RerankerBase(ABC):
"""Base class for all rerankers in VectorCode.

All rerankers should inherit from this class and implement the rerank method.
"""

def __init__(self, **kwargs):
"""Initialize the reranker with kwargs.

Args:
**kwargs: Arbitrary keyword arguments to configure the reranker.
"""
self.kwargs = kwargs

@abstractmethod
def rerank(self, results: Dict[str, Any]) -> List[str]:
"""Rerank the query results.

Args:
results: The query results from ChromaDB, typically containing ids, documents,
metadatas, and distances.

Returns:
A list of document IDs sorted in the desired order.
"""
raise NotImplementedError("Rerankers must implement rerank method")

Check warning on line 30 in src/vectorcode/rerankers/base.py

View check run for this annotation

Codecov / codecov/patch

src/vectorcode/rerankers/base.py#L30

Added line #L30 was not covered by tests


# Registry for reranker classes
_RERANKER_REGISTRY: Dict[str, Type[RerankerBase]] = {}


def register_reranker(name: str):
"""Decorator to register a reranker class.

Args:
name: The name to register the reranker under. This name can be used
in configuration to specify which reranker to use.

Returns:
A decorator function that registers the decorated class.
"""

def decorator(cls):
_RERANKER_REGISTRY[name] = cls
return cls

return decorator


def get_reranker_class(name: str) -> Type[RerankerBase]:
"""Get a reranker class by name.

Args:
name: The name of the reranker class to get.

Returns:
The reranker class.

Raises:
ValueError: If the reranker name is not registered.
"""
if name not in _RERANKER_REGISTRY:
raise ValueError(f"Unknown reranker: {name}")
return _RERANKER_REGISTRY[name]


def list_available_rerankers() -> List[str]:
"""List all available registered reranker names.

Returns:
A list of registered reranker names.
"""
return list(_RERANKER_REGISTRY.keys())
129 changes: 129 additions & 0 deletions src/vectorcode/rerankers/builtins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import heapq
from collections import defaultdict
from typing import DefaultDict, List

import numpy
from chromadb.api.types import QueryResult

from vectorcode.cli_utils import Config, QueryInclude

from .base import RerankerBase, register_reranker


@register_reranker("naive")
class NaiveReranker(RerankerBase):
"""A simple reranker that ranks documents by their mean distance."""

def __init__(self, configs: Config = None, **kwargs):
super().__init__(**kwargs)
self.configs = configs
self.n_result = configs.n_result if configs else kwargs.get("n_result", 10)

def rerank(self, results: QueryResult) -> List[str]:
"""Rerank the query results by mean distance.

Args:
results: The query results from ChromaDB.

Returns:
A list of document IDs sorted by mean distance.
"""
assert results["metadatas"] is not None
assert results["distances"] is not None
documents: DefaultDict[str, list[float]] = defaultdict(list)

include = getattr(self.configs, "include", None) if self.configs else None

for query_chunk_idx in range(len(results["ids"])):
chunk_ids = results["ids"][query_chunk_idx]
chunk_metas = results["metadatas"][query_chunk_idx]
chunk_distances = results["distances"][query_chunk_idx]
# NOTE: distances, smaller is better.
paths = [str(meta["path"]) for meta in chunk_metas]
assert len(paths) == len(chunk_distances)
for distance, identifier in zip(
chunk_distances,
chunk_ids if include and QueryInclude.chunk in include else paths,
):
if identifier is None:
# so that vectorcode doesn't break on old collections.
continue

Check warning on line 50 in src/vectorcode/rerankers/builtins.py

View check run for this annotation

Codecov / codecov/patch

src/vectorcode/rerankers/builtins.py#L50

Added line #L50 was not covered by tests
documents[identifier].append(distance)

top_k = int(numpy.mean(tuple(len(i) for i in documents.values())))
for key in documents.keys():
documents[key] = heapq.nsmallest(top_k, documents[key])

return heapq.nsmallest(
self.n_result, documents.keys(), lambda x: float(numpy.mean(documents[x]))
)


@register_reranker("crossencoder")
class CrossEncoderReranker(RerankerBase):
"""A reranker that uses a cross-encoder model for reranking."""

def __init__(
self,
configs: Config = None,
query_chunks: List[str] = None,
model_name: str = None,
**kwargs,
):
super().__init__(**kwargs)
self.configs = configs
self.n_result = configs.n_result if configs else kwargs.get("n_result", 10)

# Handle model_name correctly
self.model_name = model_name or kwargs.get("model_name")
if not self.model_name:
raise ValueError("model_name must be provided")

self.query_chunks = query_chunks or kwargs.get("query_chunks", [])
if not self.query_chunks:
raise ValueError("query_chunks must be provided")

# Import here to avoid requiring sentence-transformers for all rerankers
from sentence_transformers import CrossEncoder

self.model = CrossEncoder(self.model_name, **kwargs)

def rerank(self, results: QueryResult) -> List[str]:
"""Rerank the query results using a cross-encoder model.

Args:
results: The query results from ChromaDB.

Returns:
A list of document IDs sorted by cross-encoder scores.
"""
assert results["metadatas"] is not None
assert results["documents"] is not None
documents: DefaultDict[str, list[float]] = defaultdict(list)

include = getattr(self.configs, "include", None) if self.configs else None

for query_chunk_idx in range(len(self.query_chunks)):
chunk_ids = results["ids"][query_chunk_idx]
chunk_metas = results["metadatas"][query_chunk_idx]
chunk_docs = results["documents"][query_chunk_idx]
ranks = self.model.rank(
self.query_chunks[query_chunk_idx], chunk_docs, apply_softmax=True
)
for rank in ranks:
if include and QueryInclude.chunk in include:
documents[chunk_ids[rank["corpus_id"]]].append(float(rank["score"]))
else:
documents[chunk_metas[rank["corpus_id"]]["path"]].append(
float(rank["score"])
)

top_k = int(numpy.mean(tuple(len(i) for i in documents.values())))
for key in documents.keys():
documents[key] = heapq.nlargest(top_k, documents[key])

return heapq.nlargest(
self.n_result,
documents.keys(),
key=lambda x: float(numpy.mean(documents[x])),
)
Loading