diff --git a/.github/workflows/test_and_cov.yml b/.github/workflows/test_and_cov.yml index eaea1fee..612d9b69 100644 --- a/.github/workflows/test_and_cov.yml +++ b/.github/workflows/test_and_cov.yml @@ -38,7 +38,7 @@ jobs: run: pdm config use_uv true - name: install pdm and dependencies - run: make deps + run: EXTRA_LOCK_ARGS="--group chroma0" make deps - name: Set custom HF cache directory run: | @@ -47,18 +47,12 @@ jobs: mkdir -p "$HF_HOME" [ -z "$(ls "$HF_HOME")" ] || rm "${HF_HOME:?}/*" -rf && true - - name: run tests - run: pdm run pytest --enable-coredumpy --coredumpy-dir ${{ env.COREDUMPY_DUMP_DIR }} - - name: run coverage run: | - pdm run coverage run -m pytest + sh ./scripts/coverage.sh pdm run coverage report -m pdm run coverage xml -i - - name: static analysis by basedpyright - run: pdm run basedpyright - - name: upload coverage reports to codecov uses: codecov/codecov-action@v5 with: diff --git a/Makefile b/Makefile index 048288b8..e4409f4e 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,17 @@ -.PHONY: multitest +EXTRA_LOCK_ARGS?= +EXTRA_DEPS?= +EXTRA_COVERAGEPY_ARGS?= + +LOADED_DOT_ENV=@if [ -f .env ] ; then source .env; fi; -DEFAULT_GROUPS=--group dev --group lsp --group mcp --group debug +DEFAULT_GROUPS=--group dev --group lsp --group mcp --group debug $(EXTRA_LOCK_ARGS) + +.PHONY: multitest deps: pdm lock $(DEFAULT_GROUPS) || pdm lock $(DEFAULT_GROUPS) --group legacy; \ pdm install + [ -z "$(EXTRA_DEPS)" ] || (pdm run python -m ensurepip && pdm run python -m pip install $(EXTRA_DEPS)) test: make deps; \ @@ -18,7 +25,7 @@ multitest: coverage: make deps; \ - pdm run coverage run -m pytest; \ + pdm run coverage run $(EXTRA_COVERAGEPY_ARGS) -m pytest --enable-coredumpy --coredumpy-dir dumps; \ pdm run coverage html; \ pdm run coverage report -m diff --git a/doc/VectorCode-cli.txt b/doc/VectorCode-cli.txt index a6680e11..5717e8ae 100644 --- a/doc/VectorCode-cli.txt +++ b/doc/VectorCode-cli.txt @@ -16,7 +16,6 @@ Table of Contents *VectorCode-cli-table-of-contents* - |VectorCode-cli-installation| - |VectorCode-cli-install-from-source| - - |VectorCode-cli-migration-from-`pipx`| - |VectorCode-cli-chromadb| - |VectorCode-cli-for-windows-users| - |VectorCode-cli-legacy-environments| @@ -66,7 +65,7 @@ virtual environments. After installing `uv`, run: >bash - uv tool install "vectorcode<1.0.0" + uv tool install "vectorcode[chroma0]" < in your shell. To specify a particular version of Python, use the `--python` @@ -76,40 +75,25 @@ If you want a CPU-only installation without CUDA dependencies required by default by PyTorch, run: >bash - uv tool install "vectorcode<1.0.0" --index https://download.pytorch.org/whl/cpu --index-strategy unsafe-best-match + uv tool install "vectorcode[chroma0]" --index https://download.pytorch.org/whl/cpu --index-strategy unsafe-best-match < If you need to install multiple dependency group (for |VectorCode-cli-lsp| or |VectorCode-cli-mcp|), you can use the following syntax: >bash - uv tool install "vectorcode[lsp,mcp]<1.0.0" + uv tool install "vectorcode[lsp,mcp,chroma0]" < [!NOTE] The command only install VectorCode and `SentenceTransformer`, the default embedding engine. 
If you need to install an extra dependency, you can - use `uv tool install vectorcode --with ` + use `uv tool install vectorcode[chroma0] --with ` INSTALL FROM SOURCE ~ -To install from source, either `git clone` this repository and run `uv tool -install `, or use `pipx`: - ->bash - pipx install git+https://github.com/Davidyz/VectorCode -< - - -MIGRATION FROM PIPX ~ - -The motivation behind the change from `pipx` to `uv tool` is mainly the -performance. The caching mechanism in uv makes it a lot faster than `pipx` for -a lot of operations. If you installed VectorCode via `pipx`, you can continue -to use `pipx` to manage your VectorCode installation. If you wish to switch to -`uv`, you need to uninstall VectorCode using `pipx` and then use `uv` to -install it as described above. All your VectorCode configurations and database -files will work out of the box on your new install. +To install from source, please `git clone` this repository and run `uv tool +install `. CHROMADB ~ @@ -124,9 +108,6 @@ instructions through docker significantly reduce the IO overhead and avoid potential race condition. - If you’re setting up a standalone ChromaDB server, I recommend sticking to - v0.6.3, because VectorCode is not ready for the upgrade to ChromaDB 1.0 yet. - FOR WINDOWS USERS ~ Windows support is not officially tested at this moment. This PR @@ -309,7 +290,10 @@ extension, the json5 syntax will be accepted. This allows you to leave trailing comma in the config file, as well as writing comments (`//`). This can be very useful if you’re experimenting with the configs. -The JSON configuration file may hold the following values: - +The JSON configuration file may hold the following values: - `db_type`: string, +default: `"ChromaDB0Connector"` (for chromadb 0.6.3), the database backend to +use; - `db_params`: dictionary. See the database connector documentation +<../src/vectorcode/database/README.md> for the default values; - `embedding_function`: string, one of the embedding functions supported by Chromadb (find more here and here @@ -329,38 +313,29 @@ model_name="nomic-embed-text")`. Default: `{}`; - `embedding_dims`: integer or model supports Matryoshka Representation Learning (MRL) before using this._ Learn more about MRL here . -When set to `null` (or unset), the embeddings won’t be truncated; - `db_url`: -string, the url that points to the Chromadb server. VectorCode will start an -HTTP server for Chromadb at a randomly picked free port on `localhost` if your -configured `http://host:port` is not accessible. Default: -`http://127.0.0.1:8000`; - `db_path`: string, Path to local persistent -database. If you didn’t set up a standalone Chromadb server, this is where -the files for your database will be stored. Default: -`~/.local/share/vectorcode/chromadb/`; - `db_log_path`: string, path to the -_directory_ where the built-in chromadb server will write the log to. Default: -`~/.local/share/vectorcode/`; - `chunk_size`: integer, the maximum number of -characters per chunk. A larger value reduces the number of items in the -database, and hence accelerates the search, but at the cost of potentially -truncated data and lost information. Default: `2500`. To disable chunking, set -it to a negative number; - `overlap_ratio`: float between 0 and 1, the ratio of -overlapping/shared content between 2 adjacent chunks. A larger ratio improves -the coherence of chunks, but at the cost of increasing number of entries in the -database and hence slowing down the search. Default: `0.2`. 
_Starting from -0.4.11, VectorCode will use treesitter to parse languages that it can -automatically detect. It uses pygments to guess the language from filename, and -tree-sitter-language-pack to fetch the correct parser. overlap_ratio has no -effects when treesitter works. If VectorCode fails to find an appropriate -parser, it’ll fallback to the legacy naive parser, in which case -overlap_ratio works exactly in the same way as before;_ - `query_multiplier`: -integer, when you use the `query` command to retrieve `n` documents, VectorCode -will check `n * query_multiplier` chunks and return at most `n` documents. A -larger value of `query_multiplier` guarantees the return of `n` documents, but -with the risk of including too many less-relevant chunks that may affect the -document selection. Default: `-1` (any negative value means selecting documents -based on all indexed chunks); - `reranker`: string, the reranking method to -use. Currently supports `NaiveReranker` (sort chunks by the "distance" between -the embedding vectors) and `CrossEncoderReranker` (using sentence-transformers -cross-encoder +When set to `null` (or unset), the embeddings won’t be truncated; - +`chunk_size`: integer, the maximum number of characters per chunk. A larger +value reduces the number of items in the database, and hence accelerates the +search, but at the cost of potentially truncated data and lost information. +Default: `2500`. To disable chunking, set it to a negative number; - +`overlap_ratio`: float between 0 and 1, the ratio of overlapping/shared content +between 2 adjacent chunks. A larger ratio improves the coherence of chunks, but +at the cost of increasing number of entries in the database and hence slowing +down the search. Default: `0.2`. _Starting from 0.4.11, VectorCode will use +treesitter to parse languages that it can automatically detect. It uses +pygments to guess the language from filename, and tree-sitter-language-pack to +fetch the correct parser. overlap_ratio has no effects when treesitter works. +If VectorCode fails to find an appropriate parser, it’ll fallback to the +legacy naive parser, in which case overlap_ratio works exactly in the same way +as before;_ - `query_multiplier`: integer, when you use the `query` command to +retrieve `n` documents, VectorCode will check `n * query_multiplier` chunks and +return at most `n` documents. A larger value of `query_multiplier` guarantees +the return of `n` documents, but with the risk of including too many +less-relevant chunks that may affect the document selection. Default: `-1` (any +negative value means selecting documents based on all indexed chunks); - +`reranker`: string, the reranking method to use. Currently supports +`NaiveReranker` (sort chunks by the "distance" between the embedding vectors) +and `CrossEncoderReranker` (using sentence-transformers cross-encoder ). - `reranker_params`: dictionary, similar to `embedding_params`. The options passed to the reranker class constructor. For `CrossEncoderReranker`, these are @@ -368,23 +343,15 @@ the options passed to the `CrossEncoder` class. For example, if you want to use a non-default model, you can use the following: `json { "reranker_params": { "model_name_or_path": "your_model_here" -} }` - `db_settings`: dictionary, works in a similar way to `embedding_params`, -but for Chromadb client settings so that you can configure authentication for -remote Chromadb ; - -`hnsw`: a dictionary of hnsw settings - that may -improve the query performances or avoid runtime errors during queries. 
**It’s
-recommended to re-vectorise the collection after modifying these options,
-because some of the options can only be set during collection creation.**
-Example (and default): `json5 "hnsw": { "hnsw:M": 64, }` - `filetype_map`:
-`dict[str, list[str]]`, a dictionary where keys are language name
+} }` - `filetype_map`: `dict[str, list[str]]`, a dictionary where keys are
+language name
and values are lists of Python regex patterns that will match file
extensions. This allows overriding automatic language detection and specifying a
treesitter parser for certain file types for which the language parser cannot be
correctly identified (e.g., `.phtml` files containing both php and html). Example
-configuration: `json5 "filetype_map": { "php": ["^phtml$"] }`
+configuration: `json5 { "filetype_map": { "php": ["^phtml$"], }, }`

- `chunk_filters`: `dict[str, list[str]]`, a dictionary where the keys are
language name
<https://github.com/Goldziher/tree-sitter-language-pack?tab=readme-ov-file#available-languages>
@@ -395,10 +362,12 @@ configuration: `json5 "filetype_map": { "php": ["^phtml$"] }`
   treesitter chunker. By default, no filters will be added. Example
   configuration: >json5
-    "chunk_filters": {
-        "python": ["^[^a-zA-Z0-9]+$"], // multiple patterns will be merged (unioned)
-        // or you can use wildcard to match any languages that has no dedicated filters:
-        "*": ["^[^a-zA-Z0-9]+$"],
+    {
+        "chunk_filters": {
+            "python": ["^[^a-zA-Z0-9]+$"], // multiple patterns will be merged (unioned)
+            // or you can use wildcard to match any languages that has no dedicated filters:
+            "*": ["^[^a-zA-Z0-9]+$"],
+        },
     }
 <
 - `encoding`: string, alternative encoding used for this project. By default this
@@ -743,13 +712,13 @@ A JSON array of collection information of the following format will be printed:
 
 >json
     {
-        "project_root": str,
-        "user": str,
-        "hostname": str,
-        "collection_name": str,
-        "size": int,
-        "num_files": int,
-        "embedding_function": str
+        "project_root": "project_root",
+        "user": "user",
+        "hostname": "host",
+        "collection_name": "fuerbvo13571943ofuib",
+        "size": 10,
+        "num_files": 100,
+        "embedding_function": "SomeEmbeddingFunction"
     }
 <
diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md
index 8a323ed4..f8686d11 100644
--- a/docs/CONTRIBUTING.md
+++ b/docs/CONTRIBUTING.md
@@ -40,6 +40,41 @@ You may also find it helpful to
 [enable logging](https://github.com/Davidyz/VectorCode/blob/main/docs/cli.md#debugging-and-diagnosing)
 for the CLI when developing new features or working on fixes.
 
+### Local Dependencies
+
+Sometimes you may want `make deps` to install non-default dependencies. The
+`Makefile` provides easy ways to do that.
+
+When you want to install a dependency group of VectorCode:
+
+```bash
+EXTRA_LOCK_ARGS="--group chroma0" make deps
+```
+
+When you want to install a library that is not declared in any of the dependency
+groups (like `openai`):
+
+```bash
+EXTRA_DEPS="openai\<2.0.0" make deps
+```
+
+Both environment variables apply to `make deps`, `make test` and `make coverage`.
+
+### Database Connectors
+
+Please take a look at [the database documentation](../src/vectorcode/database/README.md),
+which contains a brief introduction to the API design that explains what you'd need
+to do to add support for a new database.
+
+### Coverage Across Multiple Runs
+
+If, for some reason, you need to run the tests multiple times to get full coverage
+(maybe when there are conflicting dependency groups like chromadb 0.6.3 vs chromadb 1.x), you can pass the `--append` flag to the `coverage` command.
+If you're using `make coverage`, you can set this flag via the `EXTRA_COVERAGEPY_ARGS` environment variable: +```bash +EXTRA_COVERAGEPY_ARGS="--append" make coverage +``` + ## Neovim Plugin At the moment, there isn't much to cover on here. As long as the code is diff --git a/docs/cli.md b/docs/cli.md index e3d2f53c..38aa6fda 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -5,7 +5,6 @@ * [Installation](#installation) * [Install from Source](#install-from-source) - * [Migration from `pipx`](#migration-from-pipx) * [Chromadb](#chromadb) * [For Windows Users](#for-windows-users) * [Legacy Environments](#legacy-environments) @@ -55,7 +54,7 @@ your system Python or project-local virtual environments. After installing `uv`, run: ```bash -uv tool install "vectorcode<1.0.0" +uv tool install "vectorcode[chroma0]" ``` in your shell. To specify a particular version of Python, use the `--python` flag. For example, `uv tool install vectorcode --python python3.11`. For hardware @@ -63,36 +62,23 @@ accelerated embedding, refer to [the relevant section](#hardware-acceleration). If you want a CPU-only installation without CUDA dependencies required by default by PyTorch, run: ```bash -uv tool install "vectorcode<1.0.0" --index https://download.pytorch.org/whl/cpu --index-strategy unsafe-best-match +uv tool install "vectorcode[chroma0]" --index https://download.pytorch.org/whl/cpu --index-strategy unsafe-best-match ``` If you need to install multiple dependency group (for [LSP](#lsp-mode) or [MCP](#mcp-server)), you can use the following syntax: ```bash -uv tool install "vectorcode[lsp,mcp]<1.0.0" +uv tool install "vectorcode[lsp,mcp,chroma0]" ``` > [!NOTE] > The command only install VectorCode and `SentenceTransformer`, the default > embedding engine. If you need to install an extra dependency, you can use -> `uv tool install vectorcode --with ` +> `uv tool install vectorcode[chroma0] --with ` ### Install from Source -To install from source, either `git clone` this repository and run `uv tool install -`, or use `pipx`: -```bash -pipx install git+https://github.com/Davidyz/VectorCode -``` - -### Migration from `pipx` - -The motivation behind the change from `pipx` to `uv tool` is mainly the -performance. The caching mechanism in uv makes it a lot faster than `pipx` for a -lot of operations. If you installed VectorCode via `pipx`, you can continue to -use `pipx` to manage your VectorCode installation. If you wish to switch to -`uv`, you need to uninstall VectorCode using `pipx` and then use `uv` to install -it as described above. All your VectorCode configurations and database files -will work out of the box on your new install. +To install from source, please `git clone` this repository and run `uv tool install +`. ### Chromadb [Chromadb](https://www.trychroma.com/) is the vector database used by VectorCode @@ -103,10 +89,6 @@ set up a standalone local server (they provides detailed instructions through [systemd](https://cookbook.chromadb.dev/running/systemd-service/)), because this will significantly reduce the IO overhead and avoid potential race condition. -> If you're setting up a standalone ChromaDB server, I recommend sticking to -> v0.6.3, -> because VectorCode is not ready for the upgrade to ChromaDB 1.0 yet. - ### For Windows Users Windows support is not officially tested at this moment. [This PR](https://github.com/Davidyz/VectorCode/pull/40) @@ -261,7 +243,12 @@ be accepted. This allows you to leave trailing comma in the config file, as well as writing comments (`//`). 
This can be very useful if you're experimenting with the configs. -The JSON configuration file may hold the following values: +The JSON configuration file may hold the following values: +- `db_type`: string, default: `"ChromaDB0Connector"` (for chromadb 0.6.3), the + database backend to use; +- `db_params`: dictionary. See + [the database connector documentation](../src/vectorcode/database/README.md) for the + default values; - `embedding_function`: string, one of the embedding functions supported by [Chromadb](https://www.trychroma.com/) (find more [here](https://docs.trychroma.com/docs/embeddings/embedding-functions) and [here](https://docs.trychroma.com/integrations/chroma-integrations)). For @@ -282,14 +269,6 @@ The JSON configuration file may hold the following values: to. _Make sure your model supports Matryoshka Representation Learning (MRL) before using this._ Learn more about MRL [here](https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings). When set to `null` (or unset), the embeddings won't be truncated; -- `db_url`: string, the url that points to the Chromadb server. VectorCode will start an - HTTP server for Chromadb at a randomly picked free port on `localhost` if your - configured `http://host:port` is not accessible. Default: `http://127.0.0.1:8000`; -- `db_path`: string, Path to local persistent database. If you didn't set up a standalone - Chromadb server, this is where the files for your database will be stored. - Default: `~/.local/share/vectorcode/chromadb/`; -- `db_log_path`: string, path to the _directory_ where the built-in chromadb - server will write the log to. Default: `~/.local/share/vectorcode/`; - `chunk_size`: integer, the maximum number of characters per chunk. A larger value reduces the number of items in the database, and hence accelerates the search, but at the cost of potentially truncated data and lost information. @@ -330,32 +309,20 @@ The JSON configuration file may hold the following values: } } ``` -- `db_settings`: dictionary, works in a similar way to `embedding_params`, but - for Chromadb client settings so that you can configure - [authentication for remote Chromadb](https://docs.trychroma.com/production/administration/auth); -- `hnsw`: a dictionary of - [hnsw settings](https://cookbook.chromadb.dev/core/configuration/#hnsw-configuration) - that may improve the query performances or avoid runtime errors during - queries. **It's recommended to re-vectorise the collection after modifying these - options, because some of the options can only be set during collection - creation.** Example (and default): +- `filetype_map`: `dict[str, list[str]]`, a dictionary where keys are + [language name](https://github.com/Goldziher/tree-sitter-language-pack?tab=readme-ov-file#available-languages) + and values are lists of [Python regex patterns](https://docs.python.org/3/library/re.html) + that will match file extensions. This allows overriding automatic language + detection and specifying a treesitter parser for certain file types for which the language parser cannot be + correctly identified (e.g., `.phtml` files containing both php and html). 
+ Example configuration: ```json5 - "hnsw": { - "hnsw:M": 64, + { + "filetype_map": { + "php": ["^phtml$"], + }, } ``` -- `filetype_map`: `dict[str, list[str]]`, a dictionary where keys are - [language name](https://github.com/Goldziher/tree-sitter-language-pack?tab=readme-ov-file#available-languages) - and values are lists of [Python regex patterns](https://docs.python.org/3/library/re.html) - that will match file extensions. This allows overriding automatic language - detection and specifying a treesitter parser for certain file types for which the language parser cannot be - correctly identified (e.g., `.phtml` files containing both php and html). - Example configuration: - ```json5 - "filetype_map": { - "php": ["^phtml$"] - } - ``` - `chunk_filters`: `dict[str, list[str]]`, a dictionary where the keys are [language name](https://github.com/Goldziher/tree-sitter-language-pack?tab=readme-ov-file#available-languages) @@ -364,10 +331,12 @@ The JSON configuration file may hold the following values: to languages supported by treesitter chunker. By default, no filters will be added. Example configuration: ```json5 - "chunk_filters": { - "python": ["^[^a-zA-Z0-9]+$"], // multiple patterns will be merged (unioned) - // or you can use wildcard to match any languages that has no dedicated filters: - "*": ["^[^a-zA-Z0-9]+$"], + { + "chunk_filters": { + "python": ["^[^a-zA-Z0-9]+$"], // multiple patterns will be merged (unioned) + // or you can use wildcard to match any languages that has no dedicated filters: + "*": ["^[^a-zA-Z0-9]+$"], + }, } ``` - `encoding`: string, alternative encoding used for this project. By default @@ -669,13 +638,13 @@ The output is in JSON format. It contains a dictionary with the following fields A JSON array of collection information of the following format will be printed: ```json { - "project_root": str, - "user": str, - "hostname": str, - "collection_name": str, - "size": int, - "num_files": int, - "embedding_function": str + "project_root": "project_root", + "user": "user", + "hostname": "host", + "collection_name": "fuerbvo13571943ofuib", + "size": 10, + "num_files": 100, + "embedding_function": "SomeEmbeddingFunction" } ``` - `"project_root"`: the path to the `project-root`; diff --git a/pyproject.toml b/pyproject.toml index 01fd3b28..9a3762c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,56 +1,70 @@ +[build-system] +build-backend = "pdm.backend" + +requires = [ "pdm-backend" ] + [project] -name = "VectorCode" -dynamic = ["version"] +name = "vectorcode" description = "A tool to vectorise repositories for RAG." 
-authors = [{ name = "Davidyz", email = "hzjlyz@gmail.com" }] -dependencies = [ - "chromadb<=0.6.3", - "sentence-transformers", - "pathspec", - "tabulate", - "shtab", - "numpy", - "psutil", - "httpx", - "tree-sitter!=0.25.0", - "tree-sitter-language-pack", - "pygments", - "transformers>=4.36.0,!=4.51.0,!=4.51.1,!=4.51.2", - "wheel<0.46.0", - "colorlog", - "charset-normalizer>=3.4.1", - "json5", - "posthog<6.0.0", - "filelock>=3.15.0", -] -requires-python = ">=3.11,<3.14" readme = "README.md" license = { text = "MIT" } -[project.urls] -homepage = "https://github.com/Davidyz/VectorCode" -github = "https://github.com/Davidyz/VectorCode" -documentation = "https://github.com/Davidyz/VectorCode/blob/main/docs/cli.md" +authors = [ { name = "Davidyz", email = "hzjlyz@gmail.com" } ] +requires-python = ">=3.11,<3.14" +classifiers = [ + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +dynamic = [ "version" ] +dependencies = [ + "charset-normalizer>=3.4.1", + "chromadb", + "colorlog", + "filelock>=3.15", + "httpx", + "json5", + "numpy", + "pathspec", + "posthog<6", + "psutil", + "pygments", + "sentence-transformers", + "shtab", + "tabulate", + "transformers>=4.36,!=4.51,!=4.51.1,!=4.51.2", + "tree-sitter!=0.25", + "tree-sitter-language-pack", + "wheel<0.46", +] +optional-dependencies.chroma0 = [ "chromadb==0.6.3" ] +optional-dependencies.debug = [ "coredumpy>=0.4.1" ] +optional-dependencies.intel = [ "openvino", "optimum[openvino]" ] +optional-dependencies.legacy = [ "numpy<2", "torch==2.2.2", "transformers<=4.49" ] +optional-dependencies.lsp = [ "lsprotocol", "pygls<2" ] +optional-dependencies.mcp = [ "mcp<2", "pydantic" ] -[project.scripts] -vectorcode = "vectorcode.main:main" -vectorcode-server = "vectorcode.lsp_main:main" -vectorcode-mcp-server = "vectorcode.mcp_main:main" +urls.documentation = "https://github.com/Davidyz/VectorCode/blob/main/docs/cli.md" +urls.github = "https://github.com/Davidyz/VectorCode" +urls.homepage = "https://github.com/Davidyz/VectorCode" -[build-system] -requires = ["pdm-backend"] -build-backend = "pdm.backend" +scripts.vectorcode = "vectorcode.main:main" +scripts.vectorcode-mcp-server = "vectorcode.mcp_main:main" +scripts.vectorcode-server = "vectorcode.lsp_main:main" -[tool.coverage.run] -omit = [ - "./tests/*", - "src/vectorcode/_version.py", - "src/vectorcode/__init__.py", - "src/vectorcode/debugging.py", - "/tmp/*", +[dependency-groups] +dev = [ + "basedpyright>=1.29.2", + "coverage>=7.6.12", + "debugpy>=1.8.12", + "ipython>=8.31", + "pdm-backend>=2.4.3", + "pre-commit>=4.0.1", + "pytest>=8.3.4", + "pytest-asyncio>=0.25.3", + "ruff>=0.9.1", ] -include = ['src/vectorcode/**/*.py'] - [tool.pdm] distribution = true @@ -60,26 +74,16 @@ source = "scm" write_to = "./vectorcode/_version.py" write_template = "__version__ = '{}' # pragma: no cover" -[dependency-groups] -dev = [ - "ipython>=8.31.0", - "ruff>=0.9.1", - "pre-commit>=4.0.1", - "pytest>=8.3.4", - "pdm-backend>=2.4.3", - "coverage>=7.6.12", - "pytest-asyncio>=0.25.3", - "debugpy>=1.8.12", - "basedpyright>=1.29.2", +[tool.coverage.run] +omit = [ + "./tests/*", + "src/vectorcode/_version.py", + "src/vectorcode/__init__.py", + "src/vectorcode/debugging.py", + "/tmp/*", ] - -[project.optional-dependencies] -legacy = ["numpy<2.0.0", "torch==2.2.2", "transformers<=4.49.0"] -intel = ['optimum[openvino]', 'openvino'] -lsp = ['pygls<2.0.0', 'lsprotocol'] -mcp = ['mcp<2.0.0', 'pydantic'] 
-debug = ["coredumpy>=0.4.1"] +include = [ 'src/vectorcode/**/*.py' ] [tool.basedpyright] typeCheckingMode = "standard" -ignore = ["./tests/"] +ignore = [ "./tests/" ] diff --git a/scripts/coverage.sh b/scripts/coverage.sh new file mode 100755 index 00000000..de32acc6 --- /dev/null +++ b/scripts/coverage.sh @@ -0,0 +1,14 @@ +#!/bin/sh + +export EXTRA_COVERAGEPY_ARGS='--append' + +make deps + +pdm run coverage erase + +# chroma 0.6.3 +EXTRA_LOCK_ARGS="--group chroma0" make coverage +# default install (chroma 1.x) +make coverage + +pdm run coverage report -m diff --git a/src/vectorcode/cli_utils.py b/src/vectorcode/cli_utils.py index 0131a5e2..ee85dd34 100644 --- a/src/vectorcode/cli_utils.py +++ b/src/vectorcode/cli_utils.py @@ -4,11 +4,22 @@ import logging import os import sys +from asyncio import Lock from dataclasses import dataclass, field, fields from datetime import datetime from enum import Enum, StrEnum from pathlib import Path -from typing import Any, Generator, Iterable, Optional, Sequence, Union +from typing import ( + Any, + Generator, + Iterable, + Literal, + Optional, + Sequence, + Type, + Union, + overload, +) import json5 import shtab @@ -87,15 +98,13 @@ class Config: files: list[Union[str, os.PathLike]] = field(default_factory=list) project_root: Optional[Union[str, Path]] = None query: Optional[list[str]] = None - db_url: str = "http://127.0.0.1:8000" + db_type: str = "ChromaDB0" + db_params: dict[str, Any] = field(default_factory=dict) embedding_function: str = "SentenceTransformerEmbeddingFunction" # This should fallback to whatever the default is. embedding_params: dict[str, Any] = field(default_factory=(lambda: {})) embedding_dims: Optional[int] = None n_result: int = 1 force: bool = False - db_path: Optional[str] = "~/.local/share/vectorcode/chromadb/" - db_log_path: str = "~/.local/share/vectorcode/" - db_settings: Optional[dict] = None chunk_size: int = 2500 overlap_ratio: float = 0.2 query_multiplier: int = -1 @@ -107,7 +116,6 @@ class Config: include: list[QueryInclude] = field( default_factory=lambda: [QueryInclude.path, QueryInclude.document] ) - hnsw: dict[str, str | int] = field(default_factory=dict) chunk_filters: dict[str, list[str]] = field(default_factory=dict) filetype_map: dict[str, list[str]] = field(default_factory=dict) encoding: str = "utf8" @@ -116,7 +124,7 @@ class Config: files_action: Optional[FilesAction] = None rm_paths: list[str] = field(default_factory=list) - def __hash__(self) -> int: + def __hash__(self) -> int: # pragma: nocover return hash(self.__repr__()) @classmethod @@ -125,14 +133,7 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config": Raise IOError if db_path is not valid. """ default_config = Config() - db_path = config_dict.get("db_path") - if db_path is None: - db_path = os.path.expanduser("~/.local/share/vectorcode/chromadb/") - elif not os.path.isdir(db_path): - raise IOError( - f"The configured db_path ({str(db_path)}) is not a valid directory." 
- ) return Config( **{ "embedding_function": config_dict.get( @@ -144,11 +145,8 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config": "embedding_dims": config_dict.get( "embedding_dims", default_config.embedding_dims ), - "db_url": config_dict.get("db_url", default_config.db_url), - "db_path": db_path, - "db_log_path": os.path.expanduser( - config_dict.get("db_log_path", default_config.db_log_path) - ), + "db_type": config_dict.get("db_type", default_config.db_type), + "db_params": config_dict.get("db_params", default_config.db_params), "chunk_size": config_dict.get("chunk_size", default_config.chunk_size), "overlap_ratio": config_dict.get( "overlap_ratio", default_config.overlap_ratio @@ -160,10 +158,6 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config": "reranker_params": config_dict.get( "reranker_params", default_config.reranker_params ), - "db_settings": config_dict.get( - "db_settings", default_config.db_settings - ), - "hnsw": config_dict.get("hnsw", default_config.hnsw), "chunk_filters": config_dict.get( "chunk_filters", default_config.chunk_filters ), @@ -599,7 +593,7 @@ async def expand_globs( def cleanup_path(path: str): - if os.path.isabs(path) and os.environ.get("HOME") is not None: + if os.path.isabs(path) and os.environ.get("HOME", "") != "": return path.replace(os.environ["HOME"], "~") return path @@ -661,12 +655,15 @@ def config_logging( ) +LockType = AsyncFileLock | Lock + + class LockManager: """ A class that manages file locks that protects the database files in daemon processes (LSP, MCP). """ - __locks: dict[str, AsyncFileLock] + __locks: dict[tuple[str, Type[LockType]], LockType] singleton: Optional["LockManager"] = None def __new__(cls) -> "LockManager": @@ -675,7 +672,23 @@ def __new__(cls) -> "LockManager": cls.singleton.__locks = {} return cls.singleton - def get_lock(self, path: str | os.PathLike) -> AsyncFileLock: + @overload + def get_lock( + self, path: str | os.PathLike, lock_type_name: Literal["asyncio"] + ) -> Lock: ... + + @overload + def get_lock( + self, + path: str | os.PathLike, + lock_type_name: Literal["filelock"] | None, + ) -> AsyncFileLock: ... 
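+    # The overloads above only narrow the return type for type checkers:
+    # requesting an "asyncio" lock yields an `asyncio.Lock`, which guards
+    # against concurrent tasks within a single process, while "filelock"
+    # (the default) yields an `AsyncFileLock`, which also guards the lock
+    # file against other processes.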
+ + def get_lock( + self, + path: str | os.PathLike, + lock_type_name: Literal["filelock"] | Literal["asyncio"] | None = "filelock", + ): path = str(expand_path(str(path), True)) if os.path.isdir(path): lock_file = os.path.join(path, "vectorcode.lock") @@ -684,9 +697,19 @@ def get_lock(self, path: str | os.PathLike) -> AsyncFileLock: with open(lock_file, mode="w") as fin: fin.write("") path = lock_file - if self.__locks.get(path) is None: - self.__locks[path] = AsyncFileLock(path) # pyright: ignore[reportArgumentType] - return self.__locks[path] + lock: LockType + match lock_type_name: + case "filelock": + lock = AsyncFileLock(path) # pyright: ignore[reportAssignmentType] + case "asyncio": + lock = Lock() + case _: # pragma: nocover + raise ValueError(f"Unsupported lock type: {lock_type_name}") + + cache_key = (path, type(lock)) + if self.__locks.get(cache_key) is None: + self.__locks[cache_key] = lock + return self.__locks[cache_key] class SpecResolver: @@ -718,6 +741,7 @@ def from_path(cls, spec_path: str, project_root: Optional[str] = None): return cls(spec_path, base_dir) def __init__(self, spec: str | GitIgnoreSpec, base_dir: str = "."): + self.spec: GitIgnoreSpec if isinstance(spec, str): with open(spec) as fin: self.spec = GitIgnoreSpec.from_lines( @@ -725,20 +749,21 @@ def __init__(self, spec: str | GitIgnoreSpec, base_dir: str = "."): ) else: self.spec = spec - self.base_dir = base_dir + self.base_dir = Path(base_dir).resolve() + + def match_file(self, path: str, negated: bool = False) -> bool: + if self.base_dir in Path(path).resolve().parents: + matched = self.spec.match_file(os.path.relpath(path, self.base_dir)) + if negated: + matched = not matched + return matched + return True def match( self, paths: Iterable[str], negated: bool = False ) -> Generator[str, None, None]: # get paths relative to `base_dir` - base = Path(self.base_dir).resolve() for p in paths: - if base in Path(p).resolve().parents: - should_yield = self.spec.match_file(os.path.relpath(p, self.base_dir)) - if negated: - should_yield = not should_yield - if should_yield: - yield p - else: + if self.match_file(p, negated): yield p diff --git a/src/vectorcode/common.py b/src/vectorcode/common.py deleted file mode 100644 index 94443f4f..00000000 --- a/src/vectorcode/common.py +++ /dev/null @@ -1,320 +0,0 @@ -import asyncio -import contextlib -import hashlib -import logging -import os -import socket -import subprocess -import sys -from asyncio.subprocess import Process -from dataclasses import dataclass -from functools import cache -from typing import Any, AsyncGenerator, Optional -from urllib.parse import urlparse - -import chromadb -import httpx -from chromadb.api import AsyncClientAPI -from chromadb.api.models.AsyncCollection import AsyncCollection -from chromadb.api.types import IncludeEnum -from chromadb.config import APIVersion, Settings -from chromadb.utils import embedding_functions - -from vectorcode.cli_utils import Config, LockManager, expand_path - -logger = logging.getLogger(name=__name__) - - -async def get_collections( - client: AsyncClientAPI, -) -> AsyncGenerator[AsyncCollection, None]: - for collection_name in await client.list_collections(): - collection = await client.get_collection(collection_name, None) - meta = collection.metadata - if meta is None: - continue - if meta.get("created-by") != "VectorCode": - continue - if meta.get("username") not in ( - os.environ.get("USER"), - os.environ.get("USERNAME"), - "DEFAULT_USER", - ): - continue - if meta.get("hostname") != socket.gethostname(): - 
continue - yield collection - - -async def try_server(base_url: str): - for ver in ("v1", "v2"): # v1 for legacy, v2 for latest chromadb. - heartbeat_url = f"{base_url}/api/{ver}/heartbeat" - try: - async with httpx.AsyncClient() as client: - response = await client.get(url=heartbeat_url) - logger.debug(f"Heartbeat {heartbeat_url} returned {response=}") - if response.status_code == 200: - return True - except (httpx.ConnectError, httpx.ConnectTimeout): - pass - return False - - -async def wait_for_server(url: str, timeout=10): - # Poll the server until it's ready or timeout is reached - - start_time = asyncio.get_event_loop().time() - while True: - if await try_server(url): - return - - if asyncio.get_event_loop().time() - start_time > timeout: - raise TimeoutError(f"Server did not start within {timeout} seconds.") - - await asyncio.sleep(0.1) # Wait before retrying - - -async def start_server(configs: Config): - assert configs.db_path is not None - db_path = os.path.expanduser(configs.db_path) - configs.db_log_path = os.path.expanduser(configs.db_log_path) - if not os.path.isdir(configs.db_log_path): - os.makedirs(configs.db_log_path) - if not os.path.isdir(db_path): - logger.warning( - f"Using local database at {os.path.expanduser('~/.local/share/vectorcode/chromadb/')}.", - ) - db_path = os.path.expanduser("~/.local/share/vectorcode/chromadb/") - env = os.environ.copy() - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) # OS selects a free ephemeral port - port = int(s.getsockname()[1]) - - server_url = f"http://127.0.0.1:{port}" - logger.warning(f"Starting bundled ChromaDB server at {server_url}.") - env.update({"ANONYMIZED_TELEMETRY": "False"}) - process = await asyncio.create_subprocess_exec( - sys.executable, - "-m", - "chromadb.cli.cli", - "run", - "--host", - "localhost", - "--port", - str(port), - "--path", - db_path, - "--log-path", - os.path.join(str(configs.db_log_path), "chroma.log"), - stdout=subprocess.DEVNULL, - stderr=sys.stderr, - env=env, - ) - - await wait_for_server(server_url) - configs.db_url = server_url - return process - - -def get_collection_name(full_path: str) -> str: - full_path = str(expand_path(full_path, absolute=True)) - hasher = hashlib.sha256() - plain_collection_name = f"{os.environ.get('USER', os.environ.get('USERNAME', 'DEFAULT_USER'))}@{socket.gethostname()}:{full_path}" - hasher.update(plain_collection_name.encode()) - collection_id = hasher.hexdigest()[:63] - logger.debug( - f"Hashing {plain_collection_name} as the collection name for {full_path}." - ) - return collection_id - - -@cache -def get_embedding_function(configs: Config) -> chromadb.EmbeddingFunction: - try: - ef = getattr(embedding_functions, configs.embedding_function)( - **configs.embedding_params - ) - if ef is None: # pragma: nocover - raise AttributeError() - return ef - except AttributeError: - logger.warning( - f"Failed to use {configs.embedding_function}. Falling back to Sentence Transformer.", - ) - return embedding_functions.SentenceTransformerEmbeddingFunction() # type:ignore - except Exception as e: - e.add_note( - "\nFor errors caused by missing dependency, consult the documentation of pipx (or whatever package manager that you installed VectorCode with) for instructions to inject libraries into the virtual environment." 
- ) - logger.error( - f"Failed to use {configs.embedding_function} with following error.", - ) - raise - - -__COLLECTION_CACHE: dict[str, AsyncCollection] = {} - - -async def get_collection( - client: AsyncClientAPI, configs: Config, make_if_missing: bool = False -): - """ - Raise ValueError when make_if_missing is False and no collection is found; - Raise IndexError on hash collision. - """ - assert configs.project_root is not None - full_path = str(expand_path(str(configs.project_root), absolute=True)) - if __COLLECTION_CACHE.get(full_path) is None: - collection_name = get_collection_name(full_path) - - collection_meta: dict[str, str | int] = { - "path": full_path, - "hostname": socket.gethostname(), - "created-by": "VectorCode", - "username": os.environ.get( - "USER", os.environ.get("USERNAME", "DEFAULT_USER") - ), - "embedding_function": configs.embedding_function, - "hnsw:M": 64, - } - if configs.hnsw: - for key in configs.hnsw.keys(): - target_key = key - if not key.startswith("hnsw:"): - target_key = f"hnsw:{key}" - collection_meta[target_key] = configs.hnsw[key] - logger.debug( - f"Getting/Creating collection with the following metadata: {collection_meta}" - ) - if not make_if_missing: - __COLLECTION_CACHE[full_path] = await client.get_collection(collection_name) - else: - collection = await client.get_or_create_collection( - collection_name, - metadata=collection_meta, - ) - if ( - not collection.metadata.get("hostname") == socket.gethostname() - or collection.metadata.get("username") - not in ( - os.environ.get("USER"), - os.environ.get("USERNAME"), - "DEFAULT_USER", - ) - or not collection.metadata.get("created-by") == "VectorCode" - ): - logger.error( - f"Failed to use existing collection due to metadata mismatch: {collection_meta}" - ) - raise IndexError( - "Failed to create the collection due to hash collision. Please file a bug report." - ) - __COLLECTION_CACHE[full_path] = collection - return __COLLECTION_CACHE[full_path] - - -def verify_ef(collection: AsyncCollection, configs: Config): - collection_ef = collection.metadata.get("embedding_function") - collection_ep = collection.metadata.get("embedding_params") - if collection_ef and collection_ef != configs.embedding_function: - logger.error(f"The collection was embedded using {collection_ef}.") - logger.error( - "Embeddings and query must use the same embedding function and parameters. Please double-check your config." - ) - return False - elif collection_ep and collection_ep != configs.embedding_params: - logger.warning( - f"The collection was embedded with a different set of configurations: {collection_ep}. 
The result may be inaccurate.", - ) - return True - - -async def list_collection_files(collection: AsyncCollection) -> list[str]: - return sorted( - list( - set( - str(c.get("path", None)) - for c in (await collection.get(include=[IncludeEnum.metadatas])).get( - "metadatas" - ) - or [] - ) - ) - ) - - -@dataclass -class _ClientModel: - client: AsyncClientAPI - is_bundled: bool = False - process: Optional[Process] = None - - -class ClientManager: - singleton: Optional["ClientManager"] = None - __clients: dict[str, _ClientModel] - - def __new__(cls) -> "ClientManager": - if cls.singleton is None: - cls.singleton = super().__new__(cls) - cls.singleton.__clients = {} - return cls.singleton - - @contextlib.asynccontextmanager - async def get_client(self, configs: Config, need_lock: bool = True): - project_root = str(expand_path(str(configs.project_root), True)) - is_bundled = False - if self.__clients.get(project_root) is None: - process = None - if not await try_server(configs.db_url): - logger.info(f"Starting a new server at {configs.db_url}") - process = await start_server(configs) - is_bundled = True - - self.__clients[project_root] = _ClientModel( - client=await self._create_client(configs), - is_bundled=is_bundled, - process=process, - ) - lock = None - if self.__clients[project_root].is_bundled and need_lock: - lock = LockManager().get_lock(str(configs.db_path)) - logger.debug(f"Locking {configs.db_path}") - await lock.acquire() - yield self.__clients[project_root].client - if lock is not None: - logger.debug(f"Unlocking {configs.db_path}") - await lock.release() - - def get_processes(self) -> list[Process]: - return [i.process for i in self.__clients.values() if i.process is not None] - - async def kill_servers(self): - termination_tasks: list[asyncio.Task] = [] - for p in self.get_processes(): - logger.info(f"Killing bundled chroma server with PID: {p.pid}") - p.terminate() - termination_tasks.append(asyncio.create_task(p.wait())) - await asyncio.gather(*termination_tasks) - - async def _create_client(self, configs: Config) -> AsyncClientAPI: - settings: dict[str, Any] = {"anonymized_telemetry": False} - if isinstance(configs.db_settings, dict): - valid_settings = { - k: v for k, v in configs.db_settings.items() if k in Settings.__fields__ - } - settings.update(valid_settings) - parsed_url = urlparse(configs.db_url) - settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1" - settings["chroma_server_http_port"] = parsed_url.port or 8000 - settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https" - settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2 - settings_obj = Settings(**settings) - return await chromadb.AsyncHttpClient( - settings=settings_obj, - host=str(settings_obj.chroma_server_host), - port=int(settings_obj.chroma_server_http_port or 8000), - ) - - def clear(self): - self.__clients.clear() diff --git a/src/vectorcode/database/DEVELOPERS.md b/src/vectorcode/database/DEVELOPERS.md new file mode 100644 index 00000000..24be3e9d --- /dev/null +++ b/src/vectorcode/database/DEVELOPERS.md @@ -0,0 +1,80 @@ +# Database Connectors + +A database connector is a compatibility layer that converts data structures used by a +database to the ones that VectorCode works with. +The connector classes provide abstractions for VectorCode operations (`vectorise`, `query`, etc.), which enables the use of different database backends. 
+
+
+
+* [Adding a New Database Connector](#adding-a-new-database-connector)
+* [Key Implementation Details](#key-implementation-details)
+  * [The `Config` Object](#the-config-object)
+  * [Implementing Abstract Methods](#implementing-abstract-methods)
+  * [Error Handling](#error-handling)
+* [Testing](#testing)
+
+
+
+## Adding a New Database Connector
+
+To add support for a new database backend, you will need to:
+
+1. **Implement a connector class**: Create a new file in this directory and implement a child class of `vectorcode.database.base.DatabaseConnectorBase`. You must implement all of its abstract methods.
+2. **Write tests**: Add tests for your new connector in the `tests/database/` directory. The tests should mock the database's API and verify that your connector correctly converts data between the database's native format and VectorCode's data structures.
+3. **Register your connector**: Add a new entry in the `get_database_connector` function in `src/vectorcode/database/__init__.py` to initialize your new connector.
+
+For a concrete example, refer to the implementation of `DatabaseConnectorBase` and the `ChromaDB0Connector`.
+
+## Key Implementation Details
+
+### The `Config` Object
+
+All settings for a connector are passed through a single `vectorcode.cli_utils.Config` object, which is available as `self._configs`. This includes:
+
+- **Database Settings**: The `db_type` string and `db_params` dictionary are used to configure the connection to the database backend. As a contributor, you should document the specific `db_params` your connector requires in the class's docstring.
+- **Operation Parameters**: Parameters for operations like `query` or `vectorise` are also present in this object.
+
+The `self._configs` attribute is mutable and can be updated for subsequent operations, but the database connection settings (`db_type`, `db_params`) should not be changed after initialization.
+
+### Implementing Abstract Methods
+
+When implementing the abstract methods from `DatabaseConnectorBase`, you should:
+
+- Read the necessary parameters from the `self._configs` object.
+- Perform the corresponding operation against the database.
+- Return data in the format specified by the method's type hints (e.g., `QueryResult`, `CollectionInfo`).
+
+**Please refer to the docstrings in `DatabaseConnectorBase` for the specific API contract of each method.**
+They contain detailed information about what each method is expected to do and what parameters it uses from the `Config` object.
+
+There are also some helper methods (non-abstract methods) in `DatabaseConnectorBase` that may be helpful.
+For example, `self.get_embedding(texts)` provides a convenient way to get the embeddings of some strings that takes all config parameters into account (embedding function, embedding dimension, etc.).
+Using these methods helps keep your implementation consistent with the overall design.
+
+### Error Handling
+
+Some exceptions raised by database backends are actually DB-agnostic, even though each database may have its own exception classes for them.
+For example, ChromaDB 0.6.3 has a `chromadb.errors.InvalidCollectionException` class, which is raised when accessing a collection that doesn't exist.
+This is not a chroma-specific error, but different database backends _probably_ have their own implementations.
+For better error handling, it's recommended to wrap them into [VectorCode's in-house exception classes](./errors.py).
+This ensures consistent error handling in the CLI and other clients.
+
+For example:
+```python
+from vectorcode.database.errors import CollectionNotFoundError
+
+try:
+    some_action_here()
+except SomeCustomException as e:
+    raise CollectionNotFoundError("The collection was not found.") from e
+```
+
+## Testing
+
+The unit tests for database backends should go under [`tests/database/`](../../../tests/database/).
+The tests should mock the request body and return values of the database. Integration
+tests that interact with an actual database are out of scope for now.
+
+> The tests for the subcommands currently use mocked database connectors. They're not
+> supposed to interact with live databases.
diff --git a/src/vectorcode/database/README.md b/src/vectorcode/database/README.md
new file mode 100644
index 00000000..2aeb48a0
--- /dev/null
+++ b/src/vectorcode/database/README.md
@@ -0,0 +1,40 @@
+# Database Configuration
+
+This document provides the `db_params` configuration for the supported database connectors in VectorCode.
+
+For instructions on how to add a new database connector, please refer to [DEVELOPERS.md](./DEVELOPERS.md).
+
+
+
+
+* [ChromaDB (v0.6.3)](#chromadb-v063)
+
+
+
+## ChromaDB (v0.6.3)
+
+The `ChromaDB0Connector` is used for ChromaDB version 0.6.3.
+
+- **`db_type`**: `"ChromaDB0"`. The `Connector` suffix is optional and will be added automatically.
+
+- **`db_params`**:
+  An example of the `db_params` for `ChromaDB0Connector` in your `config.json5`:
+  ```json5
+  {
+    "db_params": {
+      "db_url": "http://127.0.0.1:8000",
+      "db_path": "~/.local/share/vectorcode/chromadb/",
+      "db_log_path": "~/.local/share/vectorcode/",
+      "db_settings": {},
+      "hnsw": {
+        "hnsw:M": 64,
+      },
+    },
+  }
+  ```
+
+  - `db_url`: The URL of the ChromaDB server. Defaults to `"http://127.0.0.1:8000"`.
+  - `db_path`: Path to the directory where ChromaDB stores its data. Defaults to `"~/.local/share/vectorcode/chromadb/"`.
+  - `db_log_path`: Path to the directory for ChromaDB log files. Defaults to `"~/.local/share/vectorcode/"`.
+  - `db_settings`: Additional ChromaDB settings. You usually don't need to touch this, but in case you do, you can refer to [ChromaDB source](https://github.com/chroma-core/chroma/blob/a3b86a0302a385350a8f092a5f89a2dcdebcf6be/chromadb/config.py#L101) for details. Defaults to `{}`.
+  - `hnsw`: HNSW index parameters. Defaults to `{"hnsw:M": 64}`.
diff --git a/src/vectorcode/database/__init__.py b/src/vectorcode/database/__init__.py
new file mode 100644
index 00000000..55cf8a45
--- /dev/null
+++ b/src/vectorcode/database/__init__.py
@@ -0,0 +1,39 @@
+import logging
+from typing import Type
+
+from vectorcode.cli_utils import Config
+from vectorcode.database.base import DatabaseConnectorBase
+
+logger = logging.getLogger(name=__name__)
+
+
+def get_database_connector(config: Config) -> DatabaseConnectorBase:
+    """
+    It's CRUCIAL to keep the `import`s of the database connectors in the branches.
+    This allows them to be lazily imported, and it keeps the main package
+    lightweight because we don't have to include dependencies for EVERY database.
+
+    > Raises a `ValueError` in case the database connector is not supported.
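+
+    A minimal usage sketch (the `Config` values here are hypothetical):
+
+    >>> from vectorcode.cli_utils import Config
+    >>> connector = get_database_connector(Config(db_type="ChromaDB0"))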
+    """
+    cls: Type[DatabaseConnectorBase] | None = None
+
+    if not config.db_type.endswith("Connector"):
+        config.db_type = f"{config.db_type}Connector"
+        logger.debug(f"Correcting the name of the db connector to {config.db_type}")
+
+    match config.db_type:
+        case "ChromaDB0Connector":
+            from vectorcode.database.chroma0 import ChromaDB0Connector
+
+            cls = ChromaDB0Connector
+        case "ChromaDBConnector":
+            from vectorcode.database.chroma import ChromaDBConnector
+
+            cls = ChromaDBConnector
+        case _:
+            raise ValueError(f"Unrecognised database type: {config.db_type}")
+
+    return cls.create(config)
+
+
+__all__ = ["get_database_connector"]
diff --git a/src/vectorcode/database/base.py b/src/vectorcode/database/base.py
new file mode 100644
index 00000000..67b754b3
--- /dev/null
+++ b/src/vectorcode/database/base.py
@@ -0,0 +1,264 @@
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import Optional, Self, Sequence
+
+from numpy.typing import NDArray
+
+from vectorcode.chunking import Chunk, TreeSitterChunker
+from vectorcode.cli_utils import Config
+from vectorcode.database.types import (
+    CollectionContent,
+    CollectionInfo,
+    QueryResult,
+    ResultType,
+    VectoriseStats,
+)
+from vectorcode.database.utils import get_embedding_function
+
+logger = logging.getLogger(name=__name__)
+
+
+"""
+For developers:
+
+To implement a custom database connector, you should inherit the following
+`DatabaseConnectorBase` class and implement ALL abstract methods.
+
+You should also try to wrap the exceptions with the ones in
+`src/vectorcode/database/errors.py` where appropriate, because this helps the
+CLI/LSP/MCP interfaces handle some common edge cases (for example, querying
+from an unindexed project). To do this, you should do the following in a
+try-except block:
+```python
+from vectorcode.database.errors import CollectionNotFoundError
+
+try:
+    some_action_here()
+except SomeCustomException as e:
+    raise CollectionNotFoundError("The collection was not found.") from e
+```
+This will preserve the correct call stack in the error message and make
+debugging easier.
+"""
+
+
+class DatabaseConnectorBase(ABC):  # pragma: nocover
+    @classmethod
+    def create(cls, configs: Config):
+        """
+        Create a new instance of the database connector.
+        This classmethod will add the docstring of the child class to the exception if the initialisation fails.
+        """
+        try:
+            return cls(configs)
+        except Exception as e:  # pragma: nocover
+            doc_string = cls.__doc__
+            if doc_string:
+                e.add_note(doc_string)
+            raise
+
+    def __init__(self, configs: Config):
+        """
+        Initialises the database connector with the given configs.
+        It is recommended to use the `create` classmethod instead of calling this directly,
+        as it provides better error handling during initialisation.
+        """
+        self._configs = configs
+
+    async def count(self, what: ResultType = ResultType.chunk) -> int:
+        """
+        Returns the chunk count or file count of the given collection, depending on the value passed for `what`.
+        This method is implemented in the base class and relies on `list_collection_content`.
+        Child classes should not need to override this method if `list_collection_content` is implemented correctly.
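+
+        Example (hypothetical usage): `await connector.count(ResultType.document)`
+        returns the number of indexed files, while `await connector.count()`
+        defaults to counting chunks.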
+        """
+        collection_content = await self.list_collection_content(what=what)
+        match what:
+            case ResultType.chunk:
+                return len(collection_content.chunks)
+            case ResultType.document:
+                return len(collection_content.files)
+
+    @abstractmethod
+    async def query(
+        self,
+    ) -> list[QueryResult]:
+        """
+        Query the database for similar chunks.
+        The query keywords are stored in `self._configs.query`.
+        The implementation of this method should handle the conversion from the native database query result to a list of `vectorcode.database.types.QueryResult` objects.
+        """
+        pass
+
+    @abstractmethod
+    async def vectorise(
+        self,
+        file_path: str,
+        chunker: TreeSitterChunker | None = None,
+    ) -> VectoriseStats:
+        """
+        Vectorise the given file and add it to the database.
+        The duplicate checking (using file hash) should be done outside of this function.
+
+        For developers:
+        The implementation should chunk the file, generate embeddings for the chunks, and store them in the database.
+        It should return a `VectoriseStats` object to report the outcome.
+        """
+        pass
+
+    @abstractmethod
+    async def list_collections(self) -> Sequence[CollectionInfo]:
+        """
+        List all collections available in the database.
+
+        For developers:
+        The implementation should retrieve all collections and return them as a sequence of `CollectionInfo` objects.
+        This includes metadata about each collection like its ID, path, and size.
+        """
+        pass
+
+    @abstractmethod
+    async def list_collection_content(
+        self,
+        *,
+        what: Optional[ResultType] = None,
+        collection_id: str | None = None,
+        collection_path: str | None = None,
+    ) -> CollectionContent:
+        """
+        List the content of a collection (from `self._configs.project_root`).
+        You may use the keyword arguments to temporarily override the collection of interest.
+
+        When `what` is None, this method should populate both `CollectionContent.files` and `CollectionContent.chunks`.
+        Otherwise, this method may populate only one of them to reduce waiting time.
+        """
+        pass
+
+    @abstractmethod
+    async def delete(
+        self,
+    ) -> int:
+        """
+        Delete files from the database (doesn't remove files on disk).
+        Returns the actual number of files deleted.
+
+        For developers:
+        The file paths to be deleted are stored in `self._configs.rm_paths`.
+        The implementation should remove all chunks associated with these files from the database.
+        """
+        pass
+
+    @abstractmethod
+    async def drop(
+        self, *, collection_id: str | None = None, collection_path: str | None = None
+    ):
+        """
+        Delete a collection from the database.
+        The collection to be dropped is specified by `collection_id` or `collection_path`.
+        If not provided, it defaults to `self._configs.project_root`.
+        """
+        pass
+
+    def _check_new_config(self, new_config: Config) -> bool:
+        """
+        Ensures that the new config does not attempt to change database-specific settings.
+        It copies the `db_type` and `db_params` from the existing config to the new one.
+        This is a helper method for `update_config` and `replace_config`.
+        """
+        assert isinstance(new_config, Config), "`new_config` is not a `Config` object."
+        new_config.db_type = self._configs.db_type
+        new_config.db_params = self._configs.db_params
+        return True
+
+    async def update_config(self, new_config: Config) -> Self:
+        """
+        Merge the new config with the existing one.
+        This method will not change the database configs.
+        Child classes should not need to override this method.
+        """
+        assert self._check_new_config(new_config), (
+            "The new config has different database configs."
+ ) + + self._configs = await self._configs.merge_from(new_config) + + return self + + async def replace_config(self, new_config: Config) -> Self: + """ + Replace the existing config with the new one. + This method will not change the database configs. + Child classes should not need to override this method. + """ + assert self._check_new_config(new_config), ( + "The new config has different database configs." + ) + self._configs = new_config + return self + + async def check_orphanes(self) -> int: + """ + Check for files that are in the database but no longer on disk, and remove them. + Returns the number of orphaned files removed. + This method relies on `list_collection_content` and `delete`. + Child classes should not need to override this. + """ + + orphanes: list[str] = [] + database_files = ( + await self.list_collection_content(what=ResultType.document) + ).files + for file in database_files: + path = file.path + if not os.path.isfile(path): + orphanes.append(path) + logger.debug(f"Discovered orphaned file: {path}") + + await self.update_config(Config(rm_paths=orphanes)) + await self.delete() + + return len(orphanes) + + def get_embedding(self, texts: str | list[str]) -> list[NDArray]: + """ + Generate embeddings for the given texts. + It uses the embedding function specified in `self._configs.embedding_function`. + If `self._configs.embedding_dims` is set, it truncates the embeddings. + Child classes should use this method to get embeddings. + """ + if isinstance(texts, str): + texts = [texts] + if len(texts) == 0: + return [] + texts = [i for i in texts] + logger.debug(f"Getting embeddings for {texts}") + embeddings = get_embedding_function(self._configs)(texts) + if self._configs.embedding_dims: + embeddings = [e[: self._configs.embedding_dims] for e in embeddings] + return embeddings + + @abstractmethod + async def get_chunks(self, file_path) -> list[Chunk]: + """ + Retrieve all chunks for a given file from the database. + If the file is not found in the database, it should return an empty list. + + For developers: + This is useful for operations that need to inspect the chunked content of a file, for example, for debugging or analysis. + """ + pass + + async def cleanup(self) -> list[str]: + """ + Remove empty collections from the database. + Returns a list of paths of the removed collections. + This method relies on `list_collections` and `drop`. + Child classes should not need to override this. 
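+
+ Illustrative usage (a sketch; assumes `db` is a concrete connector instance):
+ ```python
+ for path in await db.cleanup():
+ print(f"Dropped empty collection for {path}")
+ ```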
+ """ + removed: list[str] = [] + for collection in await self.list_collections(): + if collection.chunk_count == 0: + removed.append(collection.path) + await self.drop(collection_path=collection.path) + + return removed diff --git a/src/vectorcode/database/chroma.py b/src/vectorcode/database/chroma.py new file mode 100644 index 00000000..937b0b11 --- /dev/null +++ b/src/vectorcode/database/chroma.py @@ -0,0 +1,472 @@ +import asyncio +import contextlib +import logging +import os +import socket +import sys +from asyncio import Lock +from typing import Any, Literal, Optional, Sequence, cast +from urllib.parse import urlparse + +import chromadb +from filelock import AsyncFileLock +from tree_sitter import Point + +from vectorcode.chunking import Chunk, TreeSitterChunker +from vectorcode.cli_utils import ( + Config, + LockManager, + QueryInclude, + expand_globs, + expand_path, +) +from vectorcode.database import DatabaseConnectorBase +from vectorcode.database.chroma_common import convert_chroma_query_results +from vectorcode.database.errors import CollectionNotFoundError +from vectorcode.database.types import ( + CollectionContent, + CollectionInfo, + FileInCollection, + QueryResult, + ResultType, + VectoriseStats, +) +from vectorcode.database.utils import get_collection_id, get_uuid, hash_file + +if not chromadb.__version__.startswith("1."): # pragma: nocover + logging.error( + f""" +Found ChromaDB {chromadb.__version__}, which is incompatible wiht your VectorCode installation. Please install `vectorcode`. + +For example: +uv tool install vectorcode +""" + ) + sys.exit(1) + + +from chromadb import Collection +from chromadb.api import ClientAPI +from chromadb.config import APIVersion, Settings +from chromadb.errors import NotFoundError + +logger = logging.getLogger(name=__name__) + +SupportedClientType = Literal["http"] | Literal["persistent"] + +_SUPPORTED_CLIENT_TYPE: set[SupportedClientType] = {"http", "persistent"} + +_default_settings: dict[str, Any] = { + "db_url": None, + "db_path": os.path.expanduser("~/.local/share/vectorcode/chromadb/"), + "db_log_path": os.path.expanduser("~/.local/share/vectorcode/"), + "db_settings": {}, + "hnsw": {"hnsw:M": 64}, +} + + +class ChromaDBConnector(DatabaseConnectorBase): + """ + This is the connector layer for **ChromaDB 1.x** + + Valid `db_params` options for ChromaDB 1.x: + - `db_url`: default to `http://127.0.0.1:8000` + - `db_path`: default to `~/.local/share/vectorcode/chromadb/`; + - `db_log_path`: default to `~/.local/share/vectorcode/` + - `db_settings`: See https://github.com/chroma-core/chroma/blob/508080841d2b2ebb3a9fbdc612087248df6f1382/chromadb/config.py#L120 + - `hnsw`: default to `{ "hnsw:M": 64 }` + """ + + def __init__(self, configs: Config): + super().__init__(configs) + params = _default_settings.copy() + params.update(self._configs.db_params.copy()) + params["db_path"] = os.path.expanduser(params["db_path"]) + params["db_log_path"] = os.path.expanduser(params["db_log_path"]) + self._configs.db_params = params + + self._client: ClientAPI | None = None + self._client_type: SupportedClientType + + # locks for persistent client + self._file_lock: AsyncFileLock | None = None # inter-process lock + self._thread_lock: Lock | None = None # inter-thread lock + + def _create_client(self) -> ClientAPI: + global _SUPPORTED_CLIENT_TYPE + settings: dict[str, Any] = {"anonymized_telemetry": False} + db_params = self._configs.db_params + settings.update(db_params["db_settings"]) + if db_params.get("db_url"): + parsed_url = 
urlparse(db_params["db_url"])
+
+ settings["chroma_server_host"] = settings.get(
+ "chroma_server_host", parsed_url.hostname or "127.0.0.1"
+ )
+ settings["chroma_server_http_port"] = settings.get(
+ "chroma_server_http_port", parsed_url.port or 8000
+ )
+ settings["chroma_server_ssl_enabled"] = settings.get(
+ "chroma_server_ssl_enabled", parsed_url.scheme == "https"
+ )
+ settings["chroma_server_api_default_path"] = settings.get(
+ "chroma_server_api_default_path", parsed_url.path or APIVersion.V2
+ )
+ settings_obj = Settings(**settings)
+ logger.info(
+ f"Created chromadb.HttpClient from the following settings: {settings_obj}"
+ )
+ self._client = chromadb.HttpClient(
+ host=settings["chroma_server_host"],
+ port=settings["chroma_server_http_port"],
+ ssl=settings["chroma_server_ssl_enabled"],
+ settings=settings_obj,
+ )
+ self._client_type = "http"
+ else:
+ logger.info(
+ f"Created chromadb.PersistentClient at `{db_params['db_path']}` from the following settings: {settings}"
+ )
+ os.makedirs(db_params["db_path"], exist_ok=True)
+ self._client = chromadb.PersistentClient(path=db_params["db_path"])
+
+ self._client_type = "persistent"
+ assert self._client_type in _SUPPORTED_CLIENT_TYPE
+ return self._client
+
+ async def get_client(self) -> ClientAPI:
+ if self._client is None:
+ self._create_client()
+ assert self._client is not None
+ if self._client_type == "persistent":
+ lock_manager = LockManager()
+ self._file_lock = lock_manager.get_lock(
+ str(self._configs.db_params["db_path"]), "filelock"
+ )
+ self._thread_lock = lock_manager.get_lock(
+ str(self._configs.db_params["db_path"]), "asyncio"
+ )
+ return self._client
+
+ @contextlib.asynccontextmanager
+ async def maybe_lock(self):
+ """
+ Acquire a file (dir) lock if using persistent client.
+ """
+ locked = False
+ if self._file_lock is not None:
+ assert self._thread_lock is not None
+ await self._file_lock.acquire()
+ await self._thread_lock.acquire()
+ locked = True
+ try:
+ yield
+ finally:
+ # release even when the guarded block raises, so that a failed
+ # operation cannot leave the database permanently locked
+ if locked:
+ assert self._thread_lock is not None
+ assert self._file_lock is not None
+ await self._file_lock.release()
+ self._thread_lock.release()
+
+ async def _create_or_get_collection(
+ self, collection_path: str, allow_create: bool = False
+ ) -> Collection:
+ """
+ This method should be used by ChromaDB methods that are expected to **create a collection when not found**.
+ For other methods, just use `client.get_collection` and let it fail if the collection doesn't exist.
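+
+ Illustrative usage (a sketch of both modes; `path` is a placeholder):
+ ```python
+ # writers may create the collection on demand:
+ col = await self._create_or_get_collection(path, allow_create=True)
+ # read-only callers let a missing collection surface as CollectionNotFoundError:
+ col = await self._create_or_get_collection(path, allow_create=False)
+ ```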
+ """ + + collection_meta: dict[str, str | int] = { + "path": os.path.abspath(str(self._configs.project_root)), + "hostname": socket.gethostname(), + "created-by": "VectorCode", + "username": os.environ.get( + "USER", os.environ.get("USERNAME", "DEFAULT_USER") + ), + "embedding_function": self._configs.embedding_function, + } + db_params = self._configs.db_params + user_hnsw = db_params.get("hnsw", {}) + for key in user_hnsw.keys(): + meta_field_name: str = key + if not meta_field_name.startswith("hnsw:"): + meta_field_name = f"hnsw:{meta_field_name}" + if user_hnsw.get(key) is not None: + collection_meta[meta_field_name] = user_hnsw[key] + + async with self.maybe_lock(): + collection_id = get_collection_id(collection_path) + client = await self.get_client() + if not allow_create: + try: + return client.get_collection(collection_id) + except (ValueError, NotFoundError) as e: + raise CollectionNotFoundError( + f"There's no existing collection for {collection_path} in ChromaDB with the following setup: {self._configs.db_params}" + ) from e + col = client.get_or_create_collection( + collection_id, metadata=collection_meta + ) + for key in collection_meta.keys(): + # validate metadata + assert collection_meta[key] == col.metadata.get(key), ( + f"Metadata field {key} mismatch!" + ) + + return col + + async def query(self) -> list[QueryResult]: + collection = await self._create_or_get_collection( + str(self._configs.project_root), False + ) + + assert self._configs.query is not None + assert len(self._configs.query), "Keywords cannot be empty" + keywords_embeddings = self.get_embedding(self._configs.query) + + query_count = self._configs.n_result or (await self.count(ResultType.chunk)) + query_filter = None + if len(self._configs.query_exclude): + query_filter = cast( + chromadb.Where, {"path": {"$nin": list(self._configs.query_exclude)}} + ) + if QueryInclude.chunk in self._configs.include: + if query_filter is None: + query_filter = cast(chromadb.Where, {"start": {"$gte": 0}}) + else: + query_filter = cast( + chromadb.Where, + {"$and": [query_filter.copy(), {"start": {"$gte": 0}}]}, + ) + + async with self.maybe_lock(): + raw_result = await asyncio.to_thread( + collection.query, + include=[ + "metadatas", + "documents", + "distances", + ], + query_embeddings=keywords_embeddings, + where=query_filter, + n_results=query_count, + ) + return convert_chroma_query_results(raw_result, self._configs.query) + + async def vectorise( + self, file_path: str, chunker: TreeSitterChunker | None = None + ) -> VectoriseStats: + collection_path = str(self._configs.project_root) + collection = await self._create_or_get_collection( + collection_path, allow_create=True + ) + chunker = chunker or TreeSitterChunker(self._configs) + + chunks = tuple(chunker.chunk(file_path)) + embeddings = self.get_embedding(list(i.text for i in chunks)) + if len(embeddings) == 0: + return VectoriseStats(skipped=1) + + file_hash = hash_file(file_path) + + def chunk_to_meta(chunk: Chunk) -> chromadb.Metadata: + meta: dict[str, int | str] = {"path": file_path, "sha256": file_hash} + if chunk.start: + meta["start"] = chunk.start.row + + if chunk.end: + meta["end"] = chunk.end.row + return meta + + max_bs = (await self.get_client()).get_max_batch_size() + for batch_start_idx in range(0, len(chunks), max_bs): + batch_chunks = [ + chunks[i].text + for i in range( + batch_start_idx, min(batch_start_idx + max_bs, len(chunks)) + ) + ] + batch_embeddings = embeddings[batch_start_idx : batch_start_idx + max_bs] + batch_meta = [ + 
chunk_to_meta(chunks[i]) + for i in range( + batch_start_idx, min(batch_start_idx + max_bs, len(chunks)) + ) + ] + async with self.maybe_lock(): + await asyncio.to_thread( + collection.add, + documents=batch_chunks, + embeddings=batch_embeddings, + metadatas=batch_meta, + ids=[get_uuid() for _ in batch_chunks], + ) + return VectoriseStats(add=1) + + async def delete(self) -> int: + project_root = self._configs.project_root + collection = await self._create_or_get_collection(str(project_root), False) + + rm_paths = self._configs.rm_paths + if isinstance(rm_paths, str): + rm_paths = [rm_paths] + rm_paths = [ + str(expand_path(path=i, absolute=True)) + for i in await expand_globs( + paths=self._configs.rm_paths, + recursive=self._configs.recursive, + include_hidden=self._configs.include_hidden, + ) + ] + + files_in_collection = set( + str(expand_path(i.path, True)) + for i in ( + await self.list_collection_content(what=ResultType.document) + ).files + ) + + rm_paths = { + str(expand_path(i, True)) + for i in rm_paths + if os.path.isfile(i) and (i in files_in_collection) + } + + if rm_paths: + async with self.maybe_lock(): + collection.delete( + where=cast(chromadb.Where, {"path": {"$in": list(rm_paths)}}) + ) + return len(rm_paths) + + async def drop( + self, *, collection_id: str | None = None, collection_path: str | None = None + ): + collection_path = str(collection_path or self._configs.project_root) + collection_id = collection_id or get_collection_id(collection_path) + try: + async with self.maybe_lock(): + await asyncio.to_thread( + (await self.get_client()).delete_collection, collection_id + ) + except ValueError as e: + raise CollectionNotFoundError( + f"Collection at {collection_path} is not found." + ) from e + + async def get_chunks(self, file_path) -> list[Chunk]: + file_path = os.path.abspath(file_path) + try: + collection = await self._create_or_get_collection( + str(self._configs.project_root), False + ) + except CollectionNotFoundError: + logger.warning( + f"There's no existing collection at {self._configs.project_root}." + ) + return [] + + raw_results = collection.get( + where={"path": file_path}, + include=["metadatas", "documents"], + ) + assert raw_results["metadatas"] is not None + assert raw_results["documents"] is not None + + result: list[Chunk] = [] + for i in range(len(raw_results["ids"])): + meta = raw_results["metadatas"][i] + text = raw_results["documents"][i] + _id = raw_results["ids"][i] + chunk = Chunk(text=text, id=_id) + if meta.get("start") is not None: + chunk.start = Point(row=cast(int, meta["start"]), column=0) + if meta.get("end") is not None: + chunk.end = Point(row=cast(int, meta["end"]), column=0) + + result.append(chunk) + return result + + async def list_collection_content( + self, + *, + what: Optional[ResultType] = None, + collection_id: str | None = None, + collection_path: str | None = None, + ) -> CollectionContent: + """ + When `what` is None, this method should populate both `CollectionContent.files` and `CollectionContent.chunks`. + Otherwise, this method may populate only one of them to save waiting time. 
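+
+ Illustrative usage (a sketch):
+ ```python
+ # only file-level metadata is needed, so chunks can be skipped:
+ files = (await self.list_collection_content(what=ResultType.document)).files
+ # both files and chunks, for a full dump of the collection:
+ content = await self.list_collection_content()
+ ```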
+ """ + if collection_id is None: + collection_path = str(collection_path or self._configs.project_root) + collection = await self._create_or_get_collection(collection_path, False) + else: + try: + collection = (await self.get_client()).get_collection(collection_id) + except (ValueError, NotFoundError) as e: + raise CollectionNotFoundError( + f"There's no existing collection for {collection_path} in ChromaDB with the following setup: {self._configs.db_params}" + ) from e + content = CollectionContent() + raw_content = await asyncio.to_thread( + collection.get, + include=[ + "metadatas", + "documents", + ], + ) + metadatas = raw_content.get("metadatas", []) + documents = raw_content.get("documents", []) + ids = raw_content.get("ids", []) + assert metadatas is not None + assert documents is not None + assert ids is not None + if what is None or what == ResultType.document: + content.files.extend( + set( + FileInCollection( + path=str(i.get("path")), sha256=str(i.get("sha256")) + ) + for i in metadatas + ) + ) + if what is None or what == ResultType.chunk: + for i in range(len(ids)): + start, end = None, None + if metadatas[i].get("start") is not None: + start = Point(row=cast(int, metadatas[i]["start"]), column=0) + if metadatas[i].get("end") is not None: + end = Point(row=cast(int, metadatas[i]["end"]), column=0) + content.chunks.append( + Chunk( + text=documents[i], + path=str(metadatas[i].get("path", "")) or None, + id=ids[i], + start=start, + end=end, + ) + ) + + return content + + async def list_collections(self) -> Sequence[CollectionInfo]: + client = await self.get_client() + result: list[CollectionInfo] = [] + for col in client.list_collections(): + project_root = str(col.metadata.get("path")) + col_counts = await self.list_collection_content( + collection_path=project_root + ) + result.append( + CollectionInfo( + id=col.name, + path=project_root, + embedding_function=col.metadata.get( + "embedding_function", + Config().embedding_function, # fallback to default + ), + database_backend="Chroma", + file_count=len(col_counts.files), + chunk_count=len(col_counts.chunks), + ) + ) + return result diff --git a/src/vectorcode/database/chroma0.py b/src/vectorcode/database/chroma0.py new file mode 100644 index 00000000..268ae5e3 --- /dev/null +++ b/src/vectorcode/database/chroma0.py @@ -0,0 +1,553 @@ +import asyncio +import atexit +import contextlib +import copy +import logging +import os +import socket +import subprocess +import sys +from asyncio.subprocess import Process +from dataclasses import dataclass +from typing import Any, Optional, cast +from urllib.parse import urlparse + +import chromadb + +if not chromadb.__version__.startswith("0.6.3"): # pragma: nocover + logging.error( + f""" + Found ChromaDB {chromadb.__version__}, which is incompatible with your VectorCode installation. Please install vectorcode[chroma0]. 
+ + For example: + uv tool install vectorcode[chroma0] + """ + ) + sys.exit(1) +import httpx +from chromadb.api import AsyncClientAPI +from chromadb.api.models.AsyncCollection import AsyncCollection +from chromadb.api.types import IncludeEnum +from chromadb.config import APIVersion, Settings +from tree_sitter import Point + +from vectorcode.chunking import Chunk, TreeSitterChunker +from vectorcode.cli_utils import ( + Config, + LockManager, + QueryInclude, + expand_globs, + expand_path, +) +from vectorcode.database.base import DatabaseConnectorBase +from vectorcode.database.chroma_common import convert_chroma_query_results +from vectorcode.database.errors import CollectionNotFoundError +from vectorcode.database.types import ( + CollectionContent, + CollectionInfo, + FileInCollection, + ResultType, + VectoriseStats, +) +from vectorcode.database.utils import get_collection_id, get_uuid, hash_file + +_logger = logging.getLogger(name=__name__) + + +async def _try_server(base_url: str): + for ver in ("v1", "v2"): # v1 for legacy, v2 for latest chromadb. + heartbeat_url = f"{base_url}/api/{ver}/heartbeat" + try: + async with httpx.AsyncClient() as client: + response = await client.get(url=heartbeat_url) + _logger.debug(f"Heartbeat {heartbeat_url} returned {response=}") + if response.status_code == 200: + return True + except (httpx.ConnectError, httpx.ConnectTimeout): # pragma: nocover + pass + return False + + +async def _wait_for_server(base_url: str, timeout: int = 10): # pragma: nocover + # Poll the server until it's ready or timeout is reached + + start_time = asyncio.get_event_loop().time() + while True: + if await _try_server(base_url): + return + + if asyncio.get_event_loop().time() - start_time > timeout: # pragma: nocover + raise TimeoutError(f"Server did not start within {timeout} seconds.") + + await asyncio.sleep(0.1) # Wait before retrying + + +async def _start_server(configs: Config): + assert configs.db_params.get("db_url") is not None + db_path = os.path.expanduser(configs.db_params["db_path"]) + db_log_path = configs.db_params["db_log_path"] + if not os.path.isdir(db_log_path): + os.makedirs(db_log_path) + if not os.path.isdir(db_path): + _logger.warning( + f"Using local database at {os.path.expanduser('~/.local/share/vectorcode/chromadb/')}.", + ) + db_path = os.path.expanduser("~/.local/share/vectorcode/chromadb/") + env = os.environ.copy() + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # OS selects a free ephemeral port + port = int(s.getsockname()[1]) + + server_url = f"http://127.0.0.1:{port}" + _logger.warning(f"Starting bundled ChromaDB server at {server_url}.") + env.update({"ANONYMIZED_TELEMETRY": "False"}) + process = await asyncio.create_subprocess_exec( + sys.executable, + "-m", + "chromadb.cli.cli", + "run", + "--host", + "localhost", + "--port", + str(port), + "--path", + db_path, + "--log-path", + os.path.join(str(db_log_path), "chroma.log"), + stdout=subprocess.DEVNULL, + stderr=sys.stderr, + env=env, + ) + + await _wait_for_server(server_url) + configs.db_params["db_url"] = server_url + return process + + +@dataclass +class _Chroma0ClientModel: + client: AsyncClientAPI + is_bundled: bool = False + process: Optional[Process] = None + + +class _Chroma0ClientManager: + singleton: Optional["_Chroma0ClientManager"] = None + __clients: dict[str, _Chroma0ClientModel] + + def __new__(cls) -> "_Chroma0ClientManager": + if cls.singleton is None: + cls.singleton = super().__new__(cls) + cls.singleton.__clients = {} + + 
atexit.register(cls.singleton.kill_servers) + + return cls.singleton + + @contextlib.asynccontextmanager + async def get_client(self, configs: Config, need_lock: bool = True): + project_root = str(expand_path(str(configs.project_root), True)) + is_bundled = False + url = configs.db_params["db_url"] + db_path = configs.db_params["db_path"] + db_log_path = configs.db_params["db_log_path"] + if self.__clients.get(project_root) is None: + process = None + if not await _try_server(url): + _logger.info(f"Starting a new server at {url}") + process = await _start_server(configs) + is_bundled = True + + try: + self.__clients[project_root] = _Chroma0ClientModel( + client=await self._create_client(configs), + is_bundled=is_bundled, + process=process, + ) + except httpx.RemoteProtocolError as e: # pragma: nocover + e.add_note(f"Please verify that {url} is a working chromadb server.") + raise + lock = None + if self.__clients[project_root].is_bundled and need_lock: + lock = LockManager().get_lock(str(db_path)) + _logger.debug(f"Locking {db_path}") + await lock.acquire() + + new_client = self.__clients[project_root].client + assert (await new_client.get_version()).split(".")[0] == "0" + yield new_client + + if lock is not None: + _logger.debug(f"Unlocking {db_log_path}") + await lock.release() + + def get_processes(self) -> list[Process]: # pragma: nocover + return [i.process for i in self.__clients.values() if i.process is not None] + + def kill_servers(self): # pragma: nocover + for p in self.get_processes(): + if p.returncode is None: + _logger.info(f"Killing bundled chroma server with PID: {p.pid}") + p.terminate() + + async def _create_client(self, configs: Config) -> AsyncClientAPI: + settings: dict[str, Any] = {"anonymized_telemetry": False} + db_settings = configs.db_params["db_settings"] + if isinstance(db_settings, dict): + valid_settings = { + k: v for k, v in db_settings.items() if k in Settings.__fields__ + } + settings.update(valid_settings) + parsed_url = urlparse(configs.db_params["db_url"]) + _logger.debug(f"Creating chromadb0 client from {db_settings}") + settings["chroma_server_host"] = settings.get( + "chroma_server_host", parsed_url.hostname or "127.0.0.1" + ) + settings["chroma_server_http_port"] = settings.get( + "chroma_server_http_port", parsed_url.port or 8000 + ) + settings["chroma_server_ssl_enabled"] = settings.get( + "chroma_server_ssl_enabled", parsed_url.scheme == "https" + ) + settings["chroma_server_api_default_path"] = settings.get( + "chroma_server_api_default_path", parsed_url.path or APIVersion.V2 + ) + settings_obj = Settings(**settings) + return await chromadb.AsyncHttpClient( + settings=settings_obj, + host=str(settings_obj.chroma_server_host), + port=int(settings_obj.chroma_server_http_port or 8000), + ) + + def clear(self): # pragma: nocover + self.__clients.clear() + + +_default_settings: dict[str, Any] = { + "db_url": "http://127.0.0.1:8000", + "db_path": os.path.expanduser("~/.local/share/vectorcode/chromadb/"), + "db_log_path": os.path.expanduser("~/.local/share/vectorcode/"), + "db_settings": {}, + "hnsw": {"hnsw:M": 64}, +} + + +class ChromaDB0Connector(DatabaseConnectorBase): + """ + This is the connector layer for **ChromaDB 0.6.3** + + Valid `db_params` options for ChromaDB 0.6.x: + - `db_url`: default to `http://127.0.0.1:8000` + - `db_path`: default to `~/.local/share/vectorcode/chromadb/`; + - `db_log_path`: default to `~/.local/share/vectorcode/` + - `db_settings`: See 
https://github.com/chroma-core/chroma/blob/a3b86a0302a385350a8f092a5f89a2dcdebcf6be/chromadb/config.py#L101 + - `hnsw`: default to `{ "hnsw:M": 64 }` + """ + + def __init__(self, configs: Config): + super().__init__(configs) + params = copy.deepcopy(_default_settings) + params.update(self._configs.db_params) + self._configs.db_params = params + + async def query(self): + assert self._configs.query is not None + assert len(self._configs.query), "Keywords cannot be empty" + keywords_embeddings = self.get_embedding(self._configs.query) + assert len(keywords_embeddings) == len(self._configs.query), ( + "Number of embeddings must match number of keywords." + ) + + collection_path = str(self._configs.project_root) + collection: AsyncCollection = await self._create_or_get_async_collection( + collection_path=collection_path, allow_create=False + ) + query_count = self._configs.n_result or (await self.count(ResultType.chunk)) + query_filter = None + if len(self._configs.query_exclude): + query_filter = cast( + chromadb.Where, {"path": {"$nin": list(self._configs.query_exclude)}} + ) + if QueryInclude.chunk in self._configs.include: + if query_filter is None: + query_filter = cast(chromadb.Where, {"start": {"$gte": 0}}) + else: + query_filter = cast( + chromadb.Where, + {"$and": [query_filter.copy(), {"start": {"$gte": 0}}]}, + ) + query_result = await collection.query( + query_embeddings=keywords_embeddings, + include=[ + IncludeEnum.metadatas, + IncludeEnum.documents, + IncludeEnum.distances, + ], + n_results=query_count, + where=query_filter, + ) + return convert_chroma_query_results(query_result, self._configs.query) + + async def _create_or_get_async_collection( + self, collection_path: str, allow_create: bool = False + ) -> AsyncCollection: + """ + This method should be used by ChromaDB methods that are expected to **create a collection when not found**. + For other methods, just use `client.get_collection` and let it fail if the collection doesn't exist. + """ + + collection_meta: dict[str, str | int] = { + "path": os.path.abspath(str(self._configs.project_root)), + "hostname": socket.gethostname(), + "created-by": "VectorCode", + "username": os.environ.get( + "USER", os.environ.get("USERNAME", "DEFAULT_USER") + ), + "embedding_function": self._configs.embedding_function, + } + db_params = self._configs.db_params + user_hnsw = db_params.get("hnsw", {}) + for key in user_hnsw.keys(): + meta_field_name: str = key + if not meta_field_name.startswith("hnsw:"): + meta_field_name = f"hnsw:{meta_field_name}" + if user_hnsw.get(key) is not None: + collection_meta[meta_field_name] = user_hnsw[key] + + async with _Chroma0ClientManager().get_client(self._configs, True) as client: + collection_id = get_collection_id(collection_path) + if not allow_create: + from chromadb.errors import InvalidCollectionException + + try: + return await client.get_collection(collection_id) + except (InvalidCollectionException, ValueError) as e: + raise CollectionNotFoundError( + f"There's no existing collection for {collection_path} in ChromaDB0 {self._configs.db_params.get('db_url')}" + ) from e + col = await client.get_or_create_collection( + collection_id, metadata=collection_meta + ) + for key in collection_meta.keys(): + # validate metadata + assert collection_meta[key] == col.metadata.get(key), ( + f"Metadata field {key} mismatch!" 
+ ) + + return col + + async def vectorise( + self, + file_path: str, + chunker: TreeSitterChunker | None = None, + ) -> VectoriseStats: + collection_path = str(self._configs.project_root) + collection = await self._create_or_get_async_collection( + collection_path, allow_create=True + ) + chunker = chunker or TreeSitterChunker(self._configs) + + chunks = tuple(chunker.chunk(file_path)) + embeddings = self.get_embedding(list(i.text for i in chunks)) + if len(embeddings) == 0: + return VectoriseStats(skipped=1) + + file_hash = hash_file(file_path) + + def chunk_to_meta(chunk: Chunk) -> chromadb.Metadata: + meta: dict[str, int | str] = {"path": file_path, "sha256": file_hash} + if chunk.start: + meta["start"] = chunk.start.row + + if chunk.end: + meta["end"] = chunk.end.row + return meta + + async with _Chroma0ClientManager().get_client(self._configs) as client: + max_bs = await client.get_max_batch_size() + for batch_start_idx in range(0, len(chunks), max_bs): + batch_chunks = [ + chunks[i].text + for i in range( + batch_start_idx, min(batch_start_idx + max_bs, len(chunks)) + ) + ] + batch_embeddings = embeddings[ + batch_start_idx : batch_start_idx + max_bs + ] + batch_meta = [ + chunk_to_meta(chunks[i]) + for i in range( + batch_start_idx, min(batch_start_idx + max_bs, len(chunks)) + ) + ] + await collection.add( + documents=batch_chunks, + embeddings=batch_embeddings, + metadatas=batch_meta, + ids=[get_uuid() for _ in batch_chunks], + ) + return VectoriseStats(add=1) + + async def list_collections(self): + async with _Chroma0ClientManager().get_client( + self._configs, need_lock=False + ) as client: + result: list[CollectionInfo] = [] + for col_name in await client.list_collections(): + col = await client.get_collection(col_name) + project_root = str(col.metadata.get("path")) + col_counts = await self.list_collection_content(collection_id=col_name) + result.append( + CollectionInfo( + id=col_name, + path=project_root, + embedding_function=col.metadata.get( + "embedding_function", + Config().embedding_function, # fallback to default + ), + database_backend="Chroma0", + file_count=len(col_counts.files), + chunk_count=len(col_counts.chunks), + ) + ) + return result + + async def list_collection_content( + self, + *, + what=None, + collection_id: str | None = None, + collection_path: str | None = None, + ) -> CollectionContent: + """ + When `what` is None, this method should populate both `CollectionContent.files` and `CollectionContent.chunks`. + Otherwise, this method may populate only one of them to save waiting time. 
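+
+ Illustrative usage (a sketch; `some_collection_id` is a placeholder):
+ ```python
+ content = await self.list_collection_content(
+ what=ResultType.chunk, collection_id=some_collection_id
+ )
+ chunk_total = len(content.chunks)
+ ```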
+ """ + if collection_id is None: + collection_path = str(collection_path or self._configs.project_root) + collection = await self._create_or_get_async_collection((collection_path)) + else: + async with _Chroma0ClientManager().get_client( + self._configs, False + ) as client: + collection = await client.get_collection(collection_id) + content = CollectionContent() + raw_content = await collection.get( + include=[ + IncludeEnum.metadatas, + IncludeEnum.documents, + ] + ) + metadatas = raw_content.get("metadatas", []) + documents = raw_content.get("documents", []) + ids = raw_content.get("ids", []) + assert metadatas is not None + assert documents is not None + assert ids is not None + if what is None or what == ResultType.document: + content.files.extend( + set( + FileInCollection( + path=str(i.get("path")), sha256=str(i.get("sha256")) + ) + for i in metadatas + ) + ) + if what is None or what == ResultType.chunk: + for i in range(len(ids)): + start, end = None, None + if metadatas[i].get("start") is not None: + start = Point(row=int(metadatas[i]["start"]), column=0) + if metadatas[i].get("end") is not None: + end = Point(row=int(metadatas[i]["end"]), column=0) + content.chunks.append( + Chunk( + text=documents[i], + path=str(metadatas[i].get("path", "")) or None, + id=ids[i], + start=start, + end=end, + ) + ) + + return content + + async def delete(self) -> int: + collection_path = str(self._configs.project_root) + collection = await self._create_or_get_async_collection(collection_path, False) + rm_paths = self._configs.rm_paths + if isinstance(rm_paths, str): + rm_paths = [rm_paths] + rm_paths = [ + str(expand_path(path=i, absolute=True)) + for i in await expand_globs( + paths=self._configs.rm_paths, + recursive=self._configs.recursive, + include_hidden=self._configs.include_hidden, + ) + ] + files_in_collection = set( + str(expand_path(i.path, True)) + for i in ( + await self.list_collection_content(what=ResultType.document) + ).files + ) + + rm_paths = { + str(expand_path(i, True)) + for i in rm_paths + if os.path.isfile(i) and (i in files_in_collection) + } + if rm_paths: + await collection.delete( + where=cast(chromadb.Where, {"path": {"$in": list(rm_paths)}}) + ) + + return len(rm_paths) + + async def drop(self, *, collection_id=None, collection_path=None): + project_root = str(collection_path or self._configs.project_root) + collection_id = collection_id or get_collection_id(project_root) + async with _Chroma0ClientManager().get_client(self._configs) as client: + try: + await client.delete_collection(collection_id) + except ValueError as e: + raise CollectionNotFoundError( + f"Collection at {project_root} is not found." + ) from e + + async def get_chunks(self, file_path) -> list[Chunk]: + file_path = os.path.abspath(file_path) + try: + collection = await self._create_or_get_async_collection( + collection_path=str(self._configs.project_root), allow_create=False + ) + except CollectionNotFoundError: + _logger.warning( + f"There's no existing collection at {self._configs.project_root}." 
+ )
+ return []
+
+ raw_results = await collection.get(
+ where={"path": file_path},
+ include=[IncludeEnum.metadatas, IncludeEnum.documents],
+ )
+ assert raw_results["metadatas"] is not None
+ assert raw_results["documents"] is not None
+
+ result: list[Chunk] = []
+ for i in range(len(raw_results["ids"])):
+ meta = raw_results["metadatas"][i]
+ text = raw_results["documents"][i]
+ _id = raw_results["ids"][i]
+ chunk = Chunk(text=text, id=_id)
+ if meta.get("start") is not None:
+ chunk.start = Point(row=int(meta["start"]), column=0)
+ if meta.get("end") is not None:
+ chunk.end = Point(row=int(meta["end"]), column=0)
+
+ result.append(chunk)
+ return result
diff --git a/src/vectorcode/database/chroma_common.py b/src/vectorcode/database/chroma_common.py
new file mode 100644
index 00000000..7610ea5c
--- /dev/null
+++ b/src/vectorcode/database/chroma_common.py
@@ -0,0 +1,42 @@
+from typing import Sequence, cast
+
+from chromadb.api.types import QueryResult as ChromaQueryResult
+from tree_sitter import Point
+
+from vectorcode.chunking import Chunk
+from vectorcode.database import types
+
+
+def convert_chroma_query_results(
+ chroma_result: ChromaQueryResult, queries: Sequence[str]
+) -> list[types.QueryResult]:
+ """Convert a chromadb query result to in-house query results."""
+ assert chroma_result["documents"] is not None
+ assert chroma_result["distances"] is not None
+ assert chroma_result["metadatas"] is not None
+ assert chroma_result["ids"] is not None
+
+ chroma_results_list: list[types.QueryResult] = []
+ for q_i in range(len(queries)):
+ q = queries[q_i]
+ documents = chroma_result["documents"][q_i]
+ distances = chroma_result["distances"][q_i]
+ metadatas = chroma_result["metadatas"][q_i]
+ ids = chroma_result["ids"][q_i]
+ for doc, dist, meta, _id in zip(documents, distances, metadatas, ids):
+ chunk = Chunk(text=doc, id=_id)
+ # use explicit `is not None` checks so that a chunk starting at
+ # row 0 is not silently dropped by a truthiness test
+ if meta.get("start") is not None:
+ chunk.start = Point(cast(int, meta["start"]), 0)
+ if meta.get("end") is not None:
+ chunk.end = Point(cast(int, meta["end"]), 0)
+ if meta.get("path"):
+ chunk.path = str(meta["path"])
+ chroma_results_list.append(
+ types.QueryResult(
+ chunk=chunk,
+ path=str(meta.get("path", "")),
+ query=(q,),
+ scores=(-dist,),
+ )
+ )
+ return chroma_results_list
diff --git a/src/vectorcode/database/errors.py b/src/vectorcode/database/errors.py
new file mode 100644
index 00000000..f7d2188d
--- /dev/null
+++ b/src/vectorcode/database/errors.py
@@ -0,0 +1,6 @@
+class CollectionNotFoundError(Exception):
+ """
+ Raised when a requested collection doesn't exist in the database.
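+
+ Connectors typically raise this from the backend's own "not found" error,
+ and callers may catch it to degrade gracefully, e.g. (an illustrative sketch):
+ ```python
+ try:
+ results = await database.query()
+ except CollectionNotFoundError:
+ results = [] # the project has not been vectorised yet
+ ```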
+ """ + + pass diff --git a/src/vectorcode/subcommands/query/types.py b/src/vectorcode/database/types.py similarity index 55% rename from src/vectorcode/subcommands/query/types.py rename to src/vectorcode/database/types.py index e7e5507f..73fa0e82 100644 --- a/src/vectorcode/subcommands/query/types.py +++ b/src/vectorcode/database/types.py @@ -1,12 +1,105 @@ import heapq +import json from collections import defaultdict -from dataclasses import dataclass -from typing import Literal, Union +from dataclasses import dataclass, field, fields +from enum import StrEnum +from typing import Any, Literal, Self, Sequence, Union import numpy +import tabulate from vectorcode.chunking import Chunk +CollectionID = str + + +class ResultType(StrEnum): + document = "document" + chunk = "chunk" + + +@dataclass +class QueryOpts: + keywords: Sequence[str] + count: int | None = None + return_type: ResultType = ResultType.chunk + excluded_files: Sequence[str] = field(default_factory=list) + + +@dataclass +class VectoriseStats: + add: int = 0 + update: int = 0 + removed: int = 0 + skipped: int = 0 + failed: int = 0 + + def to_json(self) -> str: + return json.dumps(self.to_dict()) + + def to_dict(self) -> dict[str, int]: + return {i.name: getattr(self, i.name) for i in fields(self)} + + def to_table(self) -> str: + _fields = fields(self) + return tabulate.tabulate( + [ + [i.name.capitalize() for i in _fields], + [getattr(self, i.name) for i in _fields], + ], + headers="firstrow", + ) + + def __add__(self, other) -> "VectoriseStats": + assert isinstance(other, VectoriseStats), ( + "`VectoriseStats` can only perform arithmatics with objects of the same type." + ) + new = VectoriseStats() + for f in fields(self): + f_name = f.name + setattr(new, f_name, sum(getattr(i, f_name) for i in (self, other))) + return new + + def __iadd__(self, other) -> Self: + for f in fields(self): + setattr(self, f.name, sum(getattr(obj, f.name) for obj in (self, other))) + return self + + +@dataclass +class CollectionInfo: + id: CollectionID + path: str # absolute path to the directory + embedding_function: str + database_backend: str + file_count: int = 0 + chunk_count: int = 0 + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, int | str]: + return { + "project-root": self.path, + "size": self.chunk_count, + "num_files": self.file_count, + "collection_name": self.id, + "embedding_function": self.embedding_function, + } + + +@dataclass +class FileInCollection: + path: str + sha256: str + + def __hash__(self): + return hash(self.sha256) + + +@dataclass +class CollectionContent: + files: list[FileInCollection] = field(default_factory=list) + chunks: list[Chunk] = field(default_factory=list) + @dataclass class QueryResult: diff --git a/src/vectorcode/database/utils.py b/src/vectorcode/database/utils.py new file mode 100644 index 00000000..cd5a5b63 --- /dev/null +++ b/src/vectorcode/database/utils.py @@ -0,0 +1,73 @@ +import hashlib +import logging +import os +import socket +import uuid +from functools import cache + +import chromadb +from chromadb.utils import embedding_functions + +from vectorcode.cli_utils import Config, expand_path + +logger = logging.getLogger(name=__name__) + + +def hash_str(string: str) -> str: + """Return the sha-256 hash of a string.""" + return hashlib.sha256(string.encode()).hexdigest() + + +def hash_file(path: str) -> str: + """return the sha-256 hash of a file.""" + hasher = hashlib.sha256() + with open(path, "rb") as file: + while True: + chunk = file.read(8192) + if 
chunk: + hasher.update(chunk) + else: + break + return hasher.hexdigest() + + +def get_uuid() -> str: + return uuid.uuid4().hex + + +def get_collection_id(full_path: str) -> str: + full_path = str(expand_path(full_path, absolute=True)) + hasher = hashlib.sha256() + plain_collection_name = f"{os.environ.get('USER', os.environ.get('USERNAME', 'DEFAULT_USER'))}@{socket.gethostname()}:{full_path}" + hasher.update(plain_collection_name.encode()) + collection_id = hasher.hexdigest()[:63] + logger.debug( + f"Hashing {plain_collection_name} as the collection name for {full_path}." + ) + return collection_id + + +@cache +def get_embedding_function( + configs: Config, +) -> chromadb.EmbeddingFunction: # pragma: nocover + try: + ef = getattr(embedding_functions, configs.embedding_function)( + **configs.embedding_params + ) + if ef is None: # pragma: nocover + raise AttributeError() + return ef + except AttributeError: + logger.warning( + f"Failed to use {configs.embedding_function}. Falling back to Sentence Transformer.", + ) + return embedding_functions.SentenceTransformerEmbeddingFunction() # type:ignore + except Exception as e: + e.add_note( + "\nFor errors caused by missing dependency, consult the documentation of pipx (or whatever package manager that you installed VectorCode with) for instructions to inject libraries into the virtual environment." + ) + logger.error( + f"Failed to use {configs.embedding_function} with following error.", + ) + raise diff --git a/src/vectorcode/lsp_main.py b/src/vectorcode/lsp_main.py index 80820f11..4c993fd6 100644 --- a/src/vectorcode/lsp_main.py +++ b/src/vectorcode/lsp_main.py @@ -6,20 +6,19 @@ import time import traceback import uuid -from typing import cast from urllib.parse import urlparse import shtab -from chromadb.types import Where +from vectorcode.database import get_database_connector +from vectorcode.database.types import ResultType from vectorcode.subcommands.vectorise import ( VectoriseStats, - chunked_add, - exclude_paths_by_spec, find_exclude_specs, load_files_from_include, - remove_orphanes, + vectorise_worker, ) +from vectorcode.subcommands.vectorise.filter import FilterManager try: # pragma: nocover from lsprotocol import types @@ -35,12 +34,12 @@ file=sys.stderr, ) sys.exit(1) -from chromadb.errors import InvalidCollectionException from vectorcode import __version__ from vectorcode.cli_utils import ( CliAction, FilesAction, + SpecResolver, cleanup_path, config_logging, expand_globs, @@ -49,9 +48,12 @@ get_project_config, parse_cli_args, ) -from vectorcode.common import ClientManager, get_collection, list_collection_files -from vectorcode.subcommands.ls import get_collection_list -from vectorcode.subcommands.query import build_query_results +from vectorcode.subcommands.query import ( + _prepare_formatted_result, + get_reranked_results, + preprocess_query_keywords, + verify_include, +) DEFAULT_PROJECT_ROOT: str | None = None logger = logging.getLogger(__name__) @@ -108,7 +110,6 @@ async def execute_command(ls: LanguageServer, args: list[str]): ) DEFAULT_PROJECT_ROOT = str(parsed_args.project_root) - collection = None if parsed_args.project_root is not None: parsed_args.project_root = os.path.abspath(str(parsed_args.project_root)) @@ -119,176 +120,191 @@ async def execute_command(ls: LanguageServer, args: list[str]): else: final_configs = parsed_args logger.info("Merged final configs: %s", final_configs) - async with ClientManager().get_client(final_configs) as client: - if final_configs.action in { - CliAction.vectorise, - CliAction.query, - 
CliAction.files, - }: - collection = await get_collection( - client=client, - configs=final_configs, - make_if_missing=final_configs.action in {CliAction.vectorise}, + progress_token = str(uuid.uuid4()) + + assert final_configs.action in { + CliAction.vectorise, + CliAction.query, + CliAction.files, + CliAction.ls, + }, f"Action {final_configs.action} is not supported by vectorcode LSP." + if final_configs.action in { + CliAction.vectorise, + CliAction.query, + CliAction.files, + }: + assert final_configs.project_root + + await ls.progress.create_async(progress_token) + database = get_database_connector(final_configs) + match final_configs.action: + case CliAction.query: + ls.progress.begin( + progress_token, + types.WorkDoneProgressBegin( + "VectorCode", + message=f"Querying {cleanup_path(str(final_configs.project_root))}", + ), ) - await ls.progress.create_async(progress_token) - match final_configs.action: - case CliAction.query: - ls.progress.begin( + if not verify_include(final_configs): + log_msg = "Invalid `--include` parameters!" + logger.error(log_msg) + ls.progress.end( progress_token, - types.WorkDoneProgressBegin( - "VectorCode", - message=f"Querying {cleanup_path(str(final_configs.project_root))}", - ), + types.WorkDoneProgressEnd(message=log_msg), ) - final_results = [] - try: - assert collection is not None, ( - "Failed to find the correct collection." - ) - final_results.extend( - await build_query_results(collection, final_configs) - ) - finally: - log_message = f"Retrieved {len(final_results)} result{'s' if len(final_results) > 1 else ''} in {round(time.time() - start_time, 2)}s." - ls.progress.end( - progress_token, - types.WorkDoneProgressEnd(message=log_message), + progress_token = None + return [] + final_results = [] + try: + preprocess_query_keywords(final_configs) + final_results.extend( + _prepare_formatted_result( + await get_reranked_results( + config=final_configs, database=database + ) ) - - progress_token = None - logger.info(log_message) - return final_results - case CliAction.ls: - ls.progress.begin( + ) + finally: + log_message = f"Retrieved {len(final_results)} result{'s' if len(final_results) > 1 else ''} in {round(time.time() - start_time, 2)}s." + ls.progress.end( progress_token, - types.WorkDoneProgressBegin( - "VectorCode", - message="Looking for available projects indexed by VectorCode", - ), + types.WorkDoneProgressEnd(message=log_message), ) - projects: list[dict] = [] - try: - projects.extend(await get_collection_list(client)) - finally: - ls.progress.end( - progress_token, - types.WorkDoneProgressEnd(message="List retrieved."), + progress_token = None + logger.info(log_message) + return final_results + case CliAction.ls: + ls.progress.begin( + progress_token, + types.WorkDoneProgressBegin( + "VectorCode", + message="Looking for available projects indexed by VectorCode", + ), + ) + projects: list[dict] = [] + try: + projects.extend( + i.to_dict() + for i in await get_database_connector( + final_configs + ).list_collections() + ) + finally: + ls.progress.end( + progress_token, + types.WorkDoneProgressEnd(message="List retrieved."), + ) + progress_token = None + logger.info(f"Retrieved {len(projects)} project(s).") + return projects + case CliAction.vectorise: + assert final_configs.project_root is not None, ( + "Failed to find the correct collection." 
+ ) + ls.progress.begin( + progress_token, + types.WorkDoneProgressBegin( + title="VectorCode", + message="Vectorising files...", + percentage=0, + ), + ) + files = await expand_globs( + final_configs.files + or load_files_from_include(str(final_configs.project_root)), + recursive=final_configs.recursive, + include_hidden=final_configs.include_hidden, + ) + total_file_count = len(files) + if not final_configs.force: # pragma: nocover + filters = FilterManager() + # tested in 'vectorise' subcommands + for spec_path in find_exclude_specs(final_configs): + if os.path.isfile(spec_path): + logger.info(f"Loading ignore specs from {spec_path}.") + spec = SpecResolver.from_path( + spec_path=spec_path, + project_root=str(final_configs.project_root) + if final_configs.project_root + else None, + ) + filters.add_filter(lambda x: spec.match_file(x, True)) + final_configs.files = list(str(i) for i in filters(files)) + stats = VectoriseStats( + skipped=total_file_count - len(final_configs.files) + ) + stats_lock = asyncio.Lock() + semaphore = asyncio.Semaphore(os.cpu_count() or 1) + tasks = [ + asyncio.create_task( + vectorise_worker( + database, str(file), semaphore, stats, stats_lock ) - logger.info(f"Retrieved {len(projects)} project(s).") - progress_token = None - return projects - case CliAction.vectorise: - assert collection is not None, ( - "Failed to find the correct collection." ) - ls.progress.begin( + for file in final_configs.files + ] + for i, task in enumerate(asyncio.as_completed(tasks), start=1): + await task + ls.progress.report( progress_token, - types.WorkDoneProgressBegin( - title="VectorCode", + types.WorkDoneProgressReport( message="Vectorising files...", - percentage=0, + percentage=int(100 * i / len(tasks)), ), ) - files = await expand_globs( - final_configs.files - or load_files_from_include(str(final_configs.project_root)), - recursive=final_configs.recursive, - include_hidden=final_configs.include_hidden, - ) - if not final_configs.force: # pragma: nocover - # tested in 'vectorise.py' - for spec in find_exclude_specs(final_configs): - if os.path.isfile(spec): - logger.info(f"Loading ignore specs from {spec}.") - files = exclude_paths_by_spec( - (str(i) for i in files), spec + + await database.check_orphanes() + + ls.progress.end( + progress_token, + types.WorkDoneProgressEnd( + message=f"Vectorised {stats.add + stats.update} files." 
+ ), + ) + progress_token = None + return stats.to_dict() + case CliAction.files: + match final_configs.files_action: + case FilesAction.ls: + return list( + i.path + for i in ( + await database.list_collection_content( + what=ResultType.document ) - stats = VectoriseStats() - collection_lock = asyncio.Lock() - stats_lock = asyncio.Lock() - max_batch_size = await client.get_max_batch_size() - semaphore = asyncio.Semaphore(os.cpu_count() or 1) - tasks = [ - asyncio.create_task( - chunked_add( - str(file), - collection, - collection_lock, - stats, - stats_lock, - final_configs, - max_batch_size, - semaphore, - ) + ).files ) - for file in files - ] - for i, task in enumerate(asyncio.as_completed(tasks), start=1): - await task - ls.progress.report( + case FilesAction.rm: + to_be_removed = list( + str(expand_path(p, True)) + for p in final_configs.rm_paths + if os.path.isfile(p) + ) + if len(to_be_removed) == 0: + return + ls.progress.begin( progress_token, - types.WorkDoneProgressReport( - message="Vectorising files...", - percentage=int(100 * i / len(tasks)), + types.WorkDoneProgressBegin( + title="VectorCode", + message=f"Removing {len(to_be_removed)} file(s).", ), ) - - await remove_orphanes( - collection, collection_lock, stats, stats_lock - ) - - ls.progress.end( - progress_token, - types.WorkDoneProgressEnd( - message=f"Vectorised {stats.add + stats.update} files." - ), - ) - - progress_token = None - return stats.to_dict() - case CliAction.files: - if collection is None: # pragma: nocover - raise InvalidCollectionException( - f"Failed to find the corresponding collection for {final_configs.project_root}" + final_configs.rm_paths = to_be_removed + await database.delete() + ls.progress.end( + progress_token, + types.WorkDoneProgressEnd( + message="Removal finished.", + ), ) - match final_configs.files_action: - case FilesAction.ls: - progress_token = None - return await list_collection_files(collection) - case FilesAction.rm: - to_be_removed = list( - str(expand_path(p, True)) - for p in final_configs.rm_paths - if os.path.isfile(p) - ) - if len(to_be_removed) == 0: - return - ls.progress.begin( - progress_token, - types.WorkDoneProgressBegin( - title="VectorCode", - message=f"Removing {len(to_be_removed)} file(s).", - ), - ) - await collection.delete( - where=cast( - Where, - {"path": {"$in": to_be_removed}}, - ) - ) - ls.progress.end( - progress_token, - types.WorkDoneProgressEnd( - message="Removal finished.", - ), - ) - progress_token = None - case _ as c: # pragma: nocover - error_message = f"Unsupported vectorcode subcommand: {str(c)}" - logger.error( - error_message, - ) - raise JsonRpcInvalidRequest(error_message) + progress_token = None + case _ as c: # pragma: nocover + error_message = f"Unsupported vectorcode subcommand: {str(c)}" + logger.error( + error_message, + ) + raise JsonRpcInvalidRequest(error_message) except Exception as e: # pragma: nocover if isinstance(e, JsonRpcException): # pygls exception. raise it as is. 
@@ -329,7 +345,6 @@ async def lsp_start() -> int: try: await asyncio.to_thread(server.start_io) finally: - await ClientManager().kill_servers() return 0 diff --git a/src/vectorcode/main.py b/src/vectorcode/main.py index 345aedc5..aa58dcb9 100644 --- a/src/vectorcode/main.py +++ b/src/vectorcode/main.py @@ -4,8 +4,6 @@ import sys import traceback -import httpx - from vectorcode import __version__ from vectorcode.cli_utils import ( CliAction, @@ -14,7 +12,6 @@ get_project_config, parse_cli_args, ) -from vectorcode.common import ClientManager logger = logging.getLogger(name=__name__) @@ -24,7 +21,7 @@ async def async_main(): if cli_args.no_stderr: sys.stderr = open(os.devnull, "w") - if cli_args.debug: + if cli_args.debug: # pragma: nocover from vectorcode import debugging debugging.enable() @@ -108,15 +105,10 @@ async def async_main(): from vectorcode.subcommands import files return_val = await files(final_configs) - except Exception as e: + except Exception: return_val = 1 - if isinstance(e, httpx.RemoteProtocolError): # pragma: nocover - e.add_note( - f"Please verify that {final_configs.db_url} is a working chromadb server." - ) logger.error(traceback.format_exc()) finally: - await ClientManager().kill_servers() return return_val diff --git a/src/vectorcode/mcp_main.py b/src/vectorcode/mcp_main.py index 808ba6f4..5ff83f3a 100644 --- a/src/vectorcode/mcp_main.py +++ b/src/vectorcode/mcp_main.py @@ -6,17 +6,17 @@ import traceback from dataclasses import dataclass from pathlib import Path -from typing import Optional, cast +from typing import Optional import shtab -from chromadb.types import Where +from vectorcode.database import get_database_connector +from vectorcode.database.types import ResultType from vectorcode.subcommands.vectorise import ( + FilterManager, VectoriseStats, - chunked_add, - exclude_paths_by_spec, find_exclude_specs, - remove_orphanes, + vectorise_worker, ) try: # pragma: nocover @@ -31,8 +31,7 @@ from vectorcode.cli_utils import ( Config, - LockManager, - cleanup_path, + SpecResolver, config_logging, expand_globs, expand_path, @@ -40,17 +39,14 @@ get_project_config, load_config_file, ) -from vectorcode.common import ( - ClientManager, - get_collection, - get_collections, - list_collection_files, -) from vectorcode.subcommands.prompt import prompt_by_categories -from vectorcode.subcommands.query import get_query_result_files +from vectorcode.subcommands.query import ( + _prepare_formatted_result, + get_reranked_results, + preprocess_query_keywords, +) logger = logging.getLogger(name=__name__) -locks = LockManager() @dataclass @@ -91,15 +87,12 @@ def get_arg_parser(): async def list_collections() -> list[str]: - names: list[str] = [] - async with ClientManager().get_client( - await load_config_file(default_project_root) - ) as client: - async for col in get_collections(client): - if col.metadata is not None: - names.append(cleanup_path(str(col.metadata.get("path")))) - logger.info("Retrieved the following collections: %s", names) - return names + """ + Returns a list of paths to the projects that have been indexed in the database. 
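+
+ Illustrative return value (a sketch; the paths are hypothetical):
+ ```python
+ ["/home/user/projects/foo", "/home/user/projects/bar"]
+ ```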
+ """ + + config = await load_config_file(default_project_root) + return [i.path for i in await get_database_connector(config).list_collections()] async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int]: @@ -113,56 +106,43 @@ async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int] ErrorData(code=1, message=f"{project_root} is not a valid path.") ) config = await get_project_config(project_root) + + paths = [os.path.expanduser(i) for i in await expand_globs(paths)] + final_config = await config.merge_from( + Config( + files=[i for i in paths if os.path.isfile(i)], + project_root=project_root, + ) + ) + total_file_count = len(paths) + filters = FilterManager() + for ignore_spec_file in find_exclude_specs(final_config): + if os.path.isfile(ignore_spec_file): + logger.info(f"Loading ignore specs from {ignore_spec_file}.") + spec = SpecResolver.from_path(ignore_spec_file) + filters.add_filter(lambda x: spec.match_file(x, True)) + + final_config.files = list(filters(paths)) + + database = get_database_connector(final_config) try: - async with ClientManager().get_client(config) as client: - collection = await get_collection(client, config, True) - if collection is None: # pragma: nocover - raise McpError( - ErrorData( - code=1, - message=f"Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.", - ) - ) - paths = [os.path.expanduser(i) for i in await expand_globs(paths)] - final_config = await config.merge_from( - Config( - files=[i for i in paths if os.path.isfile(i)], - project_root=project_root, - ) + stats = VectoriseStats(skipped=total_file_count - len(final_config.files)) + stats_lock = asyncio.Lock() + semaphore = asyncio.Semaphore(os.cpu_count() or 1) + tasks = [ + asyncio.create_task( + vectorise_worker(database, str(file), semaphore, stats, stats_lock) ) - for ignore_spec in find_exclude_specs(final_config): - if os.path.isfile(ignore_spec): - logger.info(f"Loading ignore specs from {ignore_spec}.") - paths = exclude_paths_by_spec((str(i) for i in paths), ignore_spec) - - stats = VectoriseStats() - collection_lock = asyncio.Lock() - stats_lock = asyncio.Lock() - max_batch_size = await client.get_max_batch_size() - semaphore = asyncio.Semaphore(os.cpu_count() or 1) - tasks = [ - asyncio.create_task( - chunked_add( - str(file), - collection, - collection_lock, - stats, - stats_lock, - final_config, - max_batch_size, - semaphore, - ) - ) - for file in paths - ] - for i, task in enumerate(asyncio.as_completed(tasks), start=1): - await task + for file in final_config.files + ] + for i, task in enumerate(asyncio.as_completed(tasks), start=1): + await task - await remove_orphanes(collection, collection_lock, stats, stats_lock) + await database.check_orphanes() return stats.to_dict() - except Exception as e: # pragma: nocover - if isinstance(e, McpError): + except Exception as e: + if isinstance(e, McpError): # pragma: nocover logger.error("Failed to access collection at %s", project_root) raise else: @@ -194,40 +174,20 @@ async def query_tool( message="Use `list_collections` tool to get a list of valid paths for this field.", ) ) + config = await get_project_config(project_root) + config.query = query_messages + config.n_result = n_query + preprocess_query_keywords(config) + config.n_result = n_query + try: - async with ClientManager().get_client(config) as client: - collection = await get_collection(client, config, False) - - if collection is None: # pragma: nocover - raise 
McpError(
-                    ErrorData(
-                        code=1,
-                        message=f"Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
-                    )
-                )
-            query_config = await config.merge_from(
-                Config(n_result=n_query, query=query_messages)
-            )
-            logger.info("Built the final config: %s", query_config)
-            result_paths = await get_query_result_files(
-                collection=collection,
-                configs=query_config,
-            )
-            results: list[str] = []
-            for result in result_paths:
-                if isinstance(result, str):
-                    if os.path.isfile(result):
-                        with open(result) as fin:
-                            rel_path = os.path.relpath(result, config.project_root)
-                            results.append(
-                                f"{rel_path}\n{fin.read()}",
-                            )
-            logger.info("Retrieved the following files: %s", result_paths)
-            return results
-
-    except Exception as e:  # pragma: nocover
-        if isinstance(e, McpError):
+        database = get_database_connector(config)
+        reranked_results = await get_reranked_results(config, database)
+        return list(str(i) for i in _prepare_formatted_result(reranked_results))
+
+    except Exception as e:
+        if isinstance(e, McpError):  # pragma: nocover
             logger.error("Failed to access collection at %s", project_root)
             raise
         else:
@@ -244,8 +203,13 @@ async def ls_files(project_root: str) -> list[str]:
     project_root: Directory to the repository. MUST be from the vectorcode `ls` tool or user input;
     """
     configs = await get_project_config(expand_path(project_root, True))
-    async with ClientManager().get_client(configs) as client:
-        return await list_collection_files(await get_collection(client, configs, False))
+    database = get_database_connector(configs)
+    return list(
+        i.path
+        for i in (
+            await database.list_collection_content(what=ResultType.document)
+        ).files
+    )


 async def rm_files(files: list[str], project_root: str):
@@ -254,17 +218,14 @@ async def rm_files(files: list[str], project_root: str):
     project_root: Directory to the repository. MUST be from the vectorcode `ls` tool or user input;
     """
     configs = await get_project_config(expand_path(project_root, True))
-    async with ClientManager().get_client(configs) as client:
-        try:
-            collection = await get_collection(client, configs, False)
-            files = [str(expand_path(i, True)) for i in files if os.path.isfile(i)]
-            if files:
-                await collection.delete(where=cast(Where, {"path": {"$in": files}}))
-            else:  # pragma: nocover
-                logger.warning(f"All paths were invalid: {files}")
-        except ValueError:  # pragma: nocover
-            logger.warning(f"Failed to find the collection at {configs.project_root}")
-    return
+    configs.rm_paths = [str(expand_path(i, True)) for i in files if os.path.isfile(i)]
+
+    if configs.rm_paths:
+        database = get_database_connector(configs)
+        num_deleted = await database.delete()
+        return f"Removed {num_deleted} files from the database of the project located at {project_root}"
+    else:
+        logger.warning(f"The provided paths were invalid: {files}")  # log the caller-supplied paths; `configs.rm_paths` is empty here


 async def mcp_server():
@@ -282,24 +243,11 @@ async def mcp_server():
     default_project_root = project_root
     default_config = await get_project_config(project_root)
     default_config.project_root = project_root
-    async with ClientManager().get_client(default_config) as client:
-        logger.info("Collection initialised for %s.", project_root)
-
-        if client is None:
-            if mcp_config.ls_on_start:  # pragma: nocover
-                logger.warning(
-                    "Failed to initialise a chromadb client. Ignoring --ls-on-start flag."
-                )
-        else:
-            if mcp_config.ls_on_start:
-                logger.info(
-                    "Adding available collections to the server instructions."
-                )
-                default_instructions += (
-                    "\nYou have access to the following collections:\n"
-                )
-                for name in await list_collections():
-                    default_instructions += f"{name}"
+    if mcp_config.ls_on_start:
+        logger.info("Adding available collections to the server instructions.")
+        default_instructions += "\nYou have access to the following collections:\n"
+        for name in await list_collections():
+            default_instructions += f"{name}\n"  # trailing newline keeps the listed paths separated

     mcp = FastMCP("VectorCode", instructions=default_instructions)
     mcp.add_tool(
@@ -352,7 +300,6 @@ async def run_server():  # pragma: nocover
         mcp = await mcp_server()
         await mcp.run_stdio_async()
     finally:
-        await ClientManager().kill_servers()
         return 0
diff --git a/src/vectorcode/subcommands/clean.py b/src/vectorcode/subcommands/clean.py
index bae7ed48..7d3f1d6b 100644
--- a/src/vectorcode/subcommands/clean.py
+++ b/src/vectorcode/subcommands/clean.py
@@ -1,26 +1,16 @@
 import logging
-import os
-
-from chromadb.api import AsyncClientAPI

 from vectorcode.cli_utils import Config
-from vectorcode.common import ClientManager, get_collections
+from vectorcode.database import get_database_connector

 logger = logging.getLogger(name=__name__)


-async def run_clean_on_client(client: AsyncClientAPI, pipe_mode: bool):
-    async for collection in get_collections(client):
-        meta = collection.metadata
-        logger.debug(f"{meta.get('path')}: {await collection.count()} chunk(s)")
-        if await collection.count() == 0 or not os.path.isdir(meta["path"]):
-            await client.delete_collection(collection.name)
-            logger.info(f"Deleted collection for {meta['path']}")
-            if not pipe_mode:
-                print(f"Deleted {meta['path']}.")
-
-
 async def clean(configs: Config) -> int:
-    async with ClientManager().get_client(configs) as client:
-        await run_clean_on_client(client, configs.pipe)
-        return 0
+    database = get_database_connector(configs)
+    for removed in await database.cleanup():
+        message = f"Deleted collection: {removed}"
+        logger.info(message)
+        if not configs.pipe:
+            print(message)
+    return 0
diff --git a/src/vectorcode/subcommands/drop.py b/src/vectorcode/subcommands/drop.py
index 155c303f..ff11f538 100644
--- a/src/vectorcode/subcommands/drop.py
+++ b/src/vectorcode/subcommands/drop.py
@@ -1,24 +1,21 @@
 import logging

-from chromadb.errors import InvalidCollectionException
-
 from vectorcode.cli_utils import Config
-from vectorcode.common import ClientManager, get_collection
+from vectorcode.database import get_database_connector
+from vectorcode.database.errors import CollectionNotFoundError

 logger = logging.getLogger(name=__name__)


 async def drop(config: Config) -> int:
-    async with ClientManager().get_client(config) as client:
-        try:
-            collection = await get_collection(client, config)
-            collection_path = collection.metadata["path"]
-            await client.delete_collection(collection.name)
-            print(f"Collection for {collection_path} has been deleted.")
-            logger.info(f"Deteted collection at {collection_path}.")
-            return 0
-        except (ValueError, InvalidCollectionException) as e:
-            logger.error(
-                f"{e.__class__.__name__}: There's no existing collection for {config.project_root}"
-            )
-            return 1
+    try:
+        database = get_database_connector(config)
+        await database.drop()
+        if not config.pipe:
+            print(f"Collection for {config.project_root} has been deleted.")
+        return 0
+    except CollectionNotFoundError:
+        logger.warning(f"Collection for {config.project_root} doesn't exist.")
+        return 1
+    except Exception:  # pragma: nocover
+        raise
diff --git a/src/vectorcode/subcommands/files/ls.py b/src/vectorcode/subcommands/files/ls.py index
6dffd3d7..9d0d01c2 100644 --- a/src/vectorcode/subcommands/files/ls.py +++ b/src/vectorcode/subcommands/files/ls.py @@ -2,22 +2,29 @@ import logging from vectorcode.cli_utils import Config -from vectorcode.common import ClientManager, get_collection, list_collection_files +from vectorcode.database import get_database_connector +from vectorcode.database.errors import CollectionNotFoundError +from vectorcode.database.types import ResultType logger = logging.getLogger(name=__name__) async def ls(configs: Config) -> int: - async with ClientManager().get_client(configs=configs) as client: - try: - collection = await get_collection(client, configs, False) - except ValueError: - logger.error(f"There's no existing collection at {configs.project_root}.") - return 1 - paths = await list_collection_files(collection) + try: + database = get_database_connector(configs) + files = list( + i.path + for i in ( + await database.list_collection_content(what=ResultType.document) + ).files + ) if configs.pipe: - print(json.dumps(list(paths))) + print(json.dumps(files)) else: - for p in paths: - print(p) - return 0 + print("\n".join(files)) + return 0 + except CollectionNotFoundError: + logger.error(f"There's no existing collection for `{configs.project_root}`.") + return 1 + except Exception: # pragma: nocover + raise diff --git a/src/vectorcode/subcommands/files/rm.py b/src/vectorcode/subcommands/files/rm.py index 1d2e9fb3..9e98d9cc 100644 --- a/src/vectorcode/subcommands/files/rm.py +++ b/src/vectorcode/subcommands/files/rm.py @@ -1,31 +1,28 @@ import logging -import os -from typing import cast -from chromadb.types import Where - -from vectorcode.cli_utils import Config, expand_path -from vectorcode.common import ClientManager, get_collection +from vectorcode.cli_utils import Config +from vectorcode.database import get_database_connector +from vectorcode.database.errors import CollectionNotFoundError +from vectorcode.database.types import ResultType logger = logging.getLogger(name=__name__) async def rm(configs: Config) -> int: - async with ClientManager().get_client(configs=configs) as client: - try: - collection = await get_collection(client, configs, False) - except ValueError: - logger.error(f"There's no existing collection at {configs.project_root}.") - return 1 - paths = list( - str(expand_path(p, True)) for p in configs.rm_paths if os.path.isfile(p) - ) - await collection.delete(where=cast(Where, {"path": {"$in": paths}})) + try: + database = get_database_connector(configs) + remove_count = await database.delete() + if not configs.pipe: - print(f"Removed {len(paths)} file(s).") - if await collection.count() == 0: + print(f"Removed {remove_count} file(s).") + if await database.count(ResultType.chunk) == 0: logger.warning( - f"The collection at {configs.project_root} is now empty and will be removed." + f"The collection at {configs.project_root} is now empty, and will be removed." 
) - await client.delete_collection(collection.name) - return 0 + await database.drop() + return 0 + except CollectionNotFoundError: + logger.error(f"There's no existing collection for `{configs.project_root}`.") + return 1 + except Exception: # pragma: nocover + raise diff --git a/src/vectorcode/subcommands/ls.py b/src/vectorcode/subcommands/ls.py index c78d82ac..3d951dac 100644 --- a/src/vectorcode/subcommands/ls.py +++ b/src/vectorcode/subcommands/ls.py @@ -1,69 +1,44 @@ import json import logging import os -import socket import tabulate -from chromadb.api import AsyncClientAPI -from chromadb.api.types import IncludeEnum -from vectorcode.cli_utils import Config, cleanup_path -from vectorcode.common import ClientManager, get_collections +from vectorcode.cli_utils import Config +from vectorcode.database import get_database_connector logger = logging.getLogger(name=__name__) -async def get_collection_list(client: AsyncClientAPI) -> list[dict]: - result = [] - async for collection in get_collections(client): - meta = collection.metadata - document_meta = await collection.get(include=[IncludeEnum.metadatas]) - unique_files = set( - i.get("path") for i in (document_meta["metadatas"] or []) if i is not None - ) - result.append( - { - "project-root": cleanup_path(meta["path"]), - "user": meta.get("username"), - "hostname": socket.gethostname(), - "collection_name": collection.name, - "size": await collection.count(), - "embedding_function": meta["embedding_function"], - "num_files": len(unique_files), - } - ) - return result - - async def ls(configs: Config) -> int: - async with ClientManager().get_client(configs) as client: - result: list[dict] = await get_collection_list(client) - logger.info(f"Found the following collections: {result}") - - if configs.pipe: - print(json.dumps(result)) - else: - table = [] - for meta in result: - project_root = meta["project-root"] - if os.environ.get("HOME"): - project_root = project_root.replace(os.environ["HOME"], "~") - row = [ - project_root, - meta["size"], - meta["num_files"], - meta["embedding_function"], - ] - table.append(row) - print( - tabulate.tabulate( - table, - headers=[ - "Project Root", - "Collection Size", - "Number of Files", - "Embedding Function", - ], - ) + result = [ + i.to_dict() for i in await get_database_connector(configs).list_collections() + ] + + if configs.pipe: + print(json.dumps(result)) + else: + table = [] + for meta in result: + project_root = str(meta["project-root"]) + if os.environ.get("HOME"): + project_root = project_root.replace(os.environ["HOME"], "~") + row = [ + project_root, + meta["size"], + meta["num_files"], + meta["embedding_function"], + ] + table.append(row) + print( + tabulate.tabulate( + table, + headers=[ + "Project Root", + "Number of Embeddings", + "Number of Files", + "Embedding Function", + ], ) - return 0 + ) + return 0 diff --git a/src/vectorcode/subcommands/query/__init__.py b/src/vectorcode/subcommands/query/__init__.py index 22ab5abc..e954b70a 100644 --- a/src/vectorcode/subcommands/query/__init__.py +++ b/src/vectorcode/subcommands/query/__init__.py @@ -1,233 +1,105 @@ import json import logging import os -from typing import Any, cast - -from chromadb import Where -from chromadb.api.models.AsyncCollection import AsyncCollection -from chromadb.api.types import IncludeEnum, QueryResult -from chromadb.errors import InvalidCollectionException, InvalidDimensionException -from tree_sitter import Point from vectorcode.chunking import Chunk, StringChunker from vectorcode.cli_utils import ( Config, 
    QueryInclude,
-    cleanup_path,
-    expand_globs,
-    expand_path,
-)
-from vectorcode.common import (
-    ClientManager,
-    get_collection,
-    get_embedding_function,
-    verify_ef,
 )
-from vectorcode.subcommands.query import types as vectorcode_types
+from vectorcode.database import get_database_connector
+from vectorcode.database.base import DatabaseConnectorBase
 from vectorcode.subcommands.query.reranker import (
-    RerankerError,
     get_reranker,
 )

 logger = logging.getLogger(name=__name__)


-def convert_query_results(
-    chroma_result: QueryResult, queries: list[str]
-) -> list[vectorcode_types.QueryResult]:
-    """Convert chromadb query result to in-house query results"""
-    assert chroma_result["documents"] is not None
-    assert chroma_result["distances"] is not None
-    assert chroma_result["metadatas"] is not None
-    assert chroma_result["ids"] is not None
-
-    chroma_results_list: list[vectorcode_types.QueryResult] = []
-    for q_i in range(len(queries)):
-        q = queries[q_i]
-        documents = chroma_result["documents"][q_i]
-        distances = chroma_result["distances"][q_i]
-        metadatas = chroma_result["metadatas"][q_i]
-        ids = chroma_result["ids"][q_i]
-        for doc, dist, meta, _id in zip(documents, distances, metadatas, ids):
-            chunk = Chunk(text=doc, id=_id)
-            if meta.get("start"):
-                chunk.start = Point(int(meta.get("start", 0)), 0)
-            if meta.get("end"):
-                chunk.end = Point(int(meta.get("end", 0)), 0)
-            if meta.get("path"):
-                chunk.path = str(meta["path"])
-            chroma_results_list.append(
-                vectorcode_types.QueryResult(
-                    chunk=chunk,
-                    path=str(meta.get("path", "")),
-                    query=(q,),
-                    scores=(-dist,),
+def _prepare_formatted_result(
+    reranked_results: list[str | Chunk],
+) -> list[dict[str, str | int]]:
+    results: list[dict[str, str | int]] = []
+    for res in reranked_results:
+        if isinstance(res, str):
+            if os.path.isfile(res):
+                # path to a file
+                with open(res) as fin:
+                    results.append({"path": res, "document": fin.read()})
+            else:  # pragma: nocover
+                logger.warning(f"Skipping non-existent file: {res}")
+        else:
+            assert isinstance(res, Chunk)
+            if res.start is None or res.end is None:  # pragma: nocover
+                logger.warning(
+                    "This chunk doesn't have line range metadata. Please try re-vectorising the project."
                )
-            )
-    return chroma_results_list
+            output_dict = {
+                "path": res.path,
+                "chunk": res.text,
+                "chunk_id": res.id,
+            }
+            # attach line ranges only when the chunk carries them
+            if res.start is not None:
+                output_dict["start_line"] = res.start.row
+            if res.end is not None:
+                output_dict["end_line"] = res.end.row
+            results.append(output_dict)
+    return results
+
+
+async def get_reranked_results(
+    config: Config,
+    database: DatabaseConnectorBase,
+) -> list[str | Chunk]:
+    """
+    Return a list of paths or `Chunk`s ranked by similarity.
+    """
+    reranker = get_reranker(config)
+    reranked_results = await reranker.rerank(results=await database.query())
+    return reranked_results


-async def get_query_result_files(
-    collection: AsyncCollection, configs: Config
-) -> list[str | Chunk]:
-    query_chunks = []
-    assert configs.query, "Query messages cannot be empty."
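+# Re-chunks each query keyword in place with `StringChunker`, so oversized queries are embedded piece by piece.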
+def preprocess_query_keywords(configs: Config): + assert configs.query + query_chunks: list[str] = [] chunker = StringChunker(configs) for q in configs.query: query_chunks.extend(str(i) for i in chunker.chunk(q)) + configs.query[:] = query_chunks + return configs - configs.query_exclude = [ - expand_path(i, True) - for i in await expand_globs(configs.query_exclude) - if os.path.isfile(i) - ] - if (await collection.count()) == 0: - logger.error("Empty collection!") - return [] - try: - if len(configs.query_exclude): - logger.info(f"Excluding {len(configs.query_exclude)} files from the query.") - filter: dict[str, Any] = {"path": {"$nin": configs.query_exclude}} - else: - filter = {} - num_query = configs.n_result - if QueryInclude.chunk in configs.include: - if filter: - filter = {"$and": [filter.copy(), {"start": {"$gte": 0}}]} - else: - filter["start"] = {"$gte": 0} - else: - num_query = await collection.count() - if configs.query_multiplier > 0: - num_query = min( - int(configs.n_result * configs.query_multiplier), - await collection.count(), - ) - logger.info(f"Querying {num_query} chunks for reranking.") - query_embeddings = get_embedding_function(configs)(query_chunks) - if isinstance(configs.embedding_dims, int) and configs.embedding_dims > 0: - query_embeddings = [e[: configs.embedding_dims] for e in query_embeddings] - chroma_query_results: QueryResult = await collection.query( - query_embeddings=query_embeddings, - n_results=num_query, - include=[ - IncludeEnum.metadatas, - IncludeEnum.distances, - IncludeEnum.documents, - ], - where=cast(Where, filter) or None, - ) - except IndexError: - # no results found - return [] - - reranker = get_reranker(configs) - converted_results = convert_query_results(chroma_query_results, configs.query) - return await reranker.rerank(converted_results) - -async def build_query_results( - collection: AsyncCollection, configs: Config -) -> list[dict[str, str | int]]: - assert configs.project_root - - def make_output_path(path: str, absolute: bool) -> str: - if absolute: - if os.path.isabs(path): - return path - return os.path.abspath(os.path.join(str(configs.project_root), path)) - else: - rel_path = os.path.relpath(path, configs.project_root) - if isinstance(rel_path, bytes): # pragma: nocover - # for some reasons, some python versions report that `os.path.relpath` returns a string. - rel_path = rel_path.decode() - return rel_path - - structured_result = [] - for res in await get_query_result_files(collection, configs): - if isinstance(res, str): - output_path = make_output_path(res, configs.use_absolute_path) - io_path = make_output_path(res, True) - if not os.path.isfile(io_path): - logger.warning(f"{io_path} is no longer a valid file.") - continue - with open(io_path) as fin: - structured_result.append({"path": output_path, "document": fin.read()}) - else: - res = cast(Chunk, res) - assert res.path, f"{res} has no `path` attribute." 
- structured_result.append( - { - "path": make_output_path(res.path, configs.use_absolute_path) - if res.path is not None - else None, - "chunk": res.text, - "start_line": res.start.row if res.start is not None else None, - "end_line": res.end.row if res.end is not None else None, - "chunk_id": res.id, - } - ) - for result in structured_result: - if result.get("path") is not None: - result["path"] = cleanup_path(result["path"]) - return structured_result - - -async def query(configs: Config) -> int: +def verify_include(configs: Config): + if QueryInclude.path not in configs.include: + configs.include.append(QueryInclude.path) if ( QueryInclude.chunk in configs.include and QueryInclude.document in configs.include ): logger.error( - "Having both chunk and document in the output is not supported!", + "`chunk` and `document` cannot be used at the same time for `--include`." ) - return 1 - async with ClientManager().get_client(configs) as client: - try: - collection = await get_collection(client, configs, False) - if not verify_ef(collection, configs): - return 1 - except (ValueError, InvalidCollectionException) as e: - logger.error( - f"{e.__class__.__name__}: There's no existing collection for {configs.project_root}", - ) - return 1 - except InvalidDimensionException as e: - logger.error( - f"{e.__class__.__name__}: The collection was embedded with a different embedding model.", - ) - return 1 - except IndexError as e: # pragma: nocover - logger.error( - f"{e.__class__.__name__}: Failed to get the collection. Please check your config." - ) - return 1 - - if not configs.pipe: - print("Starting querying...") + return False + return True - if QueryInclude.chunk in configs.include: - if len((await collection.get(where={"start": {"$gte": 0}}))["ids"]) == 0: - logger.warning( - """ - This collection doesn't contain line range metadata. Falling back to `--include path document`. 
- Please re-vectorise it to use `--include chunk`.""", - ) - configs.include = [QueryInclude.path, QueryInclude.document] - try: - structured_result = await build_query_results(collection, configs) - except RerankerError as e: # pragma: nocover - # error logs should be handled where they're raised - logger.error(f"{e.__class__.__name__}") - return 1 +async def query(configs: Config) -> int: + if not verify_include(configs): + return 1 - if configs.pipe: - print(json.dumps(structured_result)) - else: - for idx, result in enumerate(structured_result): - for include_item in configs.include: - print(f"{include_item.to_header()}{result.get(include_item.value)}") - if idx != len(structured_result) - 1: - print() - return 0 + assert configs.query + preprocess_query_keywords(configs) + + database = get_database_connector(configs) + reranked_results = await get_reranked_results(configs, database) + formatted_results = _prepare_formatted_result(reranked_results) + if configs.pipe: + print(json.dumps(formatted_results)) + else: + for idx, result in enumerate(formatted_results): + for include_item in configs.include: + print(f"{include_item.to_header()}{result.get(include_item.value)}") + if idx != len(formatted_results) - 1: + print() + return 0 diff --git a/src/vectorcode/subcommands/query/reranker/base.py b/src/vectorcode/subcommands/query/reranker/base.py index 18a4c68a..0b3178e6 100644 --- a/src/vectorcode/subcommands/query/reranker/base.py +++ b/src/vectorcode/subcommands/query/reranker/base.py @@ -7,7 +7,7 @@ from vectorcode.chunking import Chunk from vectorcode.cli_utils import Config, QueryInclude -from vectorcode.subcommands.query.types import QueryResult +from vectorcode.database.types import QueryResult logger = logging.getLogger(name=__name__) diff --git a/src/vectorcode/subcommands/query/reranker/cross_encoder.py b/src/vectorcode/subcommands/query/reranker/cross_encoder.py index 3f2fcd1d..ff41853d 100644 --- a/src/vectorcode/subcommands/query/reranker/cross_encoder.py +++ b/src/vectorcode/subcommands/query/reranker/cross_encoder.py @@ -2,7 +2,7 @@ from typing import Any from vectorcode.cli_utils import Config -from vectorcode.subcommands.query.types import QueryResult +from vectorcode.database.types import QueryResult from .base import RerankerBase diff --git a/src/vectorcode/subcommands/query/reranker/naive.py b/src/vectorcode/subcommands/query/reranker/naive.py index 65478c09..3163a61a 100644 --- a/src/vectorcode/subcommands/query/reranker/naive.py +++ b/src/vectorcode/subcommands/query/reranker/naive.py @@ -2,7 +2,7 @@ from typing import Any from vectorcode.cli_utils import Config -from vectorcode.subcommands.query.types import QueryResult +from vectorcode.database.types import QueryResult from .base import RerankerBase diff --git a/src/vectorcode/subcommands/update.py b/src/vectorcode/subcommands/update.py index 1416a7b8..9bfa32e2 100644 --- a/src/vectorcode/subcommands/update.py +++ b/src/vectorcode/subcommands/update.py @@ -1,93 +1,65 @@ import asyncio import logging import os -import sys from asyncio import Lock import tqdm -from chromadb.api.types import IncludeEnum -from chromadb.errors import InvalidCollectionException from vectorcode.cli_utils import Config -from vectorcode.common import ClientManager, get_collection, verify_ef -from vectorcode.subcommands.vectorise import VectoriseStats, chunked_add, show_stats +from vectorcode.database import get_database_connector +from vectorcode.database.types import ResultType +from vectorcode.database.utils import hash_file +from 
vectorcode.subcommands.vectorise import (
+    VectoriseStats,
+    show_stats,
+    vectorise_worker,
+)
+from vectorcode.subcommands.vectorise.filter import FilterManager

 logger = logging.getLogger(name=__name__)


 async def update(configs: Config) -> int:
-    async with ClientManager().get_client(configs) as client:
-        try:
-            collection = await get_collection(client, configs, False)
-        except IndexError as e:
-            print(
-                f"{e.__class__.__name__}: Failed to get/create the collection. Please check your config."
-            )
-            return 1
-        except (ValueError, InvalidCollectionException) as e:
-            print(
-                f"{e.__class__.__name__}: There's no existing collection for {configs.project_root}",
-                file=sys.stderr,
-            )
-            return 1
-        if collection is None:  # pragma: nocover
-            logger.error(
-                f"Failed to find a collection at {configs.project_root} from {configs.db_url}"
-            )
-            return 1
-        if not verify_ef(collection, configs):  # pragma: nocover
-            return 1
+    assert configs.project_root is not None
+    database = get_database_connector(configs)

-        metas = (await collection.get(include=[IncludeEnum.metadatas]))["metadatas"]
-        if metas is None or len(metas) == 0:  # pragma: nocover
-            logger.debug("Empty collection.")
-            return 0
+    filters = FilterManager()

-        files_gen = (str(meta.get("path", "")) for meta in metas)
-        files = set()
-        orphanes = set()
-        for file in files_gen:
-            if os.path.isfile(file):
-                files.add(file)
-            else:
-                orphanes.add(file)
+    collection_files = (
+        await database.list_collection_content(what=ResultType.document)
+    ).files

-        stats = VectoriseStats(removed=len(orphanes))
-        collection_lock = Lock()
-        stats_lock = Lock()
-        max_batch_size = await client.get_max_batch_size()
-        semaphore = asyncio.Semaphore(os.cpu_count() or 1)
+    existing_hashes = set(i.sha256 for i in collection_files)

-        with tqdm.tqdm(
-            total=len(files), desc="Vectorising files...", disable=configs.pipe
-        ) as bar:
-            logger.info(f"Updating embeddings for {len(files)} file(s).")
-            try:
-                tasks = [
-                    asyncio.create_task(
-                        chunked_add(
-                            str(file),
-                            collection,
-                            collection_lock,
-                            stats,
-                            stats_lock,
-                            configs,
-                            max_batch_size,
-                            semaphore,
-                        )
-                    )
-                    for file in files
-                ]
-                for task in asyncio.as_completed(tasks):
-                    await task
-                    bar.update(1)
-            except asyncio.CancelledError:  # pragma: nocover
-                print("Abort.", file=sys.stderr)
-                return 1
+    files = (i.path for i in collection_files)
+    if not configs.force:
+        filters.add_filter(lambda x: hash_file(x) not in existing_hashes)
+    else:  # pragma: nocover
+        logger.info("`--force` is set; re-vectorising all files regardless of their stored hashes.")
+
+    files = list(filters(files))
+    stats = VectoriseStats(skipped=len(collection_files) - len(files))
+    stats_lock = Lock()
+    semaphore = asyncio.Semaphore(os.cpu_count() or 1)
+
+    with tqdm.tqdm(
+        total=len(files), desc="Vectorising files...", disable=configs.pipe
+    ) as bar:
+        try:
+            tasks = [
+                asyncio.create_task(
+                    vectorise_worker(database, file, semaphore, stats, stats_lock)
+                )
+                for file in files
+            ]
+            for task in asyncio.as_completed(tasks):
+                await task
+                bar.update(1)
+        except asyncio.CancelledError:
+            logger.warning("Abort.")
+            return 1

-        if len(orphanes):
-            logger.info(f"Removing {len(orphanes)} orphaned files from database.")
-            await collection.delete(where={"path": {"$in": list(orphanes)}})
+    await database.check_orphanes()

-        show_stats(configs, stats)
-        return 0
+    show_stats(configs=configs, stats=stats)
+    return 0
diff --git a/src/vectorcode/subcommands/vectorise.py b/src/vectorcode/subcommands/vectorise.py
deleted file mode 100644
index 2ce0b249..00000000
---
a/src/vectorcode/subcommands/vectorise.py +++ /dev/null @@ -1,334 +0,0 @@ -import asyncio -import glob -import hashlib -import json -import logging -import os -import sys -import uuid -from asyncio import Lock -from dataclasses import dataclass, fields -from typing import Iterable, Optional - -import pathspec -import tabulate -import tqdm -from chromadb.api.models.AsyncCollection import AsyncCollection -from chromadb.api.types import IncludeEnum - -from vectorcode.chunking import Chunk, TreeSitterChunker -from vectorcode.cli_utils import ( - GLOBAL_EXCLUDE_SPEC, - GLOBAL_INCLUDE_SPEC, - Config, - SpecResolver, - expand_globs, - expand_path, -) -from vectorcode.common import ( - ClientManager, - get_collection, - get_embedding_function, - list_collection_files, - verify_ef, -) - -logger = logging.getLogger(name=__name__) - - -@dataclass -class VectoriseStats: - add: int = 0 - update: int = 0 - removed: int = 0 - skipped: int = 0 - failed: int = 0 - - def to_json(self) -> str: - return json.dumps(self.to_dict()) - - def to_dict(self) -> dict[str, int]: - return {i.name: getattr(self, i.name) for i in fields(self)} - - def to_table(self) -> str: - _fields = fields(self) - return tabulate.tabulate( - [ - [i.name.capitalize() for i in _fields], - [getattr(self, i.name) for i in _fields], - ], - headers="firstrow", - ) - - -def hash_str(string: str) -> str: - """Return the sha-256 hash of a string.""" - return hashlib.sha256(string.encode()).hexdigest() - - -def hash_file(path: str) -> str: - """return the sha-256 hash of a file.""" - hasher = hashlib.sha256() - with open(path, "rb") as file: - while True: - chunk = file.read(8192) - if chunk: - hasher.update(chunk) - else: - break - return hasher.hexdigest() - - -def get_uuid() -> str: - return uuid.uuid4().hex - - -async def chunked_add( - file_path: str, - collection: AsyncCollection, - collection_lock: Lock, - stats: VectoriseStats, - stats_lock: Lock, - configs: Config, - max_batch_size: int, - semaphore: asyncio.Semaphore, -): - embedding_function = get_embedding_function(configs) - full_path_str = str(expand_path(str(file_path), True)) - orig_sha256 = None - new_sha256 = hash_file(full_path_str) - async with collection_lock: - existing_chunks = await collection.get( - where={"path": full_path_str}, - include=[IncludeEnum.metadatas], - ) - num_existing_chunks = len((existing_chunks)["ids"]) - if existing_chunks["metadatas"]: - orig_sha256 = existing_chunks["metadatas"][0].get("sha256") - if orig_sha256 and orig_sha256 == new_sha256: - logger.debug( - f"Skipping {full_path_str} because it's unchanged since last vectorisation." 
- ) - stats.skipped += 1 - return - - if num_existing_chunks: - logger.debug( - "Deleting %s existing chunks for the current file.", num_existing_chunks - ) - async with collection_lock: - await collection.delete(where={"path": full_path_str}) - - logger.debug(f"Vectorising {file_path}") - try: - async with semaphore: - chunks: list[Chunk | str] = list( - TreeSitterChunker(configs).chunk(full_path_str) - ) - if len(chunks) == 0 or (len(chunks) == 1 and chunks[0] == ""): - # empty file - logger.debug(f"Skipping {full_path_str} because it's empty.") - stats.skipped += 1 - return - chunks.append(str(os.path.relpath(full_path_str, configs.project_root))) - logger.debug(f"Chunked into {len(chunks)} pieces.") - metas = [] - for chunk in chunks: - meta: dict[str, str | int] = { - "path": full_path_str, - "sha256": new_sha256, - } - if isinstance(chunk, Chunk): - if chunk.start: - meta["start"] = chunk.start.row - if chunk.end: - meta["end"] = chunk.end.row - - metas.append(meta) - async with collection_lock: - for idx in range(0, len(chunks), max_batch_size): - inserted_chunks = chunks[idx : idx + max_batch_size] - embeddings = embedding_function( - list(str(c) for c in inserted_chunks) - ) - if ( - isinstance(configs.embedding_dims, int) - and configs.embedding_dims > 0 - ): - logger.debug( - f"Truncating embeddings to {configs.embedding_dims} dimensions." - ) - embeddings = [e[: configs.embedding_dims] for e in embeddings] - await collection.add( - ids=[get_uuid() for _ in inserted_chunks], - documents=[str(i) for i in inserted_chunks], - embeddings=embeddings, - metadatas=metas, - ) - except (UnicodeDecodeError, UnicodeError): # pragma: nocover - logger.warning(f"Failed to decode {full_path_str}.") - stats.failed += 1 - return - - if num_existing_chunks: - async with stats_lock: - stats.update += 1 - else: - async with stats_lock: - stats.add += 1 - - -async def remove_orphanes( - collection: AsyncCollection, - collection_lock: Lock, - stats: VectoriseStats, - stats_lock: Lock, -): - async with collection_lock: - paths = await list_collection_files(collection) - orphans = set() - for path in paths: - if isinstance(path, str) and not os.path.isfile(path): - orphans.add(path) - async with stats_lock: - stats.removed = len(orphans) - if len(orphans): - logger.info(f"Removing {len(orphans)} orphaned files from database.") - await collection.delete(where={"path": {"$in": list(orphans)}}) - - -def show_stats(configs: Config, stats: VectoriseStats): - if configs.pipe: - print(stats.to_json()) - else: - print(stats.to_table()) - - -def exclude_paths_by_spec( - paths: Iterable[str], spec_path: str, project_root: Optional[str] = None -) -> list[str]: - """ - Files matched by the specs will be excluded. 
- """ - - return list(SpecResolver.from_path(spec_path, project_root).match(paths, True)) - - -def load_files_from_include(project_root: str) -> list[str]: - include_file_path = os.path.join(project_root, ".vectorcode", "vectorcode.include") - specs: Optional[pathspec.GitIgnoreSpec] = None - if os.path.isfile(include_file_path): - logger.debug("Loading from local `vectorcode.include`.") - with open(include_file_path) as fin: - specs = pathspec.GitIgnoreSpec.from_lines( - lines=(os.path.expanduser(i) for i in fin.readlines()), - ) - elif os.path.isfile(GLOBAL_INCLUDE_SPEC): - logger.debug("Loading from global `vectorcode.include`.") - with open(GLOBAL_INCLUDE_SPEC) as fin: - specs = pathspec.GitIgnoreSpec.from_lines( - lines=(os.path.expanduser(i) for i in fin.readlines()), - ) - if specs is not None: - logger.info("Populating included files from loaded specs.") - return [ - result.file - for result in specs.check_tree_files(project_root) - if result.include - ] - return [] - - -def find_exclude_specs(configs: Config) -> list[str]: - """ - Load a list of paths to exclude specs. - Can be `.gitignore` or local/global `vectorcode.exclude` - """ - if configs.recursive: - specs = glob.glob( - os.path.join(str(configs.project_root), "**", ".gitignore"), recursive=True - ) + glob.glob( - os.path.join(str(configs.project_root), "**", "vectorcode.exclude"), - recursive=True, - ) - else: - specs = [os.path.join(str(configs.project_root), ".gitignore")] - - exclude_spec_path = os.path.join( - str(configs.project_root), ".vectorcode", "vectorcode.exclude" - ) - if os.path.isfile(exclude_spec_path): - specs.append(exclude_spec_path) - elif os.path.isfile(GLOBAL_EXCLUDE_SPEC): - specs.append(GLOBAL_EXCLUDE_SPEC) - specs = [i for i in specs if os.path.isfile(i)] - logger.debug(f"Loaded exclude specs: {specs}") - return specs - - -async def vectorise(configs: Config) -> int: - assert configs.project_root is not None - async with ClientManager().get_client(configs) as client: - try: - collection = await get_collection(client, configs, True) - except IndexError as e: - print( - f"{e.__class__.__name__}: Failed to get/create the collection. Please check your config." 
- ) - return 1 - if not verify_ef(collection, configs): - return 1 - - files = await expand_globs( - configs.files or load_files_from_include(str(configs.project_root)), - recursive=configs.recursive, - include_hidden=configs.include_hidden, - ) - - if not configs.force: - for spec_path in find_exclude_specs(configs): - if os.path.isfile(spec_path): - logger.info(f"Loading ignore specs from {spec_path}.") - files = exclude_paths_by_spec( - (str(i) for i in files), spec_path, str(configs.project_root) - ) - logger.debug(f"Files after excluding: {files}") - else: # pragma: nocover - logger.info("Ignoring exclude specs.") - - stats = VectoriseStats() - collection_lock = Lock() - stats_lock = Lock() - max_batch_size = await client.get_max_batch_size() - semaphore = asyncio.Semaphore(os.cpu_count() or 1) - - with tqdm.tqdm( - total=len(files), desc="Vectorising files...", disable=configs.pipe - ) as bar: - try: - tasks = [ - asyncio.create_task( - chunked_add( - str(file), - collection, - collection_lock, - stats, - stats_lock, - configs, - max_batch_size, - semaphore, - ) - ) - for file in files - ] - for task in asyncio.as_completed(tasks): - await task - bar.update(1) - except asyncio.CancelledError: - print("Abort.", file=sys.stderr) - return 1 - - await remove_orphanes(collection, collection_lock, stats, stats_lock) - - show_stats(configs=configs, stats=stats) - return 0 diff --git a/src/vectorcode/subcommands/vectorise/__init__.py b/src/vectorcode/subcommands/vectorise/__init__.py new file mode 100644 index 00000000..96fc52f1 --- /dev/null +++ b/src/vectorcode/subcommands/vectorise/__init__.py @@ -0,0 +1,165 @@ +import asyncio +import glob +import logging +import os +from asyncio import Lock, Semaphore +from typing import Optional + +import pathspec +import tqdm + +from vectorcode.cli_utils import ( + GLOBAL_EXCLUDE_SPEC, + GLOBAL_INCLUDE_SPEC, + Config, + SpecResolver, + expand_globs, +) +from vectorcode.database import get_database_connector +from vectorcode.database.base import DatabaseConnectorBase +from vectorcode.database.errors import CollectionNotFoundError +from vectorcode.database.types import ResultType, VectoriseStats +from vectorcode.database.utils import hash_file +from vectorcode.subcommands.vectorise.filter import FilterManager + +logger = logging.getLogger(name=__name__) + + +def show_stats(configs: Config, stats: VectoriseStats): + if configs.pipe: + print(stats.to_json()) + else: + print(stats.to_table()) + + +def load_files_from_include(project_root: str) -> list[str]: + include_file_path = os.path.join(project_root, ".vectorcode", "vectorcode.include") + specs: Optional[pathspec.GitIgnoreSpec] = None + if os.path.isfile(include_file_path): + logger.debug("Loading from local `vectorcode.include`.") + with open(include_file_path) as fin: + specs = pathspec.GitIgnoreSpec.from_lines( + lines=(os.path.expanduser(i) for i in fin.readlines()), + ) + elif os.path.isfile(GLOBAL_INCLUDE_SPEC): + logger.debug("Loading from global `vectorcode.include`.") + with open(GLOBAL_INCLUDE_SPEC) as fin: + specs = pathspec.GitIgnoreSpec.from_lines( + lines=(os.path.expanduser(i) for i in fin.readlines()), + ) + if specs is not None: + logger.info("Populating included files from loaded specs.") + return [ + result.file + for result in specs.check_tree_files(project_root) + if result.include + ] + return [] + + +def find_exclude_specs(configs: Config) -> list[str]: + """ + Load a list of paths to exclude specs. 
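Only spec paths that exist on disk are returned.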
+    Can be `.gitignore` or local/global `vectorcode.exclude`
+    """
+    if configs.recursive:
+        specs = glob.glob(
+            os.path.join(str(configs.project_root), "**", ".gitignore"), recursive=True
+        ) + glob.glob(
+            os.path.join(str(configs.project_root), "**", "vectorcode.exclude"),
+            recursive=True,
+        )
+    else:
+        specs = [os.path.join(str(configs.project_root), ".gitignore")]
+
+    exclude_spec_path = os.path.join(
+        str(configs.project_root), ".vectorcode", "vectorcode.exclude"
+    )
+    if os.path.isfile(exclude_spec_path):
+        specs.append(exclude_spec_path)
+    elif os.path.isfile(GLOBAL_EXCLUDE_SPEC):
+        specs.append(GLOBAL_EXCLUDE_SPEC)
+    specs = [i for i in specs if os.path.isfile(i)]
+    logger.debug(f"Loaded exclude specs: {specs}")
+    return specs
+
+
+async def vectorise_worker(
+    database: DatabaseConnectorBase,
+    file_path: str,
+    semaphore: Semaphore,
+    stats: VectoriseStats,
+    stats_lock: Lock,
+):
+    async with semaphore:
+        if os.path.isfile(file_path):
+            result = await database.vectorise(file_path=file_path)
+            async with stats_lock:  # hold the lock only for the shared-stats update, not the whole vectorise call
+                stats += result
+
+
+async def vectorise(configs: Config) -> int:
+    assert configs.project_root is not None
+    database = get_database_connector(configs)
+
+    files = await expand_globs(
+        configs.files or load_files_from_include(str(configs.project_root)),
+        recursive=configs.recursive,
+        include_hidden=configs.include_hidden,
+    )
+
+    total_file_count = len(files)
+
+    filters = FilterManager()
+
+    try:
+        collection_files = (
+            await database.list_collection_content(what=ResultType.document)
+        ).files
+
+        existing_hashes = set(i.sha256 for i in collection_files)
+    except CollectionNotFoundError:
+        existing_hashes = set()
+
+    if not configs.force:
+        for spec_path in find_exclude_specs(configs):
+            # filter by gitignore/vectorcode.exclude
+            if os.path.isfile(spec_path):
+                logger.info(f"Loading ignore specs from {spec_path}.")
+                spec = SpecResolver.from_path(
+                    spec_path,
+                    str(configs.project_root) if configs.project_root else None,
+                )
+                filters.add_filter(lambda x, spec=spec: spec.match_file(x, True))  # `spec=spec` binds the current resolver per iteration
+
+        # filter by sha256
+        filters.add_filter(lambda x: hash_file(x) not in existing_hashes)
+    else:  # pragma: nocover
+        logger.info("Ignoring exclude specs.")
+
+    files = list(filters(files))
+    stats = VectoriseStats(skipped=total_file_count - len(files))
+    stats_lock = Lock()
+    semaphore = asyncio.Semaphore(os.cpu_count() or 1)
+
+    with tqdm.tqdm(
+        total=len(files), desc="Vectorising files...", disable=configs.pipe
+    ) as bar:
+        try:
+            tasks = [
+                asyncio.create_task(
+                    vectorise_worker(database, file, semaphore, stats, stats_lock)
+                )
+                for file in files
+            ]
+            for task in asyncio.as_completed(tasks):
+                await task
+                bar.update(1)
+        except asyncio.CancelledError:  # pragma: nocover
+            logger.warning("Abort.")
+            return 1
+
+    await database.check_orphanes()
+
+    show_stats(configs=configs, stats=stats)
+    return 0
diff --git a/src/vectorcode/subcommands/vectorise/filter.py b/src/vectorcode/subcommands/vectorise/filter.py
new file mode 100644
index 00000000..19180729
--- /dev/null
+++ b/src/vectorcode/subcommands/vectorise/filter.py
@@ -0,0 +1,50 @@
+import logging
+import os
+import sys
+from typing import Callable, Iterable, Self, Sequence
+
+logger = logging.getLogger(name=__name__)
+
+FileFilter = Callable[[str], bool]
+
+
+class FilterManager:
+    def __init__(self, from_filters: Sequence[FileFilter] | None = None) -> None:
+        self._filters: list[FileFilter] = []
+        if from_filters:  # pragma: nocover
+            self._filters.extend(from_filters)
+
+    def add_filter(self, f: FileFilter = lambda x: bool(x)) -> Self:
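+        """Register a predicate; a file is kept only when every registered predicate returns True."""
+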
self._filters.append(f)
+        return self
+
+    def _has_debugging(self):  # pragma: nocover
+        """
+        Iterators are difficult to debug.
+        Use this function to decide whether we should convert iterators to tuples
+        to make debugging easier.
+        """
+        return (
+            sys.gettrace() is not None
+            or os.environ.get("VECTORCODE_LOG_LEVEL") is not None
+        )
+
+    def __call__(self, files: Iterable[str]) -> Iterable[str]:
+        if self._has_debugging():  # pragma: nocover
+            files = tuple(files)
+            logger.debug(
+                f"Applying the following filters: {list(i.__name__ for i in self._filters)} to the following files ({len(files)}): {files}"
+            )
+
+        if self._filters:
+            for f in self._filters:
+                files = filter(f, files)
+
+                if self._has_debugging():  # pragma: nocover
+                    files = tuple(files)
+                    logger.debug(
+                        f"{f.__name__} remaining items ({len(files)}): {files}"
+                    )
+
+        return files
diff --git a/tests/database/test_chroma.py b/tests/database/test_chroma.py
new file mode 100644
index 00000000..a7e06d98
--- /dev/null
+++ b/tests/database/test_chroma.py
@@ -0,0 +1,704 @@
+import os
+import tempfile
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+try:
+    import chromadb
+
+    chroma_version = chromadb.__version__
+    if not chroma_version.startswith("1."):
+        pytest.skip(
+            f"Found chromadb {chroma_version}. Skipping chroma tests.",
+            allow_module_level=True,
+        )
+except ModuleNotFoundError:
+    pytest.skip(
+        "ChromaDB not found. Skipping chroma tests.",
+        allow_module_level=True,
+    )
+
+from chromadb.api.types import QueryResult
+from chromadb.errors import NotFoundError
+from tree_sitter import Point
+
+from vectorcode.cli_utils import Config, QueryInclude
+from vectorcode.database import types
+from vectorcode.database.chroma import (
+    ChromaDBConnector,
+)
+from vectorcode.database.chroma_common import convert_chroma_query_results
+from vectorcode.database.errors import CollectionNotFoundError
+
+
+@pytest.fixture
+def mock_config():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        yield Config(
+            project_root=tmpdir,
+            embedding_function="default",
+            db_params={
+                "db_url": "http://localhost:1234",
+                "db_path": os.path.join(tmpdir, "db"),
+                "db_log_path": os.path.join(tmpdir, "log"),
+                "db_settings": {},
+            },
+        )
+
+
+@pytest.mark.asyncio
+async def test_initialization(mock_config):
+    """Test that the ChromaDBConnector is initialized correctly."""
+    connector = ChromaDBConnector(mock_config)
+    assert connector._configs.project_root == mock_config.project_root
+    assert "hnsw" in connector._configs.db_params
+
+
+@pytest.mark.asyncio
+async def test_query(mock_config):
+    """Test the query method."""
+    with patch("chromadb.HttpClient"):
+        connector = ChromaDBConnector(mock_config)
+        connector._configs.query = ["test query"]
+        connector.get_embedding = MagicMock(return_value=[[1.0, 2.0, 3.0]])
+
+        mock_collection = MagicMock()
+        mock_collection.query.return_value = {
+            "documents": [["doc1"]],
+            "distances": [[0.1]],
+            "metadatas": [[{"path": os.path.join(mock_config.project_root, "file1")}]],
+            "ids": [["id1"]],
+        }
+        connector._create_or_get_collection = AsyncMock(return_value=mock_collection)
+        connector.count = AsyncMock(return_value=1)
+
+        with patch(
+            "vectorcode.database.chroma.convert_chroma_query_results"
+        ) as mock_convert:
+            mock_convert.return_value = ["converted_results"]
+            results = await connector.query()
+            assert results == ["converted_results"]
+            mock_collection.query.assert_called_once()
+            mock_convert.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_vectorise(mock_config):
+
"""Test the vectorise method.""" + with patch("chromadb.HttpClient") as mock_http_client: + mock_client = MagicMock() + mock_client.get_max_batch_size.return_value = 100 + mock_http_client.return_value = mock_client + + connector = ChromaDBConnector(mock_config) + mock_collection = MagicMock() + connector._create_or_get_collection = AsyncMock(return_value=mock_collection) + + mock_chunker = MagicMock() + mock_chunker.chunk.return_value = [MagicMock(text="chunk1")] + connector.get_embedding = MagicMock(return_value=[[1.0, 2.0, 3.0]]) + + with ( + patch("vectorcode.database.chroma.hash_file", return_value="hash1"), + patch("vectorcode.database.chroma.get_uuid", return_value="uuid1"), + ): + stats = await connector.vectorise( + os.path.join(mock_config.project_root, "file1"), chunker=mock_chunker + ) + + assert stats.add == 1 + mock_collection.add.assert_called_once() + + +@pytest.mark.asyncio +async def test_list_collections(mock_config): + """Test the list_collections method.""" + with patch("chromadb.HttpClient") as mock_http_client: + mock_client = MagicMock() + mock_collection = MagicMock() + mock_collection.name = "collection1" + mock_collection.metadata = {"path": mock_config.project_root} + mock_client.list_collections.return_value = [mock_collection] + mock_http_client.return_value = mock_client + + connector = ChromaDBConnector(mock_config) + + connector.list_collection_content = AsyncMock( + return_value=MagicMock(files=[], chunks=[]) + ) + + collections = await connector.list_collections() + assert len(collections) == 1 + assert collections[0].id == "collection1" + + +@pytest.mark.asyncio +async def test_list_collection_content(mock_config): + """Test the list_collection_content method.""" + with patch("chromadb.HttpClient"): + connector = ChromaDBConnector(mock_config) + mock_collection = MagicMock() + mock_collection.get.return_value = { + "metadatas": [ + { + "path": os.path.join(mock_config.project_root, "file1"), + "sha256": "hash1", + "start": 1, + "end": 2, + } + ], + "documents": ["doc1"], + "ids": ["id1"], + } + connector._create_or_get_collection = AsyncMock(return_value=mock_collection) + + content = await connector.list_collection_content() + assert len(content.files) == 1 + assert len(content.chunks) == 1 + + +@pytest.mark.asyncio +async def test_list_collection_content_no_collection(mock_config): + mock_client = MagicMock() + mock_client.get_collection.side_effect = NotFoundError + with patch("chromadb.HttpClient", return_value=mock_client): + connector = ChromaDBConnector(mock_config) + + with pytest.raises(CollectionNotFoundError): + await connector.list_collection_content() + + +@pytest.mark.asyncio +async def test_delete(mock_config): + """Test the delete method.""" + file_to_delete = os.path.join(mock_config.project_root, "file1") + with patch("chromadb.HttpClient"): + connector = ChromaDBConnector(mock_config) + mock_collection = MagicMock() + connector._create_or_get_collection = AsyncMock(return_value=mock_collection) + connector.list_collection_content = AsyncMock( + return_value=MagicMock(files=[MagicMock(path=file_to_delete)]) + ) + mock_config.rm_paths = [file_to_delete] + + def mock_expand_path(path, absolute): + return path + + with ( + patch( + "vectorcode.database.chroma.expand_globs", return_value=[file_to_delete] + ), + patch( + "vectorcode.database.chroma.expand_path", side_effect=mock_expand_path + ), + patch("os.path.isfile", return_value=True), + ): + deleted_count = await connector.delete() + assert deleted_count == 1 + 
mock_collection.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_drop(mock_config): + """Test the drop method.""" + with patch("chromadb.HttpClient") as mock_http_client: + mock_client = MagicMock() + mock_http_client.return_value = mock_client + connector = ChromaDBConnector(mock_config) + with patch( + "vectorcode.database.chroma.get_collection_id", + return_value="collection_id", + ): + await connector.drop() + mock_client.delete_collection.assert_called_once_with("collection_id") + + +@pytest.mark.asyncio +async def test_drop_invalid_collection(mock_config): + """Test the drop method.""" + with patch("chromadb.HttpClient") as mock_http_client: + mock_client = MagicMock() + mock_client.delete_collection.side_effect = ValueError + mock_http_client.return_value = mock_client + connector = ChromaDBConnector(mock_config) + with patch( + "vectorcode.database.chroma.get_collection_id", + return_value="collection_id", + ): + with pytest.raises(CollectionNotFoundError): + await connector.drop() + + +@pytest.mark.asyncio +async def test_get_chunks(mock_config): + """Test the get_chunks method.""" + with patch("chromadb.HttpClient"): + connector = ChromaDBConnector(mock_config) + mock_collection = MagicMock() + mock_collection.get.return_value = { + "metadatas": [{"start": 1, "end": 2}], + "documents": ["doc1"], + "ids": ["id1"], + } + connector._create_or_get_collection = AsyncMock(return_value=mock_collection) + + chunks = await connector.get_chunks( + os.path.join(mock_config.project_root, "file1") + ) + assert len(chunks) == 1 + assert chunks[0].text == "doc1" + + +def test_convert_chroma_query_results(mock_config): + file1_path = os.path.join(mock_config.project_root, "file1") + file2_path = os.path.join(mock_config.project_root, "file2") + chroma_result: QueryResult = { + "documents": [["doc1", "doc2"]], + "distances": [[0.1, 0.2]], + "metadatas": [ + [{"path": file1_path, "start": 1, "end": 2}, {"path": file2_path}] + ], + "ids": [["id1", "id2"]], + "embeddings": None, + "uris": None, + "data": None, + } + queries = ["query1"] + results = convert_chroma_query_results(chroma_result, queries) + assert len(results) == 2 + assert results[0].chunk.text == "doc1" + assert results[0].path == file1_path + assert results[0].scores == (-0.1,) + assert results[0].chunk.start == Point(1, 0) + assert results[0].chunk.end == Point(2, 0) + assert results[1].chunk.text == "doc2" + assert results[1].path == file2_path + assert results[1].scores == (-0.2,) + + +@pytest.mark.asyncio +async def test_get_chunks_collection_not_found(mock_config): + """Test get_chunks when collection is not found.""" + connector = ChromaDBConnector(mock_config) + connector._create_or_get_collection = AsyncMock(side_effect=CollectionNotFoundError) + with patch("vectorcode.database.chroma.logger") as mock_logger: + result = await connector.get_chunks( + os.path.join(mock_config.project_root, "file1") + ) + assert result == [] + mock_logger.warning.assert_called_once() + + +@pytest.mark.asyncio +async def test_vectorise_no_embeddings(mock_config): + """Test vectorise when there are no embeddings.""" + with patch("chromadb.HttpClient"): + connector = ChromaDBConnector(mock_config) + mock_chunker = MagicMock() + mock_chunker.chunk.return_value = [MagicMock(text="chunk1")] + connector.get_embedding = MagicMock(return_value=[]) + with patch( + "vectorcode.database.chroma.ChromaDBConnector._create_or_get_collection", + new_callable=AsyncMock, + ) as mock_create_collection: + stats = await connector.vectorise( + 
os.path.join(mock_config.project_root, "file1"), chunker=mock_chunker + ) + assert stats.skipped == 1 + mock_create_collection.assert_called_once() + + +@pytest.mark.asyncio +async def test_query_with_exclude(mock_config): + """Test query with exclude paths.""" + file1_path = os.path.join(mock_config.project_root, "file1") + file2_path = os.path.join(mock_config.project_root, "file2") + with patch("chromadb.HttpClient"): + connector = ChromaDBConnector(mock_config) + connector._configs.query = ["test query"] + connector._configs.query_exclude = [file2_path] + connector.get_embedding = MagicMock(return_value=[[1.0, 2.0, 3.0]]) + + mock_collection = MagicMock() + mock_collection.query.return_value = { + "documents": [["doc1"]], + "distances": [[0.1]], + "metadatas": [[{"path": file1_path}]], + "ids": [["id1"]], + } + connector._create_or_get_collection = AsyncMock(return_value=mock_collection) + connector.count = AsyncMock(return_value=1) + + with patch( + "vectorcode.database.chroma.convert_chroma_query_results" + ) as mock_convert: + mock_convert.return_value = ["converted_results"] + await connector.query() + mock_collection.query.assert_called_once() + _, kwargs = mock_collection.query.call_args + assert "where" in kwargs + assert kwargs["where"] == {"path": {"$nin": [file2_path]}} + + +@pytest.mark.asyncio +async def test_query_with_include_chunk(mock_config): + """Test query with include chunk.""" + with patch("chromadb.HttpClient"): + connector = ChromaDBConnector(mock_config) + connector._configs.query = ["test query"] + connector._configs.include = [QueryInclude.chunk] + connector.get_embedding = MagicMock(return_value=[[1.0, 2.0, 3.0]]) + + mock_collection = MagicMock() + mock_collection.query.return_value = { + "documents": [["doc1"]], + "distances": [[0.1]], + "metadatas": [[{"path": os.path.join(mock_config.project_root, "file1")}]], + "ids": [["id1"]], + } + connector._create_or_get_collection = AsyncMock(return_value=mock_collection) + connector.count = AsyncMock(return_value=1) + + with patch( + "vectorcode.database.chroma.convert_chroma_query_results" + ) as mock_convert: + mock_convert.return_value = ["converted_results"] + await connector.query() + mock_collection.query.assert_called_once() + _, kwargs = mock_collection.query.call_args + assert "where" in kwargs + assert kwargs["where"] == {"start": {"$gte": 0}} + + +@pytest.mark.asyncio +async def test_create_or_get_collection_not_found(mock_config): + """Test _create_or_get_collection when collection is not found and allow_create is False.""" + from chromadb.errors import NotFoundError + + with patch("chromadb.HttpClient") as mock_http_client: + mock_client = MagicMock() + mock_client.get_collection.side_effect = NotFoundError + mock_http_client.return_value = mock_client + connector = ChromaDBConnector(mock_config) + + with pytest.raises(CollectionNotFoundError): + await connector._create_or_get_collection( + "collection_path", allow_create=False + ) + + +@pytest.mark.asyncio +async def test_delete_no_paths(mock_config): + """Test delete with no paths to remove.""" + file_to_keep = os.path.join(mock_config.project_root, "file1") + non_existent_file = os.path.join(mock_config.project_root, "non_existent_file") + with patch("chromadb.HttpClient"): + connector = ChromaDBConnector(mock_config) + mock_collection = MagicMock() + connector._create_or_get_collection = AsyncMock(return_value=mock_collection) + connector.list_collection_content = AsyncMock( + return_value=MagicMock(files=[MagicMock(path=file_to_keep)]) + ) + 
mock_config.rm_paths = [non_existent_file] + + def mock_expand_path(path, absolute): + return path + + with ( + patch( + "vectorcode.database.chroma.expand_globs", + return_value=[non_existent_file], + ), + patch( + "vectorcode.database.chroma.expand_path", side_effect=mock_expand_path + ), + patch("os.path.isfile", return_value=True), + ): + deleted_count = await connector.delete() + assert deleted_count == 0 + mock_collection.delete.assert_not_called() + + +@pytest.mark.asyncio +async def test_list_collection_content_with_what(mock_config): + """Test the list_collection_content method with the 'what' parameter.""" + with ( + patch("chromadb.HttpClient"), + patch( + "vectorcode.database.chroma.ChromaDBConnector._create_or_get_collection", + new_callable=AsyncMock, + ) as mock_create_collection, + ): + mock_collection = MagicMock() + mock_collection.get.return_value = { + "metadatas": [ + { + "path": os.path.join(mock_config.project_root, "file1"), + "sha256": "hash1", + } + ], + "documents": ["doc1"], + "ids": ["id1"], + } + mock_create_collection.return_value = mock_collection + connector = ChromaDBConnector(mock_config) + + # Test with what=ResultType.document + content = await connector.list_collection_content( + what=types.ResultType.document + ) + assert len(content.files) == 1 + assert len(content.chunks) == 0 + + # Test with what=ResultType.chunk + content = await connector.list_collection_content(what=types.ResultType.chunk) + assert len(content.files) == 0 + assert len(content.chunks) == 1 + + +@pytest.mark.asyncio +async def test_delete_with_string_rm_paths(mock_config): + """Test delete with rm_paths as a string.""" + file_to_delete = os.path.join(mock_config.project_root, "file1") + with patch("chromadb.HttpClient"): + connector = ChromaDBConnector(mock_config) + mock_collection = MagicMock() + connector._create_or_get_collection = AsyncMock(return_value=mock_collection) + connector.list_collection_content = AsyncMock( + return_value=MagicMock(files=[MagicMock(path=file_to_delete)]) + ) + mock_config.rm_paths = file_to_delete + + def mock_expand_path(path, absolute): + return path + + with ( + patch( + "vectorcode.database.chroma.expand_globs", return_value=[file_to_delete] + ), + patch( + "vectorcode.database.chroma.expand_path", side_effect=mock_expand_path + ), + patch("os.path.isfile", return_value=True), + ): + deleted_count = await connector.delete() + assert deleted_count == 1 + mock_collection.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_drop_with_collection_path(mock_config): + """Test drop with collection_path.""" + with patch("chromadb.HttpClient") as mock_http_client: + mock_client = MagicMock() + mock_http_client.return_value = mock_client + connector = ChromaDBConnector(mock_config) + with patch( + "vectorcode.database.chroma.get_collection_id", + return_value="collection_id", + ) as mock_get_collection_id: + await connector.drop(collection_path=mock_config.project_root) + mock_get_collection_id.assert_called_once_with(mock_config.project_root) + mock_client.delete_collection.assert_called_once_with("collection_id") + + +@pytest.mark.asyncio +async def test_get_chunks_generic_exception(mock_config): + """Test get_chunks with a generic exception.""" + with patch("chromadb.HttpClient"): + connector = ChromaDBConnector(mock_config) + connector._create_or_get_collection = AsyncMock( + side_effect=Exception("test error") + ) + with pytest.raises(Exception) as excinfo: + await connector.get_chunks(os.path.join(mock_config.project_root, 
"file1")) + assert "test error" in str(excinfo.value) + + +@pytest.mark.asyncio +async def test_query_no_n_result(mock_config): + """Test the query method without n_result.""" + with patch("chromadb.HttpClient"): + connector = ChromaDBConnector(mock_config) + connector._configs.query = ["test query"] + connector._configs.n_result = None + connector.get_embedding = MagicMock(return_value=[[1.0, 2.0, 3.0]]) + + mock_collection = MagicMock() + mock_collection.query.return_value = { + "documents": [["doc1"]], + "distances": [[0.1]], + "metadatas": [[{"path": os.path.join(mock_config.project_root, "file1")}]], + "ids": [["id1"]], + } + connector._create_or_get_collection = AsyncMock(return_value=mock_collection) + connector.count = AsyncMock(return_value=10) + + with patch( + "vectorcode.database.chroma.convert_chroma_query_results" + ) as mock_convert: + mock_convert.return_value = ["converted_results"] + await connector.query() + _, kwargs = mock_collection.query.call_args + assert kwargs["n_results"] == 10 + + +@pytest.mark.asyncio +async def test_create_or_get_collection_exists(mock_config: Config): + """Test _create_or_get_collection when collection exists and allow_create is True.""" + mock_config.db_params["hnsw"] = {"M": 64} + with patch("chromadb.HttpClient") as mock_http_client: + mock_client = MagicMock() + mock_http_client.return_value = mock_client + connector = ChromaDBConnector(mock_config) + + mock_collection = MagicMock() + mock_collection.metadata = { + "path": os.path.abspath(str(mock_config.project_root)), + "hostname": "test-host", + "created-by": "VectorCode", + "username": "DEFAULT_USER", + "embedding_function": "default", + "hnsw:M": 64, + } + mock_client.get_or_create_collection.return_value = mock_collection + + with ( + patch("os.environ.get", return_value="DEFAULT_USER"), + patch("socket.gethostname", return_value="test-host"), + ): + collection = await connector._create_or_get_collection( + "collection_path", allow_create=True + ) + assert collection == mock_collection + mock_client.get_or_create_collection.assert_called_once() + + +@pytest.mark.asyncio +async def test_list_collection_content_with_id(mock_config): + """Test the list_collection_content method with collection_id.""" + with patch("chromadb.HttpClient") as mock_http_client: + mock_client = MagicMock() + mock_http_client.return_value = mock_client + connector = ChromaDBConnector(mock_config) + + mock_collection = MagicMock() + mock_collection.get.return_value = { + "metadatas": [ + { + "path": os.path.join(mock_config.project_root, "file1"), + "sha256": "hash1", + } + ], + "documents": ["doc1"], + "ids": ["id1"], + } + mock_client.get_collection.return_value = mock_collection + + content = await connector.list_collection_content(collection_id="test_id") + assert len(content.files) == 1 + assert len(content.chunks) == 1 + mock_client.get_collection.assert_called_once_with("test_id") + + +@pytest.mark.asyncio +async def test_query_with_exclude_and_include_chunk(mock_config): + """Test query with exclude paths and include chunk.""" + with patch("chromadb.HttpClient"): + connector = ChromaDBConnector(mock_config) + connector._configs.query = ["test query"] + connector._configs.query_exclude = ["file2"] + connector._configs.include = [QueryInclude.chunk] + connector.get_embedding = MagicMock(return_value=[[1.0, 2.0, 3.0]]) + + mock_collection = MagicMock() + mock_collection.query.return_value = { + "documents": [["doc1"]], + "distances": [[0.1]], + "metadatas": [[{"path": "file1"}]], + "ids": [["id1"]], + } + 
connector._create_or_get_collection = AsyncMock(return_value=mock_collection) + connector.count = AsyncMock(return_value=1) + + with patch( + "vectorcode.database.chroma.convert_chroma_query_results" + ) as mock_convert: + mock_convert.return_value = ["converted_results"] + await connector.query() + mock_collection.query.assert_called_once() + _, kwargs = mock_collection.query.call_args + assert "where" in kwargs + assert kwargs["where"] == { + "$and": [{"path": {"$nin": ["file2"]}}, {"start": {"$gte": 0}}] + } + + +@pytest.mark.asyncio +async def test_create_or_get_collection_metadata_mismatch(mock_config): + """Test _create_or_get_collection when metadata mismatches.""" + with patch("chromadb.HttpClient") as mock_http_client: + mock_client = MagicMock() + mock_http_client.return_value = mock_client + connector = ChromaDBConnector(mock_config) + + mock_collection = MagicMock() + mock_collection.metadata = { + "path": os.path.abspath(str(mock_config.project_root)), + "hostname": "test-host", + "created-by": "VectorCode", + "username": "DIFFERENT_USER", + "embedding_function": "default", + "hnsw:M": 64, + } + mock_client.get_or_create_collection.return_value = mock_collection + + with ( + patch("os.environ.get", return_value="DEFAULT_USER"), + patch("socket.gethostname", return_value="test-host"), + ): + with pytest.raises(AssertionError): + await connector._create_or_get_collection( + "collection_path", allow_create=True + ) + + +@pytest.mark.asyncio +async def test_delete_no_matching_files(mock_config): + """Test delete with no matching files.""" + with patch("chromadb.HttpClient"): + connector = ChromaDBConnector(mock_config) + mock_collection = MagicMock() + connector._create_or_get_collection = AsyncMock(return_value=mock_collection) + connector.list_collection_content = AsyncMock( + return_value=MagicMock(files=[MagicMock(path="file1")]) + ) + mock_config.rm_paths = ["file2"] + + def mock_expand_path(path, absolute): + return path + + with ( + patch("vectorcode.database.chroma.expand_globs", return_value=["file2"]), + patch( + "vectorcode.database.chroma.expand_path", side_effect=mock_expand_path + ), + patch("os.path.isfile", return_value=True), + ): + deleted_count = await connector.delete() + assert deleted_count == 0 + mock_collection.delete.assert_not_called() + + +@pytest.mark.asyncio +async def test_persistent_client(mock_config): + with tempfile.TemporaryDirectory() as tmp_db_dir: + mock_config.db_params = {"db_path": tmp_db_dir} + connector = ChromaDBConnector(mock_config) + await connector.get_client() + assert connector._client_type == "persistent" + assert os.path.isfile(os.path.join(tmp_db_dir, "vectorcode.lock")) + async with connector.maybe_lock(): + assert connector._file_lock.is_locked + assert connector._thread_lock.locked() diff --git a/tests/database/test_chroma0.py b/tests/database/test_chroma0.py new file mode 100644 index 00000000..f215b9b1 --- /dev/null +++ b/tests/database/test_chroma0.py @@ -0,0 +1,849 @@ +import os +import tempfile +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +try: + import chromadb + + if not chromadb.__version__.startswith("0.6.3"): + pytest.skip( + f"Found chromadb {chromadb.__version__}. Skipping chroma0 tests.", + allow_module_level=True, + ) +except ModuleNotFoundError: + pytest.skip( + "ChromaDB 0.6.3 not found. 
Skipping chroma0 tests.",
+        allow_module_level=True,
+    )
+
+from chromadb.api.types import QueryResult
+from chromadb.errors import InvalidCollectionException
+from tree_sitter import Point
+
+from vectorcode.cli_utils import Config, QueryInclude
+from vectorcode.database import types
+from vectorcode.database.chroma0 import (
+    ChromaDB0Connector,
+    _Chroma0ClientManager,
+    _start_server,
+    _try_server,
+    _wait_for_server,
+    convert_chroma_query_results,
+)
+from vectorcode.database.errors import CollectionNotFoundError
+
+
+@pytest.fixture
+def mock_config():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        yield Config(
+            project_root=tmpdir,
+            embedding_function="default",
+            db_params={
+                "db_url": "http://localhost:1234",
+                "db_path": os.path.join(tmpdir, "db"),
+                "db_log_path": os.path.join(tmpdir, "log"),
+                "db_settings": {},
+            },
+        )
+
+
+@pytest.mark.asyncio
+async def test_initialization(mock_config):
+    """Test that the ChromaDB0Connector is initialized correctly."""
+    with patch(
+        "vectorcode.database.chroma0._Chroma0ClientManager.get_client",
+        new_callable=AsyncMock,
+    ) as mock_get_client:
+        # Mock the async context manager
+        mock_async_context = AsyncMock()
+        mock_get_client.return_value = mock_async_context
+
+        # Mock the client object itself
+        mock_client = AsyncMock()
+        mock_async_context.__aenter__.return_value = mock_client
+        mock_client.get_version.return_value = "0.6.3"
+
+        connector = ChromaDB0Connector(mock_config)
+        assert connector._configs == mock_config
+
+
+@pytest.mark.asyncio
+async def test_query(mock_config):
+    """Test the query method."""
+    connector = ChromaDB0Connector(mock_config)
+    connector._configs.query = ["test query"]
+    connector.get_embedding = MagicMock(return_value=[[1.0, 2.0, 3.0]])
+
+    mock_collection = AsyncMock()
+    mock_collection.query.return_value = {
+        "documents": [["doc1"]],
+        "distances": [[0.1]],
+        "metadatas": [[{"path": os.path.join(mock_config.project_root, "file1")}]],
+        "ids": [["id1"]],
+    }
+    connector._create_or_get_async_collection = AsyncMock(return_value=mock_collection)
+
+    with patch(
+        "vectorcode.database.chroma0.convert_chroma_query_results"
+    ) as mock_convert:
+        mock_convert.return_value = ["converted_results"]
+        results = await connector.query()
+        assert results == ["converted_results"]
+        mock_collection.query.assert_called_once()
+        mock_convert.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_vectorise(mock_config):
+    """Test the vectorise method."""
+    connector = ChromaDB0Connector(mock_config)
+    mock_collection = AsyncMock()
+    connector._create_or_get_async_collection = AsyncMock(return_value=mock_collection)
+
+    mock_chunker = MagicMock()
+    mock_chunker.chunk.return_value = [MagicMock(text="chunk1")]
+    connector.get_embedding = MagicMock(return_value=[[1.0, 2.0, 3.0]])
+
+    with (
+        patch("vectorcode.database.chroma0.hash_file", return_value="hash1"),
+        patch("vectorcode.database.chroma0.get_uuid", return_value="uuid1"),
+        patch(
+            "vectorcode.database.chroma0._Chroma0ClientManager.get_client"
+        ) as mock_get_client,
+    ):
+        mock_client = AsyncMock()
+        mock_client.get_max_batch_size.return_value = 100
+        mock_get_client.return_value.__aenter__.return_value = mock_client
+
+        stats = await connector.vectorise(
+            os.path.join(mock_config.project_root, "file1"), chunker=mock_chunker
+        )
+
+        assert stats.add == 1
+        mock_collection.add.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_list_collections(mock_config):
+    """Test the list_collections method."""
+    connector = ChromaDB0Connector(mock_config)
+    with patch(
+        "vectorcode.database.chroma0._Chroma0ClientManager.get_client"
+    ) as mock_get_client:
+        mock_client = AsyncMock()
+        mock_client.list_collections.return_value = ["collection1"]
+        mock_collection = AsyncMock()
+        mock_collection.metadata = {"path": mock_config.project_root}
+        mock_client.get_collection.return_value = mock_collection
+        mock_get_client.return_value.__aenter__.return_value = mock_client
+
+        connector.list_collection_content = AsyncMock(
+            return_value=MagicMock(files=[], chunks=[])
+        )
+
+        collections = await connector.list_collections()
+        assert len(collections) == 1
+        assert collections[0].id == "collection1"
+
+
+@pytest.mark.asyncio
+async def test_list_collection_content(mock_config):
+    """Test the list_collection_content method."""
+    connector = ChromaDB0Connector(mock_config)
+    mock_collection = AsyncMock()
+    mock_collection.get.return_value = {
+        "metadatas": [
+            {
+                "path": os.path.join(mock_config.project_root, "file1"),
+                "sha256": "hash1",
+                "start": 1,
+                "end": 2,
+            }
+        ],
+        "documents": ["doc1"],
+        "ids": ["id1"],
+    }
+    connector._create_or_get_async_collection = AsyncMock(return_value=mock_collection)
+
+    content = await connector.list_collection_content()
+    assert len(content.files) == 1
+    assert len(content.chunks) == 1
+
+
+@pytest.mark.asyncio
+async def test_delete(mock_config):
+    """Test the delete method."""
+    file_to_delete = os.path.join(mock_config.project_root, "file1")
+    connector = ChromaDB0Connector(mock_config)
+    mock_collection = AsyncMock()
+    connector._create_or_get_async_collection = AsyncMock(return_value=mock_collection)
+    connector.list_collection_content = AsyncMock(
+        return_value=MagicMock(files=[MagicMock(path=file_to_delete)])
+    )
+    mock_config.rm_paths = [file_to_delete]
+
+    def mock_expand_path(path, absolute):
+        return path
+
+    with (
+        patch(
+            "vectorcode.database.chroma0.expand_globs", return_value=[file_to_delete]
+        ),
+        patch("vectorcode.database.chroma0.expand_path", side_effect=mock_expand_path),
+        patch("os.path.isfile", return_value=True),
+    ):
+        deleted_count = await connector.delete()
+        assert deleted_count == 1
+        mock_collection.delete.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_drop(mock_config):
+    """Test the drop method."""
+    connector = ChromaDB0Connector(mock_config)
+    with (
+        patch(
+            "vectorcode.database.chroma0._Chroma0ClientManager.get_client"
+        ) as mock_get_client,
+        patch(
+            "vectorcode.database.chroma0.get_collection_id",
+            return_value="collection_id",
+        ),
+    ):
+        mock_client = AsyncMock()
+        mock_get_client.return_value.__aenter__.return_value = mock_client
+        await connector.drop()
+        mock_client.delete_collection.assert_called_once_with("collection_id")
+
+
+@pytest.mark.asyncio
+async def test_drop_invalid_collection(mock_config):
+    """Test that drop raises CollectionNotFoundError for a missing collection."""
+    connector = ChromaDB0Connector(mock_config)
+    with (
+        patch(
+            "vectorcode.database.chroma0._Chroma0ClientManager.get_client"
+        ) as mock_get_client,
+        patch(
+            "vectorcode.database.chroma0.get_collection_id",
+            return_value="collection_id",
+        ),
+    ):
+        mock_client = AsyncMock()
+        mock_get_client.return_value.__aenter__.return_value = mock_client
+        mock_client.delete_collection.side_effect = ValueError
+        with pytest.raises(CollectionNotFoundError):
+            await connector.drop()
+
+
+@pytest.mark.asyncio
+async def test_get_chunks(mock_config):
+    """Test the get_chunks method."""
+    connector = ChromaDB0Connector(mock_config)
+    mock_collection = AsyncMock()
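+    # Stub collection.get() to return a single stored chunk with start/end metadata.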
mock_collection.get.return_value = { + "metadatas": [{"start": 1, "end": 2}], + "documents": ["doc1"], + "ids": ["id1"], + } + connector._create_or_get_async_collection = AsyncMock(return_value=mock_collection) + + chunks = await connector.get_chunks(os.path.join(mock_config.project_root, "file1")) + assert len(chunks) == 1 + assert chunks[0].text == "doc1" + + +def test_convert_chroma_query_results(mock_config): + file1_path = os.path.join(mock_config.project_root, "file1") + file2_path = os.path.join(mock_config.project_root, "file2") + chroma_result: QueryResult = { + "documents": [["doc1", "doc2"]], + "distances": [[0.1, 0.2]], + "metadatas": [ + [{"path": file1_path, "start": 1, "end": 2}, {"path": file2_path}] + ], + "ids": [["id1", "id2"]], + "embeddings": None, + "uris": None, + "data": None, + } + queries = ["query1"] + results = convert_chroma_query_results(chroma_result, queries) + assert len(results) == 2 + assert results[0].chunk.text == "doc1" + assert results[0].path == file1_path + assert results[0].scores == (-0.1,) + assert results[0].chunk.start == Point(1, 0) + assert results[0].chunk.end == Point(2, 0) + assert results[1].chunk.text == "doc2" + assert results[1].path == file2_path + assert results[1].scores == (-0.2,) + + +@pytest.mark.asyncio +async def test_get_chunks_collection_not_found(mock_config): + """Test get_chunks when collection is not found.""" + connector = ChromaDB0Connector(mock_config) + connector._create_or_get_async_collection = AsyncMock( + side_effect=CollectionNotFoundError + ) + with patch("vectorcode.database.chroma0._logger") as mock_logger: + result = await connector.get_chunks( + os.path.join(mock_config.project_root, "file1") + ) + assert result == [] + mock_logger.warning.assert_called_once() + + +@pytest.mark.asyncio +async def test_vectorise_no_embeddings(mock_config): + """Test vectorise when there are no embeddings.""" + connector = ChromaDB0Connector(mock_config) + mock_chunker = MagicMock() + mock_chunker.chunk.return_value = [MagicMock(text="chunk1")] + connector.get_embedding = MagicMock(return_value=[]) + with patch( + "vectorcode.database.chroma0.ChromaDB0Connector._create_or_get_async_collection", + new_callable=AsyncMock, + ) as mock_create_collection: + stats = await connector.vectorise( + os.path.join(mock_config.project_root, "file1"), chunker=mock_chunker + ) + assert stats.skipped == 1 + mock_create_collection.assert_called_once() + + +@pytest.mark.asyncio +async def test_query_with_exclude(mock_config): + """Test query with exclude paths.""" + file1_path = os.path.join(mock_config.project_root, "file1") + file2_path = os.path.join(mock_config.project_root, "file2") + connector = ChromaDB0Connector(mock_config) + connector._configs.query = ["test query"] + connector._configs.query_exclude = [file2_path] + connector.get_embedding = MagicMock(return_value=[[1.0, 2.0, 3.0]]) + + mock_collection = AsyncMock() + mock_collection.query.return_value = { + "documents": [["doc1"]], + "distances": [[0.1]], + "metadatas": [[{"path": file1_path}]], + "ids": [["id1"]], + } + connector._create_or_get_async_collection = AsyncMock(return_value=mock_collection) + + with patch( + "vectorcode.database.chroma0.convert_chroma_query_results" + ) as mock_convert: + mock_convert.return_value = ["converted_results"] + await connector.query() + mock_collection.query.assert_called_once() + _, kwargs = mock_collection.query.call_args + assert "where" in kwargs + assert kwargs["where"] == {"path": {"$nin": [file2_path]}} + + +@pytest.mark.asyncio +async 
def test_query_with_include_chunk(mock_config): + """Test query with include chunk.""" + connector = ChromaDB0Connector(mock_config) + connector._configs.query = ["test query"] + connector._configs.include = [QueryInclude.chunk] + connector.get_embedding = MagicMock(return_value=[[1.0, 2.0, 3.0]]) + + mock_collection = AsyncMock() + mock_collection.query.return_value = { + "documents": [["doc1"]], + "distances": [[0.1]], + "metadatas": [[{"path": os.path.join(mock_config.project_root, "file1")}]], + "ids": [["id1"]], + } + connector._create_or_get_async_collection = AsyncMock(return_value=mock_collection) + + with patch( + "vectorcode.database.chroma0.convert_chroma_query_results" + ) as mock_convert: + mock_convert.return_value = ["converted_results"] + await connector.query() + mock_collection.query.assert_called_once() + _, kwargs = mock_collection.query.call_args + assert "where" in kwargs + assert kwargs["where"] == {"start": {"$gte": 0}} + + +@pytest.mark.asyncio +async def test_create_or_get_async_collection_not_found(mock_config): + """Test _create_or_get_async_collection when collection is not found and allow_create is False.""" + connector = ChromaDB0Connector(mock_config) + with patch( + "vectorcode.database.chroma0._Chroma0ClientManager.get_client" + ) as mock_get_client: + mock_client = AsyncMock() + mock_client.get_collection.side_effect = InvalidCollectionException + mock_get_client.return_value.__aenter__.return_value = mock_client + + with pytest.raises(CollectionNotFoundError): + await connector._create_or_get_async_collection( + "collection_path", allow_create=False + ) + + +@pytest.mark.asyncio +async def test_delete_no_paths(mock_config): + """Test delete with no paths to remove.""" + file_to_keep = os.path.join(mock_config.project_root, "file1") + non_existent_file = os.path.join(mock_config.project_root, "non_existent_file") + connector = ChromaDB0Connector(mock_config) + mock_collection = AsyncMock() + connector._create_or_get_async_collection = AsyncMock(return_value=mock_collection) + connector.list_collection_content = AsyncMock( + return_value=MagicMock(files=[MagicMock(path=file_to_keep)]) + ) + mock_config.rm_paths = [non_existent_file] + + def mock_expand_path(path, absolute): + return path + + with ( + patch( + "vectorcode.database.chroma0.expand_globs", return_value=[non_existent_file] + ), + patch("vectorcode.database.chroma0.expand_path", side_effect=mock_expand_path), + patch("os.path.isfile", return_value=True), + ): + deleted_count = await connector.delete() + assert deleted_count == 0 + mock_collection.delete.assert_not_called() + + +@pytest.mark.asyncio +async def test_list_collection_content_with_what(mock_config): + """Test the list_collection_content method with the 'what' parameter.""" + connector = ChromaDB0Connector(mock_config) + mock_collection = AsyncMock() + mock_collection.get.return_value = { + "metadatas": [ + { + "path": os.path.join(mock_config.project_root, "file1"), + "sha256": "hash1", + } + ], + "documents": ["doc1"], + "ids": ["id1"], + } + connector._create_or_get_async_collection = AsyncMock(return_value=mock_collection) + + # Test with what=ResultType.document + content = await connector.list_collection_content(what=types.ResultType.document) + assert len(content.files) == 1 + assert len(content.chunks) == 0 + + # Test with what=ResultType.chunk + content = await connector.list_collection_content(what=types.ResultType.chunk) + assert len(content.files) == 0 + assert len(content.chunks) == 1 + + +@pytest.mark.asyncio +async def 
test_delete_with_string_rm_paths(mock_config): + """Test delete with rm_paths as a string.""" + file_to_delete = os.path.join(mock_config.project_root, "file1") + connector = ChromaDB0Connector(mock_config) + mock_collection = AsyncMock() + connector._create_or_get_async_collection = AsyncMock(return_value=mock_collection) + connector.list_collection_content = AsyncMock( + return_value=MagicMock(files=[MagicMock(path=file_to_delete)]) + ) + mock_config.rm_paths = file_to_delete + + def mock_expand_path(path, absolute): + return path + + with ( + patch( + "vectorcode.database.chroma0.expand_globs", return_value=[file_to_delete] + ), + patch("vectorcode.database.chroma0.expand_path", side_effect=mock_expand_path), + patch("os.path.isfile", return_value=True), + ): + deleted_count = await connector.delete() + assert deleted_count == 1 + mock_collection.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_drop_with_collection_path(mock_config): + """Test drop with collection_path.""" + connector = ChromaDB0Connector(mock_config) + with ( + patch( + "vectorcode.database.chroma0._Chroma0ClientManager.get_client" + ) as mock_get_client, + patch( + "vectorcode.database.chroma0.get_collection_id", + return_value="collection_id", + ) as mock_get_collection_id, + ): + mock_client = AsyncMock() + mock_get_client.return_value.__aenter__.return_value = mock_client + await connector.drop(collection_path=mock_config.project_root) + mock_get_collection_id.assert_called_once_with(mock_config.project_root) + mock_client.delete_collection.assert_called_once_with("collection_id") + + +@pytest.mark.asyncio +async def test_get_chunks_generic_exception(mock_config): + """Test get_chunks with a generic exception.""" + connector = ChromaDB0Connector(mock_config) + connector._create_or_get_async_collection = AsyncMock( + side_effect=Exception("test error") + ) + with pytest.raises(Exception) as excinfo: + await connector.get_chunks(os.path.join(mock_config.project_root, "file1")) + assert "test error" in str(excinfo.value) + + +@pytest.mark.asyncio +async def test_try_server_success(): + """Test _try_server when the server is running.""" + with patch("httpx.AsyncClient") as mock_client: + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_client.return_value.__aenter__.return_value.get.return_value = ( + mock_response + ) + + assert await _try_server("http://localhost:8000") is True + + +@pytest.mark.asyncio +async def test_try_server_failure(): + """Test _try_server when the server is not running.""" + with patch("httpx.AsyncClient") as mock_client: + mock_client.return_value.__aenter__.return_value.get.side_effect = ( + httpx.ConnectError("test") + ) + + assert await _try_server("http://localhost:8000") is False + + +@pytest.mark.asyncio +async def test_wait_for_server_success(): + """Test _wait_for_server when the server starts.""" + with patch( + "vectorcode.database.chroma0._try_server", new_callable=AsyncMock + ) as mock_try_server: + mock_try_server.side_effect = [False, True] + await _wait_for_server("http://localhost:8000", timeout=1) + assert mock_try_server.call_count == 2 + + +@pytest.mark.asyncio +async def test_wait_for_server_timeout(): + """Test _wait_for_server when the server does not start.""" + with patch( + "vectorcode.database.chroma0._try_server", new_callable=AsyncMock + ) as mock_try_server: + mock_try_server.return_value = False + with pytest.raises(TimeoutError): + await _wait_for_server("http://localhost:8000", timeout=0.2) + + +@pytest.mark.asyncio 
+async def test_start_server(mock_config): + """Test the _start_server function.""" + with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_exec: + mock_process = AsyncMock() + mock_exec.return_value = mock_process + with patch( + "vectorcode.database.chroma0._wait_for_server", new_callable=AsyncMock + ) as mock_wait: + process = await _start_server(mock_config) + assert process == mock_process + mock_exec.assert_called_once() + mock_wait.assert_called_once() + + +@pytest.mark.asyncio +async def test_client_manager_get_client_new_server(mock_config): + """Test get_client when a new server needs to be started.""" + with patch("atexit.register"): + manager = _Chroma0ClientManager() + manager.clear() + with ( + patch( + "vectorcode.database.chroma0._try_server", new_callable=AsyncMock + ) as mock_try_server, + patch( + "vectorcode.database.chroma0._start_server", new_callable=AsyncMock + ) as mock_start_server, + patch( + "vectorcode.database.chroma0._Chroma0ClientManager._create_client", + new_callable=AsyncMock, + ) as mock_create_client, + ): + mock_try_server.return_value = False + mock_process = MagicMock() + mock_process.returncode = None + mock_start_server.return_value = mock_process + mock_client = AsyncMock() + mock_client.get_version.return_value = "0.1.0" + mock_create_client.return_value = mock_client + + async with manager.get_client(mock_config, need_lock=False) as client: + assert client == mock_client + assert manager.get_processes() == [mock_process] + + manager.kill_servers() + mock_process.terminate.assert_called_once() + manager.clear() + + +@pytest.mark.asyncio +async def test_client_manager_get_client_existing_server(mock_config): + """Test get_client with an existing server.""" + manager = _Chroma0ClientManager() + manager.clear() + with ( + patch( + "vectorcode.database.chroma0._try_server", new_callable=AsyncMock + ) as mock_try_server, + patch( + "vectorcode.database.chroma0._Chroma0ClientManager._create_client", + new_callable=AsyncMock, + ) as mock_create_client, + ): + mock_try_server.return_value = True + mock_client = AsyncMock() + mock_client.get_version.return_value = "0.1.0" + mock_create_client.return_value = mock_client + + async with manager.get_client(mock_config, need_lock=False) as client: + assert client == mock_client + assert not manager.get_processes() + manager.clear() + + +@pytest.mark.asyncio +async def test_create_client(mock_config): + """Test the _create_client method.""" + manager = _Chroma0ClientManager() + with patch("chromadb.AsyncHttpClient", new_callable=AsyncMock) as mock_http_client: + await manager._create_client(mock_config) + mock_http_client.assert_called_once() + + +@pytest.mark.asyncio +async def test_client_manager_get_client_with_lock(mock_config): + """Test get_client with a lock.""" + with patch("atexit.register"): + manager = _Chroma0ClientManager() + manager.clear() + with ( + patch( + "vectorcode.database.chroma0._try_server", + new_callable=AsyncMock, + return_value=False, + ), + patch( + "vectorcode.database.chroma0._start_server", new_callable=AsyncMock + ) as mock_start_server, + patch( + "vectorcode.database.chroma0._Chroma0ClientManager._create_client", + new_callable=AsyncMock, + ) as mock_create_client, + patch("vectorcode.database.chroma0.LockManager") as mock_lock_manager, + ): + mock_process = MagicMock() + mock_process.returncode = None + mock_start_server.return_value = mock_process + mock_client = AsyncMock() + mock_client.get_version.return_value = "0.1.0" + 
mock_create_client.return_value = mock_client + mock_lock = AsyncMock() + mock_lock_manager.return_value.get_lock.return_value = mock_lock + + async with manager.get_client(mock_config, need_lock=True) as client: + assert client == mock_client + + mock_lock.acquire.assert_called_once() + mock_lock.release.assert_called_once() + + manager.kill_servers() + manager.clear() + + +@pytest.mark.asyncio +async def test_query_no_n_result(mock_config): + """Test the query method without n_result.""" + connector = ChromaDB0Connector(mock_config) + connector._configs.query = ["test query"] + connector._configs.n_result = None + connector.get_embedding = MagicMock(return_value=[[1.0, 2.0, 3.0]]) + + mock_collection = AsyncMock() + mock_collection.query.return_value = { + "documents": [["doc1"]], + "distances": [[0.1]], + "metadatas": [[{"path": os.path.join(mock_config.project_root, "file1")}]], + "ids": [["id1"]], + } + connector._create_or_get_async_collection = AsyncMock(return_value=mock_collection) + mock_content = MagicMock() + mock_content.chunks = [1] * 10 + connector.list_collection_content = AsyncMock(return_value=mock_content) + + with patch( + "vectorcode.database.chroma0.convert_chroma_query_results" + ) as mock_convert: + mock_convert.return_value = ["converted_results"] + await connector.query() + _, kwargs = mock_collection.query.call_args + assert kwargs["n_results"] == 10 + + +@pytest.mark.asyncio +async def test_create_or_get_async_collection_exists(mock_config: Config): + """Test _create_or_get_async_collection when collection exists and allow_create is True.""" + mock_config.db_params["hnsw"] = {"M": 64} + connector = ChromaDB0Connector(mock_config) + with ( + patch( + "vectorcode.database.chroma0._Chroma0ClientManager.get_client" + ) as mock_get_client, + patch("os.environ.get", return_value="DEFAULT_USER"), + ): + mock_client = AsyncMock() + mock_collection = AsyncMock() + mock_collection.metadata = { + "path": os.path.abspath(str(mock_config.project_root)), + "hostname": "test-host", + "created-by": "VectorCode", + "username": "DEFAULT_USER", + "embedding_function": "default", + "hnsw:M": 64, + } + mock_client.get_or_create_collection.return_value = mock_collection + mock_get_client.return_value.__aenter__.return_value = mock_client + with patch("socket.gethostname", return_value="test-host"): + collection = await connector._create_or_get_async_collection( + "collection_path", allow_create=True + ) + assert collection == mock_collection + mock_client.get_or_create_collection.assert_called_once() + + +@pytest.mark.asyncio +async def test_list_collection_content_with_id(mock_config): + """Test the list_collection_content method with collection_id.""" + connector = ChromaDB0Connector(mock_config) + with patch( + "vectorcode.database.chroma0._Chroma0ClientManager.get_client" + ) as mock_get_client: + mock_client = AsyncMock() + mock_collection = AsyncMock() + mock_collection.get.return_value = { + "metadatas": [ + { + "path": os.path.join(mock_config.project_root, "file1"), + "sha256": "hash1", + } + ], + "documents": ["doc1"], + "ids": ["id1"], + } + mock_client.get_collection.return_value = mock_collection + mock_get_client.return_value.__aenter__.return_value = mock_client + + content = await connector.list_collection_content(collection_id="test_id") + assert len(content.files) == 1 + assert len(content.chunks) == 1 + mock_client.get_collection.assert_called_once_with("test_id") + + +@pytest.mark.asyncio +async def test_query_with_exclude_and_include_chunk(mock_config): + """Test 
query with exclude paths and include chunk.""" + connector = ChromaDB0Connector(mock_config) + connector._configs.query = ["test query"] + connector._configs.query_exclude = ["file2"] + connector._configs.include = [QueryInclude.chunk] + connector.get_embedding = MagicMock(return_value=[[1.0, 2.0, 3.0]]) + + mock_collection = AsyncMock() + mock_collection.query.return_value = { + "documents": [["doc1"]], + "distances": [[0.1]], + "metadatas": [[{"path": "file1"}]], + "ids": [["id1"]], + } + connector._create_or_get_async_collection = AsyncMock(return_value=mock_collection) + + with patch( + "vectorcode.database.chroma0.convert_chroma_query_results" + ) as mock_convert: + mock_convert.return_value = ["converted_results"] + await connector.query() + mock_collection.query.assert_called_once() + _, kwargs = mock_collection.query.call_args + assert "where" in kwargs + assert kwargs["where"] == { + "$and": [{"path": {"$nin": ["file2"]}}, {"start": {"$gte": 0}}] + } + + +@pytest.mark.asyncio +async def test_create_or_get_async_collection_metadata_mismatch(mock_config): + """Test _create_or_get_async_collection when metadata mismatches.""" + connector = ChromaDB0Connector(mock_config) + with ( + patch( + "vectorcode.database.chroma0._Chroma0ClientManager.get_client" + ) as mock_get_client, + patch("os.environ.get", return_value="DEFAULT_USER"), + ): + mock_client = AsyncMock() + mock_collection = AsyncMock() + mock_collection.metadata = { + "path": os.path.abspath(str(mock_config.project_root)), + "hostname": "test-host", + "created-by": "VectorCode", + "username": "DIFFERENT_USER", + "embedding_function": "default", + "hnsw:M": 64, + } + mock_client.get_or_create_collection.return_value = mock_collection + mock_get_client.return_value.__aenter__.return_value = mock_client + with patch("socket.gethostname", return_value="test-host"): + with pytest.raises(AssertionError): + await connector._create_or_get_async_collection( + "collection_path", allow_create=True + ) + + +@pytest.mark.asyncio +async def test_delete_no_matching_files(mock_config): + """Test delete with no matching files.""" + connector = ChromaDB0Connector(mock_config) + mock_collection = AsyncMock() + connector._create_or_get_async_collection = AsyncMock(return_value=mock_collection) + connector.list_collection_content = AsyncMock( + return_value=MagicMock(files=[MagicMock(path="file1")]) + ) + mock_config.rm_paths = ["file2"] + + def mock_expand_path(path, absolute): + return path + + with ( + patch("vectorcode.database.chroma0.expand_globs", return_value=["file2"]), + patch("vectorcode.database.chroma0.expand_path", side_effect=mock_expand_path), + patch("os.path.isfile", return_value=True), + ): + deleted_count = await connector.delete() + assert deleted_count == 0 + mock_collection.delete.assert_not_called() diff --git a/tests/database/test_db_init.py b/tests/database/test_db_init.py new file mode 100644 index 00000000..374135d3 --- /dev/null +++ b/tests/database/test_db_init.py @@ -0,0 +1,45 @@ +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from vectorcode.cli_utils import Config +from vectorcode.database import get_database_connector + + +@pytest.mark.parametrize( + "db_type, module_to_mock, class_name", + [ + ("ChromaDB0", "vectorcode.database.chroma0", "ChromaDB0Connector"), + ("ChromaDB", "vectorcode.database.chroma", "ChromaDBConnector"), + # To test a new connector, add a tuple here following the same pattern. + # e.g. 
("NewDB", "vectorcode.database.newdb", "NewDBConnector"), + ], +) +def test_get_database_connector(db_type, module_to_mock, class_name): + """ + Tests that get_database_connector can correctly return a connector + for a given db_type. This test is parameterized to be easily + extensible for new database connectors. + """ + mock_connector_class = MagicMock() + mock_module = MagicMock() + setattr(mock_module, class_name, mock_connector_class) + + # Use patch.dict to temporarily replace the module in sys.modules. + # This prevents the actual module from being imported, avoiding + # errors if its dependencies are not installed. + with patch.dict(sys.modules, {module_to_mock: mock_module}): + config = Config(db_type=db_type) + connector = get_database_connector(config) + + # Verify that the create method was called on our mock class + mock_connector_class.create.assert_called_once_with(config) + + # Verify that the returned connector is the one from our mock + assert connector == mock_connector_class.create.return_value + + +def test_get_database_connector_invalid_type(): + with pytest.raises(ValueError): + get_database_connector(Config(db_type="InvalidDB")) diff --git a/tests/database/test_db_types.py b/tests/database/test_db_types.py new file mode 100644 index 00000000..7c48b932 --- /dev/null +++ b/tests/database/test_db_types.py @@ -0,0 +1,46 @@ +from vectorcode.chunking import Chunk +from vectorcode.database.types import QueryResult, VectoriseStats + +""" +For boilerplate code that wasn't covered in other tests. +""" + + +def test_vectorstats_add(): + assert VectoriseStats( + add=1, update=2, removed=3, skipped=4, failed=5 + ) + VectoriseStats( + add=5, update=4, removed=3, skipped=2, failed=1 + ) == VectoriseStats(add=6, update=6, removed=6, skipped=6, failed=6) + + assert VectoriseStats( + add=1, update=2, removed=3, skipped=4, failed=5 + ) + VectoriseStats( + add=5, update=4, removed=3, skipped=2, failed=1 + ) != VectoriseStats(add=6, update=6, removed=6, skipped=6, failed=5) + + +def test_query_result_equal(): + assert QueryResult( + path="some_path", + chunk=Chunk(text="some_text"), + query=("some_query",), + scores=(1,), + ) == QueryResult( + path="other_path", + chunk=Chunk(text="other_text"), + query=("some_query",), + scores=(1.0,), + ) + + assert QueryResult( + path="some_path", + chunk=Chunk(text="some_text"), + query=("some_query",), + scores=(1,), + ) != QueryResult( + path="other_path", + chunk=Chunk(text="other_text"), + query=("some_query",), + scores=(2.0,), + ) diff --git a/tests/subcommands/files/test_files_ls.py b/tests/subcommands/files/test_files_ls.py index fc51caa6..9740776e 100644 --- a/tests/subcommands/files/test_files_ls.py +++ b/tests/subcommands/files/test_files_ls.py @@ -2,47 +2,29 @@ from unittest.mock import AsyncMock, patch import pytest -from chromadb.api.models.AsyncCollection import AsyncCollection from vectorcode.cli_utils import CliAction, Config, FilesAction +from vectorcode.database.errors import CollectionNotFoundError +from vectorcode.database.types import FileInCollection from vectorcode.subcommands.files.ls import ls @pytest.fixture -def client(): - return AsyncMock() - - -@pytest.fixture -def collection(): - col = AsyncMock(spec=AsyncCollection) - col.get.return_value = { - "ids": ["id1", "id2", "id3"], - "distances": [0.1, 0.2, 0.3], - "metadatas": [ - {"path": "file1.py", "start": 1, "end": 1}, - {"path": "file2.py", "start": 1, "end": 1}, - {"path": "file3.py", "start": 1, "end": 1}, - ], - "documents": [ - "content1", - "content2", - 
"content3", - ], - } - return col +def mock_db(): + db = AsyncMock() + db.list_collection_content.return_value.files = [ + FileInCollection(path="file1.py", sha256="hash1"), + FileInCollection(path="file2.py", sha256="hash2"), + FileInCollection(path="file3.py", sha256="hash3"), + ] + return db @pytest.mark.asyncio -async def test_ls(client, collection, capsys): - with ( - patch("vectorcode.subcommands.files.ls.ClientManager") as MockClientManager, - patch( - "vectorcode.subcommands.files.ls.get_collection", return_value=collection - ), - patch("vectorcode.common.try_server", return_value=True), +async def test_ls(mock_db, capsys): + with patch( + "vectorcode.subcommands.files.ls.get_database_connector", return_value=mock_db ): - MockClientManager.return_value._create_client.return_value = client await ls(Config(action=CliAction.files, files_action=FilesAction.ls)) out = capsys.readouterr().out assert "file1.py" in out @@ -51,27 +33,21 @@ async def test_ls(client, collection, capsys): @pytest.mark.asyncio -async def test_ls_piped(client, collection, capsys): - with ( - patch("vectorcode.subcommands.files.ls.ClientManager") as MockClientManager, - patch( - "vectorcode.subcommands.files.ls.get_collection", return_value=collection - ), - patch("vectorcode.common.try_server", return_value=True), +async def test_ls_piped(mock_db, capsys): + with patch( + "vectorcode.subcommands.files.ls.get_database_connector", return_value=mock_db ): - MockClientManager.return_value._create_client.return_value = client await ls(Config(action=CliAction.files, files_action=FilesAction.ls, pipe=True)) out = capsys.readouterr().out assert json.dumps(["file1.py", "file2.py", "file3.py"]).strip() == out.strip() @pytest.mark.asyncio -async def test_ls_no_collection(client, collection, capsys): - with ( - patch("vectorcode.subcommands.files.ls.ClientManager") as MockClientManager, - patch("vectorcode.subcommands.files.ls.get_collection", side_effect=ValueError), +async def test_ls_no_collection(mock_db): + mock_db.list_collection_content.side_effect = CollectionNotFoundError + with patch( + "vectorcode.subcommands.files.ls.get_database_connector", return_value=mock_db ): - MockClientManager.return_value._create_client.return_value = client assert ( await ls( Config(action=CliAction.files, files_action=FilesAction.ls, pipe=True) @@ -81,18 +57,15 @@ async def test_ls_no_collection(client, collection, capsys): @pytest.mark.asyncio -async def test_ls_empty_collection(client, capsys): - mock_collection = AsyncMock(spec=AsyncCollection) - mock_collection.get.return_value = {} - with ( - patch("vectorcode.subcommands.files.ls.ClientManager") as MockClientManager, - patch( - "vectorcode.subcommands.files.ls.get_collection", - return_value=mock_collection, - ), - patch("vectorcode.common.try_server", return_value=True), +async def test_ls_empty_collection(mock_db, capsys): + mock_db.list_collection_content.return_value.files = [] + with patch( + "vectorcode.subcommands.files.ls.get_database_connector", return_value=mock_db ): - MockClientManager.return_value._create_client.return_value = client assert ( - await ls(Config(action=CliAction.files, files_action=FilesAction.ls)) == 0 + await ls( + Config(pipe=True, action=CliAction.files, files_action=FilesAction.ls) + ) + == 0 ) + assert capsys.readouterr().out.strip() == "[]" diff --git a/tests/subcommands/files/test_files_rm.py b/tests/subcommands/files/test_files_rm.py index f5bcf670..e0133757 100644 --- a/tests/subcommands/files/test_files_rm.py +++ 
b/tests/subcommands/files/test_files_rm.py @@ -1,101 +1,60 @@ +from pathlib import Path from unittest.mock import AsyncMock, patch import pytest -from chromadb.api.models.AsyncCollection import AsyncCollection -from vectorcode.cli_utils import CliAction, Config, FilesAction +from vectorcode.cli_utils import Config +from vectorcode.database.errors import CollectionNotFoundError from vectorcode.subcommands.files.rm import rm @pytest.fixture -def client(): - return AsyncMock() +def mock_db(): + db = AsyncMock() + def mock_delete(): + count = 0 + for f in db._configs.rm_paths: + if Path(f).name in {"file1.py", "file2.py", "file3.py"}: + count += 1 + return count -@pytest.fixture -def collection(): - col = AsyncMock(spec=AsyncCollection) - col.get.return_value = { - "ids": ["id1", "id2", "id3"], - "distances": [0.1, 0.2, 0.3], - "metadatas": [ - {"path": "file1.py", "start": 1, "end": 1}, - {"path": "file2.py", "start": 1, "end": 1}, - {"path": "file3.py", "start": 1, "end": 1}, - ], - "documents": [ - "content1", - "content2", - "content3", - ], - } - col.name = "test_collection" - return col + db.delete = AsyncMock(side_effect=mock_delete) + return db @pytest.mark.asyncio -async def test_rm(client, collection, capsys): - with ( - patch("vectorcode.subcommands.files.rm.ClientManager") as MockClientManager, - patch( - "vectorcode.subcommands.files.rm.get_collection", return_value=collection - ), - patch("vectorcode.common.try_server", return_value=True), - patch("os.path.isfile", return_value=True), - patch( - "vectorcode.subcommands.files.rm.expand_path", side_effect=lambda x, y: x - ), +async def test_rm(mock_db, capsys): + configs = Config(rm_paths=["file1.py", "file2.py"]) + mock_db._configs = configs + with patch( + "vectorcode.subcommands.files.rm.get_database_connector", return_value=mock_db ): - MockClientManager.return_value._create_client.return_value = client - config = Config( - action=CliAction.files, - files_action=FilesAction.rm, - rm_paths=["file1.py"], - ) - await rm(config) - collection.delete.assert_called_with(where={"path": {"$in": ["file1.py"]}}) + assert await rm(configs) == 0 + assert capsys.readouterr().out.strip() == "Removed 2 file(s)." 
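+
+
+# A minimal sketch (illustration only, not part of this suite) of the
+# AsyncMock side_effect pattern used by the mock_db fixture above: when the
+# mock is awaited, the synchronous side_effect callable runs and its return
+# value becomes the awaited result, so `await db.delete()` yields the count
+# computed by mock_delete().
+#
+#     import asyncio
+#     from unittest.mock import AsyncMock
+#
+#     db = AsyncMock()
+#     db.delete = AsyncMock(side_effect=lambda: 2)
+#     assert asyncio.run(db.delete()) == 2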
@pytest.mark.asyncio -async def test_rm_empty_collection(client, collection, capsys): - with ( - patch( - "vectorcode.subcommands.files.rm.get_collection", return_value=collection - ), - patch("vectorcode.common.try_server", return_value=True), - patch("os.path.isfile", return_value=True), - patch( - "vectorcode.subcommands.files.rm.expand_path", side_effect=lambda x, y: x - ), - patch( - "vectorcode.subcommands.files.rm.ClientManager._create_client", - return_value=client, - ), +async def test_rm_clean_after_rm(mock_db, capsys): + configs = Config(rm_paths=["file1.py", "file2.py"]) + mock_db._configs = configs + mock_db.count = AsyncMock(return_value=0) + with patch( + "vectorcode.subcommands.files.rm.get_database_connector", return_value=mock_db ): - config = Config( - action=CliAction.files, - files_action=FilesAction.rm, - rm_paths=["file1.py"], - ) - collection.count = AsyncMock(return_value=0) - client.delete_collection = AsyncMock() - await rm(config) - client.delete_collection.assert_called_once_with(collection.name) + assert await rm(configs) == 0 + mock_db.drop.assert_called_once() @pytest.mark.asyncio -async def test_rm_no_collection(client, collection, capsys): - with ( - patch("vectorcode.subcommands.files.rm.ClientManager") as MockClientManager, - patch("vectorcode.subcommands.files.rm.get_collection", side_effect=ValueError), +async def test_rm_no_collection(mock_db, capsys): + with patch( + "vectorcode.subcommands.files.rm.get_database_connector", return_value=mock_db ): - MockClientManager.return_value._create_client.return_value = client + mock_db.delete.side_effect = CollectionNotFoundError assert ( await rm( Config( - action=CliAction.files, - files_action=FilesAction.rm, - pipe=True, rm_paths=["file1.py"], ) ) diff --git a/tests/subcommands/query/test_query.py b/tests/subcommands/query/test_query.py index 06559836..94916f80 100644 --- a/tests/subcommands/query/test_query.py +++ b/tests/subcommands/query/test_query.py @@ -1,586 +1,120 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from chromadb import QueryResult -from chromadb.api.models.AsyncCollection import AsyncCollection -from chromadb.api.types import IncludeEnum -from chromadb.errors import InvalidCollectionException, InvalidDimensionException +from tree_sitter import Point +from vectorcode.chunking import Chunk from vectorcode.cli_utils import CliAction, Config, QueryInclude -from vectorcode.subcommands.query import ( - build_query_results, - convert_query_results, - get_query_result_files, - query, -) -from vectorcode.subcommands.query.reranker import ( - RerankerError, -) - - -@pytest.fixture -def mock_collection(): - collection = AsyncMock(spec=AsyncCollection) - collection.count.return_value = 10 - collection.query.return_value = { - "ids": [["id1", "id2", "id3"], ["id4", "id5", "id6"]], - "distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], - "metadatas": [ - [ - {"path": "file1.py", "start": 1, "end": 1}, - {"path": "file2.py", "start": 1, "end": 1}, - {"path": "file3.py", "start": 1, "end": 1}, - ], - [ - {"path": "file2.py", "start": 1, "end": 1}, - {"path": "file4.py", "start": 1, "end": 1}, - {"path": "file3.py", "start": 1, "end": 1}, - ], - ], - "documents": [ - ["content1", "content2", "content3"], - ["content4", "content5", "content6"], - ], - } - return collection +from vectorcode.database.base import DatabaseConnectorBase +from vectorcode.subcommands.query import query @pytest.fixture def mock_config(): return Config( - query=["test query", "test query 2"], + 
action=CliAction.query, + query=["test query"], n_result=3, - query_multiplier=2, - chunk_size=100, - overlap_ratio=0.2, project_root="/test/project", pipe=False, include=[QueryInclude.path, QueryInclude.document], query_exclude=[], - reranker=None, + reranker="NaiveReranker", reranker_params={}, use_absolute_path=False, ) -@pytest.mark.asyncio -async def test_get_query_result_files(mock_collection, mock_config): - mock_embedding_function = MagicMock() - mock_config.embedding_dims = 10 - with ( - patch("vectorcode.subcommands.query.get_reranker") as mock_get_reranker, - patch( - "vectorcode.subcommands.query.get_embedding_function", - return_value=mock_embedding_function, - ), - ): - mock_reranker_instance = MagicMock() - mock_reranker_instance.rerank = AsyncMock( - return_value=[ - "file1.py", - "file2.py", - "file3.py", - ] - ) - mock_get_reranker.return_value = mock_reranker_instance - - # Call the function - result = await get_query_result_files(mock_collection, mock_config) - - # Check that query was called with the right parameters - mock_collection.query.assert_called_once() - args, kwargs = mock_collection.query.call_args - mock_embedding_function.assert_called_once_with(["test query", "test query 2"]) - assert kwargs["n_results"] == 6 # n_result(3) * query_multiplier(2) - assert IncludeEnum.metadatas in kwargs["include"] - assert IncludeEnum.distances in kwargs["include"] - assert IncludeEnum.documents in kwargs["include"] - assert not kwargs["where"] # Since query_exclude is empty - - # Check reranker was used correctly - mock_get_reranker.assert_called_once_with(mock_config) - mock_reranker_instance.rerank.assert_called_once_with( - convert_query_results(mock_collection.query.return_value, mock_config.query) - ) - - # Check the result - assert result == ["file1.py", "file2.py", "file3.py"] - assert all( - len(i) == 10 for i in mock_collection.query.kwargs["query_embeddings"] - ) - - -@pytest.mark.asyncio -async def test_get_query_result_files_include_chunk(mock_collection, mock_config): - """Test get_query_result_files when QueryInclude.chunk is included.""" - mock_config.include = [QueryInclude.chunk] # Include chunk - - with patch("vectorcode.subcommands.query.reranker.NaiveReranker") as MockReranker: - mock_reranker_instance = MagicMock() - mock_reranker_instance.rerank = AsyncMock(return_value=["chunk1"]) - MockReranker.return_value = mock_reranker_instance - - await get_query_result_files(mock_collection, mock_config) - - # Check query call includes where clause for chunks - mock_collection.query.assert_called_once() - _, kwargs = mock_collection.query.call_args - # Line 43: Check the 'if' condition branch - assert kwargs["where"] == {"start": {"$gte": 0}} - assert kwargs["n_results"] == 3 # n_result should be used directly - - -@pytest.mark.asyncio -async def test_build_query_results_chunk_mode_success(mock_collection, mock_config): - """Test build_query_results in chunk mode successfully retrieves chunk details.""" - for request_abs_path in (True, False): - mock_config.include = [QueryInclude.chunk, QueryInclude.path] - mock_config.project_root = "/test/project" - mock_config.use_absolute_path = request_abs_path - mock_config.query = ["dummy_query"] - identifier = "chunk_id" - file_path = "/test/project/subdir/file1.py" - relative_path = "subdir/file1.py" - start_line = 5 - end_line = 10 - - full_file_content_lines = [f"line {i}\n" for i in range(15)] - - expected_chunk_content = "".join( - full_file_content_lines[start_line : end_line + 1] - ) - - mock_get_result = 
QueryResult( - ids=[[identifier]], - documents=[[expected_chunk_content]], - metadatas=[[{"path": file_path, "start": start_line, "end": end_line}]], - distances=[[0.2]], - ) - mock_collection.query = AsyncMock(return_value=mock_get_result) - with ( - patch( - "vectorcode.subcommands.query.get_query_result_files", - return_value=await get_query_result_files(mock_collection, mock_config), - ), - patch("os.path.isfile", return_value=False), - patch("os.path.relpath", return_value=relative_path) as mock_relpath, - ): - results = await build_query_results(mock_collection, mock_config) - - if not request_abs_path: - mock_relpath.assert_called_once_with( - file_path, str(mock_config.project_root) - ) - - assert len(results) == 1 - - expected_full_result = { - "path": file_path if request_abs_path else relative_path, - "chunk": expected_chunk_content, - "start_line": start_line, - "end_line": end_line, - "chunk_id": identifier, - } - - assert results[0] == expected_full_result - - -@pytest.mark.asyncio -async def test_get_query_result_files_with_query_exclude(mock_collection, mock_config): - # Setup query_exclude - mock_config.query_exclude = ["/excluded/path.py"] - - with ( - patch("vectorcode.subcommands.query.expand_path") as mock_expand_path, - patch("vectorcode.subcommands.query.expand_globs") as mock_expand_globs, - patch("vectorcode.subcommands.query.reranker.NaiveReranker") as MockReranker, - patch("os.path.isfile", return_value=True), # Add this line to mock isfile - ): - mock_expand_globs.return_value = ["/excluded/path.py"] - mock_expand_path.return_value = "/excluded/path.py" - - mock_reranker_instance = MagicMock() - mock_reranker_instance.rerank = AsyncMock(return_value=["file1.py", "file2.py"]) - MockReranker.return_value = mock_reranker_instance - - # Call the function - await get_query_result_files(mock_collection, mock_config) - - # Check that query was called with the right parameters including the where clause - mock_collection.query.assert_called_once() - _, kwargs = mock_collection.query.call_args - assert kwargs["where"] == {"path": {"$nin": ["/excluded/path.py"]}} - - -@pytest.mark.asyncio -async def test_get_query_result_chunks_with_query_exclude(mock_collection, mock_config): - # Setup query_exclude - mock_config.query_exclude = ["/excluded/path.py"] - mock_config.include = [QueryInclude.chunk, QueryInclude.path] - - with ( - patch("vectorcode.subcommands.query.expand_path") as mock_expand_path, - patch("vectorcode.subcommands.query.expand_globs") as mock_expand_globs, - patch("vectorcode.subcommands.query.reranker.NaiveReranker") as MockReranker, - patch("os.path.isfile", return_value=True), # Add this line to mock isfile - ): - mock_expand_globs.return_value = ["/excluded/path.py"] - mock_expand_path.return_value = "/excluded/path.py" - - mock_reranker_instance = MagicMock() - mock_reranker_instance.rerank = AsyncMock(return_value=["file1.py", "file2.py"]) - MockReranker.return_value = mock_reranker_instance - - # Call the function - await get_query_result_files(mock_collection, mock_config) - - # Check that query was called with the right parameters including the where clause - mock_collection.query.assert_called_once() - _, kwargs = mock_collection.query.call_args - assert kwargs["where"] == { - "$and": [{"path": {"$nin": ["/excluded/path.py"]}}, {"start": {"$gte": 0}}] - } - - -@pytest.mark.asyncio -async def test_get_query_reranker_initialisation_error(mock_collection, mock_config): - # Configure to use CrossEncoder reranker - mock_config.reranker = 
"cross-encoder/model-name" - - with patch( - "vectorcode.subcommands.query.reranker.CrossEncoderReranker" - ) as MockCrossEncoder: - mock_reranker_instance = MagicMock() - mock_reranker_instance.rerank = AsyncMock(return_value=["file1.py", "file2.py"]) - MockCrossEncoder.return_value = mock_reranker_instance - - with pytest.raises(RerankerError): - # Call the function - await get_query_result_files(mock_collection, mock_config) - - -@pytest.mark.asyncio -async def test_get_query_result_files_empty_collection(mock_collection, mock_config): - # Setup an empty collection - mock_collection.count.return_value = 0 - - # Call the function - result = await get_query_result_files(mock_collection, mock_config) - - # Check that the result is an empty list - assert result == [] - # Ensure query wasn't called - mock_collection.query.assert_not_called() - - -@pytest.mark.asyncio -async def test_get_query_result_files_query_error(mock_collection, mock_config): - # Make query raise an IndexError - mock_collection.query.side_effect = IndexError("No results") - - # Call the function - result = await get_query_result_files(mock_collection, mock_config) - - # Check that the result is an empty list - assert result == [] - - -@pytest.mark.asyncio -async def test_get_query_result_files_chunking(mock_collection, mock_config): - # Set a long query that will be chunked - mock_config.query = [ - "this is a longer query that should be chunked into multiple parts" +@pytest.fixture +def mock_database(): + db = AsyncMock(spec=DatabaseConnectorBase) + db.query.return_value = [ + MagicMock(path="file1.py", document="content1"), + MagicMock(path="file2.py", document="content2"), ] - mock_embedding_function = MagicMock() - with ( - patch("vectorcode.subcommands.query.StringChunker") as MockChunker, - patch("vectorcode.subcommands.query.reranker.NaiveReranker") as MockReranker, - patch( - "vectorcode.subcommands.query.get_embedding_function", - return_value=mock_embedding_function, - ), - ): - # Set up MockChunker to chunk the query - mock_chunker_instance = MagicMock() - mock_chunker_instance.chunk.return_value = ["chunk1", "chunk2", "chunk3"] - MockChunker.return_value = mock_chunker_instance - - mock_reranker_instance = MagicMock() - mock_reranker_instance.rerank = AsyncMock(return_value=["file1.py", "file2.py"]) - MockReranker.return_value = mock_reranker_instance - - # Call the function - result = await get_query_result_files(mock_collection, mock_config) - - # Check that the chunker was used correctly - MockChunker.assert_called_once_with(mock_config) - mock_chunker_instance.chunk.assert_called_once_with(mock_config.query[0]) - - # Check query was called with chunked query - mock_collection.query.assert_called_once() - _, kwargs = mock_collection.query.call_args - mock_embedding_function.assert_called_once_with(["chunk1", "chunk2", "chunk3"]) - - # Check the result - assert result == ["file1.py", "file2.py"] + return db @pytest.mark.asyncio -async def test_query_success(mock_config): - # Mock all the necessary dependencies - mock_client = AsyncMock() - mock_collection = AsyncMock() - +async def test_query_success(mock_config, mock_database, capsys): with ( - patch("vectorcode.subcommands.query.ClientManager") as MockClientManager, patch( - "vectorcode.subcommands.query.get_collection", return_value=mock_collection + "vectorcode.subcommands.query.get_database_connector", + return_value=mock_database, ), - patch("vectorcode.subcommands.query.verify_ef", return_value=True), - 
patch("vectorcode.subcommands.query.get_query_result_files") as mock_get_files, - patch("builtins.open", create=True) as mock_open, - patch("json.dumps"), - patch("os.path.isfile", return_value=True), - patch("os.path.relpath", return_value="rel/path.py"), - patch("os.path.abspath", return_value="/abs/path.py"), + patch("vectorcode.subcommands.query.get_reranker") as mock_get_reranker, ): - MockClientManager.return_value._create_client.return_value = mock_client - # Set up the mock file paths and contents - mock_get_files.return_value = ["file1.py", "file2.py"] - mock_file_handle = MagicMock() - mock_file_handle.__enter__.return_value.read.return_value = "file content" - mock_open.return_value = mock_file_handle + mock_reranker = AsyncMock() + mock_reranker.rerank.return_value = [ + "file1.py", + "file2.py", + ] + mock_get_reranker.return_value = mock_reranker - # Call the function - result = await query(mock_config) + with ( + patch("builtins.open", MagicMock()), + patch("os.path.isfile", return_value=True), + ): + result = await query(mock_config) - # Verify the function completed successfully assert result == 0 - - # Check that all the expected functions were called - mock_get_files.assert_called_once_with(mock_collection, mock_config) - - # Check file opening and reading - assert mock_open.call_count == 2 # Two files + captured = capsys.readouterr() + assert "Path: file1.py" in captured.out + assert "Path: file2.py" in captured.out @pytest.mark.asyncio -async def test_query_pipe_mode(mock_config): - # Set pipe mode to True +async def test_query_pipe_mode(mock_config, mock_database): mock_config.pipe = True - - # Similar to test_query_success but check for JSON output - mock_client = AsyncMock() - mock_collection = AsyncMock() - with ( - patch("vectorcode.subcommands.query.ClientManager") as MockClientManager, patch( - "vectorcode.subcommands.query.get_collection", return_value=mock_collection + "vectorcode.subcommands.query.get_database_connector", + return_value=mock_database, ), - patch("vectorcode.subcommands.query.verify_ef", return_value=True), - patch("vectorcode.subcommands.query.get_query_result_files") as mock_get_files, - patch("builtins.open", create=True) as mock_open, + patch("vectorcode.subcommands.query.get_reranker") as mock_get_reranker, patch("json.dumps") as mock_json_dumps, - patch("os.path.isfile", return_value=True), - patch("os.path.relpath", return_value="rel/path.py"), - patch("os.path.abspath", return_value="/abs/path.py"), ): - MockClientManager.return_value._create_client.return_value = mock_client - # Set up the mock file paths and contents - mock_get_files.return_value = ["file1.py", "file2.py"] - mock_file_handle = MagicMock() - mock_file_handle.__enter__.return_value.read.return_value = "file content" - mock_open.return_value = mock_file_handle - - # Call the function - result = await query(mock_config) + mock_reranker = AsyncMock() + mock_reranker.rerank.return_value = [ + "file1.py", + "file2.py", + ] + mock_get_reranker.return_value = mock_reranker - # Verify the function completed successfully - assert result == 0 + with ( + patch("builtins.open", MagicMock()), + patch("os.path.isfile", return_value=True), + ): + await query(mock_config) - # Check that JSON dumps was called mock_json_dumps.assert_called_once() @pytest.mark.asyncio -async def test_query_absolute_path(mock_config): - # Set use_absolute_path to True - mock_config.use_absolute_path = True - - # Mock all the necessary dependencies - mock_client = AsyncMock() - mock_collection = 
AsyncMock() - - with ( - patch("vectorcode.subcommands.query.ClientManager") as MockClientManager, - patch( - "vectorcode.subcommands.query.get_collection", return_value=mock_collection - ), - patch("vectorcode.subcommands.query.verify_ef", return_value=True), - patch("vectorcode.subcommands.query.get_query_result_files") as mock_get_files, - patch("builtins.open", create=True) as mock_open, - patch("os.path.isfile", return_value=True), - patch("os.path.relpath", return_value="rel/path.py"), - patch("os.path.abspath", return_value="/abs/path.py"), - ): - MockClientManager.return_value._create_client.return_value = mock_client - # Set up the mock file paths and contents - mock_get_files.return_value = ["file1.py"] - mock_file_handle = MagicMock() - mock_file_handle.__enter__.return_value.read.return_value = "file content" - mock_open.return_value = mock_file_handle - - # Call the function - result = await query(mock_config) - - # Verify the function completed successfully - assert result == 0 - - -@pytest.mark.asyncio -async def test_query_collection_not_found(): - config = Config(project_root="/test/project") - - with ( - patch("vectorcode.subcommands.query.ClientManager"), - patch("vectorcode.subcommands.query.get_collection") as mock_get_collection, - patch("sys.stderr"), - ): - # Make get_collection raise ValueError - mock_get_collection.side_effect = ValueError("Collection not found") - - # Call the function - result = await query(config) - - # Check the error was handled properly - assert result == 1 - - -@pytest.mark.asyncio -async def test_query_invalid_collection(): - config = Config(project_root="/test/project") - - with ( - patch("vectorcode.subcommands.query.ClientManager"), - patch("vectorcode.subcommands.query.get_collection") as mock_get_collection, - patch("sys.stderr"), - ): - # Make get_collection raise InvalidCollectionException - mock_get_collection.side_effect = InvalidCollectionException( - "Invalid collection" - ) - - # Call the function - result = await query(config) - - # Check the error was handled properly - assert result == 1 - - -@pytest.mark.asyncio -async def test_query_invalid_dimension(): - config = Config(project_root="/test/project") - - with ( - patch("vectorcode.subcommands.query.ClientManager"), - patch("vectorcode.subcommands.query.get_collection") as mock_get_collection, - patch("sys.stderr"), - ): - # Make get_collection raise InvalidDimensionException - mock_get_collection.side_effect = InvalidDimensionException("Invalid dimension") - - # Call the function - result = await query(config) - - # Check the error was handled properly - assert result == 1 - - -@pytest.mark.asyncio -async def test_query_invalid_file(mock_config): - # Set up mocks for a successful query but with an invalid file - mock_client = AsyncMock() - mock_collection = AsyncMock() - - with ( - patch("vectorcode.subcommands.query.ClientManager") as MockClientManager, - patch( - "vectorcode.subcommands.query.get_collection", return_value=mock_collection - ), - patch("vectorcode.subcommands.query.verify_ef", return_value=True), - patch("vectorcode.subcommands.query.get_query_result_files") as mock_get_files, - patch("os.path.isfile", return_value=False), - ): - MockClientManager.return_value._create_client.return_value = mock_client - # Set up the mock file paths - mock_get_files.return_value = ["invalid_file.py"] - - # Call the function - result = await query(mock_config) - - # Verify the function completed successfully despite invalid file - assert result == 0 - - -@pytest.mark.asyncio 
-async def test_query_invalid_ef(mock_config): - # Test when verify_ef returns False - mock_client = AsyncMock() - mock_collection = AsyncMock() +async def test_query_chunk_mode(mock_config, mock_database, capsys): + mock_config.include = [QueryInclude.chunk] + chunk1 = Chunk(text="chunk1", path="file1.py", start=Point(1, 0), end=Point(2, 0)) + chunk2 = Chunk(text="chunk2", path="file1.py", start=Point(3, 0), end=Point(4, 0)) with ( - patch("vectorcode.subcommands.query.ClientManager") as MockClientManager, patch( - "vectorcode.subcommands.query.get_collection", return_value=mock_collection + "vectorcode.subcommands.query.get_database_connector", + return_value=mock_database, ), - patch("vectorcode.subcommands.query.verify_ef", return_value=False), + patch("vectorcode.subcommands.query.get_reranker") as mock_get_reranker, ): - MockClientManager.return_value._create_client.return_value = mock_client - # Call the function - result = await query(mock_config) - - # Verify the function returns error code - assert result == 1 + mock_reranker = AsyncMock() + mock_reranker.rerank.return_value = [chunk1, chunk2] + mock_get_reranker.return_value = mock_reranker - -@pytest.mark.asyncio -async def test_query_invalid_include(): - faulty_config = Config( - action=CliAction.query, include=[QueryInclude.chunk, QueryInclude.document] - ) - assert await query(faulty_config) != 0 + await query(mock_config) + captured = capsys.readouterr() + assert "Chunk: chunk1" in captured.out + assert "Chunk: chunk2" in captured.out @pytest.mark.asyncio -async def test_query_chunk_mode_no_metadata_fallback(mock_config): - mock_config.include = [QueryInclude.chunk, QueryInclude.path] - mock_client = AsyncMock() - mock_collection = AsyncMock() - - # Mock collection.get to return no IDs for the metadata check - mock_collection.get.return_value = {"ids": []} - - with ( - patch("vectorcode.subcommands.query.ClientManager") as MockClientManager, - patch( - "vectorcode.subcommands.query.get_collection", return_value=mock_collection - ), - patch("vectorcode.subcommands.query.verify_ef", return_value=True), - patch("vectorcode.subcommands.query.build_query_results") as mock_build_results, - ): - MockClientManager.return_value._create_client.return_value = mock_client - mock_build_results.return_value = [] # Return empty results for simplicity - - result = await query(mock_config) - - assert result == 0 - - # Verify the metadata check call - mock_collection.get.assert_called_once_with(where={"start": {"$gte": 0}}) - - # Verify build_query_results was called with the *modified* config - mock_build_results.assert_called_once() - args, _ = mock_build_results.call_args - _, called_config = args - assert called_config.include == [QueryInclude.path, QueryInclude.document] +async def test_query_invalid_include(mock_config): + mock_config.include = [QueryInclude.chunk, QueryInclude.document] + result = await query(mock_config) + assert result != 0 diff --git a/tests/subcommands/query/test_reranker.py b/tests/subcommands/query/test_reranker.py index 31efe758..236fa0e1 100644 --- a/tests/subcommands/query/test_reranker.py +++ b/tests/subcommands/query/test_reranker.py @@ -5,6 +5,7 @@ import pytest from vectorcode.cli_utils import Config, QueryInclude +from vectorcode.database.types import QueryResult from vectorcode.subcommands.query.reranker import ( CrossEncoderReranker, NaiveReranker, @@ -14,7 +15,6 @@ get_available_rerankers, get_reranker, ) -from vectorcode.subcommands.query.types import QueryResult @pytest.fixture(scope="function") 
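The rewritten query tests above all share one pattern: instead of stubbing a Chroma client and collection, they stub the new connector layer. The sketch below condenses that pattern for reference; it is an illustration only, not part of the patch, and the import path for `DatabaseConnectorBase` is an assumption, since the hunks show the name but not the module it is defined in.

```python
# Minimal sketch of the connector-mocking pattern used by the rewritten tests.
# Assumption: DatabaseConnectorBase is importable from vectorcode.database.base.
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from vectorcode.cli_utils import Config
from vectorcode.database.base import DatabaseConnectorBase  # assumed module path
from vectorcode.subcommands.query import query


@pytest.mark.asyncio
async def test_query_pattern_sketch():
    # Stub the database connector instead of a Chroma client/collection.
    db = AsyncMock(spec=DatabaseConnectorBase)
    db.query.return_value = [MagicMock(path="file1.py", document="content1")]
    # Stub the reranker so no model gets loaded during the test.
    reranker = AsyncMock()
    reranker.rerank.return_value = ["file1.py"]
    with (
        patch("vectorcode.subcommands.query.get_database_connector", return_value=db),
        patch("vectorcode.subcommands.query.get_reranker", return_value=reranker),
        patch("builtins.open", MagicMock()),
        patch("os.path.isfile", return_value=True),
    ):
        assert await query(Config(project_root="/test/project")) == 0
```

This mirrors `test_query_success` above; the other subcommand tests (`clean`, `drop`, `ls`, `update`, `vectorise`) patch `get_database_connector` in their own modules the same way.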
diff --git a/tests/subcommands/query/test_types.py b/tests/subcommands/query/test_types.py index 392b6c6f..bcf45e1c 100644 --- a/tests/subcommands/query/test_types.py +++ b/tests/subcommands/query/test_types.py @@ -2,7 +2,7 @@ from tree_sitter import Point from vectorcode.chunking import Chunk -from vectorcode.subcommands.query.types import QueryResult +from vectorcode.database.types import QueryResult def make_dummy_chunk(): diff --git a/tests/subcommands/test_clean.py b/tests/subcommands/test_clean.py index 8c79fd7f..1b10fa08 100644 --- a/tests/subcommands/test_clean.py +++ b/tests/subcommands/test_clean.py @@ -1,82 +1,37 @@ from unittest.mock import AsyncMock, patch import pytest -from chromadb.api import AsyncClientAPI from vectorcode.cli_utils import Config -from vectorcode.subcommands.clean import clean, run_clean_on_client +from vectorcode.subcommands.clean import clean @pytest.mark.asyncio -async def test_run_clean_on_client(): - mock_client = AsyncMock(spec=AsyncClientAPI) - mock_collection1 = AsyncMock() - mock_collection1.name = "test_collection_1" - mock_collection1.metadata = {"path": "/test/path1"} - mock_collection1.count.return_value = 0 # Empty collection - mock_collection2 = AsyncMock() - mock_collection2.name = "test_collection_2" - mock_collection2.metadata = {"path": "/test/path2"} - mock_collection2.count.return_value = 1 # Non-empty collection - - async def mock_get_collections(client): - yield mock_collection1 - yield mock_collection2 - - with ( - patch("vectorcode.subcommands.clean.get_collections", new=mock_get_collections), - patch("os.path.isdir", return_value=lambda x: x == "/test/path2"), - ): - await run_clean_on_client(mock_client, pipe_mode=False) - - mock_client.delete_collection.assert_called_once_with(mock_collection1.name) - - -@pytest.mark.asyncio -async def test_run_clean_on_client_pipe_mode(): - mock_client = AsyncMock(spec=AsyncClientAPI) - mock_collection1 = AsyncMock() - mock_collection1.name = "test_collection_1" - mock_collection1.metadata = {"path": "/test/path1"} - mock_collection1.count.return_value = 0 # Empty collection - - async def mock_get_collections(client): - yield mock_collection1 +async def test_clean(capsys): + mock_db = AsyncMock() + mock_db.cleanup.return_value = ["/test/path1", "/test/path2"] with patch( - "vectorcode.subcommands.clean.get_collections", new=mock_get_collections + "vectorcode.subcommands.clean.get_database_connector", return_value=mock_db ): - await run_clean_on_client(mock_client, pipe_mode=True) + result = await clean(Config(pipe=False)) - mock_client.delete_collection.assert_called_once_with(mock_collection1.name) + assert result == 0 + captured = capsys.readouterr() + assert "Deleted collection: /test/path1" in captured.out + assert "Deleted collection: /test/path2" in captured.out @pytest.mark.asyncio -async def test_run_clean_on_removed_dir(): - mock_client = AsyncMock(spec=AsyncClientAPI) - mock_collection1 = AsyncMock() - mock_collection1.name = "test_collection_1" - mock_collection1.metadata = {"path": "/test/path1"} - mock_collection1.count.return_value = 10 - - async def mock_get_collections(client): - yield mock_collection1 +async def test_clean_pipe_mode(capsys): + mock_db = AsyncMock() + mock_db.cleanup.return_value = ["/test/path1", "/test/path2"] - with ( - patch("vectorcode.subcommands.clean.get_collections", new=mock_get_collections), - patch("os.path.isdir", return_value=False), + with patch( + "vectorcode.subcommands.clean.get_database_connector", return_value=mock_db ): - await 
run_clean_on_client(mock_client, pipe_mode=True) - - mock_client.delete_collection.assert_called_once_with(mock_collection1.name) - - -@pytest.mark.asyncio -async def test_clean(): - AsyncMock(spec=AsyncClientAPI) - mock_config = Config(pipe=False) - - with patch("vectorcode.subcommands.clean.ClientManager"): - result = await clean(mock_config) + result = await clean(Config(pipe=True)) assert result == 0 + captured = capsys.readouterr() + assert captured.out == "" diff --git a/tests/subcommands/test_drop.py b/tests/subcommands/test_drop.py index 15b990d8..8268d902 100644 --- a/tests/subcommands/test_drop.py +++ b/tests/subcommands/test_drop.py @@ -1,65 +1,27 @@ -from contextlib import asynccontextmanager from unittest.mock import AsyncMock, patch import pytest from vectorcode.cli_utils import Config +from vectorcode.database.errors import CollectionNotFoundError from vectorcode.subcommands.drop import drop -@pytest.fixture -def mock_config(): - config = Config( - project_root="/path/to/project", - ) # Removed positional args - return config - - -@pytest.fixture -def mock_client(): - return AsyncMock() - - -@pytest.fixture -def mock_collection(): - collection = AsyncMock() - collection.name = "test_collection" - collection.metadata = {"path": "/path/to/project"} - return collection - - @pytest.mark.asyncio -async def test_drop_success(mock_config, mock_client, mock_collection): - mock_client.get_collection.return_value = mock_collection - mock_client.delete_collection = AsyncMock() - with ( - patch("vectorcode.subcommands.drop.ClientManager") as MockClientManager, - patch( - "vectorcode.subcommands.drop.get_collection", return_value=mock_collection - ), +async def test_drop_success(): + mock_db = AsyncMock() + with patch( + "vectorcode.subcommands.drop.get_database_connector", return_value=mock_db ): - mock_client = AsyncMock() - - @asynccontextmanager - async def _get_client(self, config=None, need_lock=True): - yield mock_client - - mock_client_manager = MockClientManager.return_value - mock_client_manager._create_client = AsyncMock(return_value=mock_client) - mock_client_manager.get_client = _get_client - - result = await drop(mock_config) - assert result == 0 - mock_client.delete_collection.assert_called_once_with(mock_collection.name) + await drop(config=Config(project_root="DummyDir")) + mock_db.drop.assert_called_once() @pytest.mark.asyncio -async def test_drop_collection_not_found(mock_config, mock_client): - mock_client.get_collection.side_effect = ValueError("Collection not found") - with patch("vectorcode.subcommands.drop.ClientManager"): - with patch( - "vectorcode.subcommands.drop.get_collection", - side_effect=ValueError("Collection not found"), - ): - result = await drop(mock_config) - assert result == 1 +async def test_drop_collection_not_found(): + mock_db = AsyncMock() + mock_db.drop = AsyncMock(side_effect=CollectionNotFoundError) + with patch( + "vectorcode.subcommands.drop.get_database_connector", return_value=mock_db + ): + assert await drop(config=Config(project_root="DummyDir")) != 0 diff --git a/tests/subcommands/test_ls.py b/tests/subcommands/test_ls.py index bbc674eb..1b576c17 100644 --- a/tests/subcommands/test_ls.py +++ b/tests/subcommands/test_ls.py @@ -1,161 +1,70 @@ import json -import socket -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest import tabulate from vectorcode.cli_utils import Config -from vectorcode.subcommands.ls import get_collection_list, ls - - -@pytest.mark.asyncio -async def 
test_get_collection_list(): - mock_client = AsyncMock() - mock_collection1 = AsyncMock() - mock_collection1.name = "test_collection_1" - mock_collection1.metadata = { - "path": "/test/path1", - "username": "test_user", - "embedding_function": "test_ef", - } - mock_collection1.count.return_value = 100 - mock_collection1.get.return_value = { - "metadatas": [ - {"path": "/test/path1/file1.txt"}, - {"path": "/test/path1/file2.txt"}, - None, - ] - } - mock_collection2 = AsyncMock() - mock_collection2.name = "test_collection_2" - mock_collection2.metadata = { - "path": "/test/path2", - "username": "test_user", - "embedding_function": "test_ef", - } - mock_collection2.count.return_value = 200 - mock_collection2.get.return_value = { - "metadatas": [ - {"path": "/test/path2/file1.txt"}, - {"path": "/test/path2/file2.txt"}, - ] - } - - async def mock_get_collections(client): - yield mock_collection1 - yield mock_collection2 - - with patch("vectorcode.subcommands.ls.get_collections", new=mock_get_collections): - result = await get_collection_list(mock_client) - - assert len(result) == 2 - assert result[0]["project-root"] == "/test/path1" - assert result[0]["user"] == "test_user" - assert result[0]["hostname"] == socket.gethostname() - assert result[0]["collection_name"] == "test_collection_1" - assert result[0]["size"] == 100 - assert result[0]["embedding_function"] == "test_ef" - assert result[0]["num_files"] == 2 - assert result[1]["num_files"] == 2 +from vectorcode.database.types import CollectionInfo +from vectorcode.subcommands.ls import ls + + +@pytest.fixture +def mock_collections(): + return [ + CollectionInfo( + path="/test/path1", + id="test_collection_1", + chunk_count=100, + file_count=2, + embedding_function="test_ef", + database_backend="ChromaDB", + ), + CollectionInfo( + path="/test/path2", + id="test_collection_2", + chunk_count=200, + file_count=2, + embedding_function="test_ef", + database_backend="ChromaDB", + ), + ] @pytest.mark.asyncio -async def test_ls_pipe_mode(capsys): - mock_client = AsyncMock() - mock_collection = AsyncMock() - mock_collection.name = "test_collection" - mock_collection.metadata = { - "path": "/test/path", - "username": "test_user", - "embedding_function": "test_ef", - } - mock_collection.count.return_value = 50 - mock_collection.get.return_value = {"metadatas": [{"path": "/test/path/file.txt"}]} - - async def mock_get_collections(client): - yield mock_collection - - with ( - patch("vectorcode.subcommands.ls.ClientManager") as MockClientManager, - patch( - "vectorcode.subcommands.ls.get_collection_list", - return_value=[ - { - "project-root": "/test/path", - "size": 50, - "num_files": 1, - "embedding_function": "test_ef", - } - ], - ), +async def test_ls_pipe_mode(capsys, mock_collections): + mock_db = AsyncMock() + mock_db.list_collections.return_value = mock_collections + with patch( + "vectorcode.subcommands.ls.get_database_connector", return_value=mock_db ): - mock_client = MagicMock() - mock_client_manager = MockClientManager.return_value - mock_client_manager._create_client = AsyncMock(return_value=mock_client) - config = Config(pipe=True) await ls(config) captured = capsys.readouterr() - expected_output = ( - json.dumps( - [ - { - "project-root": "/test/path", - "size": 50, - "num_files": 1, - "embedding_function": "test_ef", - } - ] - ) - + "\n" - ) + expected_output = json.dumps([c.to_dict() for c in mock_collections]) + "\n" assert captured.out == expected_output @pytest.mark.asyncio -async def test_ls_table_mode(capsys, monkeypatch): - 
mock_client = AsyncMock() - mock_collection = AsyncMock() - mock_collection.name = "test_collection" - mock_collection.metadata = { - "path": "/test/path", - "username": "test_user", - "embedding_function": "test_ef", - } - mock_collection.count.return_value = 50 - mock_collection.get.return_value = {"metadatas": [{"path": "/test/path/file.txt"}]} - - async def mock_get_collections(client): - yield mock_collection - - with ( - patch("vectorcode.subcommands.ls.ClientManager") as MockClientManager, - patch( - "vectorcode.subcommands.ls.get_collection_list", - return_value=[ - { - "project-root": "/test/path", - "size": 50, - "num_files": 1, - "embedding_function": "test_ef", - } - ], - ), +async def test_ls_table_mode(capsys, mock_collections, monkeypatch): + mock_db = AsyncMock() + mock_db.list_collections.return_value = mock_collections + with patch( + "vectorcode.subcommands.ls.get_database_connector", return_value=mock_db ): - mock_client = MagicMock() - mock_client_manager = MockClientManager.return_value - mock_client_manager._create_client = AsyncMock(return_value=mock_client) - config = Config(pipe=False) await ls(config) captured = capsys.readouterr() + expected_table = [ + ["/test/path1", 100, 2, "test_ef"], + ["/test/path2", 200, 2, "test_ef"], + ] expected_output = ( tabulate.tabulate( - [["/test/path", 50, 1, "test_ef"]], + expected_table, headers=[ "Project Root", - "Collection Size", + "Number of Embeddings", "Number of Files", "Embedding Function", ], @@ -166,32 +75,22 @@ async def mock_get_collections(client): # Test with HOME environment variable set monkeypatch.setenv("HOME", "/test") - with ( - patch("vectorcode.subcommands.ls.ClientManager") as MockClientManager, - patch( - "vectorcode.subcommands.ls.get_collection_list", - return_value=[ - { - "project-root": "/test/path", - "size": 50, - "num_files": 1, - "embedding_function": "test_ef", - } - ], - ), + with patch( + "vectorcode.subcommands.ls.get_database_connector", return_value=mock_db ): - mock_client = MagicMock() - mock_client_manager = MockClientManager.return_value - mock_client_manager._create_client = AsyncMock(return_value=mock_client) config = Config(pipe=False) await ls(config) captured = capsys.readouterr() + expected_table = [ + ["~/path1", 100, 2, "test_ef"], + ["~/path2", 200, 2, "test_ef"], + ] expected_output = ( tabulate.tabulate( - [["~/path", 50, 1, "test_ef"]], + expected_table, headers=[ "Project Root", - "Collection Size", + "Number of Embeddings", "Number of Files", "Embedding Function", ], diff --git a/tests/subcommands/test_update.py b/tests/subcommands/test_update.py index 314f7c2a..f7ccf566 100644 --- a/tests/subcommands/test_update.py +++ b/tests/subcommands/test_update.py @@ -1,126 +1,152 @@ +import asyncio from unittest.mock import AsyncMock, patch import pytest -from chromadb.api.types import IncludeEnum -from chromadb.errors import InvalidCollectionException from vectorcode.cli_utils import Config +from vectorcode.database.types import FileInCollection from vectorcode.subcommands.update import update @pytest.mark.asyncio -async def test_update_success(): - mock_client = AsyncMock() - mock_collection = AsyncMock() - mock_collection.get.return_value = { - "metadatas": [{"path": "file1.py"}, {"path": "file2.py"}] - } - mock_collection.delete = AsyncMock() - mock_client.get_max_batch_size.return_value = 100 +async def test_update_success(tmp_path): + """Test successful update with some modified files.""" + config = Config(project_root=str(tmp_path), pipe=False) + + # Mock files in the 
database + file1_path = tmp_path / "file1.py" + file1_path.write_text("content1") + file2_path = tmp_path / "file2.py" + file2_path.write_text("new content2") # modified + file3_path = tmp_path / "file3.py" + file3_path.write_text("content3") + + collection_files = [ + FileInCollection(path=str(file1_path), sha256="hash1_old"), + FileInCollection(path=str(file2_path), sha256="hash2_old"), + FileInCollection(path=str(file3_path), sha256="hash3_old"), + ] with ( - patch("vectorcode.subcommands.update.ClientManager"), + patch("vectorcode.subcommands.update.get_database_connector") as mock_get_db, patch( - "vectorcode.subcommands.update.get_collection", return_value=mock_collection - ), - patch("vectorcode.subcommands.update.verify_ef", return_value=True), - patch("os.path.isfile", return_value=True), - patch( - "vectorcode.subcommands.update.chunked_add", new_callable=AsyncMock - ) as mock_chunked_add, - patch("vectorcode.subcommands.update.show_stats"), + "vectorcode.subcommands.update.vectorise_worker", new_callable=AsyncMock + ) as mock_vectorise_worker, + patch("vectorcode.subcommands.update.show_stats") as mock_show_stats, + patch("vectorcode.subcommands.update.hash_file") as mock_hash_file, ): - config = Config(project_root="/test/project", pipe=False) + mock_db = AsyncMock() + mock_db.list_collection_content.return_value.files = collection_files + mock_get_db.return_value = mock_db + + # file1.py is unchanged, file2.py is changed, file3.py is unchanged + mock_hash_file.side_effect = ["hash1_old", "hash2_new", "hash3_old"] + result = await update(config) assert result == 0 - mock_collection.get.assert_called_once_with(include=[IncludeEnum.metadatas]) - assert mock_chunked_add.call_count == 2 - mock_collection.delete.assert_not_called() + mock_db.list_collection_content.assert_called_once() + + # vectorise_worker should only be called for the modified file (file2.py) + assert mock_vectorise_worker.call_count == 1 + # Check that it was called with file2.py + called_with_file = mock_vectorise_worker.call_args_list[0][0][1] + assert called_with_file == str(file2_path) + + mock_db.check_orphanes.assert_called_once() + mock_show_stats.assert_called_once() @pytest.mark.asyncio -async def test_update_with_orphans(): - mock_client = AsyncMock() - mock_collection = AsyncMock() - mock_collection.get.return_value = { - "metadatas": [{"path": "file1.py"}, {"path": "file2.py"}, {"path": "orphan.py"}] - } - mock_collection.delete = AsyncMock() - mock_client.get_max_batch_size.return_value = 100 +async def test_update_force(tmp_path): + """Test update with force=True, all files should be re-vectorised.""" + config = Config(project_root=str(tmp_path), pipe=False, force=True) + + file1_path = tmp_path / "file1.py" + file1_path.write_text("content1") + file2_path = tmp_path / "file2.py" + file2_path.write_text("content2") + + collection_files = [ + FileInCollection(path=str(file1_path), sha256="hash1"), + FileInCollection(path=str(file2_path), sha256="hash2"), + ] with ( - patch("vectorcode.subcommands.update.ClientManager"), - patch( - "vectorcode.subcommands.update.get_collection", return_value=mock_collection - ), - patch("vectorcode.subcommands.update.verify_ef", return_value=True), - patch("os.path.isfile", side_effect=[True, True, False]), + patch("vectorcode.subcommands.update.get_database_connector") as mock_get_db, patch( - "vectorcode.subcommands.update.chunked_add", new_callable=AsyncMock - ) as mock_chunked_add, - patch("vectorcode.subcommands.update.show_stats"), + 
"vectorcode.subcommands.update.vectorise_worker", new_callable=AsyncMock + ) as mock_vectorise_worker, + patch("vectorcode.subcommands.update.show_stats") as mock_show_stats, + patch("vectorcode.subcommands.update.hash_file") as mock_hash_file, ): - config = Config(project_root="/test/project", pipe=False) + mock_db = AsyncMock() + mock_db.list_collection_content.return_value.files = collection_files + mock_get_db.return_value = mock_db + result = await update(config) assert result == 0 - mock_collection.get.assert_called_once_with(include=[IncludeEnum.metadatas]) - assert mock_chunked_add.call_count == 2 - mock_collection.delete.assert_called_once_with( - where={"path": {"$in": ["orphan.py"]}} - ) + mock_db.list_collection_content.assert_called_once() + # vectorise_worker should be called for all files + assert mock_vectorise_worker.call_count == 2 + mock_hash_file.assert_not_called() # hash_file should not be called with force=True -@pytest.mark.asyncio -async def test_update_index_error(): - mock_client = AsyncMock() - # mock_collection = AsyncMock() + mock_db.check_orphanes.assert_called_once() + mock_show_stats.assert_called_once() - with ( - patch("vectorcode.subcommands.update.ClientManager") as MockClientManager, - patch("vectorcode.subcommands.update.get_collection", side_effect=IndexError), - patch("sys.stderr"), - ): - MockClientManager.return_value._create_client.return_value = mock_client - config = Config(project_root="/test/project", pipe=False) - result = await update(config) - assert result == 1 +@pytest.mark.asyncio +async def test_update_cancelled(tmp_path): + """Test update being cancelled.""" + config = Config(project_root=str(tmp_path), pipe=False) + file1_path = tmp_path / "file1.py" + file1_path.write_text("content1") -@pytest.mark.asyncio -async def test_update_value_error(): - mock_client = AsyncMock() - # mock_collection = AsyncMock() + collection_files = [ + FileInCollection(path=str(file1_path), sha256="hash1_old"), + ] with ( - patch("vectorcode.subcommands.update.ClientManager") as MockClientManager, - patch("vectorcode.subcommands.update.get_collection", side_effect=ValueError), - patch("sys.stderr"), + patch("vectorcode.subcommands.update.get_database_connector") as mock_get_db, + patch( + "vectorcode.subcommands.update.vectorise_worker", new_callable=AsyncMock + ) as mock_vectorise_worker, + patch("vectorcode.subcommands.update.hash_file", return_value="hash1_new"), ): - MockClientManager.return_value._create_client.return_value = mock_client - config = Config(project_root="/test/project", pipe=False) + mock_db = AsyncMock() + mock_db.list_collection_content.return_value.files = collection_files + mock_get_db.return_value = mock_db + + mock_vectorise_worker.side_effect = asyncio.CancelledError + result = await update(config) assert result == 1 + mock_db.check_orphanes.assert_not_called() @pytest.mark.asyncio -async def test_update_invalid_collection_exception(): - mock_client = AsyncMock() - # mock_collection = AsyncMock() +async def test_update_empty_collection(tmp_path): + """Test update with an empty collection.""" + config = Config(project_root=str(tmp_path), pipe=False) with ( - patch("vectorcode.subcommands.update.ClientManager") as MockClientManager, + patch("vectorcode.subcommands.update.get_database_connector") as mock_get_db, patch( - "vectorcode.subcommands.update.get_collection", - side_effect=InvalidCollectionException, - ), - patch("sys.stderr"), + "vectorcode.subcommands.update.vectorise_worker", new_callable=AsyncMock + ) as 
mock_vectorise_worker, + patch("vectorcode.subcommands.update.show_stats") as mock_show_stats, ): - MockClientManager.return_value._create_client.return_value = mock_client - config = Config(project_root="/test/project", pipe=False) + mock_db = AsyncMock() + mock_db.list_collection_content.return_value.files = [] + mock_get_db.return_value = mock_db + result = await update(config) - assert result == 1 + assert result == 0 + mock_vectorise_worker.assert_not_called() + mock_db.check_orphanes.assert_called_once() + mock_show_stats.assert_called_once() diff --git a/tests/subcommands/test_vectorise.py b/tests/subcommands/test_vectorise.py index 3ce5683b..fa3c7f38 100644 --- a/tests/subcommands/test_vectorise.py +++ b/tests/subcommands/test_vectorise.py @@ -1,28 +1,17 @@ -import asyncio import hashlib import os -import socket -import tempfile -from contextlib import ExitStack -from unittest.mock import AsyncMock, MagicMock, mock_open, patch +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch -import numpy import pytest -from chromadb.api.models.AsyncCollection import AsyncCollection -from tree_sitter import Point -from vectorcode.chunking import Chunk -from vectorcode.cli_utils import CliAction, Config +from vectorcode.cli_utils import Config +from vectorcode.database.errors import CollectionNotFoundError +from vectorcode.database.types import VectoriseStats +from vectorcode.database.utils import get_uuid, hash_file, hash_str from vectorcode.subcommands.vectorise import ( - VectoriseStats, - chunked_add, - exclude_paths_by_spec, find_exclude_specs, - get_uuid, - hash_file, - hash_str, load_files_from_include, - show_stats, vectorise, ) @@ -33,34 +22,14 @@ def test_hash_str(): assert hash_str(test_string) == expected_hash -def test_hash_file_basic(): +def test_hash_file(tmp_path): content = b"This is a test file for hashing." 
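+    # compute the reference digest with hashlib directly; hash_file must reproduce it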
expected_hash = hashlib.sha256(content).hexdigest() + file_path = tmp_path / "test_file.txt" + file_path.write_bytes(content) - with tempfile.NamedTemporaryFile(delete=False) as tmp_file: - tmp_file.write(content) - tmp_file_path = tmp_file.name - - try: - actual_hash = hash_file(tmp_file_path) - assert actual_hash == expected_hash - finally: - os.remove(tmp_file_path) - - -def test_hash_file_empty(): - content = b"" - expected_hash = hashlib.sha256(content).hexdigest() - - with tempfile.NamedTemporaryFile(delete=False) as tmp_file: - tmp_file.write(content) - tmp_file_path = tmp_file.name - - try: - actual_hash = hash_file(tmp_file_path) - assert actual_hash == expected_hash - finally: - os.remove(tmp_file_path) + actual_hash = hash_file(str(file_path)) + assert actual_hash == expected_hash def test_get_uuid(): @@ -69,242 +38,115 @@ def test_get_uuid(): assert len(uuid_str) == 32 # UUID4 hex string length -@pytest.mark.asyncio -async def test_chunked_add(): - file_path = "test_file.py" - collection = AsyncMock() - collection_lock = asyncio.Lock() - stats = VectoriseStats() - stats_lock = asyncio.Lock() - configs = Config(chunk_size=100, overlap_ratio=0.2, project_root=".") - max_batch_size = 50 - semaphore = asyncio.Semaphore(1) +@patch("tabulate.tabulate") +def test_show_stats_pipe_false(mock_tabulate, capsys): + from vectorcode.subcommands.vectorise import show_stats - with ( - patch("vectorcode.chunking.TreeSitterChunker.chunk") as mock_chunk, - patch("vectorcode.subcommands.vectorise.hash_file") as mock_hash_file, - ): - mock_hash_file.return_value = "hash1" - mock_chunk.return_value = [Chunk("chunk1", Point(1, 0), Point(1, 5)), "chunk2"] - await chunked_add( - file_path, - collection, - collection_lock, - stats, - stats_lock, - configs, - max_batch_size, - semaphore, - ) - - assert stats.add == 1 - assert stats.update == 0 - collection.add.assert_called() - assert collection.add.call_count == 1 + configs = Config(pipe=False) + stats = VectoriseStats(**{"add": 1, "update": 2, "removed": 3}) + show_stats(configs, stats) + mock_tabulate.assert_called_once() -@pytest.mark.asyncio -async def test_chunked_add_truncated(): - file_path = "test_file.py" - collection = AsyncMock() - collection_lock = asyncio.Lock() - stats = VectoriseStats() - stats_lock = asyncio.Lock() - configs = Config( - chunk_size=100, overlap_ratio=0.2, project_root=".", embedding_dims=10 - ) - max_batch_size = 50 - semaphore = asyncio.Semaphore(1) +def test_show_stats_pipe_true(capsys): + from vectorcode.subcommands.vectorise import show_stats - with ( - patch("vectorcode.chunking.TreeSitterChunker.chunk") as mock_chunk, - patch("vectorcode.subcommands.vectorise.hash_file") as mock_hash_file, - ): - mock_hash_file.return_value = "hash1" - mock_chunk.return_value = [Chunk("chunk1", Point(1, 0), Point(1, 5)), "chunk2"] - await chunked_add( - file_path, - collection, - collection_lock, - stats, - stats_lock, - configs, - max_batch_size, - semaphore, - ) - - assert stats.add == 1 - assert stats.update == 0 - collection.add.assert_called() - assert collection.add.call_count == 1 - - assert all(len(i) == 10 for i in collection.add.call_args.kwargs["embeddings"]) + configs = Config(pipe=True) + stats = VectoriseStats(**{"add": 1, "update": 2, "removed": 3}) + show_stats(configs, stats) + captured = capsys.readouterr() + assert captured.out.strip() == (stats.to_json()) @pytest.mark.asyncio -async def test_chunked_add_with_existing(): - file_path = "test_file.py" - collection = AsyncMock() - collection.get = AsyncMock() - 
collection.get.return_value = {"ids": ["id1"], "metadatas": [{"sha256": "hash1"}]} - collection_lock = asyncio.Lock() - stats = VectoriseStats() - stats_lock = asyncio.Lock() - configs = Config(chunk_size=100, overlap_ratio=0.2, project_root=".") - max_batch_size = 50 - semaphore = asyncio.Semaphore(1) - +async def test_vectorise_success(tmp_path): + config = Config(project_root=str(tmp_path), files=["file1.py", "file2.py"]) with ( - patch("vectorcode.chunking.TreeSitterChunker.chunk") as mock_chunk, - patch("vectorcode.subcommands.vectorise.hash_file") as mock_hash_file, + patch("vectorcode.subcommands.vectorise.get_database_connector") as mock_get_db, + patch( + "vectorcode.subcommands.vectorise.expand_globs", return_value=config.files + ), + patch("vectorcode.subcommands.vectorise.FilterManager") as mock_filter_manager, + patch( + "vectorcode.subcommands.vectorise.vectorise_worker", new_callable=AsyncMock + ) as mock_worker, + patch("vectorcode.subcommands.vectorise.show_stats") as mock_show_stats, ): - mock_hash_file.return_value = "hash1" - mock_chunk.return_value = [Chunk("chunk1", Point(1, 0), Point(1, 5)), "chunk2"] - await chunked_add( - file_path, - collection, - collection_lock, - stats, - stats_lock, - configs, - max_batch_size, - semaphore, - ) - - assert stats.add == 0 - assert stats.update == 0 - collection.add.assert_not_called() - + mock_db = AsyncMock() + mock_db.list_collection_content.return_value.files = [] + mock_get_db.return_value = mock_db + mock_filter_manager.return_value.return_value = config.files -@pytest.mark.asyncio -async def test_chunked_add_update_existing(): - file_path = "test_file.py" - collection = AsyncMock() - collection.get = AsyncMock() - collection.get.return_value = {"ids": ["id1"], "metadatas": [{"sha256": "hash1"}]} - collection_lock = asyncio.Lock() - stats = VectoriseStats() - stats_lock = asyncio.Lock() - configs = Config(chunk_size=100, overlap_ratio=0.2, project_root=".") - max_batch_size = 50 - semaphore = asyncio.Semaphore(1) + result = await vectorise(config) - with ( - patch("vectorcode.chunking.TreeSitterChunker.chunk") as mock_chunk, - patch("vectorcode.subcommands.vectorise.hash_file") as mock_hash_file, - ): - mock_hash_file.return_value = "hash2" - mock_chunk.return_value = [Chunk("chunk1", Point(1, 0), Point(1, 5)), "chunk2"] - await chunked_add( - file_path, - collection, - collection_lock, - stats, - stats_lock, - configs, - max_batch_size, - semaphore, - ) - - assert stats.add == 0 - assert stats.update == 1 - collection.add.assert_called() + assert result == 0 + assert mock_worker.call_count == 2 + mock_show_stats.assert_called_once() @pytest.mark.asyncio -async def test_chunked_add_empty_file(): - file_path = "test_file.py" - collection = AsyncMock() - collection_lock = asyncio.Lock() - stats = VectoriseStats(**{"add": 0, "update": 0}) - stats_lock = asyncio.Lock() - configs = Config(chunk_size=100, overlap_ratio=0.2, project_root=".") - max_batch_size = 50 - semaphore = asyncio.Semaphore(1) - +async def test_vectorise_with_excludes(tmp_path): + config = Config(project_root=str(tmp_path), files=["file1.py", "file2.py"]) with ( - patch("vectorcode.chunking.TreeSitterChunker.chunk") as mock_chunk, - patch("vectorcode.subcommands.vectorise.hash_file") as mock_hash_file, + patch("vectorcode.subcommands.vectorise.get_database_connector") as mock_get_db, + patch( + "vectorcode.subcommands.vectorise.expand_globs", return_value=config.files + ), + patch("vectorcode.subcommands.vectorise.FilterManager") as mock_filter_manager, + 
patch( + "vectorcode.subcommands.vectorise.vectorise_worker", new_callable=AsyncMock + ) as mock_worker, + patch( + "vectorcode.subcommands.vectorise.find_exclude_specs", + return_value=[".gitignore"], + ), ): - mock_hash_file.return_value = "hash1" - mock_chunk.return_value = [] - await chunked_add( - file_path, - collection, - collection_lock, - stats, - stats_lock, - configs, - max_batch_size, - semaphore, - ) - - assert stats.add == 0 - assert stats.update == 0 - assert collection.add.call_count == 0 - - -@patch("tabulate.tabulate") -def test_show_stats_pipe_false(mock_tabulate, capsys): - configs = Config(pipe=False) - stats = VectoriseStats(**{"add": 1, "update": 2, "removed": 3}) - show_stats(configs, stats) - mock_tabulate.assert_called_once() - + mock_db = AsyncMock() + mock_db.list_collection_content.return_value.files = [] + mock_get_db.return_value = mock_db + mock_filter_manager.return_value.return_value = ["file1.py"] -def test_show_stats_pipe_true(capsys): - configs = Config(pipe=True) - stats = VectoriseStats(**{"add": 1, "update": 2, "removed": 3}) - show_stats(configs, stats) - captured = capsys.readouterr() - assert captured.out.strip() == (stats.to_json()) + await vectorise(config) + assert mock_worker.call_count == 1 -def test_exclude_paths_by_spec(): - with tempfile.TemporaryDirectory() as dir: - paths = list( - os.path.join(dir, i) for i in ["file1.py", "file2.py", "exclude.py"] - ) - spec_path = os.path.join(dir, ".gitignore") - with open(spec_path, mode="w") as spec_file: - spec_file.writelines(["exclude.py"]) - paths_after_exclude = exclude_paths_by_spec(paths, spec_path) - assert "exclude.py" not in paths_after_exclude - assert len(paths_after_exclude) == 2 - os.remove(spec_path) +@pytest.mark.asyncio +async def test_vectorise_collection_not_found(tmp_path): + full_path = os.path.join(tmp_path, "file1.py") + config = Config(project_root=str(tmp_path), files=[full_path]) + Path(full_path).touch() + with ( + patch("vectorcode.subcommands.vectorise.get_database_connector") as mock_get_db, + patch( + "vectorcode.subcommands.vectorise.expand_globs", return_value=config.files + ), + ): + mock_db = AsyncMock() + mock_db.list_collection_content.side_effect = CollectionNotFoundError + mock_get_db.return_value = mock_db + # This should not raise an exception + await vectorise(config) -def test_nested_exclude_paths_by_spec(): - paths = [ - "file1.py", - "file2.py", - "exclude.py", - os.path.join("nested", "nested_exclude.py"), - ] - with tempfile.TemporaryDirectory() as project_root: - paths = [os.path.join(project_root, i) for i in paths] - with open(os.path.join(project_root, ".gitignore"), mode="w") as fin: - fin.writelines(["/exclude.py"]) - nested_git_dir = os.path.join(project_root, "nested") - os.makedirs(nested_git_dir, exist_ok=True) - with open(os.path.join(nested_git_dir, ".gitignore"), mode="w") as fin: - fin.writelines(["/nested_exclude.py"]) +def test_find_exclude_specs(tmp_path): + config = Config(project_root=str(tmp_path), recursive=True) + gitignore_path = tmp_path / ".gitignore" + gitignore_path.touch() + nested_dir = tmp_path / "nested" + nested_dir.mkdir() + nested_gitignore_path = nested_dir / ".gitignore" + nested_gitignore_path.touch() - specs = find_exclude_specs(Config(project_root=project_root, recursive=True)) - paths_after_exclude = paths[:] - for spec in specs: - paths_after_exclude = exclude_paths_by_spec(paths_after_exclude, spec) - assert "exclude.py" not in paths_after_exclude - assert "nested/nested_exclude.py" not in paths_after_exclude - 
assert len(paths_after_exclude) == 2 + specs = find_exclude_specs(config) + assert str(gitignore_path) in specs + assert str(nested_gitignore_path) in specs -@patch("os.path.isfile") +@patch("os.path.isfile", return_value=True) @patch("pathspec.PathSpec.check_tree_files") def test_load_files_from_local_include(mock_check_tree_files, mock_isfile, tmp_path): - """Tests loading files when a local '.vectorcode/vectorcode.include' exists.""" project_root = str(tmp_path) local_include_dir = tmp_path / ".vectorcode" local_include_dir.mkdir() @@ -312,545 +154,107 @@ def test_load_files_from_local_include(mock_check_tree_files, mock_isfile, tmp_p local_include_content = "local_file1.py\nlocal_file2.py" local_include_file.write_text(local_include_content) - # Mock os.path.isfile to return True only for the local file mock_isfile.side_effect = lambda p: str(p) == str(local_include_file) - # Mock check_tree_files mock_check_tree_files.return_value = [ MagicMock(file="local_file1.py", include=True), MagicMock(file="local_file2.py", include=True), MagicMock(file="ignored_file.py", include=False), ] - # Use mock_open for the specific local file path - m_open = mock_open(read_data=local_include_content) + m_open = MagicMock() + m_open.return_value.__enter__.return_value.readlines.return_value = ( + local_include_content.splitlines() + ) with patch("builtins.open", m_open): files = load_files_from_include(project_root) assert "local_file1.py" in files assert "local_file2.py" in files assert "ignored_file.py" not in files - assert len(files) == 2 - mock_isfile.assert_any_call(str(local_include_file)) - m_open.assert_called_once_with(str(local_include_file)) - mock_check_tree_files.assert_called_once() -@patch("os.path.isfile") -@patch("pathspec.PathSpec.check_tree_files") -def test_load_files_from_global_include(mock_check_tree_files, mock_isfile, tmp_path): - """Tests loading files when only a global include spec exists.""" - project_root = str(tmp_path) - local_include_file = tmp_path / ".vectorcode" / "vectorcode.include" +def test_find_exclude_specs_non_recursive(tmp_path): + config = Config(project_root=str(tmp_path), recursive=False) + gitignore_path = tmp_path / ".gitignore" + gitignore_path.touch() + nested_dir = tmp_path / "nested" + nested_dir.mkdir() + nested_gitignore_path = nested_dir / ".gitignore" + nested_gitignore_path.touch() - # Simulate a global include file - # Note: We don't actually need the real global path, just a path to use in mocks - temp_global_include_dir = tmp_path / "global_config" - temp_global_include_dir.mkdir() - global_include_file = temp_global_include_dir / "vectorcode.include" - global_include_content = "global_file1.py\nglobal_file2.py" - global_include_file.write_text(global_include_content) + specs = find_exclude_specs(config) + assert str(gitignore_path) in specs + assert str(nested_gitignore_path) not in specs - # Mock os.path.isfile: False for local, True for (mocked) global - mock_isfile.side_effect = lambda p: str(p) == str(global_include_file) - # Mock check_tree_files - mock_check_tree_files.return_value = [ - MagicMock(file="global_file1.py", include=True), - MagicMock(file="global_file2.py", include=True), - MagicMock(file="ignored_global.py", include=False), - ] +@patch("os.path.isfile") +def test_find_exclude_specs_global(mock_isfile, tmp_path): + from vectorcode.subcommands.vectorise import GLOBAL_EXCLUDE_SPEC - m_open = mock_open(read_data=global_include_content) - # Patch builtins.open and the GLOBAL_INCLUDE_SPEC constant used internally - with ( - 
patch("builtins.open", m_open), - patch( - "vectorcode.subcommands.vectorise.GLOBAL_INCLUDE_SPEC", - str(global_include_file), - ), - ): - files = load_files_from_include(project_root) + config = Config(project_root=str(tmp_path), recursive=False) - assert "global_file1.py" in files - assert "global_file2.py" in files - assert "ignored_global.py" not in files - assert len(files) == 2 - mock_isfile.assert_any_call(str(local_include_file)) - mock_isfile.assert_any_call(str(global_include_file)) - m_open.assert_called_once_with( - str(global_include_file) - ) # Check the global file was opened - mock_check_tree_files.assert_called_once() + def isfile_side_effect(path): + if path == GLOBAL_EXCLUDE_SPEC: + return True + return os.path.join(str(tmp_path), ".gitignore") == path + mock_isfile.side_effect = isfile_side_effect -@patch("os.path.isfile", return_value=False) # Neither local nor global exists -@patch("pathspec.PathSpec.check_tree_files") -def test_load_files_from_include_no_files(mock_check_tree_files, mock_isfile, tmp_path): - """Tests behavior when neither local nor global include files exist.""" - project_root = str(tmp_path) - local_include_file = tmp_path / ".vectorcode" / "vectorcode.include" - # Assume a mocked global path for the check - mocked_global_path = "/mock/global/.config/vectorcode/vectorcode.include" + specs = find_exclude_specs(config) + assert GLOBAL_EXCLUDE_SPEC in specs - with patch( - "vectorcode.subcommands.vectorise.GLOBAL_INCLUDE_SPEC", mocked_global_path - ): - files = load_files_from_include(project_root) - assert files == [] - mock_isfile.assert_any_call(str(local_include_file)) - mock_isfile.assert_any_call(mocked_global_path) - mock_check_tree_files.assert_not_called() +def test_find_exclude_specs_non_recursive_no_gitignore(tmp_path): + config = Config(project_root=str(tmp_path), recursive=False) + specs = find_exclude_specs(config) + assert specs == [] -@pytest.mark.asyncio -async def test_vectorise(capsys): - configs = Config( - db_url="http://test_host:1234", - db_path="test_db", - embedding_function="SentenceTransformerEmbeddingFunction", - embedding_params={}, - project_root="/test_project", - files=["test_file.py"], - recursive=False, - force=False, - pipe=False, - ) - mock_client = AsyncMock() - mock_collection = MagicMock(spec=AsyncCollection) - mock_collection.get.return_value = {"ids": []} - mock_collection.delete.return_value = None - mock_collection.metadata = { - "embedding_function": "SentenceTransformerEmbeddingFunction", - "path": "/test_project", - "hostname": socket.gethostname(), - "created-by": "VectorCode", - "username": os.environ.get("USER", os.environ.get("USERNAME", "DEFAULT_USER")), - } - mock_client.get_max_batch_size.return_value = 50 - mock_embedding_function = MagicMock() - - with ExitStack() as stack: - stack.enter_context( - patch("vectorcode.subcommands.vectorise.ClientManager"), - ) - stack.enter_context(patch("os.path.isfile", return_value=False)) - stack.enter_context( - patch( - "vectorcode.subcommands.vectorise.expand_globs", - return_value=["test_file.py"], - ) - ) - mock_chunked_add = stack.enter_context( - patch("vectorcode.subcommands.vectorise.chunked_add", return_value=None) - ) - stack.enter_context( - patch( - "vectorcode.common.get_embedding_function", - return_value=mock_embedding_function, - ) - ) - stack.enter_context( - patch( - "vectorcode.subcommands.vectorise.get_collection", - return_value=mock_collection, - ) - ) - - result = await vectorise(configs) - assert result == 0 - assert 
mock_chunked_add.call_count == 1 +def test_find_exclude_specs_local_exclude(tmp_path): + config = Config(project_root=str(tmp_path), recursive=False) + exclude_dir = tmp_path / ".vectorcode" + exclude_dir.mkdir() + exclude_file = exclude_dir / "vectorcode.exclude" + exclude_file.touch() + specs = find_exclude_specs(config) + assert str(exclude_file) in specs -@pytest.mark.asyncio -async def test_vectorise_cancelled(): - configs = Config( - db_url="http://test_host:1234", - db_path="test_db", - embedding_function="SentenceTransformerEmbeddingFunction", - embedding_params={}, - project_root="/test_project", - files=["test_file.py"], - recursive=False, - force=False, - pipe=False, - ) - - async def mock_chunked_add(*args, **kwargs): - raise asyncio.CancelledError - mock_client = AsyncMock() - mock_collection = AsyncMock() - - with ( - patch( - "vectorcode.subcommands.vectorise.chunked_add", side_effect=mock_chunked_add - ) as mock_add, - patch("sys.stderr") as mock_stderr, - patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager, - patch( - "vectorcode.subcommands.vectorise.get_collection", - return_value=mock_collection, - ), - patch("vectorcode.subcommands.vectorise.verify_ef", return_value=True), - patch( - "os.path.isfile", - lambda x: not (x.endswith("gitignore") or x.endswith("vectorcode.exclude")), - ), - ): - MockClientManager.return_value._create_client.return_value = mock_client - result = await vectorise(configs) - assert result == 1 - mock_add.assert_called_once() - mock_stderr.write.assert_called() +@patch("os.path.isfile") +@patch("pathspec.PathSpec.check_tree_files") +def test_load_files_from_global_include(mock_check_tree_files, mock_isfile, tmp_path): + from vectorcode.subcommands.vectorise import GLOBAL_INCLUDE_SPEC + project_root = str(tmp_path) + global_include_content = "global_file1.py\nglobal_file2.py" -@pytest.mark.asyncio -async def test_vectorise_orphaned_files(): - configs = Config( - db_url="http://test_host:1234", - db_path="test_db", - embedding_function="SentenceTransformerEmbeddingFunction", - embedding_params={}, - project_root="/test_project", - files=["test_file.py"], - recursive=False, - force=False, - pipe=False, - ) + def isfile_side_effect(p): + return str(p) == GLOBAL_INCLUDE_SPEC - AsyncMock() - mock_collection = AsyncMock() + mock_isfile.side_effect = isfile_side_effect - # Define a mock response for collection.get in vectorise - get_return = { - "metadatas": [{"path": "test_file.py"}, {"path": "non_existent_file.py"}] - } - mock_collection.get.side_effect = [ - {"ids": [], "metadatas": []}, # Return value for chunked_add - get_return, # Return value for orphaned files + mock_check_tree_files.return_value = [ + MagicMock(file="global_file1.py", include=True), + MagicMock(file="global_file2.py", include=True), + MagicMock(file="ignored_file.py", include=False), ] - mock_collection.delete.return_value = None - - # Mock TreeSitterChunker - mock_chunker = AsyncMock() - - def chunk(*args, **kwargs): - return ["chunk1", "chunk2"] - - mock_chunker.chunk = chunk - - # Mock os.path.isfile - def is_file_side_effect(path): - if path == "non_existent_file.py": - return False - elif path.endswith(".gitignore") or path.endswith("vectorcode.exclude"): - return False - else: - return True - - mock_embedding_function = MagicMock(return_value=numpy.random.random((100,))) - with ( - patch("os.path.isfile", side_effect=is_file_side_effect), - patch( - "vectorcode.subcommands.vectorise.TreeSitterChunker", - return_value=mock_chunker, - ), - 
patch("vectorcode.subcommands.vectorise.ClientManager"), - patch( - "vectorcode.subcommands.vectorise.get_collection", - return_value=mock_collection, - ), - patch( - "vectorcode.subcommands.vectorise.get_embedding_function", - return_value=mock_embedding_function, - ), - patch("vectorcode.subcommands.vectorise.verify_ef", return_value=True), - patch( - "vectorcode.subcommands.vectorise.expand_globs", - return_value=["test_file.py"], # Ensure expand_globs returns a valid file - ), - patch("vectorcode.subcommands.vectorise.hash_file") as mock_hash_file, - ): - mock_hash_file.return_value = "hash1" - result = await vectorise(configs) - - assert result == 0 - mock_collection.delete.assert_called_once_with( - where={"path": {"$in": ["non_existent_file.py"]}} - ) - -@pytest.mark.asyncio -async def test_vectorise_collection_index_error(): - configs = Config( - db_url="http://test_host:1234", - db_path="test_db", - embedding_function="SentenceTransformerEmbeddingFunction", - embedding_params={}, - project_root="/test_project", - files=["test_file.py"], - recursive=False, - force=False, - pipe=False, + m_open = MagicMock() + m_open.return_value.__enter__.return_value.readlines.return_value = ( + global_include_content.splitlines() ) + with patch("builtins.open", m_open): + files = load_files_from_include(project_root) - mock_client = AsyncMock() - - with ( - patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager, - patch("vectorcode.subcommands.vectorise.get_collection") as mock_get_collection, - patch("os.path.isfile", return_value=False), - ): - MockClientManager.return_value._create_client.return_value = mock_client - mock_get_collection.side_effect = IndexError("Collection not found") - result = await vectorise(configs) - assert result == 1 - - -@pytest.mark.asyncio -async def test_vectorise_verify_ef_false(): - configs = Config( - db_url="http://test_host:1234", - db_path="test_db", - embedding_function="SentenceTransformerEmbeddingFunction", - embedding_params={}, - project_root="/test_project", - files=["test_file.py"], - recursive=False, - force=False, - pipe=False, - ) - mock_client = AsyncMock() - mock_collection = AsyncMock() - - with ( - patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager, - patch( - "vectorcode.subcommands.vectorise.get_collection", - return_value=mock_collection, - ), - patch("vectorcode.subcommands.vectorise.verify_ef", return_value=False), - patch("os.path.isfile", return_value=False), - ): - MockClientManager.return_value._create_client.return_value = mock_client - result = await vectorise(configs) - assert result == 1 - - -@pytest.mark.asyncio -async def test_vectorise_gitignore(): - configs = Config( - db_url="http://test_host:1234", - db_path="test_db", - embedding_function="SentenceTransformerEmbeddingFunction", - embedding_params={}, - project_root="/test_project", - files=["test_file.py"], - recursive=False, - force=False, - pipe=False, - ) - mock_client = AsyncMock() - mock_collection = AsyncMock() - mock_collection.get.return_value = {"metadatas": []} - - with ( - patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager, - patch( - "vectorcode.subcommands.vectorise.get_collection", - return_value=mock_collection, - ), - patch("vectorcode.subcommands.vectorise.verify_ef", return_value=True), - patch( - "os.path.isfile", - side_effect=lambda path: path - == os.path.join("/test_project", ".gitignore"), - ), - patch("builtins.open", return_value=MagicMock()), - patch( - 
"vectorcode.subcommands.vectorise.expand_globs", - return_value=["test_file.py"], - ), - patch( - "vectorcode.subcommands.vectorise.exclude_paths_by_spec" - ) as mock_exclude_paths, - ): - MockClientManager.return_value._create_client.return_value = mock_client - await vectorise(configs) - mock_exclude_paths.assert_called_once() - - -@pytest.mark.asyncio -async def test_vectorise_exclude_file(): - # Create a temporary .vectorcode directory and vectorcode.exclude file - with tempfile.TemporaryDirectory() as tmpdir: - exclude_dir = os.path.join(tmpdir, ".vectorcode") - nested_dir = os.path.join(tmpdir, "nested") - - os.makedirs(exclude_dir, exist_ok=True) - os.makedirs(nested_dir, exist_ok=True) - - exclude_spec = os.path.join(exclude_dir, "vectorcode.exclude") - with open(exclude_spec, mode="w") as fin: - fin.writelines(["excluded_file.py"]) - with open(os.path.join(nested_dir, ".gitignore"), "w") as fin: - fin.writelines(["excluded_file.py"]) - nested_file_path = os.path.join(nested_dir, "nested_excluded_file.py") - with open(nested_file_path, "w") as fin: - # non-recursive case. This file should be included. - fin.writelines(['print("hello world")']) - - configs = Config( - db_url="http://test_host:1234", - db_path="test_db", - embedding_function="SentenceTransformerEmbeddingFunction", - embedding_params={}, - project_root=str(tmpdir), - files=[ - os.path.join(tmpdir, "test_file.py"), - os.path.join(tmpdir, "excluded_file.py"), - nested_file_path, - ], - recursive=False, - force=False, - pipe=False, - ) - mock_client = AsyncMock() - mock_collection = AsyncMock() - mock_collection.get.return_value = {"ids": []} - - with ( - patch( - "vectorcode.subcommands.vectorise.ClientManager" - ) as MockClientManager, - patch( - "vectorcode.subcommands.vectorise.get_collection", - return_value=mock_collection, - ), - patch("vectorcode.subcommands.vectorise.verify_ef", return_value=True), - patch( - "vectorcode.subcommands.vectorise.expand_globs", - return_value=configs.files, - ), - patch("vectorcode.subcommands.vectorise.chunked_add") as mock_chunked_add, - ): - MockClientManager.return_value._create_client.return_value = mock_client - await vectorise(configs) - # Assert that chunked_add is only called for test_file.py, not excluded_file.py - call_args = [call[0][0] for call in mock_chunked_add.call_args_list] - assert str(os.path.join(tmpdir, "excluded_file.py")) not in call_args - assert os.path.join(tmpdir, "test_file.py") in call_args - assert mock_chunked_add.call_count == 2 - - -@pytest.mark.asyncio -async def test_vectorise_exclude_file_recursive(): - # Create a temporary .vectorcode directory and vectorcode.exclude file - with tempfile.TemporaryDirectory() as tmpdir: - exclude_dir = os.path.join(tmpdir, ".vectorcode") - nested_dir = os.path.join(tmpdir, "nested") - - os.makedirs(exclude_dir, exist_ok=True) - os.makedirs(nested_dir, exist_ok=True) - - exclude_spec = os.path.join(exclude_dir, "vectorcode.exclude") - with open(exclude_spec, mode="w") as fin: - fin.writelines(["excluded_file.py"]) - with open(os.path.join(nested_dir, ".gitignore"), "w") as fin: - fin.writelines(["excluded_file.py"]) - with open(os.path.join(nested_dir, "excluded_file.py"), "w") as fin: - # recursive case. This file should be skipped. 
- fin.writelines(['print("hello world")']) - - configs = Config( - db_url="http://test_host:1234", - db_path="test_db", - embedding_function="SentenceTransformerEmbeddingFunction", - embedding_params={}, - project_root=str(tmpdir), - files=[ - os.path.join(tmpdir, "test_file.py"), - os.path.join(tmpdir, "excluded_file.py"), - ], - recursive=True, - force=False, - pipe=False, - ) - mock_client = AsyncMock() - mock_collection = AsyncMock() - mock_collection.get.return_value = {"ids": []} - - with ( - patch( - "vectorcode.subcommands.vectorise.ClientManager" - ) as MockClientManager, - patch( - "vectorcode.subcommands.vectorise.get_collection", - return_value=mock_collection, - ), - patch("vectorcode.subcommands.vectorise.verify_ef", return_value=True), - patch( - "vectorcode.subcommands.vectorise.expand_globs", - return_value=configs.files, - ), - patch("vectorcode.subcommands.vectorise.chunked_add") as mock_chunked_add, - ): - MockClientManager.return_value._create_client.return_value = mock_client - await vectorise(configs) - # Assert that chunked_add is only called for test_file.py, not excluded_file.py - call_args = [call[0][0] for call in mock_chunked_add.call_args_list] - assert str(os.path.join(tmpdir, "excluded_file.py")) not in call_args - assert os.path.join(tmpdir, "test_file.py") in call_args - assert mock_chunked_add.call_count == 1 + assert "global_file1.py" in files + assert "global_file2.py" in files + assert "ignored_file.py" not in files -@pytest.mark.asyncio -async def test_vectorise_uses_global_exclude_when_local_missing(): - mock_client = AsyncMock() - mock_collection = AsyncMock() - mock_collection.get.return_value = {"ids": []} - - with tempfile.TemporaryDirectory() as temp_home: - os.environ["HOME"] = temp_home - global_config_dir = os.path.join(temp_home, ".config", "vectorcode") - os.makedirs(global_config_dir, exist_ok=True) - with open( - os.path.join(global_config_dir, "vectorcode.exclude"), mode="w" - ) as fin: - fin.writelines(["exclude.py"]) - - project_root = os.path.join(temp_home, "project") - os.makedirs(project_root, exist_ok=True) - files = list( - os.path.join(project_root, i) for i in ("include.py", "exclude.py") - ) - for f_name in files: - full_path = os.path.join(project_root, f_name) - with open(full_path, mode="w") as fin: - pass - with ( - patch( - "vectorcode.subcommands.vectorise.ClientManager" - ) as MockClientManager, - patch( - "vectorcode.subcommands.vectorise.get_collection", - return_value=mock_collection, - ), - patch("vectorcode.subcommands.vectorise.verify_ef", return_value=True), - patch("vectorcode.subcommands.vectorise.chunked_add") as mock_chunked_add, - patch( - "vectorcode.subcommands.vectorise.GLOBAL_EXCLUDE_SPEC", - os.path.join(temp_home, ".config", "vectorcode", "vectorcode.exclude"), - ), - ): - MockClientManager.return_value._create_client.return_value = mock_client - await vectorise( - Config( - project_root=project_root, - files=list(os.path.join(project_root, i) for i in files), - action=CliAction.vectorise, - ) - ) - mock_chunked_add.assert_called_once() +@patch("os.path.isfile", return_value=False) +def test_load_files_from_include_no_spec(mock_isfile, tmp_path): + project_root = str(tmp_path) + files = load_files_from_include(project_root) + assert files == [] diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py index bd10efc5..daffa4c0 100644 --- a/tests/test_cli_utils.py +++ b/tests/test_cli_utils.py @@ -36,7 +36,7 @@ async def test_config_import_from(): os.makedirs(db_path, exist_ok=True) config_dict: 
Dict[str, Any] = { "db_path": db_path, - "db_url": "http://test_host:1234", + "db_params": {"url": "http://test_host:1234"}, "embedding_function": "TestEmbedding", "embedding_params": {"param1": "value1"}, "chunk_size": 512, @@ -44,12 +44,9 @@ async def test_config_import_from(): "query_multiplier": 5, "reranker": "TestReranker", "reranker_params": {"reranker_param1": "reranker_value1"}, - "db_settings": {"db_setting1": "db_value1"}, } config = await Config.import_from(config_dict) - assert config.db_path == db_path - assert config.db_log_path == os.path.expanduser("~/.local/share/vectorcode/") - assert config.db_url == "http://test_host:1234" + assert isinstance(config.db_params, dict) assert config.embedding_function == "TestEmbedding" assert config.embedding_params == {"param1": "value1"} assert config.chunk_size == 512 @@ -57,44 +54,24 @@ async def test_config_import_from(): assert config.query_multiplier == 5 assert config.reranker == "TestReranker" assert config.reranker_params == {"reranker_param1": "reranker_value1"} - assert config.db_settings == {"db_setting1": "db_value1"} - - -@pytest.mark.asyncio -async def test_config_import_from_invalid_path(): - config_dict: Dict[str, Any] = {"db_path": "/path/does/not/exist"} - with pytest.raises(IOError): - await Config.import_from(config_dict) - - -@pytest.mark.asyncio -async def test_config_import_from_db_path_is_file(): - with tempfile.TemporaryDirectory() as temp_dir: - db_path = os.path.join(temp_dir, "test_db_file") - with open(db_path, "w") as f: - f.write("test") - - config_dict: Dict[str, Any] = {"db_path": db_path} - with pytest.raises(IOError): - await Config.import_from(config_dict) @pytest.mark.asyncio async def test_config_merge_from(): - config1 = Config(db_url="http://host1:8001", n_result=5) - config2 = Config(db_url="http://host2:8002", query=["test"]) + config1 = Config(db_params={"url": "http://host1:8001"}, n_result=5) + config2 = Config(db_params={"url": "http://host2:8002"}, query=["test"]) merged_config = await config1.merge_from(config2) - assert merged_config.db_url == "http://host2:8002" + assert merged_config.db_params["url"] == "http://host2:8002" assert merged_config.n_result == 5 assert merged_config.query == ["test"] @pytest.mark.asyncio async def test_config_merge_from_new_fields(): - config1 = Config(db_url="http://host1:8001") + config1 = Config(db_params={"url": "http://host1:8001"}) config2 = Config(query=["test"], n_result=10, recursive=True) merged_config = await config1.merge_from(config2) - assert merged_config.db_url == "http://host1:8001" + assert merged_config.db_params["url"] == "http://host1:8001" assert merged_config.query == ["test"] assert merged_config.n_result == 10 assert merged_config.recursive @@ -104,18 +81,17 @@ async def test_config_merge_from_new_fields(): async def test_config_import_from_missing_keys(): config_dict: Dict[str, Any] = {} # Empty dictionary, all keys missing config = await Config.import_from(config_dict) + default_config = Config() # Assert that default values are used - assert config.embedding_function == "SentenceTransformerEmbeddingFunction" - assert config.embedding_params == {} - assert config.db_url == "http://127.0.0.1:8000" - assert config.db_path == os.path.expanduser("~/.local/share/vectorcode/chromadb/") - assert config.chunk_size == 2500 - assert config.overlap_ratio == 0.2 - assert config.query_multiplier == -1 - assert config.reranker == "NaiveReranker" - assert config.reranker_params == {} - assert config.db_settings is None + assert 
config.embedding_function == default_config.embedding_function + assert config.embedding_params == default_config.embedding_params + assert config.db_params == default_config.db_params + assert config.chunk_size == default_config.chunk_size + assert config.overlap_ratio == default_config.overlap_ratio + assert config.query_multiplier == default_config.query_multiplier + assert config.reranker == default_config.reranker + assert config.reranker_params == default_config.reranker_params def test_expand_envs_in_dict(): @@ -133,6 +109,8 @@ def test_expand_envs_in_dict(): expand_envs_in_dict(d) assert d["key4"] == "$TEST_VAR2" # Should remain unchanged + expand_envs_in_dict(None) + del os.environ["TEST_VAR"] # Clean up the env @@ -222,12 +200,12 @@ async def test_load_from_default_config(): config_dir, ) os.makedirs(config_dir, exist_ok=True) - config_content = '{"db_url": "http://default.url:8000"}' + config_content = '{"db_params": {"url": "http://default.url:8000"}}' with open(config_path, "w") as fin: fin.write(config_content) config = await load_config_file() - assert config.db_url == "http://default.url:8000" + assert isinstance(config.db_params, dict) @pytest.mark.asyncio @@ -321,6 +299,7 @@ async def test_cli_arg_parser(): def test_query_include_to_header(): assert QueryInclude.path.to_header() == "Path: " assert QueryInclude.document.to_header() == "Document:\n" + assert QueryInclude.chunk.to_header() == "Chunk: " def test_find_project_root(): @@ -402,12 +381,12 @@ async def test_parse_cli_args_vectorise_recursive_dir(): @pytest.mark.asyncio async def test_parse_cli_args_vectorise_recursive_dir_include_hidden(): - with patch("sys.argv", ["vectorcode", "vectorise", "-r", "."]): + with patch("sys.argv", ["vectorcode", "vectorise", "-r", ".", "--include-hidden"]): config = await parse_cli_args() assert config.action == CliAction.vectorise assert config.files == ["."] assert config.recursive is True - assert config.include_hidden is False + assert config.include_hidden is True @pytest.mark.asyncio @@ -425,10 +404,10 @@ async def test_get_project_config_local_config(tmp_path): vectorcode_dir.mkdir(parents=True) config_file = vectorcode_dir / "config.json" - config_file.write_text('{"db_url": "http://test_host:9999" }') + config_file.write_text('{"db_params": {"url": "http://test_host:9999"} }') config = await get_project_config(project_root) - assert config.db_url == "http://test_host:9999" + assert isinstance(config.db_params, dict) @pytest.mark.asyncio @@ -438,10 +417,10 @@ async def test_get_project_config_local_config_json5(tmp_path): vectorcode_dir.mkdir(parents=True) config_file = vectorcode_dir / "config.json5" - config_file.write_text('{"db_url": "http://test_host:9999" }') + config_file.write_text('{"db_params": {"url": "http://test_host:9999"} }') config = await get_project_config(project_root) - assert config.db_url == "http://test_host:9999" + assert isinstance(config.db_params, dict) def test_find_project_root_file_input(tmp_path): @@ -484,9 +463,10 @@ async def test_parse_cli_args_check(): @pytest.mark.asyncio async def test_parse_cli_args_init(): - with patch("sys.argv", ["vectorcode", "init"]): + with patch("sys.argv", ["vectorcode", "init", "--force"]): config = await parse_cli_args() assert config.action == CliAction.init + assert config.force is True @pytest.mark.asyncio @@ -527,37 +507,15 @@ async def test_parse_cli_args_files(): assert config.rm_paths == ["foo.txt"] -@pytest.mark.asyncio -async def test_config_import_from_hnsw(): - with tempfile.TemporaryDirectory() as 
temp_dir: - db_path = os.path.join(temp_dir, "test_db") - os.makedirs(db_path, exist_ok=True) - config_dict: Dict[str, Any] = { - "hnsw": {"space": "cosine", "ef_construction": 200, "m": 32} - } - config = await Config.import_from(config_dict) - assert config.hnsw["space"] == "cosine" - assert config.hnsw["ef_construction"] == 200 - assert config.hnsw["m"] == 32 - - -@pytest.mark.asyncio -async def test_hnsw_config_merge(): - config1 = Config(hnsw={"space": "ip"}) - config2 = Config(hnsw={"ef_construction": 200}) - merged_config = await config1.merge_from(config2) - assert merged_config.hnsw["space"] == "ip" - assert merged_config.hnsw["ef_construction"] == 200 - - def test_cleanup_path(): home = os.environ.get("HOME") - if home is None: - return - assert cleanup_path(os.path.join(home, "test_path")) == os.path.join( - "~", "test_path" - ) + if home: + assert cleanup_path(os.path.join(home, "test_path")) == os.path.join( + "~", "test_path" + ) assert cleanup_path("/etc/dir") == "/etc/dir" + with patch.dict(os.environ, {"HOME": ""}): + assert cleanup_path("/etc/dir") == "/etc/dir" def test_shtab(): @@ -576,8 +534,10 @@ def test_shtab(): async def test_filelock(): manager = LockManager() with tempfile.TemporaryDirectory() as tmp_dir: - manager.get_lock(tmp_dir) + lock = manager.get_lock(tmp_dir) assert os.path.isfile(os.path.join(tmp_dir, "vectorcode.lock")) + # test getting existing lock + assert lock is manager.get_lock(tmp_dir) def test_specresolver(): @@ -590,6 +550,12 @@ def test_specresolver(): SpecResolver(spec, base_dir="nested").match([nested_path], negated=True) ) + assert SpecResolver(spec, base_dir="nested").match_file(nested_path) + assert not SpecResolver(spec, base_dir="nested").match_file( + nested_path, negated=True + ) + assert SpecResolver(spec).match_file("../outside_file.txt") + with tempfile.TemporaryDirectory() as dir: nested_dir = os.path.join(dir, "nested") nested_path = os.path.join(nested_dir, "file1.txt") @@ -610,23 +576,38 @@ def test_specresolver_builder(): patch("vectorcode.cli_utils.open"), ): base_dir = os.path.normpath(os.path.join("foo", "bar")) - assert ( - os.path.normpath( - SpecResolver.from_path(os.path.join(base_dir, ".gitignore")).base_dir - ) - == base_dir - ) + assert os.path.normpath( + SpecResolver.from_path(os.path.join(base_dir, ".gitignore")).base_dir + ) == os.path.abspath(base_dir) - assert ( - os.path.normpath( - SpecResolver.from_path( - os.path.join(base_dir, ".vectorcode", "vectorcode.exclude") - ).base_dir - ) - == base_dir - ) assert os.path.normpath( SpecResolver.from_path( - os.path.join(base_dir, "vectorcode", "vectorcode.exclude") + os.path.join(base_dir, ".vectorcode", "vectorcode.exclude") ).base_dir - ) == os.path.normpath(".") + ) == os.path.abspath(base_dir) + assert os.path.normpath( + SpecResolver.from_path( + os.path.join(base_dir, "vectorcode", "vectorcode.exclude"), + project_root=base_dir, + ).base_dir + ) == os.path.abspath(base_dir) + with pytest.raises(ValueError): + SpecResolver.from_path("foo/bar") + + +@pytest.mark.asyncio +async def test_find_project_root_at_root(): + with tempfile.TemporaryDirectory() as temp_dir: + os.makedirs(os.path.join(temp_dir, ".git")) + # in a git repo, find_project_root should not go beyond the git root + assert os.path.samefile(find_project_root(temp_dir, ".git"), temp_dir) + assert find_project_root(temp_dir, ".vectorcode") is None + + +@pytest.mark.asyncio +async def test_find_project_config_dir_at_root(): + with tempfile.TemporaryDirectory() as temp_dir: + git_dir = 
os.path.join(temp_dir, ".git") + os.makedirs(git_dir) + # in a git repo, find_project_root should not go beyond the git root + assert os.path.samefile(await find_project_config_dir(temp_dir), git_dir) diff --git a/tests/test_common.py b/tests/test_common.py deleted file mode 100644 index c0dbdc5f..00000000 --- a/tests/test_common.py +++ /dev/null @@ -1,642 +0,0 @@ -import os -import socket -import subprocess -import sys -import tempfile -from unittest.mock import AsyncMock, MagicMock, patch - -import httpx -import pytest -from chromadb.api import AsyncClientAPI -from chromadb.api.models.AsyncCollection import AsyncCollection -from chromadb.utils import embedding_functions - -from vectorcode.cli_utils import Config -from vectorcode.common import ( - ClientManager, - get_collection, - get_collection_name, - get_collections, - get_embedding_function, - start_server, - try_server, - verify_ef, - wait_for_server, -) - - -def test_get_collection_name(): - with tempfile.TemporaryDirectory() as temp_dir: - file_path = os.path.join(temp_dir, "test_file.txt") - collection_name = get_collection_name(file_path) - assert isinstance(collection_name, str) - assert len(collection_name) == 63 - - # Test that the collection name is consistent for the same path - collection_name2 = get_collection_name(file_path) - assert collection_name == collection_name2 - - # Test that the collection name is different for different paths - file_path2 = os.path.join(temp_dir, "another_file.txt") - collection_name2 = get_collection_name(file_path2) - assert collection_name != collection_name2 - - # Test with absolute path - abs_file_path = os.path.abspath(file_path) - collection_name3 = get_collection_name(abs_file_path) - assert collection_name == collection_name3 - - -def test_get_embedding_function(): - # Test with a valid embedding function - config = Config( - embedding_function="SentenceTransformerEmbeddingFunction", embedding_params={} - ) - embedding_function = get_embedding_function(config) - assert "SentenceTransformerEmbeddingFunction" in str(type(embedding_function)) - - # Test with an invalid embedding function (fallback to SentenceTransformer) - config = Config(embedding_function="FakeEmbeddingFunction", embedding_params={}) - embedding_function = get_embedding_function(config) - assert "SentenceTransformerEmbeddingFunction" in str(type(embedding_function)) - - # Test with specific embedding parameters - config = Config( - embedding_function="SentenceTransformerEmbeddingFunction", - embedding_params={"param1": "value1"}, - ) - embedding_function = get_embedding_function(config) - assert "SentenceTransformerEmbeddingFunction" in str(type(embedding_function)) - - -def test_get_embedding_function_init_exception(): - # Test when the embedding function exists but raises an error during initialization - config = Config( - embedding_function="SentenceTransformerEmbeddingFunction", - embedding_params={"model_name": "non_existent_model_should_cause_error"}, - ) - - # Mock SentenceTransformerEmbeddingFunction.__init__ to raise a generic exception - with patch.object( - embedding_functions, "SentenceTransformerEmbeddingFunction", autospec=True - ) as mock_stef: - # Simulate an error during the embedding function's __init__ - mock_stef.side_effect = Exception("Simulated initialization error") - - with pytest.raises(Exception) as excinfo: - get_embedding_function(config) - - # Check if the raised exception is the one we simulated - assert "Simulated initialization error" in str(excinfo.value) - # Check if the additional 
note was added - assert "For errors caused by missing dependency" in excinfo.value.__notes__[0] - - # Verify that the constructor was called with the correct parameters - mock_stef.assert_called_once_with( - model_name="non_existent_model_should_cause_error" - ) - - -@pytest.mark.asyncio -async def test_try_server_versions(): - # Test successful v1 response - with patch("httpx.AsyncClient") as mock_client: - mock_response = MagicMock() - mock_response.status_code = 200 - mock_client.return_value.__aenter__.return_value.get.return_value = ( - mock_response - ) - assert await try_server("http://localhost:8300") is True - mock_client.return_value.__aenter__.return_value.get.assert_called_once_with( - url="http://localhost:8300/api/v1/heartbeat" - ) - - # Test fallback to v2 when v1 fails - with patch("httpx.AsyncClient") as mock_client: - mock_response_v1 = MagicMock() - mock_response_v1.status_code = 404 - mock_response_v2 = MagicMock() - mock_response_v2.status_code = 200 - mock_client.return_value.__aenter__.return_value.get.side_effect = [ - mock_response_v1, - mock_response_v2, - ] - assert await try_server("http://localhost:8300") is True - assert mock_client.return_value.__aenter__.return_value.get.call_count == 2 - - # Test both versions fail - with patch("httpx.AsyncClient") as mock_client: - mock_response_v1 = MagicMock() - mock_response_v1.status_code = 404 - mock_response_v2 = MagicMock() - mock_response_v2.status_code = 500 - mock_client.return_value.__aenter__.return_value.get.side_effect = [ - mock_response_v1, - mock_response_v2, - ] - assert await try_server("http://localhost:8300") is False - - # Test connection error cases - with patch("httpx.AsyncClient") as mock_client: - mock_client.return_value.__aenter__.return_value.get.side_effect = ( - httpx.ConnectError("Cannot connect") - ) - assert await try_server("http://localhost:8300") is False - - with patch("httpx.AsyncClient") as mock_client: - mock_client.return_value.__aenter__.return_value.get.side_effect = ( - httpx.ConnectTimeout("Connection timeout") - ) - assert await try_server("http://localhost:8300") is False - - -def test_verify_ef(): - # Mocking AsyncCollection and Config - mock_collection = MagicMock() - mock_config = MagicMock() - - # Test when collection_ef and config.embedding_function are the same - mock_collection.metadata = {"embedding_function": "test_embedding_function"} - mock_config.embedding_function = "test_embedding_function" - assert verify_ef(mock_collection, mock_config) is True - - # Test when collection_ef and config.embedding_function are different - mock_collection.metadata = {"embedding_function": "test_embedding_function"} - mock_config.embedding_function = "another_embedding_function" - assert verify_ef(mock_collection, mock_config) is False - - # Test when collection_ep and config.embedding_params are the same - mock_collection.metadata = {"embedding_params": {"param1": "value1"}} - mock_config.embedding_params = {"param1": "value1"} - assert verify_ef(mock_collection, mock_config) is True - - # Test when collection_ep and config.embedding_params are different - mock_collection.metadata = {"embedding_params": {"param1": "value1"}} - mock_config.embedding_params = {"param1": "value2"} - assert ( - verify_ef(mock_collection, mock_config) is True - ) # It should return True according to the source code. 
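The behaviour pinned by the assertion above (mismatched embedding_params still verifying) is easier to see in isolation. A rough, illustrative reconstruction of the check, assuming verify_ef compares only the recorded embedding function name and tolerates everything else; this is a sketch, not the shipped implementation:

    def verify_ef(collection, config) -> bool:
        # Compare the embedding function name recorded in collection metadata, if any.
        metadata = collection.metadata or {}
        collection_ef = metadata.get("embedding_function")
        if collection_ef is not None and collection_ef != config.embedding_function:
            return False
        # Differing embedding_params are tolerated, matching the assertion above.
        return True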
- - # Test when collection_ef is None - mock_collection.metadata = {} - mock_config.embedding_function = "test_embedding_function" - assert verify_ef(mock_collection, mock_config) is True - - -@patch("socket.socket") -@pytest.mark.asyncio -async def test_try_server_mocked(mock_socket): - # Mocking httpx.AsyncClient and its get method to simulate a successful connection - with patch("httpx.AsyncClient") as mock_client: - mock_response = MagicMock() - mock_response.status_code = 200 - mock_client.return_value.__aenter__.return_value.get.return_value = ( - mock_response - ) - assert await try_server("http://localhost:8000") is True - - # Mocking httpx.AsyncClient to raise a ConnectError - with patch("httpx.AsyncClient") as mock_client: - mock_client.return_value.__aenter__.return_value.get.side_effect = ( - httpx.ConnectError("Simulated connection error") - ) - assert await try_server("http://localhost:8000") is False - - # Mocking httpx.AsyncClient to raise a ConnectTimeout - with patch("httpx.AsyncClient") as mock_client: - mock_client.return_value.__aenter__.return_value.get.side_effect = ( - httpx.ConnectTimeout("Simulated connection timeout") - ) - assert await try_server("http://localhost:8000") is False - - -@pytest.mark.asyncio -async def test_get_collection(): - config = Config( - db_url="http://test_host:1234", - db_path="test_db", - embedding_function="SentenceTransformerEmbeddingFunction", - embedding_params={}, - project_root="/test_project", - ) - - # Test retrieving an existing collection - with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient: - mock_client = MagicMock(spec=AsyncClientAPI) - mock_collection = MagicMock() - mock_client.get_collection.return_value = mock_collection - MockAsyncHttpClient.return_value = mock_client - - collection = await get_collection(mock_client, config) - assert collection == mock_collection - mock_client.get_collection.assert_called_once() - mock_client.get_or_create_collection.assert_not_called() - - # Test creating a collection if it doesn't exist - with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient: - mock_client = MagicMock(spec=AsyncClientAPI) - mock_collection = MagicMock() - - # Clear the collection cache - from vectorcode.common import __COLLECTION_CACHE - - __COLLECTION_CACHE.clear() - - # Make get_collection raise ValueError to trigger get_or_create_collection - mock_client.get_collection.side_effect = ValueError("Collection not found") - mock_collection.metadata = { - "hostname": socket.gethostname(), - "username": os.environ.get( - "USER", os.environ.get("USERNAME", "DEFAULT_USER") - ), - "created-by": "VectorCode", - } - - async def mock_get_or_create_collection( - self, - name=None, - configuration=None, - metadata=None, - embedding_function=None, - data_loader=None, - ): - mock_collection.metadata.update(metadata or {}) - return mock_collection - - mock_client.get_or_create_collection.side_effect = mock_get_or_create_collection - MockAsyncHttpClient.return_value = mock_client - - collection = await get_collection(mock_client, config, make_if_missing=True) - assert collection.metadata["hostname"] == socket.gethostname() - assert collection.metadata["username"] == os.environ.get( - "USER", os.environ.get("USERNAME", "DEFAULT_USER") - ) - assert collection.metadata["created-by"] == "VectorCode" - assert collection.metadata["hnsw:M"] == 64 - mock_client.get_or_create_collection.assert_called_once() - mock_client.get_collection.side_effect = None - - # Test raising IndexError on hash collision. 
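For context on the IndexError case below: collection names are derived deterministically from the project path (an earlier test pins the name length at 63 characters), so two distinct paths can in principle collide. A hypothetical sketch of such a derivation; the helper name and body are invented for illustration:

    import hashlib
    import os

    def make_collection_name(path: str) -> str:
        # Hypothetical: hash the absolute path and trim to a 63-character identifier.
        digest = hashlib.sha256(os.path.abspath(path).encode("utf-8")).hexdigest()
        return ("v" + digest)[:63]

    # Deterministic: the same path always yields the same 63-character name.
    assert make_collection_name("/tmp/a.txt") == make_collection_name("/tmp/a.txt")
    assert len(make_collection_name("/tmp/a.txt")) == 63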
- with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient: - mock_client = MagicMock(spec=AsyncClientAPI) - mock_client.get_or_create_collection.side_effect = IndexError( - "Hash collision occurred" - ) - MockAsyncHttpClient.return_value = mock_client - from vectorcode.common import __COLLECTION_CACHE - - __COLLECTION_CACHE.clear() - with pytest.raises(IndexError): - await get_collection(mock_client, config, make_if_missing=True) - - -@pytest.mark.asyncio -async def test_get_collection_hnsw(): - config = Config( - db_url="http://test_host:1234", - db_path="test_db", - embedding_function="SentenceTransformerEmbeddingFunction", - embedding_params={}, - project_root="/test_project", - hnsw={"ef_construction": 200, "M": 32}, - ) - - with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient: - mock_client = MagicMock(spec=AsyncClientAPI) - mock_collection = MagicMock() - mock_collection.metadata = { - "hostname": socket.gethostname(), - "username": os.environ.get( - "USER", os.environ.get("USERNAME", "DEFAULT_USER") - ), - "created-by": "VectorCode", - "hnsw:ef_construction": 200, - "hnsw:M": 32, - "embedding_function": "SentenceTransformerEmbeddingFunction", - "path": "/test_project", - } - mock_client.get_or_create_collection.return_value = mock_collection - MockAsyncHttpClient.return_value = mock_client - - # Clear the collection cache to force creation - from vectorcode.common import __COLLECTION_CACHE - - __COLLECTION_CACHE.clear() - - collection = await get_collection(mock_client, config, make_if_missing=True) - - assert collection.metadata["hostname"] == socket.gethostname() - assert collection.metadata["username"] == os.environ.get( - "USER", os.environ.get("USERNAME", "DEFAULT_USER") - ) - assert collection.metadata["created-by"] == "VectorCode" - assert collection.metadata["hnsw:ef_construction"] == 200 - assert collection.metadata["hnsw:M"] == 32 - mock_client.get_or_create_collection.assert_called_once() - assert ( - mock_client.get_or_create_collection.call_args.kwargs["metadata"] - == mock_collection.metadata - ) - - -@pytest.mark.asyncio -async def test_start_server(): - with tempfile.TemporaryDirectory() as temp_dir: - - def _new_isdir(path): - if str(temp_dir) in str(path): - return True - return False - - # Mock subprocess.Popen - with ( - patch("asyncio.create_subprocess_exec") as MockCreateProcess, - patch("asyncio.sleep"), - patch("socket.socket") as MockSocket, - patch("vectorcode.common.wait_for_server") as MockWaitForServer, - patch("os.path.isdir") as mock_isdir, - patch("os.makedirs") as mock_makedirs, - ): - mock_isdir.side_effect = _new_isdir - # Mock socket to return a specific port - mock_socket = MagicMock() - mock_socket.getsockname.return_value = ("localhost", 12345) # Mock port - MockSocket.return_value.__enter__.return_value = mock_socket - - # Mock the process object - mock_process = MagicMock() - mock_process.returncode = 0 # Simulate successful execution - MockCreateProcess.return_value = mock_process - - # Create a config object - config = Config( - db_url="http://localhost:8000", - db_path=temp_dir, - project_root=temp_dir, - ) - - # Call start_server - process = await start_server(config) - - # Assert that asyncio.create_subprocess_exec was called with the correct arguments - MockCreateProcess.assert_called_once() - args, kwargs = MockCreateProcess.call_args - expected_args = [ - sys.executable, - "-m", - "chromadb.cli.cli", - "run", - "--host", - "localhost", - "--port", - str(12345), # Check the mocked port - "--path", - temp_dir, - 
"--log-path", - os.path.join(str(config.db_log_path), "chroma.log"), - ] - assert args[0] == sys.executable - assert tuple(args[1:]) == tuple(expected_args[1:]) - assert kwargs["stdout"] == subprocess.DEVNULL - assert kwargs["stderr"] == sys.stderr - assert "ANONYMIZED_TELEMETRY" in kwargs["env"] - assert config.db_url == "http://127.0.0.1:12345" - - MockWaitForServer.assert_called_once_with("http://127.0.0.1:12345") - - assert process == mock_process - mock_makedirs.assert_called_once_with(config.db_log_path) - - -@pytest.mark.asyncio -async def test_get_collections(): - # Mocking AsyncClientAPI and AsyncCollection - mock_client = MagicMock(spec=AsyncClientAPI) - - # Mock successful get_collection - mock_collection1 = MagicMock(spec=AsyncCollection) - mock_collection1.metadata = { - "created-by": "VectorCode", - "username": os.environ.get("USER", os.environ.get("USERNAME", "DEFAULT_USER")), - "hostname": socket.gethostname(), - } - - # collection with meta == None - mock_collection2 = MagicMock(spec=AsyncCollection) - mock_collection2.metadata = None - - # collection with wrong "created-by" - mock_collection3 = MagicMock(spec=AsyncCollection) - mock_collection3.metadata = { - "created-by": "NotVectorCode", - "username": os.environ.get("USER", os.environ.get("USERNAME", "DEFAULT_USER")), - "hostname": socket.gethostname(), - } - - # collection with wrong "username" - mock_collection4 = MagicMock(spec=AsyncCollection) - mock_collection4.metadata = { - "created-by": "VectorCode", - "username": "wrong_user", - "hostname": socket.gethostname(), - } - - # collection with wrong "hostname" - mock_collection5 = MagicMock(spec=AsyncCollection) - mock_collection5.metadata = { - "created-by": "VectorCode", - "username": os.environ.get("USER", os.environ.get("USERNAME", "DEFAULT_USER")), - "hostname": "wrong_host", - } - - mock_client.list_collections.return_value = [ - "collection1", - "collection2", - "collection3", - "collection4", - "collection5", - ] - mock_client.get_collection.side_effect = [ - mock_collection1, - mock_collection2, - mock_collection3, - mock_collection4, - mock_collection5, - ] - - collections = [ - collection async for collection in get_collections(mock_client) - ] # call get_collections - assert len(collections) == 1 - assert collections[0] == mock_collection1 - - -def test_get_embedding_function_fallback(): - # Test with an invalid embedding function that causes AttributeError - config = Config(embedding_function="InvalidFunction", embedding_params={}) - embedding_function = get_embedding_function(config) - assert "SentenceTransformerEmbeddingFunction" in str(type(embedding_function)) - - -@pytest.mark.asyncio -async def test_wait_for_server_success(): - # Mock try_server to return True immediately - with patch("vectorcode.common.try_server") as mock_try_server: - mock_try_server.return_value = True - - # Should complete immediately without timeout - await wait_for_server("http://localhost:8000", timeout=1) - - # Verify try_server was called once - mock_try_server.assert_called_once_with("http://localhost:8000") - - -@pytest.mark.asyncio -async def test_wait_for_server_timeout(): - # Mock try_server to always return False - with patch("vectorcode.common.try_server") as mock_try_server: - mock_try_server.return_value = False - - # Should raise TimeoutError after 0.1 seconds (minimum timeout) - with pytest.raises(TimeoutError) as excinfo: - await wait_for_server("http://localhost:8000", timeout=0.1) - - assert "Server did not start within 0.1 seconds" in str(excinfo.value) - - 
# Verify try_server was called multiple times (due to retries) - assert mock_try_server.call_count > 1 - - -@pytest.mark.asyncio -async def test_client_manager_get_client(): - ClientManager().clear() - config = Config( - db_url="https://test_host:1234", db_path="test_db", project_root="test_proj" - ) - config1 = Config( - db_url="http://test_host1:1234", - db_path="test_db", - project_root="test_proj1", - db_settings={"anonymized_telemetry": True}, - ) - config1_alt = Config( - db_url="http://test_host1:1234", - db_path="test_db", - project_root="test_proj1", - db_settings={"anonymized_telemetry": True, "other_setting": "value"}, - ) - # Patch chromadb.AsyncHttpClient to avoid actual network calls - with ( - patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient, - patch("vectorcode.common.try_server", return_value=True), - ): - mock_client = MagicMock(spec=AsyncClientAPI, parent=AsyncClientAPI) - MockAsyncHttpClient.return_value = mock_client - - async with ( - ClientManager().get_client(config), - ): - MockAsyncHttpClient.assert_called() - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host - == "test_host" - ) - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_http_port - == 1234 - ) - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].anonymized_telemetry - is False - ) - assert ( - MockAsyncHttpClient.call_args.kwargs[ - "settings" - ].chroma_server_ssl_enabled - is True - ) - - async with ( - ClientManager().get_client(config1) as client1, - ClientManager().get_client(config1_alt) as client1_alt, - ): - MockAsyncHttpClient.assert_called() - assert ( - MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host - == "test_host1" - ) - assert ( - MockAsyncHttpClient.call_args.kwargs[ - "settings" - ].chroma_server_http_port - == 1234 - ) - assert ( - MockAsyncHttpClient.call_args.kwargs[ - "settings" - ].anonymized_telemetry - is True - ) - - # Test with multiple db_settings, including an invalid one. The invalid one - # should be filtered out inside get_client. 
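The identity assertion below (two configs differing only in an unrecognised setting share one client) suggests a cache keyed on the server address, with unknown settings dropped before the key is built. A speculative sketch of that shape; the class, whitelist, and key layout are invented for illustration:

    KNOWN_SETTINGS = {"anonymized_telemetry"}  # hypothetical whitelist

    class ClientCache:
        def __init__(self) -> None:
            self._clients: dict = {}

        def get(self, url: str, settings: dict):
            # Drop settings the backend does not understand before keying the cache.
            filtered = tuple(
                (k, v) for k, v in sorted(settings.items()) if k in KNOWN_SETTINGS
            )
            key = (url, filtered)
            if key not in self._clients:
                self._clients[key] = object()  # stand-in for a real client handle
            return self._clients[key]

    cache = ClientCache()
    a = cache.get("http://host:1234", {"anonymized_telemetry": True})
    b = cache.get("http://host:1234", {"anonymized_telemetry": True, "other_setting": "x"})
    assert a is b  # the unknown setting does not change the cache key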
- assert id(client1_alt) == id(client1) - - -@pytest.mark.asyncio -async def test_client_manager_list_server_processes(): - async def _try_server(url): - return "127.0.0.1" in url or "localhost" in url - - async def _start_server(cfg): - return AsyncMock() - - with ( - tempfile.TemporaryDirectory() as temp_dir, - patch("vectorcode.common.start_server", side_effect=_start_server), - patch("vectorcode.common.try_server", side_effect=_try_server), - patch("vectorcode.common.ClientManager._create_client"), - ): - db_path = os.path.join(temp_dir, "db") - os.makedirs(db_path, exist_ok=True) - - async with ClientManager().get_client( - Config( - db_url="http://test_host:8001", - project_root="proj1", - db_path=db_path, - ) - ): - print(ClientManager().get_processes()) - async with ClientManager().get_client( - Config( - db_url="http://test_host:8002", - project_root="proj2", - db_path=db_path, - ) - ): - pass - assert len(ClientManager().get_processes()) == 2 - - -@pytest.mark.asyncio -async def test_client_manager_kill_servers(): - manager = ClientManager() - manager.clear() - - async def _try_server(url): - return "127.0.0.1" in url or "localhost" in url - - mock_process = AsyncMock() - mock_process.terminate = MagicMock() - with ( - patch("vectorcode.common.start_server", return_value=mock_process), - patch("vectorcode.common.try_server", side_effect=_try_server), - ): - manager._create_client = AsyncMock(return_value=AsyncMock()) - async with manager.get_client(Config(db_url="http://test_host:1081")): - pass - assert len(manager.get_processes()) == 1 - await manager.kill_servers() - mock_process.terminate.assert_called_once() - mock_process.wait.assert_awaited() diff --git a/tests/test_lsp.py b/tests/test_lsp.py index a97f9dfc..d3e458d0 100644 --- a/tests/test_lsp.py +++ b/tests/test_lsp.py @@ -1,6 +1,5 @@ import os -from contextlib import asynccontextmanager -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, call, patch import pytest from lsprotocol.types import WorkspaceFolder @@ -9,6 +8,15 @@ from vectorcode import __version__ from vectorcode.cli_utils import CliAction, Config, FilesAction, QueryInclude +from vectorcode.database.types import ( + CollectionContent as FileList, +) +from vectorcode.database.types import ( + CollectionInfo as Collection, +) +from vectorcode.database.types import ( + FileInCollection as File, +) from vectorcode.lsp_main import ( execute_command, lsp_start, @@ -48,16 +56,15 @@ async def test_execute_command_query(mock_language_server, mock_config): patch( "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock ) as mock_parse_cli_args, - patch("vectorcode.lsp_main.ClientManager"), - patch("vectorcode.lsp_main.get_collection", new_callable=AsyncMock), + patch("vectorcode.lsp_main.get_database_connector"), patch( - "vectorcode.lsp_main.build_query_results", new_callable=AsyncMock - ) as mock_get_query_result_files, + "vectorcode.lsp_main.get_reranked_results", new_callable=AsyncMock + ) as mock_get_reranked_results, patch("os.path.isfile", return_value=True), patch("builtins.open", MagicMock()) as mock_open, ): mock_parse_cli_args.return_value = mock_config - mock_get_query_result_files.return_value = ["/test/file.txt"] + mock_get_reranked_results.return_value = [] # Configure the MagicMock object to return a string when read() is called mock_file = MagicMock() @@ -66,6 +73,7 @@ async def test_execute_command_query(mock_language_server, mock_config): # Ensure parsed_args.project_root is not None 
mock_config.project_root = "/test/project" + mock_config.query = ["test"] # Mock the merge_from method mock_config.merge_from = AsyncMock(return_value=mock_config) @@ -77,6 +85,42 @@ async def test_execute_command_query(mock_language_server, mock_config): mock_language_server.progress.end.assert_called() +@pytest.mark.asyncio +async def test_execute_command_query_invalid_include(mock_language_server, mock_config): + with ( + patch( + "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock + ) as mock_parse_cli_args, + patch("vectorcode.lsp_main.get_database_connector"), + patch( + "vectorcode.lsp_main.get_reranked_results", new_callable=AsyncMock + ) as mock_get_reranked_results, + patch("os.path.isfile", return_value=True), + patch("builtins.open", MagicMock()) as mock_open, + ): + mock_parse_cli_args.return_value = mock_config + mock_get_reranked_results.return_value = [] + + # Configure the MagicMock object to return a string when read() is called + mock_file = MagicMock() + mock_file.__enter__.return_value.read.return_value = "{}" # Return valid JSON + mock_open.return_value = mock_file + + # Ensure parsed_args.project_root is not None + mock_config.project_root = "/test/project" + mock_config.query = ["test"] + mock_config.include = [QueryInclude.chunk, QueryInclude.document] + + # Mock the merge_from method + mock_config.merge_from = AsyncMock(return_value=mock_config) + + result = await execute_command(mock_language_server, ["query", "test"]) + + assert result == [] + mock_language_server.progress.begin.assert_called() + mock_language_server.progress.end.assert_called() + + @pytest.mark.asyncio async def test_execute_command_query_default_proj_root( mock_language_server, mock_config @@ -85,18 +129,17 @@ async def test_execute_command_query_default_proj_root( patch( "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock ) as mock_parse_cli_args, - patch("vectorcode.lsp_main.ClientManager"), - patch("vectorcode.lsp_main.get_collection", new_callable=AsyncMock), + patch("vectorcode.lsp_main.get_database_connector"), patch( - "vectorcode.lsp_main.build_query_results", new_callable=AsyncMock - ) as mock_get_query_result_files, + "vectorcode.lsp_main.get_reranked_results", new_callable=AsyncMock + ) as mock_get_reranked_results, patch("os.path.isfile", return_value=True), patch("builtins.open", MagicMock()) as mock_open, ): global DEFAULT_PROJECT_ROOT mock_config.project_root = None mock_parse_cli_args.return_value = mock_config - mock_get_query_result_files.return_value = ["/test/file.txt"] + mock_get_reranked_results.return_value = [] # Configure the MagicMock object to return a string when read() is called mock_file = MagicMock() @@ -105,6 +148,7 @@ async def test_execute_command_query_default_proj_root( # Ensure parsed_args.project_root is not None DEFAULT_PROJECT_ROOT = "/test/project" + mock_config.query = ["test"] # Mock the merge_from method mock_config.merge_from = AsyncMock(return_value=mock_config) @@ -123,11 +167,10 @@ async def test_execute_command_query_workspace_dir(mock_language_server, mock_co patch( "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock ) as mock_parse_cli_args, - patch("vectorcode.lsp_main.ClientManager"), - patch("vectorcode.lsp_main.get_collection", new_callable=AsyncMock), + patch("vectorcode.lsp_main.get_database_connector"), patch( - "vectorcode.lsp_main.build_query_results", new_callable=AsyncMock - ) as mock_get_query_result_files, + "vectorcode.lsp_main.get_reranked_results", new_callable=AsyncMock + ) as 
mock_get_reranked_results, patch("os.path.isfile", return_value=True), patch("os.path.isdir", return_value=True), patch("builtins.open", MagicMock()) as mock_open, @@ -135,8 +178,9 @@ async def test_execute_command_query_workspace_dir(mock_language_server, mock_co mock_language_server.workspace = MagicMock() mock_language_server.workspace.folders = {"dummy_dir": workspace_folder} mock_config.project_root = None + mock_config.query = ["test"] mock_parse_cli_args.return_value = mock_config - mock_get_query_result_files.return_value = ["/test/file.txt"] + mock_get_reranked_results.return_value = [] # Configure the MagicMock object to return a string when read() is called mock_file = MagicMock() @@ -151,9 +195,7 @@ async def test_execute_command_query_workspace_dir(mock_language_server, mock_co assert isinstance(result, list) mock_language_server.progress.begin.assert_called() mock_language_server.progress.end.assert_called() - assert ( - mock_get_query_result_files.call_args.args[1].project_root == "/dummy_dir" - ) + assert mock_config.project_root == "/dummy_dir" @pytest.mark.asyncio @@ -168,14 +210,21 @@ async def test_execute_command_ls(mock_language_server, mock_config): patch( "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock ) as mock_parse_cli_args, - patch("vectorcode.lsp_main.ClientManager"), patch( - "vectorcode.lsp_main.get_collection_list", new_callable=AsyncMock - ) as mock_get_collection_list, - patch("vectorcode.common.get_embedding_function") as mock_embedding_function, - patch("vectorcode.common.get_collection") as mock_get_collection, + "vectorcode.lsp_main.get_database_connector" + ) as mock_get_database_connector, ): mock_parse_cli_args.return_value = mock_config + mock_db_connector = AsyncMock() + mock_db_connector.list_collections.return_value = [ + Collection( + id="dummy", + path="/test/project", + embedding_function="", + database_backend="", + ) + ] + mock_get_database_connector.return_value = mock_db_connector # Ensure parsed_args.project_root is not None mock_config.project_root = "/test/project" @@ -183,10 +232,6 @@ async def test_execute_command_ls(mock_language_server, mock_config): # Mock the merge_from method mock_config.merge_from = AsyncMock(return_value=mock_config) - mock_get_collection_list.return_value = [{"project": "/test/project"}] - mock_embedding_function.return_value = MagicMock() # Mock embedding function - mock_get_collection.return_value = MagicMock() - result = await execute_command(mock_language_server, ["ls"]) assert isinstance(result, list) @@ -203,8 +248,6 @@ async def test_execute_command_vectorise(mock_language_server, mock_config: Conf mock_config.include_hidden = False mock_config.force = False # To test exclude_paths_by_spec path - # Files that load_files_from_include will return and expand_globs will process - dummy_initial_files = ["file_a.py", "file_b.txt"] # Files after expand_globs dummy_expanded_files = ["/test/project/file_a.py", "/test/project/file_b.txt"] @@ -213,104 +256,36 @@ async def test_execute_command_vectorise(mock_language_server, mock_config: Conf patch( "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock ) as mock_parse_cli_args, - patch("vectorcode.lsp_main.ClientManager") as MockClientManager, patch( - "vectorcode.lsp_main.get_collection", new_callable=AsyncMock - ) as mock_get_collection, + "vectorcode.lsp_main.get_database_connector" + ) as mock_get_database_connector, patch( "vectorcode.lsp_main.expand_globs", new_callable=AsyncMock ) as mock_expand_globs, - patch( - 
"vectorcode.lsp_main.find_exclude_specs", return_value=[] - ) as mock_find_exclude_specs, - patch( - "vectorcode.lsp_main.exclude_paths_by_spec", - side_effect=lambda files, spec: files, - ) as mock_exclude_paths_by_spec, - patch( - "vectorcode.lsp_main.chunked_add", new_callable=AsyncMock - ) as mock_chunked_add, - patch( - "vectorcode.lsp_main.load_files_from_include", - return_value=dummy_initial_files, - ) as mock_load_files_from_include, - patch("os.cpu_count", return_value=1), # For asyncio.Semaphore - patch( - "vectorcode.lsp_main.remove_orphanes", new_callable=AsyncMock - ) as mock_remove_orphanes, + patch("os.path.isfile", lambda x: x in dummy_expanded_files), + patch("vectorcode.lsp_main.find_exclude_specs", return_value=[]), + patch("os.cpu_count", return_value=1), + patch("vectorcode.lsp_main.get_project_config", return_value=mock_config), ): - from unittest.mock import ANY - - from lsprotocol import types - - @asynccontextmanager - async def _get_client(*args): - yield mock_client - # Set return values for mocks mock_parse_cli_args.return_value = mock_config - mock_client = AsyncMock() - MockClientManager.return_value.get_client.side_effect = _get_client - mock_collection = AsyncMock() - mock_get_collection.return_value = mock_collection - mock_client.get_max_batch_size.return_value = 100 # Mock batch size + mock_db_connector = AsyncMock() + mock_get_database_connector.return_value = mock_db_connector mock_expand_globs.return_value = ( dummy_expanded_files # What expand_globs should return ) - # Mock merge_from as it's called mock_config.merge_from = AsyncMock(return_value=mock_config) - # Execute the command - result = await execute_command( - mock_language_server, ["vectorise", "/test/project"] - ) - assert isinstance(result, dict) and all( - k in ("add", "update", "removed", "failed", "skipped") - for k in result.keys() - ) - - # Assertions - mock_language_server.progress.create_async.assert_called_once() - mock_language_server.progress.begin.assert_called_once_with( - ANY, # progress_token - types.WorkDoneProgressBegin( - title="VectorCode", message="Vectorising files...", percentage=0 - ), - ) - - mock_load_files_from_include.assert_called_once_with( - str(mock_config.project_root) - ) - mock_expand_globs.assert_called_once_with( - dummy_initial_files, # Should be the result of load_files_from_include - recursive=mock_config.recursive, - include_hidden=mock_config.include_hidden, - ) - mock_find_exclude_specs.assert_called_once() - mock_exclude_paths_by_spec.assert_not_called() # Because mock_find_exclude_specs returns empty list (no specs to exclude by) - mock_client.get_max_batch_size.assert_called_once() - - # Check chunked_add calls - assert mock_chunked_add.call_count == len(dummy_expanded_files) - for file_path in dummy_expanded_files: - mock_chunked_add.assert_any_call( - file_path, - mock_collection, - ANY, # asyncio.Lock object - ANY, # stats dict - ANY, # stats_lock - ANY, - 100, # max_batch_size - ANY, # semaphore - ) - # Check progress report calls - assert mock_language_server.progress.report.call_count == len( - dummy_expanded_files + await execute_command( + mock_language_server, + ["vectorise", "/test/project", "file_a.py", "file_b.txt"], ) - mock_remove_orphanes.assert_called_once() - mock_language_server.progress.end.assert_called_once() + assert mock_db_connector.vectorise.await_args_list == [ + call(file_path="/test/project/file_a.py"), + call(file_path="/test/project/file_b.txt"), + ] @pytest.mark.asyncio @@ -328,15 +303,9 @@ async def 
test_execute_command_unsupported_action( patch( "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock ) as mock_parse_cli_args, - patch( - "vectorcode.lsp_main.get_collection", new_callable=AsyncMock - ) as mock_get_collection, ): mock_parse_cli_args.return_value = mock_config - mock_collection = MagicMock() - mock_get_collection.return_value = mock_collection - # Mock the merge_from method mock_config.merge_from = AsyncMock(return_value=mock_config) @@ -424,37 +393,28 @@ async def test_execute_command_files_ls(mock_language_server, mock_config: Confi mock_config.files_action = FilesAction.ls mock_config.project_root = "/test/project" - dummy_files = ["/test/project/file1.py", "/test/project/file2.txt"] - mock_client = AsyncMock() - mock_collection = AsyncMock() - + dummy_files = FileList( + files=[ + File(path="/test/project/file1.py", sha256="1"), + File(path="/test/project/file2.txt", sha256="2"), + ] + ) with ( patch( "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock ) as mock_parse_cli_args, patch( - "vectorcode.lsp_main.ClientManager._create_client", return_value=mock_client - ), - patch("vectorcode.common.try_server", return_value=True), - patch("vectorcode.lsp_main.get_collection", return_value=mock_collection), - patch( - "vectorcode.lsp_main.list_collection_files", return_value=dummy_files - ) as mock_list_collection_files, + "vectorcode.lsp_main.get_database_connector" + ) as mock_get_database_connector, ): mock_parse_cli_args.return_value = mock_config - + mock_db_connector = AsyncMock() + mock_db_connector.list_collection_content.return_value = dummy_files + mock_get_database_connector.return_value = mock_db_connector mock_config.merge_from = AsyncMock(return_value=mock_config) - result = await execute_command(mock_language_server, ["files", "ls"]) - - assert result == dummy_files - mock_language_server.progress.create_async.assert_called_once() - - mock_list_collection_files.assert_called_once_with(mock_collection) - # For 'ls' action, progress.begin/end are not explicitly called in the lsp_main, - # but create_async is called before the match statement. 
- mock_language_server.progress.begin.assert_not_called() - mock_language_server.progress.end.assert_not_called() + await execute_command(mock_language_server, ["files", "ls"]) + mock_db_connector.list_collection_content.assert_called_once() @pytest.mark.asyncio @@ -468,18 +428,14 @@ async def test_execute_command_files_rm(mock_language_server, mock_config: Confi "/test/project/file_to_remove.py", "/test/project/another_file.txt", ] - mock_client = AsyncMock() - mock_collection = AsyncMock() with ( patch( "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock ) as mock_parse_cli_args, patch( - "vectorcode.lsp_main.ClientManager._create_client", return_value=mock_client - ), - patch("vectorcode.common.try_server", return_value=True), - patch("vectorcode.lsp_main.get_collection", return_value=mock_collection), + "vectorcode.lsp_main.get_database_connector" + ) as mock_get_database_connector, patch( "os.path.isfile", side_effect=lambda x: x in expanded_paths or x in mock_config.rm_paths, @@ -490,17 +446,15 @@ async def test_execute_command_files_rm(mock_language_server, mock_config: Confi ), ): mock_parse_cli_args.return_value = mock_config - + mock_db_connector = AsyncMock() + mock_get_database_connector.return_value = mock_db_connector mock_config.merge_from = AsyncMock(return_value=mock_config) await execute_command( mock_language_server, ["files", "rm", "file_to_remove.py", "another_file.txt"], ) - - mock_collection.delete.assert_called_once_with( - where={"path": {"$in": expanded_paths}} - ) + mock_db_connector.delete.assert_called_once() @pytest.mark.asyncio @@ -512,18 +466,13 @@ async def test_execute_command_files_rm_no_files_to_remove( mock_config.project_root = "/test/project" mock_config.rm_paths = ["non_existent_file.py"] - mock_client = AsyncMock() - mock_collection = AsyncMock() - with ( patch( "vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock ) as mock_parse_cli_args, patch( - "vectorcode.lsp_main.ClientManager._create_client", return_value=mock_client - ), - patch("vectorcode.common.try_server", return_value=True), - patch("vectorcode.lsp_main.get_collection", return_value=mock_collection), + "vectorcode.lsp_main.get_database_connector" + ) as mock_get_database_connector, patch("os.path.isfile", return_value=False), patch( "vectorcode.lsp_main.expand_path", @@ -531,12 +480,11 @@ async def test_execute_command_files_rm_no_files_to_remove( ), ): mock_parse_cli_args.return_value = mock_config - + mock_db_connector = AsyncMock() + mock_get_database_connector.return_value = mock_db_connector mock_config.merge_from = AsyncMock(return_value=mock_config) - result = await execute_command( + await execute_command( mock_language_server, ["files", "rm", "non_existent_file.py"] ) - - assert result is None - mock_collection.delete.assert_not_called() + mock_db_connector.assert_not_called() diff --git a/tests/test_main.py b/tests/test_main.py index b46c9989..4f675217 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -97,7 +97,6 @@ async def test_async_main_cli_action_chunks(monkeypatch): monkeypatch.setattr( "vectorcode.main.get_project_config", AsyncMock(return_value=Config()) ) - monkeypatch.setattr("vectorcode.common.try_server", AsyncMock(return_value=True)) return_code = await async_main() assert return_code == 0 @@ -141,7 +140,9 @@ async def test_async_main_cli_action_query(monkeypatch): "vectorcode.main.parse_cli_args", AsyncMock(return_value=mock_cli_args) ) mock_final_configs = Config( - db_url="http://test_host:1234", action=CliAction.query, pipe=False + 
+        db_params={"db_url": "http://test_host:1234"},
+        action=CliAction.query,
+        pipe=False,
     )
     monkeypatch.setattr(
         "vectorcode.main.get_project_config",
         AsyncMock(
             return_value=MagicMock(
                 merge_from=AsyncMock(return_value=mock_final_configs)
             )
         ),
     )
@@ -151,7 +152,6 @@
-    monkeypatch.setattr("vectorcode.common.try_server", AsyncMock(return_value=True))
 
     mock_query = AsyncMock(return_value=0)
     monkeypatch.setattr("vectorcode.subcommands.query", mock_query)
@@ -172,7 +172,9 @@ async def test_async_main_cli_action_vectorise(monkeypatch):
         "vectorcode.main.parse_cli_args", AsyncMock(return_value=mock_cli_args)
     )
     mock_final_configs = Config(
-        db_url="http://test_host:1234", action=CliAction.vectorise, include_hidden=True
+        db_params={"db_url": "http://test_host:1234"},
+        action=CliAction.vectorise,
+        include_hidden=True,
     )
     monkeypatch.setattr(
         "vectorcode.main.get_project_config",
         AsyncMock(
             return_value=MagicMock(
                 merge_from=AsyncMock(return_value=mock_final_configs)
             )
         ),
     )
@@ -182,7 +184,6 @@
-    monkeypatch.setattr("vectorcode.common.try_server", AsyncMock(return_value=True))
 
     mock_vectorise = AsyncMock(return_value=0)
     monkeypatch.setattr("vectorcode.subcommands.vectorise", mock_vectorise)
@@ -197,7 +198,9 @@ async def test_async_main_cli_action_drop(monkeypatch):
     monkeypatch.setattr(
         "vectorcode.main.parse_cli_args", AsyncMock(return_value=mock_cli_args)
     )
-    mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.drop)
+    mock_final_configs = Config(
+        db_params={"db_url": "http://test_host:1234"}, action=CliAction.drop
+    )
     monkeypatch.setattr(
         "vectorcode.main.get_project_config",
         AsyncMock(
             return_value=MagicMock(
                 merge_from=AsyncMock(return_value=mock_final_configs)
             )
         ),
     )
@@ -206,7 +209,6 @@
-    monkeypatch.setattr("vectorcode.common.try_server", AsyncMock(return_value=True))
 
     mock_drop = AsyncMock(return_value=0)
     monkeypatch.setattr("vectorcode.subcommands.drop", mock_drop)
@@ -221,7 +223,9 @@ async def test_async_main_cli_action_ls(monkeypatch):
     monkeypatch.setattr(
         "vectorcode.main.parse_cli_args", AsyncMock(return_value=mock_cli_args)
     )
-    mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.ls)
+    mock_final_configs = Config(
+        db_params={"db_url": "http://test_host:1234"}, action=CliAction.ls
+    )
     monkeypatch.setattr(
         "vectorcode.main.get_project_config",
         AsyncMock(
             return_value=MagicMock(
                 merge_from=AsyncMock(return_value=mock_final_configs)
             )
         ),
     )
@@ -230,7 +234,6 @@
-    monkeypatch.setattr("vectorcode.common.try_server", AsyncMock(return_value=True))
 
     mock_ls = AsyncMock(return_value=0)
     monkeypatch.setattr("vectorcode.subcommands.ls", mock_ls)
@@ -247,6 +250,10 @@ async def test_async_main_cli_action_files(monkeypatch):
     monkeypatch.setattr(
         "vectorcode.main.parse_cli_args", AsyncMock(return_value=cli_args)
     )
+    monkeypatch.setattr(
+        "vectorcode.main.get_project_config",
+        AsyncMock(return_value=MagicMock(merge_from=AsyncMock(return_value=cli_args))),
+    )
 
     assert await async_main() == 0
     mock_files.assert_called_once()
@@ -257,7 +264,9 @@ async def test_async_main_cli_action_update(monkeypatch):
     monkeypatch.setattr(
         "vectorcode.main.parse_cli_args", AsyncMock(return_value=mock_cli_args)
     )
-    mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.update)
+    mock_final_configs = Config(
+        db_params={"db_url": "http://test_host:1234"}, action=CliAction.update
+    )
     monkeypatch.setattr(
         "vectorcode.main.get_project_config",
         AsyncMock(
             return_value=MagicMock(
                 merge_from=AsyncMock(return_value=mock_final_configs)
             )
         ),
     )
@@ -266,7 +275,6 @@
-    monkeypatch.setattr("vectorcode.common.try_server", AsyncMock(return_value=True))
 
     mock_update = AsyncMock(return_value=0)
     monkeypatch.setattr("vectorcode.subcommands.update", mock_update)
@@ -281,7 +289,9 @@ async def test_async_main_cli_action_clean(monkeypatch):
     monkeypatch.setattr(
         "vectorcode.main.parse_cli_args", AsyncMock(return_value=mock_cli_args)
     )
-    mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.clean)
+    mock_final_configs = Config(
+        db_params={"db_url": "http://test_host:1234"}, action=CliAction.clean
+    )
     monkeypatch.setattr(
         "vectorcode.main.get_project_config",
         AsyncMock(
             return_value=MagicMock(
                 merge_from=AsyncMock(return_value=mock_final_configs)
             )
         ),
     )
@@ -290,7 +300,6 @@
-    monkeypatch.setattr("vectorcode.common.try_server", AsyncMock(return_value=True))
 
     mock_clean = AsyncMock(return_value=0)
     monkeypatch.setattr("vectorcode.subcommands.clean", mock_clean)
@@ -305,7 +314,9 @@ async def test_async_main_exception_handling(monkeypatch):
     monkeypatch.setattr(
         "vectorcode.main.parse_cli_args", AsyncMock(return_value=mock_cli_args)
     )
-    mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.query)
+    mock_final_configs = Config(
+        db_params={"db_url": "http://test_host:1234"}, action=CliAction.query
+    )
     monkeypatch.setattr(
         "vectorcode.main.get_project_config",
         AsyncMock(
             return_value=MagicMock(
                 merge_from=AsyncMock(return_value=mock_final_configs)
             )
         ),
     )
@@ -314,7 +325,6 @@
-    monkeypatch.setattr("vectorcode.common.try_server", AsyncMock(return_value=True))
 
     mock_query = AsyncMock(side_effect=Exception("Test Exception"))
     monkeypatch.setattr("vectorcode.subcommands.query", mock_query)
diff --git a/tests/test_mcp.py b/tests/test_mcp.py
index b9a40bbf..336171fb 100644
--- a/tests/test_mcp.py
+++ b/tests/test_mcp.py
@@ -1,18 +1,15 @@
-import os
-import tempfile
-from argparse import ArgumentParser
-from unittest.mock import AsyncMock, MagicMock, mock_open, patch
+import sys
+from unittest.mock import ANY, AsyncMock, MagicMock, patch
 
-import numpy
 import pytest
-from mcp import McpError
+from mcp import ErrorData, McpError
 
 from vectorcode.cli_utils import Config
-from vectorcode.common import ClientManager
 from vectorcode.mcp_main import (
     get_arg_parser,
     list_collections,
     ls_files,
+    mcp_config,
     mcp_server,
     parse_cli_args,
     query_tool,
@@ -23,407 +20,268 @@
 
 
 @pytest.mark.asyncio
 async def test_list_collections_success():
-    mock_client = AsyncMock()
-    with (
-        patch("vectorcode.mcp_main.get_collections") as mock_get_collections,
-        patch("vectorcode.common.try_server", return_value=True),
-        patch(
-            "vectorcode.mcp_main.ClientManager._create_client", return_value=mock_client
-        ),
-    ):
-        mock_collection1 = AsyncMock()
-        mock_collection1.metadata = {"path": "path1"}
-        mock_collection2 = AsyncMock()
-        mock_collection2.metadata = {"path": "path2"}
-
-        async def async_generator():
-            yield mock_collection1
-            yield mock_collection2
-
-        mock_get_collections.return_value = async_generator()
+    with patch("vectorcode.mcp_main.get_database_connector") as mock_get_db:
+        mock_db = AsyncMock()
+        mock_db.list_collections.return_value = [
+            MagicMock(path="path1"),
+            MagicMock(path="path2"),
+        ]
+        mock_get_db.return_value = mock_db
 
         result = await list_collections()
         assert result == ["path1", "path2"]
 
 
 @pytest.mark.asyncio
-async def test_list_collections_no_metadata():
-    mock_client = AsyncMock()
+async def test_query_tool_invalid_project_root():
+    with patch("os.path.isdir", return_value=False):
+        with pytest.raises(McpError) as exc_info:
+            await query_tool(
+                n_query=5,
+                query_messages=["keyword1", "keyword2"],
+                project_root="invalid_path",
+            )
+        assert exc_info.value.error.code == 1
+
+
+@pytest.mark.asyncio
+async def test_query_tool_success(tmp_path):
+    mock_config = Config(project_root=tmp_path)
     with (
-        patch("vectorcode.mcp_main.get_collections") as mock_get_collections,
-        patch("vectorcode.common.try_server", return_value=True),
+        patch("vectorcode.mcp_main.get_database_connector") as mock_get_db,
+        patch("vectorcode.mcp_main.get_project_config", return_value=mock_config),
         patch(
-            "vectorcode.mcp_main.ClientManager._create_client", return_value=mock_client
+            "vectorcode.subcommands.query.reranker.naive.NaiveReranker.rerank",
+            new_callable=AsyncMock,
+            return_value=[],
         ),
     ):
-        mock_collection1 = AsyncMock()
-        mock_collection1.metadata = {"path": "path1"}
-        mock_collection2 = AsyncMock()
-        mock_collection2.metadata = None
+        mock_db = AsyncMock()
+        mock_get_db.return_value = mock_db
+        mock_db._configs = mock_config
+        mock_db.query.return_value = []
 
-        async def async_generator(cli):
-            yield mock_collection1
-            yield mock_collection2
-
-        mock_get_collections.side_effect = async_generator
-
-        result = await list_collections()
-        assert result == ["path1"]
-
-
-@pytest.mark.asyncio
-async def test_query_tool_invalid_project_root():
-    with pytest.raises(McpError) as exc_info:
         await query_tool(
-            n_query=5,
-            query_messages=["keyword1", "keyword2"],
-            project_root="invalid_path",
+            n_query=2, query_messages=["keyword1"], project_root=str(tmp_path)
         )
-        assert exc_info.value.error.code == 1
-        assert (
-            exc_info.value.error.message
-            == "Use `list_collections` tool to get a list of valid paths for this field."
-        )
+        mock_db.query.assert_called_once()
+        assert mock_db._configs.n_result == 2
+        assert mock_db._configs.query == ["keyword1"]
 
 
 @pytest.mark.asyncio
-async def test_query_tool_success():
-    mock_client = AsyncMock()
-
-    with tempfile.TemporaryDirectory() as temp_dir:
-        os.chdir(temp_dir)
-        # Mock the collection's query method to return a valid QueryResult
-        mock_collection = AsyncMock()
-        mock_collection.query.return_value = {
-            "ids": [["id1", "id2"]],
-            "embeddings": None,
-            "metadatas": [[{"path": "file1.py"}, {"path": "file2.py"}]],
-            "documents": [["doc1", "doc2"]],
-            "uris": None,
-            "data": None,
-            "distances": [[0.1, 0.2]],  # Valid distances
-        }
-        for i in range(1, 3):
-            with open(os.path.join(temp_dir, f"file{i}.py"), "w") as fin:
-                fin.writelines([f"doc{i}"])
-        with (
-            patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config,
-            patch("vectorcode.mcp_main.get_collection", return_value=mock_collection),
-            patch(
-                "vectorcode.mcp_main.ClientManager._create_client",
-                return_value=mock_client,
-            ),
-            patch(
-                "vectorcode.subcommands.query.get_query_result_files"
-            ) as mock_get_query_result_files,
-            patch("vectorcode.common.try_server", return_value=True),
-            patch("vectorcode.cli_utils.load_config_file") as mock_load_config_file,
-        ):
-            mock_config = Config(
-                chunk_size=100, overlap_ratio=0.1, reranker=None, project_root=temp_dir
-            )
-            mock_load_config_file.return_value = mock_config
-            mock_get_project_config.return_value = mock_config
-
-            # mock_get_collection.return_value = mock_collection
-
-            mock_get_query_result_files.return_value = [
-                os.path.join(temp_dir, i) for i in ("file1.py", "file2.py")
-            ]
-
-            result = await query_tool(
-                n_query=2, query_messages=["keyword1"], project_root=temp_dir
-            )
-
-            assert len(result) == 2
+async def test_vectorise_tool_invalid_project_root():
+    with patch("os.path.isdir", return_value=False):
+        with pytest.raises(McpError):
+            await vectorise_files(paths=["foo.bar"], project_root=".")
 
 
 @pytest.mark.asyncio
-async def test_query_tool_collection_access_failure():
+async def test_vectorise_files_success(tmp_path):
+    mock_db = AsyncMock()
+    mock_config = Config(project_root=str(tmp_path))
+    (tmp_path / "file1.py").touch()
     with (
-        patch("os.path.isdir", return_value=True),
-        patch("vectorcode.mcp_main.get_project_config"),
-        patch("vectorcode.mcp_main.get_collection"),
+        patch("vectorcode.mcp_main.get_database_connector", return_value=mock_db),
+        patch("vectorcode.mcp_main.get_project_config", return_value=mock_config),
         patch(
-            "vectorcode.mcp_main.ClientManager._create_client",
-            side_effect=Exception("Failed to connect"),
-        ),
+            "vectorcode.mcp_main.vectorise_worker", new_callable=AsyncMock
+        ) as mock_worker,
     ):
-        with pytest.raises(McpError):
-            await query_tool(
-                n_query=2, query_messages=["keyword1"], project_root="/valid/path"
-            )
+        await vectorise_files(
+            paths=[str(tmp_path / "file1.py")], project_root=str(tmp_path)
+        )
+        mock_worker.assert_called_once()
 
 
 @pytest.mark.asyncio
-async def test_query_tool_no_collection():
-    mock_client = AsyncMock()
+async def test_vectorise_files_with_ignore_spec(tmp_path):
+    project_root = tmp_path
+    (project_root / ".gitignore").write_text("ignored.py")
+    (project_root / "file1.py").touch()
+    (project_root / "ignored.py").touch()
+
+    mock_db = AsyncMock()
+    mock_config = Config(project_root=str(project_root))
     with (
-        patch("os.path.isdir", return_value=True),
-        patch("vectorcode.mcp_main.get_project_config"),
-        patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
+        patch("vectorcode.mcp_main.get_database_connector", return_value=mock_db),
+        patch("vectorcode.mcp_main.get_project_config", return_value=mock_config),
         patch(
-            "vectorcode.mcp_main.ClientManager._create_client", return_value=mock_client
-        ),
+            "vectorcode.mcp_main.vectorise_worker", new_callable=AsyncMock
+        ) as mock_worker,
     ):
-        mock_get_collection.return_value = None
-
-        with pytest.raises(McpError):
-            await query_tool(
-                n_query=2, query_messages=["keyword1"], project_root="/valid/path"
-            )
+        await vectorise_files(
+            paths=[str(project_root / "file1.py"), str(project_root / "ignored.py")],
+            project_root=str(project_root),
+        )
+        mock_worker.assert_called_once_with(
+            mock_db, str(project_root / "file1.py"), ANY, ANY, ANY
+        )
 
 
 @pytest.mark.asyncio
-async def test_vectorise_tool_invalid_project_root():
+async def test_mcp_server(tmp_path):
     with (
-        patch("os.path.isdir", return_value=False),
+        patch("mcp.server.fastmcp.FastMCP.add_tool") as mock_add_tool,
+        patch("vectorcode.mcp_main.find_project_config_dir", return_value=tmp_path),
+        patch("vectorcode.mcp_main.get_project_config", return_value=Config()),
     ):
-        with pytest.raises(McpError):
-            await vectorise_files(paths=["foo.bar"], project_root=".")
+        await mcp_server()
+        assert mock_add_tool.call_count > 0
 
 
 @pytest.mark.asyncio
-async def test_vectorise_files_success():
-    with tempfile.TemporaryDirectory() as temp_dir:
-        file_path = f"{temp_dir}/test_file.py"
-        with open(file_path, "w") as f:
-            f.write("def func(): pass")
-        mock_client = AsyncMock()
-
-        mock_embedding_function = MagicMock(return_value=numpy.random.random((100,)))
-        with (
-            patch("os.path.isdir", return_value=True),
-            patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config,
-            patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
-            patch(
-                "vectorcode.mcp_main.ClientManager._create_client",
-                return_value=mock_client,
-            ),
-            patch(
-                "vectorcode.subcommands.vectorise.get_embedding_function",
-                return_value=mock_embedding_function,
-            ),
-            patch("vectorcode.subcommands.vectorise.chunked_add"),
-            patch(
-                "vectorcode.subcommands.vectorise.hash_file", return_value="test_hash"
-            ),
-            patch("vectorcode.common.try_server", return_value=True),
-        ):
-            mock_config = Config(project_root=temp_dir)
-            mock_get_project_config.return_value = mock_config
-
-            mock_collection = AsyncMock()
-            mock_collection.get.return_value = {"ids": [], "metadatas": []}
-            mock_get_collection.return_value = mock_collection
-            mock_client.get_max_batch_size.return_value = 100
-
-            result = await vectorise_files(paths=[file_path], project_root=temp_dir)
-
-            assert result["add"] == 1
-            mock_get_project_config.assert_called_once_with(temp_dir)
-            # Assert that the mocked get_collection was called with our mock_client.
-            mock_get_collection.assert_called_once()
+async def test_mcp_server_ls_on_start(tmp_path):
+    with (
+        patch("mcp.server.fastmcp.FastMCP.add_tool") as mock_add_tool,
+        patch("vectorcode.mcp_main.find_project_config_dir", return_value=tmp_path),
+        patch("vectorcode.mcp_main.get_project_config", return_value=Config()),
+        patch("vectorcode.mcp_main.list_collections", return_value=["path1", "path2"]),
+    ):
+        mcp_config.ls_on_start = True
+        await mcp_server()
+        assert mock_add_tool.call_count > 0
+        mcp_config.ls_on_start = False
 
 
 @pytest.mark.asyncio
-async def test_vectorise_files_collection_access_failure():
+async def test_ls_files_success(tmp_path):
     with (
-        patch("os.path.isdir", return_value=True),
-        patch("vectorcode.mcp_main.get_project_config"),
-        patch(
-            "vectorcode.mcp_main.ClientManager._create_client",
-            side_effect=Exception("Client error"),
-        ),
-        patch("vectorcode.mcp_main.get_collection"),
+        patch("vectorcode.mcp_main.get_database_connector") as mock_get_db,
+        patch("vectorcode.mcp_main.get_project_config") as mock_get_config,
     ):
-        with pytest.raises(McpError):
-            await vectorise_files(paths=["file.py"], project_root="/valid/path")
+        mock_db = AsyncMock()
+        mock_db.list_collection_content.return_value.files = [
+            MagicMock(path="file1.py"),
+            MagicMock(path="file2.py"),
+        ]
+        mock_get_db.return_value = mock_db
+        mock_get_config.return_value = Config(project_root=str(tmp_path))
 
+        result = await ls_files(project_root=str(tmp_path))
 
-@pytest.mark.asyncio
-async def test_vectorise_files_with_exclude_spec():
-    with tempfile.TemporaryDirectory() as temp_dir:
-        file1 = f"{temp_dir}/file1.py"
-        excluded_file = f"{temp_dir}/excluded.py"
-        exclude_spec_file = f"{temp_dir}/.vectorcode/vectorcode.exclude"
-
-        os.makedirs(f"{temp_dir}/.vectorcode")
-        with open(file1, "w") as f:
-            f.write("content1")
-        with open(excluded_file, "w") as f:
-            f.write("content_excluded")
-        with open(exclude_spec_file, "w") as fin:
-            fin.writelines(["excluded.py"])
-
-        # Create mock file handles for specific file contents
-        mock_exclude_file_handle = mock_open(read_data="excluded.py").return_value
-
-        def mock_open_side_effect(filename, *args, **kwargs):
-            if filename == exclude_spec_file:
-                return mock_exclude_file_handle
-            # For other files that might be opened, return a generic mock
-            return MagicMock()
-
-        mock_client = AsyncMock()
-        with (
-            patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config,
-            patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
-            patch(
-                "vectorcode.mcp_main.ClientManager._create_client",
-                return_value=mock_client,
-            ),
-            patch("vectorcode.mcp_main.chunked_add") as mock_chunked_add,
-            patch(
-                "vectorcode.subcommands.vectorise.hash_file", return_value="test_hash"
-            ),
-            patch("vectorcode.common.try_server", return_value=True),
-        ):
-            mock_config = Config(project_root=temp_dir)
-            mock_get_project_config.return_value = mock_config
-
-            mock_collection = AsyncMock()
-            mock_collection.get.return_value = {"ids": [], "metadatas": []}
-            mock_get_collection.return_value = mock_collection
-            mock_client.get_max_batch_size.return_value = 100
-
-            await vectorise_files(paths=[file1, excluded_file], project_root=temp_dir)
-
-            assert mock_chunked_add.call_count == 1
-            call_args = [call[0][0] for call in mock_chunked_add.call_args_list]
-            assert excluded_file not in call_args
+        assert result == ["file1.py", "file2.py"]
 
 
 @pytest.mark.asyncio
-async def test_mcp_server():
-    mock_client = AsyncMock()
+async def test_rm_files_success(tmp_path):
+    (tmp_path / "file1.py").touch()
     with (
-        patch(
-            "vectorcode.mcp_main.find_project_config_dir"
-        ) as mock_find_project_config_dir,
-        patch("vectorcode.mcp_main.load_config_file") as mock_load_config_file,
-        patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
-        patch("mcp.server.fastmcp.FastMCP.add_tool") as mock_add_tool,
-        patch("vectorcode.common.try_server", return_value=True),
-        patch(
-            "vectorcode.mcp_main.ClientManager._create_client", return_value=mock_client
-        ),
+        patch("vectorcode.mcp_main.get_database_connector") as mock_get_db,
+        patch("vectorcode.mcp_main.get_project_config") as mock_get_config,
     ):
-        mock_find_project_config_dir.return_value = "/path/to/config"
-        mock_load_config_file.return_value = Config(project_root="/path/to/project")
-
-        mock_collection = AsyncMock()
-        mock_get_collection.return_value = mock_collection
+        mock_db = AsyncMock()
+        mock_get_db.return_value = mock_db
+        mock_get_config.return_value = Config(project_root=str(tmp_path))
 
-        await mcp_server()
+        await rm_files(files=[str(tmp_path / "file1.py")], project_root=str(tmp_path))
 
-        assert mock_add_tool.call_count == 5
+        mock_db.delete.assert_called_once()
 
 
 @pytest.mark.asyncio
-async def test_mcp_server_ls_on_start():
-    mock_client = AsyncMock()
-    mock_collection = AsyncMock()
-
+async def test_rm_files_no_files(tmp_path):
     with (
-        patch(
-            "vectorcode.mcp_main.find_project_config_dir"
-        ) as mock_find_project_config_dir,
-        patch("vectorcode.mcp_main.load_config_file") as mock_load_config_file,
-        patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
-        patch(
-            "vectorcode.mcp_main.get_collections", spec=AsyncMock
-        ) as mock_get_collections,
-        patch("mcp.server.fastmcp.FastMCP.add_tool") as mock_add_tool,
-        patch("vectorcode.common.try_server", return_value=True),
-        patch(
-            "vectorcode.mcp_main.ClientManager._create_client", return_value=mock_client
-        ),
+        patch("vectorcode.mcp_main.get_database_connector") as mock_get_db,
+        patch("vectorcode.mcp_main.get_project_config") as mock_get_config,
    ):
-        from vectorcode.mcp_main import mcp_config
+        mock_db = AsyncMock()
+        mock_get_db.return_value = mock_db
+        mock_get_config.return_value = Config(project_root=str(tmp_path))
 
-        mcp_config.ls_on_start = True
-        mock_find_project_config_dir.return_value = "/path/to/config"
-        mock_load_config_file.return_value = Config(project_root="/path/to/project")
+        await rm_files(files=["file1.py"], project_root=str(tmp_path))
 
-        mock_collection.metadata = {"path": "/path/to/project"}
-        mock_get_collection.return_value = mock_collection
+        mock_db.delete.assert_not_called()
 
-        async def new_get_collections(clients):
-            yield mock_collection
 
-        mock_get_collections.side_effect = new_get_collections
+def test_get_arg_parser():
+    parser = get_arg_parser()
+    args = parser.parse_args(["-n", "5", "--ls-on-start"])
+    assert args.number == 5
+    assert args.ls_on_start is True
 
-        await mcp_server()
 
-        assert mock_add_tool.call_count == 5
-        mock_get_collections.assert_called()
+def test_parse_cli_args():
+    with patch.object(sys, "argv", ["", "-n", "5", "--ls-on-start"]):
+        config = parse_cli_args()
+        assert config.n_results == 5
+        assert config.ls_on_start is True
 
 
 @pytest.mark.asyncio
-async def test_ls_files_success():
-    ClientManager().clear()
-    mock_client = MagicMock()
-    mock_collection = MagicMock()
-    expected_files = ["/test/project/file1.py", "/test/project/dir/file2.txt"]
-
+async def test_vectorise_files_exception(tmp_path):
+    mock_db = AsyncMock()
+    mock_config = Config(project_root=str(tmp_path))
+    (tmp_path / "file1.py").touch()
     with (
-        patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config,
+        patch("vectorcode.mcp_main.get_database_connector", return_value=mock_db),
+        patch("vectorcode.mcp_main.get_project_config", return_value=mock_config),
         patch(
-            "vectorcode.mcp_main.ClientManager._create_client", return_value=mock_client
+            "vectorcode.mcp_main.vectorise_worker", side_effect=Exception("test error")
         ),
-        patch("vectorcode.common.try_server", return_value=True),
-        patch("vectorcode.mcp_main.get_collection", return_value=mock_collection),
-        patch(
-            "vectorcode.mcp_main.list_collection_files", return_value=expected_files
-        ) as mock_list_collection_files,
-        patch(
-            "vectorcode.cli_utils.expand_path", side_effect=lambda x, y: x
-        ),  # Mock expand_path to return input
     ):
-        mock_get_project_config.return_value = Config(project_root="/test/project")
-        result = await ls_files(project_root="/test/project")
-
-        assert result == expected_files
-        mock_get_project_config.assert_called_once_with("/test/project")
-
-        mock_list_collection_files.assert_called_once_with(mock_collection)
+        with pytest.raises(McpError):
+            await vectorise_files(
+                paths=[str(tmp_path / "file1.py")], project_root=str(tmp_path)
+            )
 
 
 @pytest.mark.asyncio
-async def test_rm_files_success():
-    ClientManager().clear()
-    mock_client = MagicMock()
-    mock_collection = MagicMock()
-    files_to_remove = ["/test/project/file1.py", "/test/project/file2.txt"]
-
+async def test_query_tool_exception(tmp_path):
+    mock_config = Config(project_root=tmp_path)
     with (
-        patch("os.path.isfile", side_effect=lambda x: x in files_to_remove),
-        patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config,
+        patch("vectorcode.mcp_main.get_database_connector") as mock_get_db,
+        patch("vectorcode.mcp_main.get_project_config", return_value=mock_config),
         patch(
-            "vectorcode.mcp_main.ClientManager._create_client", return_value=mock_client
+            "vectorcode.mcp_main.get_reranked_results",
+            side_effect=Exception("test error"),
         ),
-        patch("vectorcode.common.try_server", return_value=True),
-        patch("vectorcode.mcp_main.get_collection", return_value=mock_collection),
-        patch("vectorcode.cli_utils.expand_path", side_effect=lambda x, y: x),
     ):
-        mock_get_project_config.return_value = Config(project_root="/test/project")
-        mock_collection.delete = AsyncMock()
+        mock_db = AsyncMock()
+        mock_get_db.return_value = mock_db
+        mock_db._configs = mock_config
 
-        await rm_files(files=files_to_remove, project_root="/test/project")
+        with pytest.raises(McpError):
+            await query_tool(
+                n_query=2, query_messages=["keyword1"], project_root=str(tmp_path)
+            )
 
-        mock_get_project_config.assert_called_once_with("/test/project")
-        mock_collection.delete.assert_called_once_with(
-            where={"path": {"$in": files_to_remove}}
-        )
+
+@pytest.mark.asyncio
+async def test_vectorise_files_mcp_exception(tmp_path):
+    mock_db = AsyncMock()
+    mock_config = Config(project_root=str(tmp_path))
+    (tmp_path / "file1.py").touch()
+    with (
+        patch("vectorcode.mcp_main.get_database_connector", return_value=mock_db),
+        patch("vectorcode.mcp_main.get_project_config", return_value=mock_config),
+        patch(
+            "vectorcode.mcp_main.vectorise_worker",
+            side_effect=McpError(ErrorData(code=1, message="test error")),
+        ),
+    ):
+        with pytest.raises(McpError):
+            await vectorise_files(
+                paths=[str(tmp_path / "file1.py")], project_root=str(tmp_path)
+            )
 
-def test_arg_parser():
-    assert isinstance(get_arg_parser(), ArgumentParser)
 
+@pytest.mark.asyncio
+async def test_query_tool_mcp_exception(tmp_path):
+    mock_config = Config(project_root=tmp_path)
+    with (
+        patch("vectorcode.mcp_main.get_database_connector") as mock_get_db,
+        patch("vectorcode.mcp_main.get_project_config", return_value=mock_config),
+        patch(
+            "vectorcode.mcp_main.get_reranked_results",
+            side_effect=McpError(ErrorData(code=1, message="test error")),
+        ),
+    ):
+        mock_db = AsyncMock()
+        mock_get_db.return_value = mock_db
+        mock_db._configs = mock_config
 
-def test_args_parsing():
-    args = ["--number", "15", "--ls-on-start"]
-    parsed = parse_cli_args(args)
-    assert parsed.n_results == 15
-    assert parsed.ls_on_start
+        with pytest.raises(McpError):
+            await query_tool(
+                n_query=2, query_messages=["keyword1"], project_root=str(tmp_path)
+            )