From 9aea2724c4d0c614730dfcdf286823f439f38edc Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Thu, 3 Jul 2025 11:28:09 +0800 Subject: [PATCH 1/2] refactor(cli): Modify the logics of metadata processing to avoid passing unnecessary metadata. --- src/vectorcode/common.py | 23 ++++++++++++++--------- tests/test_common.py | 33 +++++++++++++++++++++++---------- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/src/vectorcode/common.py b/src/vectorcode/common.py index c5f7cee4..012c9411 100644 --- a/src/vectorcode/common.py +++ b/src/vectorcode/common.py @@ -182,16 +182,11 @@ async def get_collection( logger.debug( f"Getting/Creating collection with the following metadata: {collection_meta}" ) - if not make_if_missing: - __COLLECTION_CACHE[full_path] = await client.get_collection( + try: + collection = await client.get_collection( collection_name, embedding_function ) - else: - collection = await client.get_or_create_collection( - collection_name, - metadata=collection_meta, - embedding_function=embedding_function, - ) + __COLLECTION_CACHE[full_path] = collection if ( not collection.metadata.get("hostname") == socket.gethostname() or collection.metadata.get("username") @@ -208,7 +203,17 @@ async def get_collection( raise IndexError( "Failed to create the collection due to hash collision. Please file a bug report." ) - __COLLECTION_CACHE[full_path] = collection + except ValueError: + if make_if_missing: + collection = await client.create_collection( + collection_name, + metadata=collection_meta, + embedding_function=embedding_function, + ) + + __COLLECTION_CACHE[full_path] = collection + else: + raise return __COLLECTION_CACHE[full_path] diff --git a/tests/test_common.py b/tests/test_common.py index c0dbdc5f..d5b6d292 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -224,6 +224,16 @@ async def test_get_collection(): with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient: mock_client = MagicMock(spec=AsyncClientAPI) mock_collection = MagicMock() + mock_collection.metadata = { + "path": config.project_root, + "hostname": socket.gethostname(), + "created-by": "VectorCode", + "username": os.environ.get( + "USER", os.environ.get("USERNAME", "DEFAULT_USER") + ), + "embedding_function": config.embedding_function, + "hnsw:M": 64, + } mock_client.get_collection.return_value = mock_collection MockAsyncHttpClient.return_value = mock_client @@ -252,7 +262,7 @@ async def test_get_collection(): "created-by": "VectorCode", } - async def mock_get_or_create_collection( + async def mock_create_collection( self, name=None, configuration=None, @@ -263,7 +273,7 @@ async def mock_get_or_create_collection( mock_collection.metadata.update(metadata or {}) return mock_collection - mock_client.get_or_create_collection.side_effect = mock_get_or_create_collection + mock_client.create_collection.side_effect = mock_create_collection MockAsyncHttpClient.return_value = mock_client collection = await get_collection(mock_client, config, make_if_missing=True) @@ -273,16 +283,18 @@ async def mock_get_or_create_collection( ) assert collection.metadata["created-by"] == "VectorCode" assert collection.metadata["hnsw:M"] == 64 - mock_client.get_or_create_collection.assert_called_once() + mock_client.create_collection.assert_called_once() mock_client.get_collection.side_effect = None # Test raising IndexError on hash collision. - with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient: + with ( + patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient, + patch("socket.gethostname", side_effect=(lambda: "dummy")), + ): mock_client = MagicMock(spec=AsyncClientAPI) - mock_client.get_or_create_collection.side_effect = IndexError( - "Hash collision occurred" - ) + MockAsyncHttpClient.return_value = mock_client + mock_client.get_collection = AsyncMock(return_value=mock_collection) from vectorcode.common import __COLLECTION_CACHE __COLLECTION_CACHE.clear() @@ -315,7 +327,8 @@ async def test_get_collection_hnsw(): "embedding_function": "SentenceTransformerEmbeddingFunction", "path": "/test_project", } - mock_client.get_or_create_collection.return_value = mock_collection + mock_client.create_collection.return_value = mock_collection + mock_client.get_collection.side_effect = ValueError MockAsyncHttpClient.return_value = mock_client # Clear the collection cache to force creation @@ -332,9 +345,9 @@ async def test_get_collection_hnsw(): assert collection.metadata["created-by"] == "VectorCode" assert collection.metadata["hnsw:ef_construction"] == 200 assert collection.metadata["hnsw:M"] == 32 - mock_client.get_or_create_collection.assert_called_once() + mock_client.create_collection.assert_called_once() assert ( - mock_client.get_or_create_collection.call_args.kwargs["metadata"] + mock_client.create_collection.call_args.kwargs["metadata"] == mock_collection.metadata ) From 57994d2af1121b526e35c35f7d59070213bad7ad Mon Sep 17 00:00:00 2001 From: Zhe Yu Date: Thu, 3 Jul 2025 16:25:22 +0800 Subject: [PATCH 2/2] tests(cli): Add test for non-existing collection --- tests/test_common.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_common.py b/tests/test_common.py index d5b6d292..b12fc642 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -242,6 +242,18 @@ async def test_get_collection(): mock_client.get_collection.assert_called_once() mock_client.get_or_create_collection.assert_not_called() + # Test retrieving a non-existing collection + with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient: + from vectorcode.common import __COLLECTION_CACHE + + __COLLECTION_CACHE.clear() + mock_client = MagicMock(spec=AsyncClientAPI) + mock_client.get_collection.side_effect = ValueError + MockAsyncHttpClient.return_value = mock_client + + with pytest.raises(ValueError): + collection = await get_collection(mock_client, config, False) + # Test creating a collection if it doesn't exist with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient: mock_client = MagicMock(spec=AsyncClientAPI)