Skip to content

Commit f3dfdae

Browse files
committed
feat(db): Implemented all methods in the ChromaDB 1.x connector
1 parent 19db7dd commit f3dfdae

File tree

1 file changed

+176
-9
lines changed

1 file changed

+176
-9
lines changed

src/vectorcode/database/chroma.py

Lines changed: 176 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,23 @@
99

1010
import chromadb
1111
from filelock import AsyncFileLock
12+
from tree_sitter import Point
1213

1314
from vectorcode.chunking import Chunk, TreeSitterChunker
14-
from vectorcode.cli_utils import Config, LockManager, QueryInclude
15+
from vectorcode.cli_utils import (
16+
Config,
17+
LockManager,
18+
QueryInclude,
19+
expand_globs,
20+
expand_path,
21+
)
1522
from vectorcode.database import DatabaseConnectorBase
1623
from vectorcode.database.chroma_common import convert_chroma_query_results
1724
from vectorcode.database.errors import CollectionNotFoundError
1825
from vectorcode.database.types import (
1926
CollectionContent,
2027
CollectionInfo,
28+
FileInCollection,
2129
QueryResult,
2230
ResultType,
2331
VectoriseStats,
@@ -39,6 +47,7 @@
3947
from chromadb import Collection
4048
from chromadb.api import ClientAPI
4149
from chromadb.config import APIVersion, Settings
50+
from chromadb.errors import NotFoundError
4251

4352
logger = logging.getLogger(name=__name__)
4453

@@ -56,10 +65,23 @@
5665

5766

5867
class ChromaDBConnector(DatabaseConnectorBase):
68+
"""
69+
This is the connector layer for **ChromaDB 1.x**
70+
71+
Valid `db_params` options for ChromaDB 1.x:
72+
- `db_url`: default to `http://127.0.0.1:8000`
73+
- `db_path`: default to `~/.local/share/vectorcode/chromadb/`
74+
- `db_log_path`: default to `~/.local/share/vectorcode/`
75+
- `db_settings`: See https://github.com/chroma-core/chroma/blob/508080841d2b2ebb3a9fbdc612087248df6f1382/chromadb/config.py#L120
76+
- `hnsw`: default to `{ "hnsw:M": 64 }`
77+
"""
78+
5979
def __init__(self, configs: Config):
6080
super().__init__(configs)
6181
params = _default_settings.copy()
6282
params.update(self._configs.db_params.copy())
83+
params["db_path"] = os.path.expanduser(params["db_path"])
84+
params["db_log_path"] = os.path.expanduser(params["db_log_path"])
6385
self._configs.db_params = params
6486

6587
self._lock: AsyncFileLock | None = None
@@ -162,7 +184,7 @@ async def _create_or_get_collection(
162184
if not allow_create:
163185
try:
164186
return client.get_collection(collection_id)
165-
except ValueError as e:
187+
except (ValueError, NotFoundError) as e:
166188
raise CollectionNotFoundError(
167189
f"There's no existing collection for {collection_path} in ChromaDB with the following setup: {self._configs.db_params}"
168190
) from e
@@ -266,17 +288,88 @@ def chunk_to_meta(chunk: Chunk) -> chromadb.Metadata:
266288
return VectoriseStats(add=1)
267289

268290
async def delete(self) -> int:
    """
    Remove files selected by `rm_paths` from the current project's collection.

    The configured `rm_paths` (a single path/glob or a list of them) are
    glob-expanded and made absolute. Only paths that are existing files on disk
    AND are already present in the collection are deleted.

    Returns:
        The number of distinct files whose chunks were removed.

    Raises:
        CollectionNotFoundError: propagated from `_create_or_get_collection`
            when the project has no collection yet.
    """
    project_root = self._configs.project_root
    collection = await self._create_or_get_collection(str(project_root), False)

    rm_paths = self._configs.rm_paths
    if isinstance(rm_paths, str):
        rm_paths = [rm_paths]

    # Expand the *normalised* list. Passing the raw config value here would
    # silently discard the str -> list conversion above.
    expanded = await expand_globs(
        paths=rm_paths,
        recursive=self._configs.recursive,
        include_hidden=self._configs.include_hidden,
    )
    candidates = {str(expand_path(path=p, absolute=True)) for p in expanded}

    files_in_collection = {
        str(expand_path(f.path, True))
        for f in (
            await self.list_collection_content(what=ResultType.document)
        ).files
    }

    targets = [
        p for p in candidates if os.path.isfile(p) and p in files_in_collection
    ]

    if targets:
        async with self.maybe_lock():
            # The ChromaDB client is synchronous; keep it off the event loop,
            # the same way `drop` does.
            await asyncio.to_thread(
                collection.delete,
                where=cast(chromadb.Where, {"path": {"$in": targets}}),
            )
    return len(targets)
270325

271326
async def drop(
    self, *, collection_id: str | None = None, collection_path: str | None = None
):
    """
    Delete a whole collection from the database.

    `collection_path` falls back to the configured project root, and
    `collection_id` falls back to the ID derived from that path.

    Raises:
        CollectionNotFoundError: when the client reports that the collection
            does not exist.
    """
    collection_path = str(collection_path or self._configs.project_root)
    collection_id = collection_id or get_collection_id(collection_path)
    try:
        async with self.maybe_lock():
            client = await self.get_client()
            await asyncio.to_thread(client.delete_collection, collection_id)
    # Catch the same pair of exceptions as the other collection lookups in this
    # connector: `NotFoundError` is what they expect from ChromaDB 1.x, and
    # `ValueError` is kept for compatibility.
    except (ValueError, NotFoundError) as e:
        raise CollectionNotFoundError(
            f"Collection at {collection_path} is not found."
        ) from e
277340

278341
async def get_chunks(self, file_path) -> list[Chunk]:
    """
    Return the chunks stored in the current project's collection for one file.

    `file_path` is made absolute before being matched against the `path`
    metadata field. If the project has no collection, a warning is logged and
    an empty list is returned.

    Each returned `Chunk` carries its text and ID. `start`/`end` are set (as
    `Point`s with column 0) only when the corresponding metadata is present.
    """
    file_path = os.path.abspath(file_path)
    try:
        collection = await self._create_or_get_collection(
            str(self._configs.project_root), False
        )
    except CollectionNotFoundError:
        logger.warning(
            f"There's no existing collection at {self._configs.project_root}."
        )
        return []

    # The ChromaDB client is synchronous; run the query in a worker thread like
    # `list_collection_content` does, so the event loop is not blocked.
    raw_results = await asyncio.to_thread(
        collection.get,
        where={"path": file_path},
        include=["metadatas", "documents"],
    )
    metadatas = raw_results["metadatas"]
    documents = raw_results["documents"]
    # Both fields were requested via `include`; the asserts narrow the types.
    assert metadatas is not None
    assert documents is not None

    result: list[Chunk] = []
    for _id, text, meta in zip(raw_results["ids"], documents, metadatas):
        chunk = Chunk(text=text, id=_id)
        if meta.get("start") is not None:
            chunk.start = Point(row=cast(int, meta["start"]), column=0)
        if meta.get("end") is not None:
            chunk.end = Point(row=cast(int, meta["end"]), column=0)
        result.append(chunk)
    return result
280373

281374
async def list_collection_content(
282375
self,
@@ -285,7 +378,81 @@ async def list_collection_content(
285378
collection_id: str | None = None,
286379
collection_path: str | None = None,
287380
) -> CollectionContent:
288-
return CollectionContent(files=[], chunks=[])
381+
"""
382+
When `what` is None, this method should populate both `CollectionContent.files` and `CollectionContent.chunks`.
383+
Otherwise, this method may populate only one of them to save waiting time.
384+
"""
385+
if collection_id is None:
386+
collection_path = str(collection_path or self._configs.project_root)
387+
collection = await self._create_or_get_collection(collection_path, False)
388+
else:
389+
try:
390+
collection = (await self.get_client()).get_collection(collection_id)
391+
except (ValueError, NotFoundError) as e:
392+
raise CollectionNotFoundError(
393+
f"There's no existing collection for {collection_path} in ChromaDB with the following setup: {self._configs.db_params}"
394+
) from e
395+
content = CollectionContent()
396+
raw_content = await asyncio.to_thread(
397+
collection.get,
398+
include=[
399+
"metadatas",
400+
"documents",
401+
],
402+
)
403+
metadatas = raw_content.get("metadatas", [])
404+
documents = raw_content.get("documents", [])
405+
ids = raw_content.get("ids", [])
406+
assert metadatas is not None
407+
assert documents is not None
408+
assert ids is not None
409+
if what is None or what == ResultType.document:
410+
content.files.extend(
411+
set(
412+
FileInCollection(
413+
path=str(i.get("path")), sha256=str(i.get("sha256"))
414+
)
415+
for i in metadatas
416+
)
417+
)
418+
if what is None or what == ResultType.chunk:
419+
for i in range(len(ids)):
420+
start, end = None, None
421+
if metadatas[i].get("start") is not None:
422+
start = Point(row=cast(int, metadatas[i]["start"]), column=0)
423+
if metadatas[i].get("end") is not None:
424+
end = Point(row=cast(int, metadatas[i]["end"]), column=0)
425+
content.chunks.append(
426+
Chunk(
427+
text=documents[i],
428+
path=str(metadatas[i].get("path", "")) or None,
429+
id=ids[i],
430+
start=start,
431+
end=end,
432+
)
433+
)
434+
435+
return content
289436

290437
async def list_collections(self) -> Sequence[CollectionInfo]:
    """
    Describe every collection known to the ChromaDB client.

    For each collection this reads the `path` and `embedding_function`
    metadata (the latter falling back to the default `Config` value) and
    counts its files and chunks via `list_collection_content`.
    """
    client = await self.get_client()
    # Build the default once instead of once per collection.
    default_embedding_function = Config().embedding_function
    result: list[CollectionInfo] = []
    for col in client.list_collections():
        # Guard against a collection that reports no metadata at all.
        metadata = col.metadata or {}
        project_root = str(metadata.get("path"))
        # Look the collection up by the name we already hold, so the lookup
        # does not depend on the `path` metadata being present and intact.
        col_counts = await self.list_collection_content(
            collection_id=col.name, collection_path=project_root
        )
        result.append(
            CollectionInfo(
                id=col.name,
                path=project_root,
                embedding_function=metadata.get(
                    "embedding_function",
                    default_embedding_function,  # fallback to default
                ),
                database_backend="Chroma",
                file_count=len(col_counts.files),
                chunk_count=len(col_counts.chunks),
            )
        )
    return result

0 commit comments

Comments
 (0)