diff --git a/src/vectorcode/chunking.py b/src/vectorcode/chunking.py
index 6f8218d3..89e30ad1 100644
--- a/src/vectorcode/chunking.py
+++ b/src/vectorcode/chunking.py
@@ -25,21 +25,36 @@ class Chunk:
"""
text: str
- start: Point
- end: Point
+ start: Point | None = None
+ end: Point | None = None
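+    # `path` and `id` are set when the chunk is tied to a source file, e.g. when
+    # it is rebuilt from collection metadata (see `convert_query_results`).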
+ path: str | None = None
+ id: str | None = None
def __str__(self):
return self.text
+ def __hash__(self) -> int:
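+        # hash on path, span and text so identical chunks deduplicate in sets/dict keys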
+ return hash(f"VectorCodeChunk_{self.path}({self.start}:{self.end}@{self.text})")
+
def export_dict(self):
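+        # serialise to a JSON-friendly dict; optional fields are emitted only when set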
+ d: dict[str, str | dict[str, int]] = {"text": self.text}
if self.start is not None:
- return {
- "text": self.text,
- "start": {"row": self.start.row, "column": self.start.column},
- "end": {"row": self.end.row, "column": self.end.column},
- }
- else:
- return {"text": self.text}
+            d["start"] = {"row": self.start.row, "column": self.start.column}
+        if self.end is not None:
+            d["end"] = {"row": self.end.row, "column": self.end.column}
+ if self.path:
+ d["path"] = self.path
+ if self.id:
+ d["chunk_id"] = self.id
+ return d
@dataclass
@@ -129,7 +144,7 @@ def chunk(
) -> Generator[Chunk, None, None]:
logger.info("Started chunking %s using FileChunker.", data.name)
lines = data.readlines()
- if len(lines) == 0:
+ if len(lines) == 0: # pragma: nocover
return
if (
self.config.chunk_size < 0
diff --git a/src/vectorcode/mcp_main.py b/src/vectorcode/mcp_main.py
index 2ba782c1..808ba6f4 100644
--- a/src/vectorcode/mcp_main.py
+++ b/src/vectorcode/mcp_main.py
@@ -3,6 +3,7 @@
import logging
import os
import sys
+import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, cast
@@ -160,14 +161,17 @@ async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int]
await remove_orphanes(collection, collection_lock, stats, stats_lock)
return stats.to_dict()
- except Exception as e:
- logger.error("Failed to access collection at %s", project_root)
- raise McpError(
- ErrorData(
- code=1,
- message=f"{e.__class__.__name__}: Failed to create the collection at {project_root}.",
- )
- )
+ except Exception as e: # pragma: nocover
+ if isinstance(e, McpError):
+ logger.error("Failed to access collection at %s", project_root)
+ raise
+ else:
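+            # wrap unexpected errors with the full traceback so MCP clients can see the cause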
+ raise McpError(
+ ErrorData(
+ code=1,
+                    message="".join(traceback.format_exception(e)),
+ )
+ ) from e
async def query_tool(
@@ -211,24 +215,28 @@ async def query_tool(
configs=query_config,
)
results: list[str] = []
- for path in result_paths:
- if os.path.isfile(path):
- with open(path) as fin:
- rel_path = os.path.relpath(path, config.project_root)
- results.append(
- f"{rel_path}\n{fin.read()}",
- )
+ for result in result_paths:
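+            # the reranker may return Chunk objects as well as path strings; only paths are read here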
+ if isinstance(result, str):
+ if os.path.isfile(result):
+ with open(result) as fin:
+ rel_path = os.path.relpath(result, config.project_root)
+ results.append(
+ f"{rel_path}\n{fin.read()}",
+ )
logger.info("Retrieved the following files: %s", result_paths)
return results
- except Exception as e:
- logger.error("Failed to access collection at %s", project_root)
- raise McpError(
- ErrorData(
- code=1,
- message=f"{e.__class__.__name__}: Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
- )
- )
+ except Exception as e: # pragma: nocover
+ if isinstance(e, McpError):
+ logger.error("Failed to access collection at %s", project_root)
+ raise
+ else:
+ raise McpError(
+ ErrorData(
+ code=1,
+                    message="".join(traceback.format_exception(e)),
+ )
+ ) from e
async def ls_files(project_root: str) -> list[str]:
diff --git a/src/vectorcode/subcommands/query/__init__.py b/src/vectorcode/subcommands/query/__init__.py
index 8dea28b3..fec70da4 100644
--- a/src/vectorcode/subcommands/query/__init__.py
+++ b/src/vectorcode/subcommands/query/__init__.py
@@ -3,12 +3,13 @@
import os
from typing import Any, cast
-from chromadb import GetResult, Where
+from chromadb import Where
from chromadb.api.models.AsyncCollection import AsyncCollection
-from chromadb.api.types import IncludeEnum
+from chromadb.api.types import IncludeEnum, QueryResult
from chromadb.errors import InvalidCollectionException, InvalidDimensionException
+from tree_sitter import Point
-from vectorcode.chunking import StringChunker
+from vectorcode.chunking import Chunk, StringChunker
from vectorcode.cli_utils import (
Config,
QueryInclude,
@@ -22,6 +23,7 @@
get_embedding_function,
verify_ef,
)
+from vectorcode.subcommands.query import types as vectorcode_types
from vectorcode.subcommands.query.reranker import (
RerankerError,
get_reranker,
@@ -30,14 +32,49 @@
logger = logging.getLogger(name=__name__)
+def convert_query_results(
+ chroma_result: QueryResult, queries: list[str]
+) -> list[vectorcode_types.QueryResult]:
+ """Convert chromadb query result to in-house query results"""
+ assert chroma_result["documents"] is not None
+ assert chroma_result["distances"] is not None
+ assert chroma_result["metadatas"] is not None
+ assert chroma_result["ids"] is not None
+
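+    # chromadb returns one row of parallel lists per query; flatten them into one
+    # QueryResult per (query, chunk) pair, negating distances so that a higher
+    # score means "more similar".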
+ chroma_results_list: list[vectorcode_types.QueryResult] = []
+ for q_i in range(len(queries)):
+ q = queries[q_i]
+ documents = chroma_result["documents"][q_i]
+ distances = chroma_result["distances"][q_i]
+ metadatas = chroma_result["metadatas"][q_i]
+ ids = chroma_result["ids"][q_i]
+ for doc, dist, meta, _id in zip(documents, distances, metadatas, ids):
+ chunk = Chunk(text=doc, id=_id)
+ if meta.get("start"):
+ chunk.start = Point(int(meta.get("start", 0)), 0)
+ if meta.get("end"):
+ chunk.end = Point(int(meta.get("end", 0)), 0)
+ if meta.get("path"):
+ chunk.path = str(meta["path"])
+ chroma_results_list.append(
+ vectorcode_types.QueryResult(
+ chunk=chunk,
+ path=str(meta.get("path", "")),
+ query=(q,),
+ scores=(-dist,),
+ )
+ )
+ return chroma_results_list
+
+
async def get_query_result_files(
collection: AsyncCollection, configs: Config
-) -> list[str]:
+) -> list[str | Chunk]:
query_chunks = []
- if configs.query:
- chunker = StringChunker(configs)
- for q in configs.query:
- query_chunks.extend(str(i) for i in chunker.chunk(q))
+ assert configs.query, "Query messages cannot be empty."
+ chunker = StringChunker(configs)
+ for q in configs.query:
+ query_chunks.extend(str(i) for i in chunker.chunk(q))
configs.query_exclude = [
expand_path(i, True)
@@ -70,7 +107,7 @@ async def get_query_result_files(
query_embeddings = get_embedding_function(configs)(query_chunks)
if isinstance(configs.embedding_dims, int) and configs.embedding_dims > 0:
query_embeddings = [e[: configs.embedding_dims] for e in query_embeddings]
- results = await collection.query(
+ chroma_query_results: QueryResult = await collection.query(
query_embeddings=query_embeddings,
n_results=num_query,
include=[
@@ -85,69 +122,51 @@ async def get_query_result_files(
return []
reranker = get_reranker(configs)
- return await reranker.rerank(results)
+ return await reranker.rerank(
+ convert_query_results(chroma_query_results, configs.query)
+ )
async def build_query_results(
collection: AsyncCollection, configs: Config
) -> list[dict[str, str | int]]:
- structured_result = []
- for identifier in await get_query_result_files(collection, configs):
- if os.path.isfile(identifier):
- if configs.use_absolute_path:
- output_path = os.path.abspath(identifier)
- else:
- output_path = os.path.relpath(identifier, configs.project_root)
- full_result = {"path": output_path}
- with open(identifier) as fin:
- document = fin.read()
- full_result["document"] = document
+ assert configs.project_root
- structured_result.append(
- {str(key): full_result[str(key)] for key in configs.include}
- )
- elif QueryInclude.chunk in configs.include:
- chunks: GetResult = await collection.get(
- identifier, include=[IncludeEnum.metadatas, IncludeEnum.documents]
- )
- meta = chunks.get(
- "metadatas",
- )
- if meta is not None and len(meta) != 0:
- chunk_texts = chunks.get("documents")
- assert chunk_texts is not None, (
- "QueryResult does not contain `documents`!"
- )
- full_result: dict[str, str | int] = {
- "chunk": str(chunk_texts[0]),
- "chunk_id": identifier,
- }
- if meta[0].get("start") is not None and meta[0].get("end") is not None:
- path = str(meta[0].get("path"))
- with open(path) as fin:
- start: int = int(meta[0]["start"])
- end: int = int(meta[0]["end"])
- full_result["chunk"] = "".join(fin.readlines()[start : end + 1])
- full_result["start_line"] = start
- full_result["end_line"] = end
- if QueryInclude.path in configs.include:
- full_result["path"] = str(
- meta[0]["path"]
- if configs.use_absolute_path
- else os.path.relpath(
- str(meta[0]["path"]), str(configs.project_root)
- )
- )
-
- structured_result.append(full_result)
- else: # pragma: nocover
- logger.error(
- "This collection doesn't support chunk-mode output because it lacks the necessary metadata. Please re-vectorise it.",
- )
+ def make_output_path(path: str, absolute: bool) -> str:
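+        # collection metadata may hold absolute or project-relative paths; normalise either way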
+ if absolute:
+ if os.path.isabs(path):
+ return path
+ return os.path.abspath(os.path.join(str(configs.project_root), path))
+ else:
+ rel_path = os.path.relpath(path, configs.project_root)
+ if isinstance(rel_path, bytes): # pragma: nocover
+            # on some Python versions, the type stubs suggest `os.path.relpath` may return `bytes`.
+ rel_path = rel_path.decode()
+ return rel_path
+ structured_result = []
+ for res in await get_query_result_files(collection, configs):
+ if isinstance(res, str):
+ output_path = make_output_path(res, configs.use_absolute_path)
+ io_path = make_output_path(res, True)
+ if not os.path.isfile(io_path):
+ logger.warning(f"{io_path} is no longer a valid file.")
+ continue
+ with open(io_path) as fin:
+ structured_result.append({"path": output_path, "document": fin.read()})
else:
- logger.warning(
- f"{identifier} is no longer a valid file! Please re-run vectorcode vectorise to refresh the database.",
+ res = cast(Chunk, res)
+ assert res.path, f"{res} has no `path` attribute."
+ structured_result.append(
+ {
+ "path": make_output_path(res.path, configs.use_absolute_path)
+ if res.path is not None
+ else None,
+ "chunk": res.text,
+ "start_line": res.start.row if res.start is not None else None,
+ "end_line": res.end.row if res.end is not None else None,
+ "chunk_id": res.id,
+ }
)
for result in structured_result:
if result.get("path") is not None:
diff --git a/src/vectorcode/subcommands/query/reranker/base.py b/src/vectorcode/subcommands/query/reranker/base.py
index 047d978e..18a4c68a 100644
--- a/src/vectorcode/subcommands/query/reranker/base.py
+++ b/src/vectorcode/subcommands/query/reranker/base.py
@@ -1,13 +1,13 @@
import heapq
import logging
from abc import ABC, abstractmethod
-from collections import defaultdict
-from typing import Any, DefaultDict, Optional, Sequence, cast
+from typing import Any
import numpy
-from chromadb.api.types import QueryResult
+from vectorcode.chunking import Chunk
from vectorcode.cli_utils import Config, QueryInclude
+from vectorcode.subcommands.query.types import QueryResult
logger = logging.getLogger(name=__name__)
@@ -29,7 +29,7 @@ def __init__(self, configs: Config, **kwargs: Any):
"'configs' should contain the query messages."
)
self.n_result = configs.n_result
- self._raw_results: Optional[QueryResult] = None
+ self._raw_results: list[QueryResult] = []
@classmethod
def create(cls, configs: Config, **kwargs: Any):
@@ -47,52 +47,37 @@ def create(cls, configs: Config, **kwargs: Any):
@abstractmethod
async def compute_similarity(
- self, results: list[str], query_message: str
- ) -> Sequence[float]: # pragma: nocover
- """Given a list of n results and 1 query message,
- return a list-like object of length n that contains the similarity scores between
- each item in `results` and the `query_message`.
-
- A high similarity score means the strings are semantically similar to each other.
- `query_message` will be loaded in the same order as they appear in `self.configs.query`.
-
- If you need the raw query results from chromadb,
- it'll be saved in `self._raw_results` before this method is called.
+ self, results: list[QueryResult]
+ ) -> None: # pragma: nocover
+ """
+        Modify the `scores` field of each `QueryResult` **IN-PLACE** so that it holds the correct similarity scores.
"""
raise NotImplementedError
- async def rerank(self, results: QueryResult | dict) -> list[str]:
- if len(results["ids"]) == 0 or all(len(i) == 0 for i in results["ids"]):
+ async def rerank(self, results: list[QueryResult]) -> list[str | Chunk]:
+ if len(results) == 0:
return []
- self._raw_results = cast(QueryResult, results)
- query_chunks = self.configs.query
- assert query_chunks
- assert results["metadatas"] is not None
- assert results["documents"] is not None
- documents: DefaultDict[str, list[float]] = defaultdict(list)
- for query_chunk_idx in range(len(query_chunks)):
- chunk_ids = results["ids"][query_chunk_idx]
- chunk_metas = results["metadatas"][query_chunk_idx]
- chunk_docs = results["documents"][query_chunk_idx]
- scores = await self.compute_similarity(
- chunk_docs, query_chunks[query_chunk_idx]
+ # compute the similarity scores
+ await self.compute_similarity(results)
+
+ # group the results by the query type: file (path) or chunk
+ # and only keep the `top_k` results for each group
+ group_by = "path"
+ if QueryInclude.chunk in self.configs.include:
+ group_by = "chunk"
+ grouped_results = QueryResult.group(*results, by=group_by, top_k="auto")
+
+ # compute the mean scores for each of the groups
+ scores: dict[Chunk | str, float] = {}
+ for key in grouped_results.keys():
+ scores[key] = float(
+ numpy.mean(tuple(i.mean_score() for i in grouped_results[key]))
)
- for i, score in enumerate(scores):
- if QueryInclude.chunk in self.configs.include:
- documents[chunk_ids[i]].append(float(score))
- else:
- documents[str(chunk_metas[i]["path"])].append(float(score))
- logger.debug("Document scores: %s", documents)
- top_k = int(numpy.mean(tuple(len(i) for i in documents.values())))
- for key in documents.keys():
- documents[key] = heapq.nlargest(top_k, documents[key])
-
- self._raw_results = None
-
- return heapq.nlargest(
- self.n_result,
- documents.keys(),
- key=lambda x: float(numpy.mean(documents[x])),
+        return heapq.nlargest(
+            self.configs.n_result, grouped_results.keys(), key=lambda x: scores[x]
)
diff --git a/src/vectorcode/subcommands/query/reranker/cross_encoder.py b/src/vectorcode/subcommands/query/reranker/cross_encoder.py
index 51758f76..3f2fcd1d 100644
--- a/src/vectorcode/subcommands/query/reranker/cross_encoder.py
+++ b/src/vectorcode/subcommands/query/reranker/cross_encoder.py
@@ -1,8 +1,8 @@
-import asyncio
import logging
from typing import Any
from vectorcode.cli_utils import Config
+from vectorcode.subcommands.query.types import QueryResult
from .base import RerankerBase
@@ -34,8 +34,8 @@ def __init__(
model_name = configs.reranker_params.pop("model_name_or_path")
self.model = CrossEncoder(model_name, **configs.reranker_params)
- async def compute_similarity(self, results: list[str], query_message: str):
- scores = await asyncio.to_thread(
- self.model.predict, [(chunk, query_message) for chunk in results]
- )
- return list(float(i) for i in scores)
+ async def compute_similarity(self, results: list[QueryResult]):
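+        # score every (chunk text, query) pair in a single batched predict call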
+ scores = self.model.predict([(str(res.chunk), res.query[0]) for res in results])
+
+ for res, score in zip(results, scores):
+ res.scores = (score,)
diff --git a/src/vectorcode/subcommands/query/reranker/naive.py b/src/vectorcode/subcommands/query/reranker/naive.py
index 6e6b4336..65478c09 100644
--- a/src/vectorcode/subcommands/query/reranker/naive.py
+++ b/src/vectorcode/subcommands/query/reranker/naive.py
@@ -1,7 +1,8 @@
import logging
-from typing import Any, Sequence
+from typing import Any
from vectorcode.cli_utils import Config
+from vectorcode.subcommands.query.types import QueryResult
from .base import RerankerBase
@@ -17,15 +18,8 @@ class NaiveReranker(RerankerBase):
def __init__(self, configs: Config, **kwargs: Any):
super().__init__(configs)
- async def compute_similarity(
- self, results: list[str], query_message: str
- ) -> Sequence[float]:
- assert self._raw_results is not None, "Expecting raw results from the database."
- assert self._raw_results.get("distances") is not None
- assert self.configs.query, "Expecting query messages in self.configs"
- idx = self.configs.query.index(query_message)
- dist = self._raw_results.get("distances")
- if dist is None: # pragma: nocover
- raise ValueError("QueryResult should contain distances!")
- else:
- return list(-i for i in dist[idx])
+ async def compute_similarity(self, results: list[QueryResult]):
+ """
+        No-op: the `QueryResult` objects already carry negated distances as their scores.
+ """
+ pass
diff --git a/src/vectorcode/subcommands/query/types.py b/src/vectorcode/subcommands/query/types.py
new file mode 100644
index 00000000..e7e5507f
--- /dev/null
+++ b/src/vectorcode/subcommands/query/types.py
@@ -0,0 +1,94 @@
+import heapq
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Literal, Union
+
+import numpy
+
+from vectorcode.chunking import Chunk
+
+
+@dataclass
+class QueryResult:
+ """
+ The container for one single query result.
+
+ args:
+ - path: path to the file
+    - chunk: `vectorcode.chunking.Chunk` object that stores the chunk
+    - query: query messages used for the search
+    - scores: similarity scores for the corresponding queries.
+ """
+
+ path: str
+ chunk: Chunk
+ query: tuple[str, ...]
+ scores: tuple[float, ...]
+
+ @classmethod
+ def merge(cls, *results: "QueryResult") -> "QueryResult":
+ """
+ Given the results of a single chunk/document from different queries, merge them into a single `QueryResult` object.
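+
+        Raises a `ValueError` if the inputs do not all refer to the same chunk
+        (checked via `is_same_doc`).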
+ """
+        for i in range(len(results) - 1):
+            if not results[i].is_same_doc(results[i + 1]):
+ raise ValueError(
+ f"The inputs are not the same chunk: {results[i]}, {results[i + 1]}"
+ )
+
+ return QueryResult(
+ path=results[0].path,
+ chunk=results[0].chunk,
+ query=sum((tuple(i.query) for i in results), start=tuple()),
+ scores=sum((tuple(i.scores) for i in results), start=tuple()),
+ )
+
+ @staticmethod
+ def group(
+ *results: "QueryResult",
+ by: Union[Literal["path"], Literal["chunk"]] = "path",
+ top_k: int | Literal["auto"] | None = None,
+ ) -> dict[Chunk | str, list["QueryResult"]]:
+ """
+        Group the query results based on `by`.
+
+ args:
+ - `by`: either "path" or "chunk"
+ - `top_k`: if set, only return the top k results for each group based on mean scores. If "auto", top k is decided by the mean number of results per group.
+
+ returns:
+        - a dictionary that maps either a path or a chunk to a list of `QueryResult` objects.
+
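+        Illustrative sketch (assumed values, not an exercised doctest):
+
+            >>> a = QueryResult(path="a.py", chunk=Chunk(text="x"), query=("q",), scores=(0.9,))
+            >>> b = QueryResult(path="a.py", chunk=Chunk(text="y"), query=("q",), scores=(0.1,))
+            >>> list(QueryResult.group(a, b).keys())
+            ['a.py']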
+ """
+ assert by in {"path", "chunk"}
+ grouped_result: dict[Chunk | str, list["QueryResult"]] = defaultdict(list)
+
+ for res in results:
+ grouped_result[getattr(res, by)].append(res)
+
+ if top_k == "auto":
+ top_k = int(numpy.mean(tuple(len(i) for i in grouped_result.values())))
+
+ if top_k and top_k > 0:
+ for group in grouped_result.keys():
+ grouped_result[group] = heapq.nlargest(top_k, grouped_result[group])
+ return grouped_result
+
+ def mean_score(self):
+ return float(numpy.mean(self.scores))
+
+ def __lt__(self, other: "QueryResult"):
+ assert isinstance(other, QueryResult)
+ return self.mean_score() < other.mean_score()
+
+ def __gt__(self, other: "QueryResult"):
+ assert isinstance(other, QueryResult)
+ return self.mean_score() > other.mean_score()
+
+ def __eq__(self, other: object, /) -> bool:
+ return (
+ isinstance(other, QueryResult) and self.mean_score() == other.mean_score()
+ )
+
+ def is_same_doc(self, other: "QueryResult") -> bool:
+ return self.path == other.path and self.chunk == other.chunk
diff --git a/src/vectorcode/subcommands/vectorise.py b/src/vectorcode/subcommands/vectorise.py
index 1cf51569..2ce0b249 100644
--- a/src/vectorcode/subcommands/vectorise.py
+++ b/src/vectorcode/subcommands/vectorise.py
@@ -139,8 +139,10 @@ async def chunked_add(
"sha256": new_sha256,
}
if isinstance(chunk, Chunk):
- meta["start"] = chunk.start.row
- meta["end"] = chunk.end.row
+ if chunk.start:
+ meta["start"] = chunk.start.row
+ if chunk.end:
+ meta["end"] = chunk.end.row
metas.append(meta)
async with collection_lock:
diff --git a/tests/subcommands/query/test_query.py b/tests/subcommands/query/test_query.py
index 43392526..67251e4c 100644
--- a/tests/subcommands/query/test_query.py
+++ b/tests/subcommands/query/test_query.py
@@ -1,7 +1,7 @@
-from unittest.mock import AsyncMock, MagicMock, mock_open, patch
+from unittest.mock import AsyncMock, MagicMock, patch
import pytest
-from chromadb import GetResult
+from chromadb import QueryResult
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.types import IncludeEnum
from chromadb.errors import InvalidCollectionException, InvalidDimensionException
@@ -9,6 +9,7 @@
from vectorcode.cli_utils import CliAction, Config, QueryInclude
from vectorcode.subcommands.query import (
build_query_results,
+ convert_query_results,
get_query_result_files,
query,
)
@@ -47,7 +48,7 @@ def mock_collection():
@pytest.fixture
def mock_config():
return Config(
- query=["test query"],
+ query=["test query", "test query 2"],
n_result=3,
query_multiplier=2,
chunk_size=100,
@@ -65,6 +66,7 @@ def mock_config():
@pytest.mark.asyncio
async def test_get_query_result_files(mock_collection, mock_config):
mock_embedding_function = MagicMock()
+ mock_config.embedding_dims = 10
with (
patch("vectorcode.subcommands.query.get_reranker") as mock_get_reranker,
patch(
@@ -88,7 +90,7 @@ async def test_get_query_result_files(mock_collection, mock_config):
# Check that query was called with the right parameters
mock_collection.query.assert_called_once()
args, kwargs = mock_collection.query.call_args
- mock_embedding_function.assert_called_once_with(["test query"])
+ mock_embedding_function.assert_called_once_with(["test query", "test query 2"])
assert kwargs["n_results"] == 6 # n_result(3) * query_multiplier(2)
assert IncludeEnum.metadatas in kwargs["include"]
assert IncludeEnum.distances in kwargs["include"]
@@ -98,11 +100,14 @@ async def test_get_query_result_files(mock_collection, mock_config):
# Check reranker was used correctly
mock_get_reranker.assert_called_once_with(mock_config)
mock_reranker_instance.rerank.assert_called_once_with(
- mock_collection.query.return_value
+ convert_query_results(mock_collection.query.return_value, mock_config.query)
)
# Check the result
assert result == ["file1.py", "file2.py", "file3.py"]
+ assert all(
+        len(i) == 10 for i in mock_collection.query.call_args.kwargs["query_embeddings"]
+ )
@pytest.mark.asyncio
@@ -128,59 +133,56 @@ async def test_get_query_result_files_include_chunk(mock_collection, mock_config
@pytest.mark.asyncio
async def test_build_query_results_chunk_mode_success(mock_collection, mock_config):
"""Test build_query_results in chunk mode successfully retrieves chunk details."""
- mock_config.include = [QueryInclude.chunk, QueryInclude.path]
- mock_config.project_root = "/test/project"
- mock_config.use_absolute_path = False
- identifier = "chunk_id_1"
- file_path = "/test/project/subdir/file1.py"
- relative_path = "subdir/file1.py"
- start_line = 5
- end_line = 10
-
- full_file_content_lines = [f"line {i}\n" for i in range(15)]
- full_file_content = "".join(full_file_content_lines)
-
- expected_chunk_content = "".join(full_file_content_lines[start_line : end_line + 1])
-
- mock_get_result = GetResult(
- ids=[identifier],
- embeddings=None,
- documents=["original chunk doc in db"],
- metadatas=[{"path": file_path, "start": start_line, "end": end_line}],
- )
-
- with (
- patch(
- "vectorcode.subcommands.query.get_query_result_files",
- return_value=[identifier],
- ),
- patch("os.path.isfile", return_value=False),
- patch("builtins.open", mock_open(read_data=full_file_content)) as mocked_open,
- patch("os.path.relpath", return_value=relative_path) as mock_relpath,
- ):
- mock_collection.get = AsyncMock(return_value=mock_get_result)
-
- results = await build_query_results(mock_collection, mock_config)
-
- mock_collection.get.assert_called_once_with(
- identifier, include=[IncludeEnum.metadatas, IncludeEnum.documents]
+ for request_abs_path in (True, False):
+ mock_config.include = [QueryInclude.chunk, QueryInclude.path]
+ mock_config.project_root = "/test/project"
+ mock_config.use_absolute_path = request_abs_path
+ mock_config.query = ["dummy_query"]
+ identifier = "chunk_id"
+ file_path = "/test/project/subdir/file1.py"
+ relative_path = "subdir/file1.py"
+ start_line = 5
+ end_line = 10
+
+ full_file_content_lines = [f"line {i}\n" for i in range(15)]
+
+ expected_chunk_content = "".join(
+ full_file_content_lines[start_line : end_line + 1]
)
- mocked_open.assert_called_once_with(file_path)
-
- mock_relpath.assert_called_once_with(file_path, str(mock_config.project_root))
-
- assert len(results) == 1
-
- expected_full_result = {
- "path": relative_path,
- "chunk": expected_chunk_content,
- "start_line": start_line,
- "end_line": end_line,
- "chunk_id": identifier,
- }
-
- assert results[0] == expected_full_result
+ mock_get_result = QueryResult(
+ ids=[[identifier]],
+ documents=[[expected_chunk_content]],
+ metadatas=[[{"path": file_path, "start": start_line, "end": end_line}]],
+ distances=[[0.2]],
+ )
+ mock_collection.query = AsyncMock(return_value=mock_get_result)
+ with (
+ patch(
+ "vectorcode.subcommands.query.get_query_result_files",
+ return_value=await get_query_result_files(mock_collection, mock_config),
+ ),
+ patch("os.path.isfile", return_value=False),
+ patch("os.path.relpath", return_value=relative_path) as mock_relpath,
+ ):
+ results = await build_query_results(mock_collection, mock_config)
+
+ if not request_abs_path:
+ mock_relpath.assert_called_once_with(
+ file_path, str(mock_config.project_root)
+ )
+
+ assert len(results) == 1
+
+ expected_full_result = {
+ "path": file_path if request_abs_path else relative_path,
+ "chunk": expected_chunk_content,
+ "start_line": start_line,
+ "end_line": end_line,
+ "chunk_id": identifier,
+ }
+
+ assert results[0] == expected_full_result
@pytest.mark.asyncio
@@ -323,40 +325,6 @@ async def test_get_query_result_files_chunking(mock_collection, mock_config):
assert result == ["file1.py", "file2.py"]
-@pytest.mark.asyncio
-async def test_get_query_result_files_multiple_queries(mock_collection, mock_config):
- # Set multiple query terms
- mock_config.query = ["term1", "term2", "term3"]
- mock_config.embedding_dims = 10
-
- with (
- patch("vectorcode.subcommands.query.StringChunker") as MockChunker,
- patch("vectorcode.subcommands.query.reranker.NaiveReranker") as MockReranker,
- ):
- # Set up MockChunker to return the query terms as is
- mock_chunker_instance = MagicMock()
- mock_chunker_instance.chunk.side_effect = lambda q: [q]
- MockChunker.return_value = mock_chunker_instance
-
- mock_reranker_instance = MagicMock()
- mock_reranker_instance.rerank = AsyncMock(return_value=["file1.py", "file2.py"])
- MockReranker.return_value = mock_reranker_instance
-
- # Call the function
- result = await get_query_result_files(mock_collection, mock_config)
-
- # Check that chunker was called for each query term
- assert mock_chunker_instance.chunk.call_count == 3
-
- # Check query was called with all query terms
- mock_collection.query.assert_called_once()
- _, kwargs = mock_collection.query.call_args
- assert all(len(i) == 10 for i in kwargs["query_embeddings"])
-
- # Check the result
- assert result == ["file1.py", "file2.py"]
-
-
@pytest.mark.asyncio
async def test_query_success(mock_config):
# Mock all the necessary dependencies
diff --git a/tests/subcommands/query/test_reranker.py b/tests/subcommands/query/test_reranker.py
index 26ba4a20..31efe758 100644
--- a/tests/subcommands/query/test_reranker.py
+++ b/tests/subcommands/query/test_reranker.py
@@ -14,6 +14,7 @@
get_available_rerankers,
get_reranker,
)
+from vectorcode.subcommands.query.types import QueryResult
@pytest.fixture(scope="function")
@@ -37,29 +38,50 @@ def naive_reranker_conf():
@pytest.fixture(scope="function")
-def query_result():
- return {
- "ids": [["id1", "id2", "id3"], ["id4", "id5", "id6"]],
- "distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
- "metadatas": [
- [{"path": "file1.py"}, {"path": "file2.py"}, {"path": "file3.py"}],
- [{"path": "file2.py"}, {"path": "file4.py"}, {"path": "file3.py"}],
- ],
- "documents": [
- ["content1", "content2", "content3"],
- ["content4", "content5", "content6"],
- ],
- }
+def query_result() -> list[QueryResult]:
+ return [
+ QueryResult(
+ path="file1.py",
+ chunk=MagicMock(),
+ query=("query chunk 1",),
+ scores=(0.5,),
+ ),
+ QueryResult(
+ path="file2.py",
+ chunk=MagicMock(),
+ query=("query chunk 1",),
+ scores=(0.9,),
+ ),
+ QueryResult(
+ path="file3.py",
+ chunk=MagicMock(),
+ query=("query chunk 1",),
+ scores=(0.3,),
+ ),
+ QueryResult(
+ path="file2.py",
+ chunk=MagicMock(),
+ query=("query chunk 2",),
+ scores=(0.6,),
+ ),
+ QueryResult(
+ path="file4.py",
+ chunk=MagicMock(),
+ query=("query chunk 2",),
+ scores=(0.7,),
+ ),
+ QueryResult(
+ path="file3.py",
+ chunk=MagicMock(),
+ query=("query chunk 2",),
+ scores=(0.2,),
+ ),
+ ]
@pytest.fixture(scope="function")
def empty_query_result():
- return {
- "ids": [],
- "distances": [],
- "metadatas": [],
- "documents": [],
- }
+ return []
@pytest.fixture(scope="function")
@@ -103,8 +125,24 @@ async def test_naive_reranker_rerank(naive_reranker_conf, query_result):
assert len(result) <= naive_reranker_conf.n_result
# Check all returned items are strings (paths)
- for path in result:
- assert isinstance(path, str)
+ for res in result:
+ assert isinstance(res, str)
+
+
+@pytest.mark.asyncio
+async def test_naive_reranker_rerank_chunks(naive_reranker_conf, query_result):
+ """Test basic reranking functionality of NaiveReranker"""
+ naive_reranker_conf.include = [QueryInclude.chunk]
+ reranker = NaiveReranker(naive_reranker_conf)
+ chunks = {i.chunk for i in query_result}
+ result = await reranker.rerank(query_result)
+
+    # Check the result is a list of chunks with correct length
+ assert isinstance(result, list)
+ assert len(result) <= naive_reranker_conf.n_result
+
+ for res in result:
+ assert res in chunks
@pytest.mark.asyncio
@@ -143,21 +181,7 @@ async def test_cross_encoder_reranker_rerank(mock_cross_encoder, config, query_r
mock_model = MagicMock()
mock_cross_encoder.return_value = mock_model
- # Configure mock predict to return numpy array with float32 dtype
- scores = numpy.array([0.9, 0.7, 0.8], dtype=numpy.float32)
- mock_model.predict.return_value = scores
-
- # Ensure complete query_result structure
- query_result.update(
- {
- "ids": [["id1", "id2", "id3"], ["id4", "id5", "id6"]],
- "documents": [["doc1", "doc2", "doc3"], ["doc4", "doc5", "doc6"]],
- "metadatas": [
- [{"path": "p1"}, {"path": "p2"}, {"path": "p3"}],
- [{"path": "p4"}, {"path": "p5"}, {"path": "p6"}],
- ],
- }
- )
+ mock_model.predict = lambda x: numpy.random.random((len(x),))
reranker = CrossEncoderReranker(config)
result = await reranker.rerank(query_result)
@@ -184,46 +208,6 @@ async def test_naive_reranker_document_selection_logic(
assert "file2.py" in result or "file3.py" in result
-@pytest.mark.asyncio
-async def test_naive_reranker_with_chunk_ids(naive_reranker_conf, query_result):
- """Test NaiveReranker returns chunk IDs when QueryInclude.chunk is set"""
- naive_reranker_conf.include.append(
- QueryInclude.chunk
- ) # Assuming QueryInclude.chunk would be "chunk"
-
- reranker = NaiveReranker(naive_reranker_conf)
- result = await reranker.rerank(query_result)
-
- assert isinstance(result, list)
- assert len(result) <= naive_reranker_conf.n_result
- assert all(isinstance(id, str) for id in result)
- assert all(id.startswith("id") for id in result) # Verify IDs not paths
-
-
-@pytest.mark.asyncio
-@patch("sentence_transformers.CrossEncoder")
-async def test_cross_encoder_reranker_with_chunk_ids(
- mock_cross_encoder, config, query_result
-):
- """Test CrossEncoderReranker returns chunk IDs when QueryInclude.chunk is set"""
- mock_model = MagicMock()
- mock_cross_encoder.return_value = mock_model
-
- # Setup mock to return numpy array scores
- scores = numpy.array([0.9, 0.7], dtype=numpy.float32)
- mock_model.predict.return_value = scores
-
- config.include = {QueryInclude.chunk}
- reranker = CrossEncoderReranker(config)
-
- result = await reranker.rerank(query_result)
-
- mock_model.predict.assert_called()
- assert isinstance(result, list)
- assert all(isinstance(id, str) for id in result)
- assert all(id in ["id1", "id2", "id3", "id4"] for id in result)
-
-
def test_get_reranker(config, naive_reranker_conf):
assert get_reranker(naive_reranker_conf).configs.reranker == "NaiveReranker"
diff --git a/tests/subcommands/query/test_types.py b/tests/subcommands/query/test_types.py
new file mode 100644
index 00000000..392b6c6f
--- /dev/null
+++ b/tests/subcommands/query/test_types.py
@@ -0,0 +1,82 @@
+import pytest
+from tree_sitter import Point
+
+from vectorcode.chunking import Chunk
+from vectorcode.subcommands.query.types import QueryResult
+
+
+def make_dummy_chunk():
+ return QueryResult(
+ path="dummy1.py",
+ chunk=Chunk(
+ text="hello", start=Point(row=1, column=0), end=Point(row=1, column=4)
+ ),
+ query=["hello"],
+ scores=[0.9],
+ )
+
+
+def test_QueryResult_merge():
+ res1, res2 = (make_dummy_chunk(), make_dummy_chunk())
+ res2.query = ["bye"]
+ res2.scores = [0.1]
+
+ merged = QueryResult.merge(res1, res2)
+ assert merged.path == res1.path
+ assert merged.chunk == res1.chunk
+ assert merged.mean_score() == 0.5
+ assert merged.query == ("hello", "bye")
+
+
+def test_QueryResult_merge_failed():
+ res1, res2 = (make_dummy_chunk(), make_dummy_chunk())
+ res2.path = "dummy2.py"
+ with pytest.raises(ValueError):
+ QueryResult.merge(res1, res2)
+
+
+def test_QueryResult_group_by_path():
+ res1, res2 = (make_dummy_chunk(), make_dummy_chunk())
+ res2.chunk = Chunk(
+ "hello", start=Point(row=2, column=0), end=Point(row=2, column=4)
+ )
+ res2.query = ["bye"]
+ res2.scores = [0.1]
+
+ grouped_dict = QueryResult.group(res1, res2)
+ assert len(grouped_dict.keys()) == 1
+ assert len(grouped_dict["dummy1.py"]) == 2
+
+
+def test_QueryResult_group_by_chunk():
+ res1, res2 = (make_dummy_chunk(), make_dummy_chunk())
+ res2.query = ["bye"]
+ res2.scores = [0.1]
+
+ grouped_dict = QueryResult.group(res1, res2, by="chunk")
+ assert len(grouped_dict.keys()) == 1
+ assert len(grouped_dict[res1.chunk]) == 2
+
+
+def test_QueryResult_group_top_k():
+ res1, res2 = (make_dummy_chunk(), make_dummy_chunk())
+ res2.chunk = Chunk(
+ "hello", start=Point(row=2, column=0), end=Point(row=2, column=4)
+ )
+ res2.query = ["bye"]
+ res2.scores = [0.1]
+
+ grouped_dict = QueryResult.group(res1, res2, top_k=1)
+ assert len(grouped_dict.keys()) == 1
+ assert len(grouped_dict["dummy1.py"]) == 1
+ assert grouped_dict["dummy1.py"][0].query[0] == "hello"
+
+
+def test_QueryResult_lt():
+ res1, res2 = (make_dummy_chunk(), make_dummy_chunk())
+ res2.chunk = Chunk(
+ "hello", start=Point(row=2, column=0), end=Point(row=2, column=4)
+ )
+ res2.query = ["bye"]
+ res2.scores = [0.1]
+ assert res2 < res1
diff --git a/tests/subcommands/test_chunks.py b/tests/subcommands/test_chunks.py
index 1a9d0f03..f90ce577 100644
--- a/tests/subcommands/test_chunks.py
+++ b/tests/subcommands/test_chunks.py
@@ -1,3 +1,4 @@
+import json
from unittest.mock import MagicMock, patch
import pytest
@@ -37,3 +38,62 @@ async def test_chunks():
assert mock_chunker.config == mock_config
mock_chunker.chunk.assert_called()
assert mock_chunker.chunk.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_chunks_pipe(capsys):
+ # Mock the Config object
+ mock_config = MagicMock(spec=Config)
+ mock_config.chunk_size = 2000
+ mock_config.overlap_ratio = 0.2
+ mock_config.files = ["file1.py"]
+ mock_config.pipe = True
+
+ # Mock the TreeSitterChunker
+ mock_chunker = TreeSitterChunker(mock_config)
+ mock_chunker.chunk = MagicMock()
+ _chunks = [
+ Chunk("chunk1_file1", Point(0, 1), Point(0, 12), path="file1.py", id="c1"),
+ Chunk("chunk2_file1", Point(1, 1), Point(1, 12), path="file1.py", id="c2"),
+ ]
+ mock_chunker.chunk.side_effect = [
+ _chunks,
+ ]
+ with patch(
+ "vectorcode.subcommands.chunks.TreeSitterChunker", return_value=mock_chunker
+ ):
+ # Call the chunks function
+ result = await chunks(mock_config)
+
+ # Assertions
+ assert result == 0
+ assert mock_chunker.config == mock_config
+ mock_chunker.chunk.assert_called()
+ assert mock_chunker.chunk.call_count == 1
+
+ captured = capsys.readouterr()
+ output = json.loads(captured.out.strip())
+ assert output == [
+ [
+ {
+ "text": "chunk1_file1",
+ "start": {
+ "row": 0,
+ "column": 1,
+ },
+ "end": {"row": 0, "column": 12},
+ "path": "file1.py",
+ "chunk_id": "c1",
+ },
+ {
+ "text": "chunk2_file1",
+ "start": {
+ "row": 1,
+ "column": 1,
+ },
+ "end": {"row": 1, "column": 12},
+ "path": "file1.py",
+ "chunk_id": "c2",
+ },
+ ]
+ ]
diff --git a/tests/test_mcp.py b/tests/test_mcp.py
index 0d1b8bd8..b9a40bbf 100644
--- a/tests/test_mcp.py
+++ b/tests/test_mcp.py
@@ -108,7 +108,7 @@ async def test_query_tool_success():
fin.writelines([f"doc{i}"])
with (
patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config,
- patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
+ patch("vectorcode.mcp_main.get_collection", return_value=mock_collection),
patch(
"vectorcode.mcp_main.ClientManager._create_client",
return_value=mock_client,
@@ -125,7 +125,7 @@ async def test_query_tool_success():
mock_load_config_file.return_value = mock_config
mock_get_project_config.return_value = mock_config
- mock_get_collection.return_value = mock_collection
mock_get_query_result_files.return_value = [
os.path.join(temp_dir, i) for i in ("file1.py", "file2.py")
@@ -149,17 +149,11 @@ async def test_query_tool_collection_access_failure():
side_effect=Exception("Failed to connect"),
),
):
- with pytest.raises(McpError) as exc_info:
+ with pytest.raises(McpError):
await query_tool(
n_query=2, query_messages=["keyword1"], project_root="/valid/path"
)
- assert exc_info.value.error.code == 1
- assert (
- "Failed to access the collection at /valid/path. Use `list_collections` tool to get a list of valid paths for this field."
- in exc_info.value.error.message
- )
-
@pytest.mark.asyncio
async def test_query_tool_no_collection():
@@ -174,17 +168,11 @@ async def test_query_tool_no_collection():
):
mock_get_collection.return_value = None
- with pytest.raises(McpError) as exc_info:
+ with pytest.raises(McpError):
await query_tool(
n_query=2, query_messages=["keyword1"], project_root="/valid/path"
)
- assert exc_info.value.error.code == 1
- assert (
- "Failed to access the collection at /valid/path. Use `list_collections` tool to get a list of valid paths for this field."
- in exc_info.value.error.message
- )
-
@pytest.mark.asyncio
async def test_vectorise_tool_invalid_project_root():
@@ -249,15 +237,9 @@ async def test_vectorise_files_collection_access_failure():
),
patch("vectorcode.mcp_main.get_collection"),
):
- with pytest.raises(McpError) as exc_info:
+ with pytest.raises(McpError):
await vectorise_files(paths=["file.py"], project_root="/valid/path")
- assert exc_info.value.error.code == 1
- assert (
- "Failed to create the collection at /valid/path"
- in exc_info.value.error.message
- )
-
@pytest.mark.asyncio
async def test_vectorise_files_with_exclude_spec():