diff --git a/src/vectorcode/chunking.py b/src/vectorcode/chunking.py
index 6f8218d3..89e30ad1 100644
--- a/src/vectorcode/chunking.py
+++ b/src/vectorcode/chunking.py
@@ -25,21 +25,36 @@ class Chunk:
     """

     text: str
-    start: Point
-    end: Point
+    start: Point | None = None
+    end: Point | None = None
+    path: str | None = None
+    id: str | None = None

     def __str__(self):
         return self.text

+    def __hash__(self) -> int:
+        return hash(f"VectorCodeChunk_{self.path}({self.start}:{self.end}@{self.text})")
+
     def export_dict(self):
+        d: dict[str, str | dict[str, int]] = {"text": self.text}
         if self.start is not None:
-            return {
-                "text": self.text,
-                "start": {"row": self.start.row, "column": self.start.column},
-                "end": {"row": self.end.row, "column": self.end.column},
-            }
-        else:
-            return {"text": self.text}
+            d.update(
+                {
+                    "start": {"row": self.start.row, "column": self.start.column},
+                }
+            )
+        if self.end is not None:
+            d.update(
+                {
+                    "end": {"row": self.end.row, "column": self.end.column},
+                }
+            )
+        if self.path:
+            d["path"] = self.path
+        if self.id:
+            d["chunk_id"] = self.id
+        return d


 @dataclass
@@ -129,7 +144,7 @@ def chunk(
     ) -> Generator[Chunk, None, None]:
         logger.info("Started chunking %s using FileChunker.", data.name)
         lines = data.readlines()
-        if len(lines) == 0:
+        if len(lines) == 0:  # pragma: nocover
            return
        if (
            self.config.chunk_size < 0
diff --git a/src/vectorcode/mcp_main.py b/src/vectorcode/mcp_main.py
index 2ba782c1..808ba6f4 100644
--- a/src/vectorcode/mcp_main.py
+++ b/src/vectorcode/mcp_main.py
@@ -3,6 +3,7 @@
 import logging
 import os
 import sys
+import traceback
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, cast
@@ -160,14 +161,17 @@ async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int]:
             await remove_orphanes(collection, collection_lock, stats, stats_lock)
         return stats.to_dict()
-    except Exception as e:
-        logger.error("Failed to access collection at %s", project_root)
-        raise McpError(
-            ErrorData(
-                code=1,
-                message=f"{e.__class__.__name__}: Failed to create the collection at {project_root}.",
-            )
-        )
+    except Exception as e:  # pragma: nocover
+        if isinstance(e, McpError):
+            logger.error("Failed to access collection at %s", project_root)
+            raise
+        else:
+            raise McpError(
+                ErrorData(
+                    code=1,
+                    message="\n".join(traceback.format_exception(e)),
+                )
+            ) from e


 async def query_tool(
@@ -211,24 +215,28 @@ async def query_tool(
             configs=query_config,
         )
         results: list[str] = []
-        for path in result_paths:
-            if os.path.isfile(path):
-                with open(path) as fin:
-                    rel_path = os.path.relpath(path, config.project_root)
-                    results.append(
-                        f"{rel_path}\n{fin.read()}",
-                    )
+        for result in result_paths:
+            if isinstance(result, str):
+                if os.path.isfile(result):
+                    with open(result) as fin:
+                        rel_path = os.path.relpath(result, config.project_root)
+                        results.append(
+                            f"{rel_path}\n{fin.read()}",
+                        )
         logger.info("Retrieved the following files: %s", result_paths)
         return results
-    except Exception as e:
-        logger.error("Failed to access collection at %s", project_root)
-        raise McpError(
-            ErrorData(
-                code=1,
-                message=f"{e.__class__.__name__}: Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
-            )
-        )
+    except Exception as e:  # pragma: nocover
+        if isinstance(e, McpError):
+            logger.error("Failed to access collection at %s", project_root)
+            raise
+        else:
+            raise McpError(
+                ErrorData(
+                    code=1,
+                    message="\n".join(traceback.format_exception(e)),
+                )
+            ) from e


 async def ls_files(project_root: str) -> list[str]:
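# A minimal sketch (illustrative, not part of the patch) of the error format
# the new MCP handlers produce: `traceback.format_exception(exc)` accepts a
# single exception object on Python 3.10+ and returns a list of
# newline-terminated strings, so the joined message carries the full traceback.
import traceback

try:
    raise ValueError("boom")
except ValueError as exc:
    message = "\n".join(traceback.format_exception(exc))
print(message)  # the full traceback text, as embedded in the McpError message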
diff --git a/src/vectorcode/subcommands/query/__init__.py b/src/vectorcode/subcommands/query/__init__.py
index 8dea28b3..fec70da4 100644
--- a/src/vectorcode/subcommands/query/__init__.py
+++ b/src/vectorcode/subcommands/query/__init__.py
@@ -3,12 +3,13 @@
 import os
 from typing import Any, cast

-from chromadb import GetResult, Where
+from chromadb import Where
 from chromadb.api.models.AsyncCollection import AsyncCollection
-from chromadb.api.types import IncludeEnum
+from chromadb.api.types import IncludeEnum, QueryResult
 from chromadb.errors import InvalidCollectionException, InvalidDimensionException
+from tree_sitter import Point

-from vectorcode.chunking import StringChunker
+from vectorcode.chunking import Chunk, StringChunker
 from vectorcode.cli_utils import (
     Config,
     QueryInclude,
@@ -22,6 +23,7 @@
     get_embedding_function,
     verify_ef,
 )
+from vectorcode.subcommands.query import types as vectorcode_types
 from vectorcode.subcommands.query.reranker import (
     RerankerError,
     get_reranker,
@@ -30,14 +32,49 @@
 logger = logging.getLogger(name=__name__)


+def convert_query_results(
+    chroma_result: QueryResult, queries: list[str]
+) -> list[vectorcode_types.QueryResult]:
+    """Convert a chromadb query result into in-house query results."""
+    assert chroma_result["documents"] is not None
+    assert chroma_result["distances"] is not None
+    assert chroma_result["metadatas"] is not None
+    assert chroma_result["ids"] is not None
+
+    chroma_results_list: list[vectorcode_types.QueryResult] = []
+    for q_i in range(len(queries)):
+        q = queries[q_i]
+        documents = chroma_result["documents"][q_i]
+        distances = chroma_result["distances"][q_i]
+        metadatas = chroma_result["metadatas"][q_i]
+        ids = chroma_result["ids"][q_i]
+        for doc, dist, meta, _id in zip(documents, distances, metadatas, ids):
+            chunk = Chunk(text=doc, id=_id)
+            if meta.get("start") is not None:
+                chunk.start = Point(int(meta.get("start", 0)), 0)
+            if meta.get("end") is not None:
+                chunk.end = Point(int(meta.get("end", 0)), 0)
+            if meta.get("path"):
+                chunk.path = str(meta["path"])
+            chroma_results_list.append(
+                vectorcode_types.QueryResult(
+                    chunk=chunk,
+                    path=str(meta.get("path", "")),
+                    query=(q,),
+                    scores=(-dist,),
+                )
+            )
+    return chroma_results_list
+
+
 async def get_query_result_files(
     collection: AsyncCollection, configs: Config
-) -> list[str]:
+) -> list[str | Chunk]:
     query_chunks = []
-    if configs.query:
-        chunker = StringChunker(configs)
-        for q in configs.query:
-            query_chunks.extend(str(i) for i in chunker.chunk(q))
+    assert configs.query, "Query messages cannot be empty."
+    chunker = StringChunker(configs)
+    for q in configs.query:
+        query_chunks.extend(str(i) for i in chunker.chunk(q))

     configs.query_exclude = [
         expand_path(i, True)
@@ -70,7 +107,7 @@ async def get_query_result_files(
     query_embeddings = get_embedding_function(configs)(query_chunks)
     if isinstance(configs.embedding_dims, int) and configs.embedding_dims > 0:
         query_embeddings = [e[: configs.embedding_dims] for e in query_embeddings]
-    results = await collection.query(
+    chroma_query_results: QueryResult = await collection.query(
         query_embeddings=query_embeddings,
         n_results=num_query,
         include=[
@@ -85,69 +122,51 @@ async def get_query_result_files(
         return []

     reranker = get_reranker(configs)
-    return await reranker.rerank(results)
+    return await reranker.rerank(
+        convert_query_results(chroma_query_results, configs.query)
+    )


 async def build_query_results(
     collection: AsyncCollection, configs: Config
 ) -> list[dict[str, str | int]]:
-    structured_result = []
-    for identifier in await get_query_result_files(collection, configs):
-        if os.path.isfile(identifier):
-            if configs.use_absolute_path:
-                output_path = os.path.abspath(identifier)
-            else:
-                output_path = os.path.relpath(identifier, configs.project_root)
-            full_result = {"path": output_path}
-            with open(identifier) as fin:
-                document = fin.read()
-                full_result["document"] = document
+    assert configs.project_root

-            structured_result.append(
-                {str(key): full_result[str(key)] for key in configs.include}
-            )
-        elif QueryInclude.chunk in configs.include:
-            chunks: GetResult = await collection.get(
-                identifier, include=[IncludeEnum.metadatas, IncludeEnum.documents]
-            )
-            meta = chunks.get(
-                "metadatas",
-            )
-            if meta is not None and len(meta) != 0:
-                chunk_texts = chunks.get("documents")
-                assert chunk_texts is not None, (
-                    "QueryResult does not contain `documents`!"
-                )
-                full_result: dict[str, str | int] = {
-                    "chunk": str(chunk_texts[0]),
-                    "chunk_id": identifier,
-                }
-                if meta[0].get("start") is not None and meta[0].get("end") is not None:
-                    path = str(meta[0].get("path"))
-                    with open(path) as fin:
-                        start: int = int(meta[0]["start"])
-                        end: int = int(meta[0]["end"])
-                        full_result["chunk"] = "".join(fin.readlines()[start : end + 1])
-                        full_result["start_line"] = start
-                        full_result["end_line"] = end
-                if QueryInclude.path in configs.include:
-                    full_result["path"] = str(
-                        meta[0]["path"]
-                        if configs.use_absolute_path
-                        else os.path.relpath(
-                            str(meta[0]["path"]), str(configs.project_root)
-                        )
-                    )
-
-                structured_result.append(full_result)
-            else:  # pragma: nocover
-                logger.error(
-                    "This collection doesn't support chunk-mode output because it lacks the necessary metadata. Please re-vectorise it.",
-                )
+    def make_output_path(path: str, absolute: bool) -> str:
+        if absolute:
+            if os.path.isabs(path):
+                return path
+            return os.path.abspath(os.path.join(str(configs.project_root), path))
+        else:
+            rel_path = os.path.relpath(path, configs.project_root)
+            if isinstance(rel_path, bytes):  # pragma: nocover
+                # Some Python versions type `os.path.relpath` as possibly
+                # returning `bytes`, so decode defensively.
+                rel_path = rel_path.decode()
+            return rel_path
+
+    structured_result = []
+    for res in await get_query_result_files(collection, configs):
+        if isinstance(res, str):
+            output_path = make_output_path(res, configs.use_absolute_path)
+            io_path = make_output_path(res, True)
+            if not os.path.isfile(io_path):
+                logger.warning(f"{io_path} is no longer a valid file.")
+                continue
+            with open(io_path) as fin:
+                structured_result.append({"path": output_path, "document": fin.read()})
         else:
-            logger.warning(
-                f"{identifier} is no longer a valid file! Please re-run vectorcode vectorise to refresh the database.",
+            res = cast(Chunk, res)
+            assert res.path, f"{res} has no `path` attribute."
+            structured_result.append(
+                {
+                    "path": make_output_path(res.path, configs.use_absolute_path),
+                    "chunk": res.text,
+                    "start_line": res.start.row if res.start is not None else None,
+                    "end_line": res.end.row if res.end is not None else None,
+                    "chunk_id": res.id,
+                }
             )

     for result in structured_result:
         if result.get("path") is not None:
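# A small illustration (assumed inputs, not from the patch) of what
# `convert_query_results` yields: one in-house QueryResult per hit, with the
# chromadb distance negated into a similarity score.
from vectorcode.subcommands.query import convert_query_results

chroma_result = {
    "ids": [["c1", "c2"]],
    "documents": [["def foo(): ...", "def bar(): ..."]],
    "distances": [[0.1, 0.4]],
    "metadatas": [[{"path": "a.py", "start": 3, "end": 9}, {"path": "b.py"}]],
}
for res in convert_query_results(chroma_result, ["my query"]):
    print(res.path, res.query, res.scores)
# a.py ('my query',) (-0.1,)
# b.py ('my query',) (-0.4,)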
diff --git a/src/vectorcode/subcommands/query/reranker/base.py b/src/vectorcode/subcommands/query/reranker/base.py
index 047d978e..18a4c68a 100644
--- a/src/vectorcode/subcommands/query/reranker/base.py
+++ b/src/vectorcode/subcommands/query/reranker/base.py
@@ -1,13 +1,13 @@
 import heapq
 import logging
 from abc import ABC, abstractmethod
-from collections import defaultdict
-from typing import Any, DefaultDict, Optional, Sequence, cast
+from typing import Any

 import numpy
-from chromadb.api.types import QueryResult

+from vectorcode.chunking import Chunk
 from vectorcode.cli_utils import Config, QueryInclude
+from vectorcode.subcommands.query.types import QueryResult

 logger = logging.getLogger(name=__name__)
@@ -29,7 +29,7 @@ def __init__(self, configs: Config, **kwargs: Any):
                 "'configs' should contain the query messages."
             )
         self.n_result = configs.n_result
-        self._raw_results: Optional[QueryResult] = None
+        self._raw_results: list[QueryResult] = []

     @classmethod
     def create(cls, configs: Config, **kwargs: Any):
@@ -47,52 +47,37 @@ def create(cls, configs: Config, **kwargs: Any):

     @abstractmethod
     async def compute_similarity(
-        self, results: list[str], query_message: str
-    ) -> Sequence[float]:  # pragma: nocover
-        """Given a list of n results and 1 query message,
-        return a list-like object of length n that contains the similarity scores between
-        each item in `results` and the `query_message`.
-
-        A high similarity score means the strings are semantically similar to each other.
-        `query_message` will be loaded in the same order as they appear in `self.configs.query`.
-
-        If you need the raw query results from chromadb,
-        it'll be saved in `self._raw_results` before this method is called.
+        self, results: list[QueryResult]
+    ) -> None:  # pragma: nocover
+        """
+        Modify the `scores` field of each `QueryResult` **IN-PLACE** so that it
+        contains the correct similarity scores.
         """
         raise NotImplementedError

-    async def rerank(self, results: QueryResult | dict) -> list[str]:
-        if len(results["ids"]) == 0 or all(len(i) == 0 for i in results["ids"]):
+    async def rerank(self, results: list[QueryResult]) -> list[str | Chunk]:
+        if len(results) == 0:
             return []
-        self._raw_results = cast(QueryResult, results)
-        query_chunks = self.configs.query
-        assert query_chunks
-        assert results["metadatas"] is not None
-        assert results["documents"] is not None
-        documents: DefaultDict[str, list[float]] = defaultdict(list)
-        for query_chunk_idx in range(len(query_chunks)):
-            chunk_ids = results["ids"][query_chunk_idx]
-            chunk_metas = results["metadatas"][query_chunk_idx]
-            chunk_docs = results["documents"][query_chunk_idx]
-            scores = await self.compute_similarity(
-                chunk_docs, query_chunks[query_chunk_idx]
-            )
-            for i, score in enumerate(scores):
-                if QueryInclude.chunk in self.configs.include:
-                    documents[chunk_ids[i]].append(float(score))
-                else:
-                    documents[str(chunk_metas[i]["path"])].append(float(score))
-        logger.debug("Document scores: %s", documents)
-        top_k = int(numpy.mean(tuple(len(i) for i in documents.values())))
-        for key in documents.keys():
-            documents[key] = heapq.nlargest(top_k, documents[key])
-
-        self._raw_results = None
-
-        return heapq.nlargest(
-            self.n_result,
-            documents.keys(),
-            key=lambda x: float(numpy.mean(documents[x])),
+        # compute the similarity scores
+        await self.compute_similarity(results)
+
+        # group the results by granularity (file path or chunk)
+        # and only keep the `top_k` results for each group
+        group_by = "path"
+        if QueryInclude.chunk in self.configs.include:
+            group_by = "chunk"
+        grouped_results = QueryResult.group(*results, by=group_by, top_k="auto")
+
+        # compute the mean score for each of the groups
+        scores: dict[Chunk | str, float] = {}
+        for key in grouped_results.keys():
+            scores[key] = float(
+                numpy.mean(tuple(i.mean_score() for i in grouped_results[key]))
+            )
+        return heapq.nlargest(
+            self.configs.n_result, grouped_results.keys(), key=lambda x: scores[x]
         )
diff --git a/src/vectorcode/subcommands/query/reranker/cross_encoder.py b/src/vectorcode/subcommands/query/reranker/cross_encoder.py
index 51758f76..3f2fcd1d 100644
--- a/src/vectorcode/subcommands/query/reranker/cross_encoder.py
+++ b/src/vectorcode/subcommands/query/reranker/cross_encoder.py
@@ -1,8 +1,8 @@
-import asyncio
 import logging
 from typing import Any

 from vectorcode.cli_utils import Config
+from vectorcode.subcommands.query.types import QueryResult

 from .base import RerankerBase

@@ -34,8 +34,8 @@ def __init__(
             model_name = configs.reranker_params.pop("model_name_or_path")
         self.model = CrossEncoder(model_name, **configs.reranker_params)

-    async def compute_similarity(self, results: list[str], query_message: str):
-        scores = await asyncio.to_thread(
-            self.model.predict, [(chunk, query_message) for chunk in results]
-        )
-        return list(float(i) for i in scores)
+    async def compute_similarity(self, results: list[QueryResult]):
+        scores = self.model.predict([(str(res.chunk), res.query[0]) for res in results])
+
+        for res, score in zip(results, scores):
+            res.scores = (float(score),)
diff --git a/src/vectorcode/subcommands/query/reranker/naive.py b/src/vectorcode/subcommands/query/reranker/naive.py
index 6e6b4336..65478c09 100644
--- a/src/vectorcode/subcommands/query/reranker/naive.py
+++ b/src/vectorcode/subcommands/query/reranker/naive.py
@@ -1,7 +1,8 @@
 import logging
-from typing import Any, Sequence
+from typing import Any

 from vectorcode.cli_utils import Config
+from vectorcode.subcommands.query.types import QueryResult

 from .base import RerankerBase

@@ -17,15 +18,8 @@ class NaiveReranker(RerankerBase):
     def __init__(self, configs: Config, **kwargs: Any):
         super().__init__(configs)

-    async def compute_similarity(
-        self, results: list[str], query_message: str
-    ) -> Sequence[float]:
-        assert self._raw_results is not None, "Expecting raw results from the database."
-        assert self._raw_results.get("distances") is not None
-        assert self.configs.query, "Expecting query messages in self.configs"
-        idx = self.configs.query.index(query_message)
-        dist = self._raw_results.get("distances")
-        if dist is None:  # pragma: nocover
-            raise ValueError("QueryResult should contain distances!")
-        else:
-            return list(-i for i in dist[idx])
+    async def compute_similarity(self, results: list[QueryResult]):
+        """
+        Do nothing: the `QueryResult` objects already carry the (negated)
+        chromadb distances as their scores.
+        """
+        pass
diff --git a/src/vectorcode/subcommands/query/types.py b/src/vectorcode/subcommands/query/types.py
new file mode 100644
index 00000000..e7e5507f
--- /dev/null
+++ b/src/vectorcode/subcommands/query/types.py
@@ -0,0 +1,94 @@
+import heapq
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Literal, Union
+
+import numpy
+
+from vectorcode.chunking import Chunk
+
+
+@dataclass
+class QueryResult:
+    """
+    The container for a single query result.
+
+    args:
+        - path: path to the file;
+        - chunk: `vectorcode.chunking.Chunk` object that stores the chunk;
+        - query: query messages used for the search;
+        - scores: similarity scores for the corresponding queries.
+    """
+
+    path: str
+    chunk: Chunk
+    query: tuple[str, ...]
+    scores: tuple[float, ...]
+
+    @classmethod
+    def merge(cls, *results: "QueryResult") -> "QueryResult":
+        """
+        Given the results of a single chunk/document from different queries,
+        merge them into a single `QueryResult` object.
+        """
+        for i in range(len(results) - 1):
+            if not results[i].is_same_doc(results[i + 1]):
+                raise ValueError(
+                    f"The inputs are not the same chunk: {results[i]}, {results[i + 1]}"
+                )
+
+        return QueryResult(
+            path=results[0].path,
+            chunk=results[0].chunk,
+            query=sum((tuple(i.query) for i in results), start=tuple()),
+            scores=sum((tuple(i.scores) for i in results), start=tuple()),
+        )
+
+    @staticmethod
+    def group(
+        *results: "QueryResult",
+        by: Union[Literal["path"], Literal["chunk"]] = "path",
+        top_k: int | Literal["auto"] | None = None,
+    ) -> dict[Chunk | str, list["QueryResult"]]:
+        """
+        Group the query results based on `by`.
+
+        args:
+            - `by`: either "path" or "chunk";
+            - `top_k`: if set, only return the top k results for each group based
+              on mean scores. If "auto", top k is decided by the mean number of
+              results per group.
+
+        returns:
+            - a dictionary that maps either a path or a chunk to a list of
+              `QueryResult` objects.
+
+        """
+        assert by in {"path", "chunk"}
+        grouped_result: dict[Chunk | str, list["QueryResult"]] = defaultdict(list)
+
+        for res in results:
+            grouped_result[getattr(res, by)].append(res)
+
+        if top_k == "auto":
+            top_k = int(numpy.mean(tuple(len(i) for i in grouped_result.values())))
+
+        if top_k and top_k > 0:
+            for group in grouped_result.keys():
+                grouped_result[group] = heapq.nlargest(top_k, grouped_result[group])
+        return grouped_result
+
+    def mean_score(self):
+        return float(numpy.mean(self.scores))
+
+    def __lt__(self, other: "QueryResult"):
+        assert isinstance(other, QueryResult)
+        return self.mean_score() < other.mean_score()
+
+    def __gt__(self, other: "QueryResult"):
+        assert isinstance(other, QueryResult)
+        return self.mean_score() > other.mean_score()
+
+    def __eq__(self, other: object, /) -> bool:
+        return (
+            isinstance(other, QueryResult) and self.mean_score() == other.mean_score()
+        )
+
+    def is_same_doc(self, other: "QueryResult") -> bool:
+        return self.path == other.path and self.chunk == other.chunk
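# A usage sketch for the new QueryResult helpers (hypothetical data): group
# per-query hits by file, then rank files by the mean of their surviving chunk
# scores, which is essentially what RerankerBase.rerank now does.
from vectorcode.chunking import Chunk
from vectorcode.subcommands.query.types import QueryResult

hits = [
    QueryResult(path="a.py", chunk=Chunk("x"), query=("q1",), scores=(0.9,)),
    QueryResult(path="a.py", chunk=Chunk("y"), query=("q2",), scores=(0.5,)),
    QueryResult(path="b.py", chunk=Chunk("z"), query=("q1",), scores=(0.2,)),
]
groups = QueryResult.group(*hits, by="path", top_k="auto")
ranked = sorted(
    groups,
    key=lambda k: sum(r.mean_score() for r in groups[k]) / len(groups[k]),
    reverse=True,
)
print(ranked)  # ['a.py', 'b.py']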
diff --git a/src/vectorcode/subcommands/vectorise.py b/src/vectorcode/subcommands/vectorise.py
index 1cf51569..2ce0b249 100644
--- a/src/vectorcode/subcommands/vectorise.py
+++ b/src/vectorcode/subcommands/vectorise.py
@@ -139,8 +139,10 @@ async def chunked_add(
             "sha256": new_sha256,
         }
         if isinstance(chunk, Chunk):
-            meta["start"] = chunk.start.row
-            meta["end"] = chunk.end.row
+            if chunk.start:
+                meta["start"] = chunk.start.row
+            if chunk.end:
+                meta["end"] = chunk.end.row
         metas.append(meta)

     async with collection_lock:
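# Illustrative sketch (hypothetical values): with start/end now optional on
# Chunk, chunks produced without tree-sitter positions simply omit the row
# metadata at vectorise time.
from vectorcode.chunking import Chunk

meta = {"path": "a.txt", "sha256": "0" * 64}
chunk = Chunk(text="hello world")  # no start/end points
if chunk.start:
    meta["start"] = chunk.start.row
if chunk.end:
    meta["end"] = chunk.end.row
assert "start" not in meta and "end" not in meta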
diff --git a/tests/subcommands/query/test_query.py b/tests/subcommands/query/test_query.py
index 43392526..67251e4c 100644
--- a/tests/subcommands/query/test_query.py
+++ b/tests/subcommands/query/test_query.py
@@ -1,7 +1,7 @@
-from unittest.mock import AsyncMock, MagicMock, mock_open, patch
+from unittest.mock import AsyncMock, MagicMock, patch

 import pytest
-from chromadb import GetResult
+from chromadb import QueryResult
 from chromadb.api.models.AsyncCollection import AsyncCollection
 from chromadb.api.types import IncludeEnum
 from chromadb.errors import InvalidCollectionException, InvalidDimensionException
@@ -9,6 +9,7 @@
 from vectorcode.cli_utils import CliAction, Config, QueryInclude
 from vectorcode.subcommands.query import (
     build_query_results,
+    convert_query_results,
     get_query_result_files,
     query,
 )
@@ -47,7 +48,7 @@ def mock_collection():
 @pytest.fixture
 def mock_config():
     return Config(
-        query=["test query"],
+        query=["test query", "test query 2"],
         n_result=3,
         query_multiplier=2,
         chunk_size=100,
@@ -65,6 +66,7 @@
 @pytest.mark.asyncio
 async def test_get_query_result_files(mock_collection, mock_config):
     mock_embedding_function = MagicMock()
+    mock_config.embedding_dims = 10
     with (
         patch("vectorcode.subcommands.query.get_reranker") as mock_get_reranker,
         patch(
@@ -88,7 +90,7 @@
         # Check that query was called with the right parameters
         mock_collection.query.assert_called_once()
         args, kwargs = mock_collection.query.call_args
-        mock_embedding_function.assert_called_once_with(["test query"])
+        mock_embedding_function.assert_called_once_with(["test query", "test query 2"])
         assert kwargs["n_results"] == 6  # n_result(3) * query_multiplier(2)
         assert IncludeEnum.metadatas in kwargs["include"]
         assert IncludeEnum.distances in kwargs["include"]
@@ -98,11 +100,14 @@
         # Check reranker was used correctly
         mock_get_reranker.assert_called_once_with(mock_config)
         mock_reranker_instance.rerank.assert_called_once_with(
-            mock_collection.query.return_value
+            convert_query_results(mock_collection.query.return_value, mock_config.query)
         )

         # Check the result
         assert result == ["file1.py", "file2.py", "file3.py"]
+        assert all(
+            len(i) == 10
+            for i in mock_collection.query.call_args.kwargs["query_embeddings"]
+        )


 @pytest.mark.asyncio
@@ -128,59 +133,56 @@ async def test_get_query_result_files_include_chunk(mock_collection, mock_config):
 @pytest.mark.asyncio
 async def test_build_query_results_chunk_mode_success(mock_collection, mock_config):
     """Test build_query_results in chunk mode successfully retrieves chunk details."""
-    mock_config.include = [QueryInclude.chunk, QueryInclude.path]
-    mock_config.project_root = "/test/project"
-    mock_config.use_absolute_path = False
-    identifier = "chunk_id_1"
-    file_path = "/test/project/subdir/file1.py"
-    relative_path = "subdir/file1.py"
-    start_line = 5
-    end_line = 10
-
-    full_file_content_lines = [f"line {i}\n" for i in range(15)]
-    full_file_content = "".join(full_file_content_lines)
-
-    expected_chunk_content = "".join(full_file_content_lines[start_line : end_line + 1])
-
-    mock_get_result = GetResult(
-        ids=[identifier],
-        embeddings=None,
-        documents=["original chunk doc in db"],
-        metadatas=[{"path": file_path, "start": start_line, "end": end_line}],
-    )
-
-    with (
-        patch(
-            "vectorcode.subcommands.query.get_query_result_files",
-            return_value=[identifier],
-        ),
-        patch("os.path.isfile", return_value=False),
-        patch("builtins.open", mock_open(read_data=full_file_content)) as mocked_open,
-        patch("os.path.relpath", return_value=relative_path) as mock_relpath,
-    ):
-        mock_collection.get = AsyncMock(return_value=mock_get_result)
-
-        results = await build_query_results(mock_collection, mock_config)
-
-        mock_collection.get.assert_called_once_with(
-            identifier, include=[IncludeEnum.metadatas, IncludeEnum.documents]
+    for request_abs_path in (True, False):
+        mock_config.include = [QueryInclude.chunk, QueryInclude.path]
+        mock_config.project_root = "/test/project"
+        mock_config.use_absolute_path = request_abs_path
+        mock_config.query = ["dummy_query"]
+        identifier = "chunk_id"
+        file_path = "/test/project/subdir/file1.py"
+        relative_path = "subdir/file1.py"
+        start_line = 5
+        end_line = 10
+
+        full_file_content_lines = [f"line {i}\n" for i in range(15)]
+
+        expected_chunk_content = "".join(
+            full_file_content_lines[start_line : end_line + 1]
        )
-        mocked_open.assert_called_once_with(file_path)
-
-        mock_relpath.assert_called_once_with(file_path, str(mock_config.project_root))
-
-        assert len(results) == 1
-
-        expected_full_result = {
-            "path": relative_path,
-            "chunk": expected_chunk_content,
-            "start_line": start_line,
-            "end_line": end_line,
-            "chunk_id": identifier,
-        }
-
-        assert results[0] == expected_full_result
+        mock_get_result = QueryResult(
+            ids=[[identifier]],
+            documents=[[expected_chunk_content]],
+            metadatas=[[{"path": file_path, "start": start_line, "end": end_line}]],
+            distances=[[0.2]],
+        )
+        mock_collection.query = AsyncMock(return_value=mock_get_result)
+        with (
+            patch(
+                "vectorcode.subcommands.query.get_query_result_files",
+                return_value=await get_query_result_files(mock_collection, mock_config),
+            ),
+            patch("os.path.isfile", return_value=False),
+            patch("os.path.relpath", return_value=relative_path) as mock_relpath,
+        ):
+            results = await build_query_results(mock_collection, mock_config)
+
+            if not request_abs_path:
+                mock_relpath.assert_called_once_with(
+                    file_path, str(mock_config.project_root)
+                )
+
+            assert len(results) == 1
+
+            expected_full_result = {
+                "path": file_path if request_abs_path else relative_path,
+                "chunk": expected_chunk_content,
+                "start_line": start_line,
+                "end_line": end_line,
+                "chunk_id": identifier,
+            }
+
+            assert results[0] == expected_full_result


 @pytest.mark.asyncio
@@ -323,40 +325,6 @@ async def test_get_query_result_files_chunking(mock_collection, mock_config):
     assert result == ["file1.py", "file2.py"]


-@pytest.mark.asyncio
-async def test_get_query_result_files_multiple_queries(mock_collection, mock_config):
-    # Set multiple query terms
-    mock_config.query = ["term1", "term2", "term3"]
-    mock_config.embedding_dims = 10
-
-    with (
-        patch("vectorcode.subcommands.query.StringChunker") as MockChunker,
-        patch("vectorcode.subcommands.query.reranker.NaiveReranker") as MockReranker,
-    ):
-        # Set up MockChunker to return the query terms as is
-        mock_chunker_instance = MagicMock()
-        mock_chunker_instance.chunk.side_effect = lambda q: [q]
-        MockChunker.return_value = mock_chunker_instance
-
-        mock_reranker_instance = MagicMock()
-        mock_reranker_instance.rerank = AsyncMock(return_value=["file1.py", "file2.py"])
-        MockReranker.return_value = mock_reranker_instance
-
-        # Call the function
-        result = await get_query_result_files(mock_collection, mock_config)
-
-        # Check that chunker was called for each query term
-        assert mock_chunker_instance.chunk.call_count == 3
-
-        # Check query was called with all query terms
-        mock_collection.query.assert_called_once()
-        _, kwargs = mock_collection.query.call_args
-        assert all(len(i) == 10 for i in kwargs["query_embeddings"])
-
-        # Check the result
-        assert result == ["file1.py", "file2.py"]
-
-
 @pytest.mark.asyncio
 async def test_query_success(mock_config):
     # Mock all the necessary dependencies
diff --git a/tests/subcommands/query/test_reranker.py b/tests/subcommands/query/test_reranker.py
index 26ba4a20..31efe758 100644
--- a/tests/subcommands/query/test_reranker.py
+++ b/tests/subcommands/query/test_reranker.py
@@ -14,6 +14,7 @@
     get_available_rerankers,
     get_reranker,
 )
+from vectorcode.subcommands.query.types import QueryResult


 @pytest.fixture(scope="function")
@@ -37,29 +38,50 @@ def naive_reranker_conf():


 @pytest.fixture(scope="function")
-def query_result():
-    return {
-        "ids": [["id1", "id2", "id3"], ["id4", "id5", "id6"]],
-        "distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
-        "metadatas": [
-            [{"path": "file1.py"}, {"path": "file2.py"}, {"path": "file3.py"}],
-            [{"path": "file2.py"}, {"path": "file4.py"}, {"path": "file3.py"}],
-        ],
-        "documents": [
-            ["content1", "content2", "content3"],
-            ["content4", "content5", "content6"],
-        ],
-    }
+def query_result() -> list[QueryResult]:
+    return [
+        QueryResult(
+            path="file1.py",
+            chunk=MagicMock(),
+            query=("query chunk 1",),
+            scores=(0.5,),
+        ),
+        QueryResult(
+            path="file2.py",
+            chunk=MagicMock(),
+            query=("query chunk 1",),
+            scores=(0.9,),
+        ),
+        QueryResult(
+            path="file3.py",
+            chunk=MagicMock(),
+            query=("query chunk 1",),
+            scores=(0.3,),
+        ),
+        QueryResult(
+            path="file2.py",
+            chunk=MagicMock(),
+            query=("query chunk 2",),
+            scores=(0.6,),
+        ),
+        QueryResult(
+            path="file4.py",
+            chunk=MagicMock(),
+            query=("query chunk 2",),
+            scores=(0.7,),
+        ),
+        QueryResult(
+            path="file3.py",
+            chunk=MagicMock(),
+            query=("query chunk 2",),
+            scores=(0.2,),
+        ),
+    ]


 @pytest.fixture(scope="function")
 def empty_query_result():
-    return {
-        "ids": [],
-        "distances": [],
-        "metadatas": [],
-        "documents": [],
-    }
+    return []


 @pytest.fixture(scope="function")
@@ -103,8 +125,24 @@ async def test_naive_reranker_rerank(naive_reranker_conf, query_result):
     assert len(result) <= naive_reranker_conf.n_result

     # Check all returned items are strings (paths)
-    for path in result:
-        assert isinstance(path, str)
+    for res in result:
+        assert isinstance(res, str)
+
+
+@pytest.mark.asyncio
+async def test_naive_reranker_rerank_chunks(naive_reranker_conf, query_result):
+    """Test reranking functionality of NaiveReranker in chunk mode"""
+    naive_reranker_conf.include = [QueryInclude.chunk]
+    reranker = NaiveReranker(naive_reranker_conf)
+    chunks = {i.chunk for i in query_result}
+    result = await reranker.rerank(query_result)
+
+    # Check the result is a list of chunks with correct length
+    assert isinstance(result, list)
+    assert len(result) <= naive_reranker_conf.n_result
+
+    for res in result:
+        assert res in chunks


 @pytest.mark.asyncio
@@ -143,21 +181,7 @@ async def test_cross_encoder_reranker_rerank(mock_cross_encoder, config, query_r
     mock_model = MagicMock()
     mock_cross_encoder.return_value = mock_model

-    # Configure mock predict to return numpy array with float32 dtype
-    scores = numpy.array([0.9, 0.7, 0.8], dtype=numpy.float32)
-    mock_model.predict.return_value = scores
-
-    # Ensure complete query_result structure
-    query_result.update(
-        {
-            "ids": [["id1", "id2", "id3"], ["id4", "id5", "id6"]],
-            "documents": [["doc1", "doc2", "doc3"], ["doc4", "doc5", "doc6"]],
-            "metadatas": [
-                [{"path": "p1"}, {"path": "p2"}, {"path": "p3"}],
-                [{"path": "p4"}, {"path": "p5"}, {"path": "p6"}],
-            ],
-        }
-    )
+    mock_model.predict = lambda x: numpy.random.random((len(x),))

     reranker = CrossEncoderReranker(config)
     result = await reranker.rerank(query_result)
@@ -184,46 +208,6 @@ async def test_naive_reranker_document_selection_logic(
     assert "file2.py" in result or "file3.py" in result


-@pytest.mark.asyncio
-async def test_naive_reranker_with_chunk_ids(naive_reranker_conf, query_result):
-    """Test NaiveReranker returns chunk IDs when QueryInclude.chunk is set"""
-    naive_reranker_conf.include.append(
-        QueryInclude.chunk
-    )  # Assuming QueryInclude.chunk would be "chunk"
-
-    reranker = NaiveReranker(naive_reranker_conf)
-    result = await reranker.rerank(query_result)
-
-    assert isinstance(result, list)
-    assert len(result) <= naive_reranker_conf.n_result
-    assert all(isinstance(id, str) for id in result)
-    assert all(id.startswith("id") for id in result)  # Verify IDs not paths
-
-
-@pytest.mark.asyncio
-@patch("sentence_transformers.CrossEncoder")
-async def test_cross_encoder_reranker_with_chunk_ids(
-    mock_cross_encoder, config, query_result
-):
-    """Test CrossEncoderReranker returns chunk IDs when QueryInclude.chunk is set"""
-    mock_model = MagicMock()
-    mock_cross_encoder.return_value = mock_model
-
-    # Setup mock to return numpy array scores
-    scores = numpy.array([0.9, 0.7], dtype=numpy.float32)
-    mock_model.predict.return_value = scores
-
-    config.include = {QueryInclude.chunk}
-    reranker = CrossEncoderReranker(config)
-
-    result = await reranker.rerank(query_result)
-
-    mock_model.predict.assert_called()
-    assert isinstance(result, list)
-    assert all(isinstance(id, str) for id in result)
-    assert all(id in ["id1", "id2", "id3", "id4"] for id in result)
-
-
 def test_get_reranker(config, naive_reranker_conf):
     assert get_reranker(naive_reranker_conf).configs.reranker == "NaiveReranker"
diff --git a/tests/subcommands/query/test_types.py b/tests/subcommands/query/test_types.py
new file mode 100644
index 00000000..392b6c6f
--- /dev/null
+++ b/tests/subcommands/query/test_types.py
@@ -0,0 +1,82 @@
+import pytest
+from tree_sitter import Point
+
+from vectorcode.chunking import Chunk
+from vectorcode.subcommands.query.types import QueryResult
+
+
+def make_dummy_chunk():
+    return QueryResult(
+        path="dummy1.py",
+        chunk=Chunk(
+            text="hello", start=Point(row=1, column=0), end=Point(row=1, column=4)
+        ),
+        query=("hello",),
+        scores=(0.9,),
+    )
+
+
+def test_QueryResult_merge():
+    res1, res2 = (make_dummy_chunk(), make_dummy_chunk())
+    res2.query = ("bye",)
+    res2.scores = (0.1,)
+
+    merged = QueryResult.merge(res1, res2)
+    assert merged.path == res1.path
+    assert merged.chunk == res1.chunk
+    assert merged.mean_score() == 0.5
+    assert merged.query == ("hello", "bye")
+
+
+def test_QueryResult_merge_failed():
+    res1, res2 = (make_dummy_chunk(), make_dummy_chunk())
+    res2.path = "dummy2.py"
+    with pytest.raises(ValueError):
+        QueryResult.merge(res1, res2)
+
+
+def test_QueryResult_group_by_path():
+    res1, res2 = (make_dummy_chunk(), make_dummy_chunk())
+    res2.chunk = Chunk(
+        "hello", start=Point(row=2, column=0), end=Point(row=2, column=4)
+    )
+    res2.query = ("bye",)
+    res2.scores = (0.1,)
+
+    grouped_dict = QueryResult.group(res1, res2)
+    assert len(grouped_dict.keys()) == 1
+    assert len(grouped_dict["dummy1.py"]) == 2
+
+
+def test_QueryResult_group_by_chunk():
+    res1, res2 = (make_dummy_chunk(), make_dummy_chunk())
+    res2.query = ("bye",)
+    res2.scores = (0.1,)
+
+    grouped_dict = QueryResult.group(res1, res2, by="chunk")
+    assert len(grouped_dict.keys()) == 1
+    assert len(grouped_dict[res1.chunk]) == 2
+
+
+def test_QueryResult_group_top_k():
+    res1, res2 = (make_dummy_chunk(), make_dummy_chunk())
+    res2.chunk = Chunk(
+        "hello", start=Point(row=2, column=0), end=Point(row=2, column=4)
+    )
+    res2.query = ("bye",)
+    res2.scores = (0.1,)
+
+    grouped_dict = QueryResult.group(res1, res2, top_k=1)
+    assert len(grouped_dict.keys()) == 1
+    assert len(grouped_dict["dummy1.py"]) == 1
+    assert grouped_dict["dummy1.py"][0].query[0] == "hello"
+
+
+def test_QueryResult_lt():
+    res1, res2 = (make_dummy_chunk(), make_dummy_chunk())
+    res2.chunk = Chunk(
+        "hello", start=Point(row=2, column=0), end=Point(row=2, column=4)
+    )
+    res2.query = ("bye",)
+    res2.scores = (0.1,)
+    assert res2 < res1
diff --git a/tests/subcommands/test_chunks.py b/tests/subcommands/test_chunks.py
index 1a9d0f03..f90ce577 100644
--- a/tests/subcommands/test_chunks.py
+++ b/tests/subcommands/test_chunks.py
@@ -1,3 +1,4 @@
+import json
 from unittest.mock import MagicMock, patch

 import pytest
@@ -37,3 +38,62 @@ async def test_chunks():
         assert mock_chunker.config == mock_config
         mock_chunker.chunk.assert_called()
         assert mock_chunker.chunk.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_chunks_pipe(capsys):
+    # Mock the Config object
+    mock_config = MagicMock(spec=Config)
+    mock_config.chunk_size = 2000
+    mock_config.overlap_ratio = 0.2
+    mock_config.files = ["file1.py"]
+    mock_config.pipe = True
+
+    # Mock the TreeSitterChunker
+    mock_chunker = TreeSitterChunker(mock_config)
+    mock_chunker.chunk = MagicMock()
+    _chunks = [
+        Chunk("chunk1_file1", Point(0, 1), Point(0, 12), path="file1.py", id="c1"),
+        Chunk("chunk2_file1", Point(1, 1), Point(1, 12), path="file1.py", id="c2"),
+    ]
+    mock_chunker.chunk.side_effect = [
+        _chunks,
+    ]
+    with patch(
+        "vectorcode.subcommands.chunks.TreeSitterChunker", return_value=mock_chunker
+    ):
+        # Call the chunks function
+        result = await chunks(mock_config)
+
+        # Assertions
+        assert result == 0
+        assert mock_chunker.config == mock_config
+        mock_chunker.chunk.assert_called()
+        assert mock_chunker.chunk.call_count == 1
+
+    captured = capsys.readouterr()
+    output = json.loads(captured.out.strip())
+    assert output == [
+        [
+            {
+                "text": "chunk1_file1",
+                "start": {
+                    "row": 0,
+                    "column": 1,
+                },
+                "end": {"row": 0, "column": 12},
+                "path": "file1.py",
+                "chunk_id": "c1",
+            },
+            {
+                "text": "chunk2_file1",
+                "start": {
+                    "row": 1,
+                    "column": 1,
+                },
+                "end": {"row": 1, "column": 12},
+                "path": "file1.py",
+                "chunk_id": "c2",
+            },
+        ]
+    ]
diff --git a/tests/test_mcp.py b/tests/test_mcp.py
index 0d1b8bd8..b9a40bbf 100644
--- a/tests/test_mcp.py
+++ b/tests/test_mcp.py
@@ -108,7 +108,7 @@ async def test_query_tool_success():
             fin.writelines([f"doc{i}"])
     with (
         patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config,
-        patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
+        patch("vectorcode.mcp_main.get_collection", return_value=mock_collection),
         patch(
             "vectorcode.mcp_main.ClientManager._create_client",
             return_value=mock_client,
@@ -125,7 +125,6 @@
         mock_load_config_file.return_value = mock_config
         mock_get_project_config.return_value = mock_config
-        mock_get_collection.return_value = mock_collection

         mock_get_query_result_files.return_value = [
             os.path.join(temp_dir, i) for i in ("file1.py", "file2.py")
@@ -149,17 +149,11 @@ async def test_query_tool_collection_access_failure():
             side_effect=Exception("Failed to connect"),
         ),
     ):
-        with pytest.raises(McpError) as exc_info:
+        with pytest.raises(McpError):
             await query_tool(
                 n_query=2, query_messages=["keyword1"], project_root="/valid/path"
             )

-        assert exc_info.value.error.code == 1
-        assert (
-            "Failed to access the collection at /valid/path. Use `list_collections` tool to get a list of valid paths for this field."
-            in exc_info.value.error.message
-        )
-

 @pytest.mark.asyncio
 async def test_query_tool_no_collection():
@@ -174,17 +168,11 @@ async def test_query_tool_no_collection():
     ):
         mock_get_collection.return_value = None

-        with pytest.raises(McpError) as exc_info:
+        with pytest.raises(McpError):
             await query_tool(
                 n_query=2, query_messages=["keyword1"], project_root="/valid/path"
             )

-        assert exc_info.value.error.code == 1
-        assert (
-            "Failed to access the collection at /valid/path. Use `list_collections` tool to get a list of valid paths for this field."
-            in exc_info.value.error.message
-        )
-

 @pytest.mark.asyncio
 async def test_vectorise_tool_invalid_project_root():
@@ -249,15 +237,9 @@ async def test_vectorise_files_collection_access_failure():
         ),
         patch("vectorcode.mcp_main.get_collection"),
     ):
-        with pytest.raises(McpError) as exc_info:
+        with pytest.raises(McpError):
             await vectorise_files(paths=["file.py"], project_root="/valid/path")

-        assert exc_info.value.error.code == 1
-        assert (
-            "Failed to create the collection at /valid/path"
-            in exc_info.value.error.message
-        )
-

 @pytest.mark.asyncio
 async def test_vectorise_files_with_exclude_spec():
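# For reference, a sketch of the per-chunk dict that the new Chunk.export_dict
# serializes (matching the test_chunks_pipe expectations above; values are
# illustrative).
import json

from tree_sitter import Point
from vectorcode.chunking import Chunk

chunk = Chunk("chunk1_file1", Point(0, 1), Point(0, 12), path="file1.py", id="c1")
print(json.dumps(chunk.export_dict()))
# {"text": "chunk1_file1", "start": {"row": 0, "column": 1},
#  "end": {"row": 0, "column": 12}, "path": "file1.py", "chunk_id": "c1"}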