diff --git a/CHANGES.md b/CHANGES.md index 5cde77cfebf8..dda05276be48 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -75,6 +75,9 @@ * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * Add pip-based install support for JupyterLab Sidepanel extension ([#35397](https://github.com/apache/beam/issues/#35397)). +* Milvus enrichment handler added (Python) ([#35216](https://github.com/apache/beam/pull/35216)). + Beam now supports Milvus enrichment handler capabilities for vector, keyword, + and hybrid search operations. ## Breaking Changes diff --git a/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py b/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py index 77eb27ed37ba..f6117a260a34 100644 --- a/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py +++ b/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py @@ -220,11 +220,11 @@ def format_query(self, chunks: List[Chunk]) -> str: # Create embeddings subquery for this group embedding_unions = [] for chunk in group_chunks: - if chunk.embedding is None or chunk.embedding.dense_embedding is None: + if not chunk.dense_embedding: raise ValueError(f"Chunk {chunk.id} missing embedding") embedding_str = ( f"SELECT '{chunk.id}' as id, " - f"{[float(x) for x in chunk.embedding.dense_embedding]} " + f"{[float(x) for x in chunk.dense_embedding]} " f"as embedding") embedding_unions.append(embedding_str) group_embeddings = " UNION ALL ".join(embedding_unions) diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py new file mode 100644 index 000000000000..a0f597f5366f --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py @@ -0,0 +1,599 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections.abc import Sequence +from dataclasses import dataclass +from dataclasses import field +from enum import Enum +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +from google.protobuf.json_format import MessageToDict +from pymilvus import AnnSearchRequest +from pymilvus import Hit +from pymilvus import Hits +from pymilvus import MilvusClient +from pymilvus import SearchResult + +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Embedding +from apache_beam.transforms.enrichment import EnrichmentSourceHandler + + +class SearchStrategy(Enum): + """Search strategies for information retrieval. + + Args: + HYBRID: Combines vector and keyword search approaches. Leverages both + semantic understanding and exact matching. Typically provides the most + comprehensive results. Useful for queries with both conceptual and + specific keyword components. + VECTOR: Vector similarity search only. Based on semantic similarity between + query and documents. Effective for conceptual searches and finding related + content. Less sensitive to exact terminology than keyword search. + KEYWORD: Keyword/text search only. 
Based on exact or fuzzy matching of
+    specific terms. Effective for precise queries where exact wording matters.
+    Less effective for conceptual or semantic searches.
+  """
+  HYBRID = "hybrid"
+  VECTOR = "vector"
+  KEYWORD = "keyword"
+
+
+class KeywordSearchMetrics(Enum):
+  """Metrics for keyword search.
+
+  Args:
+    BM25: Range [0 to ∞), Best Match 25 ranking algorithm for text relevance.
+      Combines term frequency, inverse document frequency, and document length.
+      Higher scores indicate greater relevance.
+      Takes into account diminishing returns of term frequency.
+      Balances between exact matching and semantic relevance.
+  """
+  BM25 = "BM25"
+
+
+class VectorSearchMetrics(Enum):
+  """Metrics for vector search.
+
+  Args:
+    COSINE: Range [-1 to 1], higher values indicate greater similarity. Value 1
+      means vectors point in identical direction. Value 0 means vectors are
+      perpendicular to each other (no relationship). Value -1 means vectors
+      point in exactly opposite directions.
+    EUCLIDEAN_DISTANCE (L2): Range [0 to ∞), lower values indicate greater
+      similarity. Value 0 means vectors are identical. Larger values mean more
+      dissimilarity between vectors.
+    INNER_PRODUCT (IP): Range varies based on vector magnitudes, higher values
+      indicate greater similarity. Value 0 means vectors are perpendicular to
+      each other. Positive values mean vectors share some directional component.
+      Negative values mean vectors point in opposing directions.
+  """
+  COSINE = "COSINE"
+  EUCLIDEAN_DISTANCE = "L2"
+  INNER_PRODUCT = "IP"
+
+
+class MilvusBaseRanker:
+  """Base class for ranking algorithms in Milvus hybrid search strategy."""
+  def __init__(self):
+    return
+
+  def dict(self):
+    return {}
+
+  def __str__(self):
+    return self.dict().__str__()
+
+
+@dataclass
+class MilvusConnectionParameters:
+  """Parameters for establishing connections to Milvus servers.
+ + Args: + uri: URI endpoint for connecting to Milvus server in the format + "http(s)://hostname:port". + user: Username for authentication. Required if authentication is enabled and + not using token authentication. + password: Password for authentication. Required if authentication is enabled + and not using token authentication. + db_id: Database ID to connect to. Specifies which Milvus database to use. + Defaults to 'default'. + token: Authentication token as an alternative to username/password. + timeout: Connection timeout in seconds. Uses client default if None. + kwargs: Optional keyword arguments for additional connection parameters. + Enables forward compatibility. + """ + uri: str + user: str = field(default_factory=str) + password: str = field(default_factory=str) + db_id: str = "default" + token: str = field(default_factory=str) + timeout: Optional[float] = None + kwargs: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if not self.uri: + raise ValueError("URI must be provided for Milvus connection") + + +@dataclass +class BaseSearchParameters: + """Base parameters for both vector and keyword search operations. + + Args: + anns_field: Approximate nearest neighbor search field indicates field name + containing the embedding to search. Required for both vector and keyword + search. + limit: Maximum number of results to return per query. Must be positive. + Defaults to 3 search results. + filter: Boolean expression string for filtering search results. + Example: 'price <= 1000 AND category == "electronics"'. + search_params: Additional search parameters specific to the search type. + Example: {"metric_type": VectorSearchMetrics.EUCLIDEAN_DISTANCE}. + consistency_level: Consistency level for read operations. + Options: "Strong", "Session", "Bounded", "Eventually". Defaults to + "Bounded" if not specified when creating the collection. 
+ """ + anns_field: str + limit: int = 3 + filter: str = field(default_factory=str) + search_params: Dict[str, Any] = field(default_factory=dict) + consistency_level: Optional[str] = None + + def __post_init__(self): + if not self.anns_field: + raise ValueError( + "Approximate Nearest Neighbor Search (ANNS) field must be provided") + + if self.limit <= 0: + raise ValueError(f"Search limit must be positive, got {self.limit}") + + +@dataclass +class VectorSearchParameters(BaseSearchParameters): + """Parameters for vector similarity search operations. + + Inherits all parameters from BaseSearchParameters with the same semantics. + The anns_field should contain dense vector embeddings for this search type. + + Args: + kwargs: Optional keyword arguments for additional vector search parameters. + Enables forward compatibility. + + Note: + For inherited parameters documentation, see BaseSearchParameters. + """ + kwargs: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class KeywordSearchParameters(BaseSearchParameters): + """Parameters for keyword/text search operations. + + This class inherits all parameters from BaseSearchParameters with the same + semantics. The anns_field should contain sparse vector embeddings content for + this search type. + + Args: + kwargs: Optional keyword arguments for additional keyword search parameters. + Enables forward compatibility. + + Note: + For inherited parameters documentation, see BaseSearchParameters. + """ + kwargs: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class HybridSearchParameters: + """Parameters for hybrid (vector + keyword) search operations. + + Args: + vector: Parameters for the vector search component. + keyword: Parameters for the keyword search component. + ranker: Ranker for combining vector and keyword search results. + Example: RRFRanker(k=100). + limit: Maximum number of results to return per query. Defaults to 3 search + results. 
+ kwargs: Optional keyword arguments for additional hybrid search parameters. + Enables forward compatibility. + """ + vector: VectorSearchParameters + keyword: KeywordSearchParameters + ranker: MilvusBaseRanker + limit: int = 3 + kwargs: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if not self.vector or not self.keyword: + raise ValueError( + "Vector and keyword search parameters must be provided for " + "hybrid search") + + if not self.ranker: + raise ValueError("Ranker must be provided for hybrid search") + + if self.limit <= 0: + raise ValueError(f"Search limit must be positive, got {self.limit}") + + +SearchStrategyType = Union[VectorSearchParameters, + KeywordSearchParameters, + HybridSearchParameters] + + +@dataclass +class MilvusSearchParameters: + """Parameters configuring Milvus search operations. + + This class encapsulates all parameters needed to execute searches against + Milvus collections, supporting vector, keyword, and hybrid search strategies. + + Args: + collection_name: Name of the collection to search in. + search_strategy: Type of search to perform (VECTOR, KEYWORD, or HYBRID). + partition_names: List of partition names to restrict the search to. If + empty, all partitions will be searched. + output_fields: List of field names to include in search results. If empty, + only primary fields including distances will be returned. + timeout: Search operation timeout in seconds. If not specified, the client's + default timeout is used. + round_decimal: Number of decimal places for distance/similarity scores. + Defaults to -1 means no rounding. 
+ """ + collection_name: str + search_strategy: SearchStrategyType + partition_names: List[str] = field(default_factory=list) + output_fields: List[str] = field(default_factory=list) + timeout: Optional[float] = None + round_decimal: int = -1 + + def __post_init__(self): + if not self.collection_name: + raise ValueError("Collection name must be provided") + + if not self.search_strategy: + raise ValueError("Search strategy must be provided") + + +@dataclass +class MilvusCollectionLoadParameters: + """Parameters that control how Milvus loads a collection into memory. + + This class provides fine-grained control over collection loading, which is + particularly important in resource-constrained environments. Proper + configuration can significantly reduce memory usage and improve query + performance by loading only necessary data. + + Args: + refresh: If True, forces a reload of the collection even if already loaded. + Ensures the most up-to-date data is in memory. + resource_groups: List of resource groups to load the collection into. Can be + used for load balancing across multiple query nodes. + load_fields: Specify which fields to load into memory. Loading only + necessary fields reduces memory usage. If empty, all fields loaded. + skip_load_dynamic_field: If True, dynamic/growing fields will not be loaded + into memory. Saves memory when dynamic fields aren't needed. + kwargs: Optional keyword arguments for additional collection load + parameters. Enables forward compatibility. + """ + refresh: bool = field(default_factory=bool) + resource_groups: List[str] = field(default_factory=list) + load_fields: List[str] = field(default_factory=list) + skip_load_dynamic_field: bool = field(default_factory=bool) + kwargs: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class MilvusSearchResult: + """Search result from Milvus per chunk. + + Args: + id: List of entity IDs returned from the search. Can be either string or + integer IDs. 
+ distance: List of distances/similarity scores for each returned entity. + fields: List of dictionaries containing additional field values for each + entity. Each dictionary corresponds to one returned entity. + """ + id: List[Union[str, int]] = field(default_factory=list) + distance: List[float] = field(default_factory=list) + fields: List[Dict[str, Any]] = field(default_factory=list) + + +InputT, OutputT = Union[Chunk, List[Chunk]], List[Tuple[Chunk, Dict[str, Any]]] + + +class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]): + """Enrichment handler for Milvus vector database searches. + + This handler is designed to work with the + :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler` transform. + It enables enriching data through vector similarity, keyword, or hybrid + searches against Milvus collections. + + The handler supports different search strategies: + * Vector search - For finding similar embeddings based on vector similarity + * Keyword search - For text-based retrieval using BM25 or other text metrics + * Hybrid search - For combining vector and keyword search results + + This handler queries the Milvus database per element by default. To enable + batching for improved performance, set the `min_batch_size` and + `max_batch_size` parameters. These control the batching behavior in the + :class:`apache_beam.transforms.utils.BatchElements` transform. + + For memory-intensive operations, the handler allows fine-grained control over + collection loading through the `collection_load_parameters`. 
+ """ + def __init__( + self, + connection_parameters: MilvusConnectionParameters, + search_parameters: MilvusSearchParameters, + *, + collection_load_parameters: Optional[MilvusCollectionLoadParameters], + min_batch_size: int = 1, + max_batch_size: int = 1000, + **kwargs): + """ + Example Usage: + connection_paramters = MilvusConnectionParameters( + uri="http://localhost:19530") + search_parameters = MilvusSearchParameters( + collection_name="my_collection", + search_strategy=VectorSearchParameters(anns_field="embedding")) + collection_load_parameters = MilvusCollectionLoadParameters( + load_fields=["embedding", "metadata"]), + milvus_handler = MilvusSearchEnrichmentHandler( + connection_paramters, + search_parameters, + collection_load_parameters=collection_load_parameters, + min_batch_size=10, + max_batch_size=100) + + Args: + connection_parameters (MilvusConnectionParameters): Configuration for + connecting to the Milvus server, including URI, credentials, and + connection options. + search_parameters (MilvusSearchParameters): Configuration for search + operations, including collection name, search strategy, and output + fields. + collection_load_parameters (Optional[MilvusCollectionLoadParameters]): + Parameters controlling how collections are loaded into memory, which can + significantly impact resource usage and performance. + min_batch_size (int): Minimum number of elements to batch together when + querying Milvus. Default is 1 (no batching when max_batch_size is 1). + max_batch_size (int): Maximum number of elements to batch together.Default + is 1000. Higher values may improve throughput but increase memory usage. + **kwargs: Additional keyword arguments for Milvus Enrichment Handler. + + Note: + * For large collections, consider setting appropriate values in + collection_load_parameters to reduce memory usage. + * The search_strategy in search_parameters determines the type of search + (vector, keyword, or hybrid) and associated parameters. 
+ * Batching can significantly improve performance but requires more memory. + """ + self._connection_parameters = connection_parameters + self._search_parameters = search_parameters + self._collection_load_parameters = collection_load_parameters + if not self._collection_load_parameters: + self._collection_load_parameters = MilvusCollectionLoadParameters() + self._batching_kwargs = { + 'min_batch_size': min_batch_size, 'max_batch_size': max_batch_size + } + self.kwargs = kwargs + self.join_fn = join_fn + self.use_custom_types = True + + def __enter__(self): + connection_params = unpack_dataclass_with_kwargs( + self._connection_parameters) + collection_load_params = unpack_dataclass_with_kwargs( + self._collection_load_parameters) + self._client = MilvusClient(**connection_params) + self._client.load_collection( + collection_name=self.collection_name, + partition_names=self.partition_names, + **collection_load_params) + + def __call__(self, request: Union[Chunk, List[Chunk]], *args, + **kwargs) -> List[Tuple[Chunk, Dict[str, Any]]]: + reqs = request if isinstance(request, list) else [request] + search_result = self._search_documents(reqs) + return self._get_call_response(reqs, search_result) + + def _search_documents(self, chunks: List[Chunk]): + if isinstance(self.search_strategy, HybridSearchParameters): + data = self._get_hybrid_search_data(chunks) + return self._client.hybrid_search( + collection_name=self.collection_name, + partition_names=self.partition_names, + output_fields=self.output_fields, + timeout=self.timeout, + round_decimal=self.round_decimal, + reqs=data, + ranker=self.search_strategy.ranker, + limit=self.search_strategy.limit, + **self.search_strategy.kwargs) + elif isinstance(self.search_strategy, VectorSearchParameters): + data = list(map(self._get_vector_search_data, chunks)) + vector_search_params = unpack_dataclass_with_kwargs(self.search_strategy) + return self._client.search( + collection_name=self.collection_name, + 
partition_names=self.partition_names, + output_fields=self.output_fields, + timeout=self.timeout, + round_decimal=self.round_decimal, + data=data, + **vector_search_params) + elif isinstance(self.search_strategy, KeywordSearchParameters): + data = list(map(self._get_keyword_search_data, chunks)) + keyword_search_params = unpack_dataclass_with_kwargs(self.search_strategy) + return self._client.search( + collection_name=self.collection_name, + partition_names=self.partition_names, + output_fields=self.output_fields, + timeout=self.timeout, + round_decimal=self.round_decimal, + data=data, + **keyword_search_params) + else: + raise ValueError( + f"Not supported search strategy yet: {self.search_strategy}") + + def _get_hybrid_search_data(self, chunks: List[Chunk]): + vector_search_data = list(map(self._get_vector_search_data, chunks)) + keyword_search_data = list(map(self._get_keyword_search_data, chunks)) + + vector_search_req = AnnSearchRequest( + data=vector_search_data, + anns_field=self.search_strategy.vector.anns_field, + param=self.search_strategy.vector.search_params, + limit=self.search_strategy.vector.limit, + expr=self.search_strategy.vector.filter) + + keyword_search_req = AnnSearchRequest( + data=keyword_search_data, + anns_field=self.search_strategy.keyword.anns_field, + param=self.search_strategy.keyword.search_params, + limit=self.search_strategy.keyword.limit, + expr=self.search_strategy.keyword.filter) + + reqs = [vector_search_req, keyword_search_req] + return reqs + + def _get_vector_search_data(self, chunk: Chunk): + if not chunk.dense_embedding: + raise ValueError( + f"Chunk {chunk.id} missing dense embedding required for vector search" + ) + return chunk.dense_embedding + + def _get_keyword_search_data(self, chunk: Chunk): + if not chunk.content.text and not chunk.sparse_embedding: + raise ValueError( + f"Chunk {chunk.id} missing both text content and sparse embedding " + "required for keyword search") + + sparse_embedding = 
self.convert_sparse_embedding_to_milvus_format( + chunk.sparse_embedding) + + return chunk.content.text or sparse_embedding + + def _get_call_response( + self, chunks: List[Chunk], search_result: SearchResult[Hits]): + response = [] + for i in range(len(chunks)): + chunk = chunks[i] + hits: Hits = search_result[i] + result = MilvusSearchResult() + for i in range(len(hits)): + hit: Hit = hits[i] + normalized_fields = self._normalize_milvus_fields(hit.fields) + result.id.append(hit.id) + result.distance.append(hit.distance) + result.fields.append(normalized_fields) + response.append((chunk, result.__dict__)) + return response + + def _normalize_milvus_fields(self, fields: Dict[str, Any]): + normalized_fields = {} + for field, value in fields.items(): + value = self._normalize_milvus_value(value) + normalized_fields[field] = value + return normalized_fields + + def _normalize_milvus_value(self, value: Any): + # Convert Milvus-specific types to Python native types. + neither_str_nor_dict_nor_bytes = not isinstance(value, (str, dict, bytes)) + if isinstance(value, Sequence) and neither_str_nor_dict_nor_bytes: + return list(value) + elif hasattr(value, 'DESCRIPTOR'): + # Handle protobuf messages. + return MessageToDict(value) + else: + # Keep other types as they are. + return value + + def convert_sparse_embedding_to_milvus_format( + self, sparse_vector: Tuple[List[int], List[float]]) -> Dict[int, float]: + if not sparse_vector: + return None + # Converts sparse embedding from (indices, values) tuple format to + # Milvus-compatible values dict format {dimension_index: value, ...}. 
+ indices, values = sparse_vector + return {int(idx): float(val) for idx, val in zip(indices, values)} + + @property + def collection_name(self): + """Getter method for collection_name property""" + return self._search_parameters.collection_name + + @property + def search_strategy(self): + """Getter method for search_strategy property""" + return self._search_parameters.search_strategy + + @property + def partition_names(self): + """Getter method for partition_names property""" + return self._search_parameters.partition_names + + @property + def output_fields(self): + """Getter method for output_fields property""" + return self._search_parameters.output_fields + + @property + def timeout(self): + """Getter method for search timeout property""" + return self._search_parameters.timeout + + @property + def round_decimal(self): + """Getter method for search round_decimal property""" + return self._search_parameters.round_decimal + + def __exit__(self, exc_type, exc_val, exc_tb): + self._client.release_collection(self.collection_name) + self._client.close() + self._client = None + + def batch_elements_kwargs(self) -> Dict[str, int]: + """Returns kwargs for beam.BatchElements.""" + return self._batching_kwargs + + +def join_fn(left: Embedding, right: Dict[str, Any]) -> Embedding: + left.metadata['enrichment_data'] = right + return left + + +def unpack_dataclass_with_kwargs(dataclass_instance): + # Create a copy of the dataclass's __dict__. + params_dict: dict = dataclass_instance.__dict__.copy() + + # Extract the nested kwargs dictionary. + nested_kwargs = params_dict.pop('kwargs', {}) + + # Merge the dictionaries, with nested_kwargs taking precedence + # in case of duplicate keys. 
+ return {**params_dict, **nested_kwargs} diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py new file mode 100644 index 000000000000..ebc05722841c --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py @@ -0,0 +1,1371 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import contextlib +import logging +import os +import platform +import re +import socket +import tempfile +import unittest +from collections import defaultdict +from dataclasses import dataclass +from dataclasses import field +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import cast + +import pytest +import yaml +from pymilvus import CollectionSchema +from pymilvus import DataType +from pymilvus import FieldSchema +from pymilvus import Function +from pymilvus import FunctionType +from pymilvus import MilvusClient +from pymilvus import RRFRanker +from pymilvus.milvus_client import IndexParams +from testcontainers.core.config import MAX_TRIES as TC_MAX_TRIES +from testcontainers.core.config import testcontainers_config +from testcontainers.core.generic import DbContainer +from testcontainers.milvus import MilvusContainer + +import apache_beam as beam +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that + +try: + from apache_beam.transforms.enrichment import Enrichment + from apache_beam.ml.rag.enrichment.milvus_search import ( + MilvusSearchEnrichmentHandler, + MilvusConnectionParameters, + MilvusSearchParameters, + MilvusCollectionLoadParameters, + VectorSearchParameters, + KeywordSearchParameters, + HybridSearchParameters, + VectorSearchMetrics, + KeywordSearchMetrics) +except ImportError as e: + raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}') + +_LOGGER = logging.getLogger(__name__) + + +def _construct_index_params(): + index_params = IndexParams() + + # Milvus doesn't support multiple indexes on the same field. This is a + # limitation of Milvus - someone can only create one index per field as yet. 
+ + # Cosine similarity index on first dense embedding field + index_params.add_index( + field_name="dense_embedding_cosine", + index_name="dense_embedding_cosine_ivf_flat", + index_type="IVF_FLAT", + metric_type=VectorSearchMetrics.COSINE.value, + params={"nlist": 1}) + + # Euclidean distance index on second dense embedding field + index_params.add_index( + field_name="dense_embedding_euclidean", + index_name="dense_embedding_euclidean_ivf_flat", + index_type="IVF_FLAT", + metric_type=VectorSearchMetrics.EUCLIDEAN_DISTANCE.value, + params={"nlist": 1}) + + # Inner product index on third dense embedding field + index_params.add_index( + field_name="dense_embedding_inner_product", + index_name="dense_embedding_inner_product_ivf_flat", + index_type="IVF_FLAT", + metric_type=VectorSearchMetrics.INNER_PRODUCT.value, + params={"nlist": 1}) + + index_params.add_index( + field_name="sparse_embedding_inner_product", + index_name="sparse_embedding_inner_product_inverted_index", + index_type="SPARSE_INVERTED_INDEX", + metric_type=VectorSearchMetrics.INNER_PRODUCT.value, + params={ + "inverted_index_algo": "TAAT_NAIVE", + }) + + # BM25 index on sparse_embedding field. + # + # For deterministic testing results + # 1. Using TAAT_NAIVE: Most predictable algorithm that processes each term + # completely before moving to the next. + # 2. Using k1=1: Moderate term frequency weighting – repeated terms matter + # but with diminishing returns. + # 3. Using b=0: No document length normalization – longer documents not + # penalized. + # This combination provides maximum transparency and predictability for + # test assertions. 
+ index_params.add_index( + field_name="sparse_embedding_bm25", + index_name="sparse_embedding_bm25_inverted_index", + index_type="SPARSE_INVERTED_INDEX", + metric_type=KeywordSearchMetrics.BM25.value, + params={ + "inverted_index_algo": "TAAT_NAIVE", + "bm25_k1": 1, + "bm25_b": 0, + }) + + return index_params + + +@dataclass +class MilvusITDataConstruct: + id: int + content: str + domain: str + cost: int + metadata: dict + tags: list[str] + dense_embedding: list[float] + sparse_embedding: dict + vocabulary: Dict[str, int] = field(default_factory=dict) + + def __getitem__(self, key): + return getattr(self, key) + + +MILVUS_IT_CONFIG = { + "collection_name": "docs_catalog", + "fields": [ + FieldSchema( + name="id", dtype=DataType.INT64, is_primary=True, auto_id=False), + FieldSchema( + name="content", + dtype=DataType.VARCHAR, + max_length=512, + enable_analyzer=True), + FieldSchema(name="domain", dtype=DataType.VARCHAR, max_length=128), + FieldSchema(name="cost", dtype=DataType.INT32), + FieldSchema(name="metadata", dtype=DataType.JSON), + FieldSchema( + name="tags", + dtype=DataType.ARRAY, + element_type=DataType.VARCHAR, + max_length=64, + max_capacity=64), + FieldSchema( + name="dense_embedding_cosine", dtype=DataType.FLOAT_VECTOR, dim=3), + FieldSchema( + name="dense_embedding_euclidean", + dtype=DataType.FLOAT_VECTOR, + dim=3), + FieldSchema( + name="dense_embedding_inner_product", + dtype=DataType.FLOAT_VECTOR, + dim=3), + FieldSchema( + name="sparse_embedding_bm25", dtype=DataType.SPARSE_FLOAT_VECTOR), + FieldSchema( + name="sparse_embedding_inner_product", + dtype=DataType.SPARSE_FLOAT_VECTOR) + ], + "functions": [ + Function( + name="content_bm25_emb", + input_field_names=["content"], + output_field_names=["sparse_embedding_bm25"], + function_type=FunctionType.BM25) + ], + "index": _construct_index_params, + "corpus": [ + MilvusITDataConstruct( + id=1, + content="This is a test document", + domain="medical", + cost=49, + metadata={"language": "en"}, + 
tags=["healthcare", "patient", "clinical"], + dense_embedding=[0.1, 0.2, 0.3], + sparse_embedding={ + 1: 0.05, 2: 0.41, 3: 0.05, 4: 0.41 + }), + MilvusITDataConstruct( + id=2, + content="Another test document", + domain="legal", + cost=75, + metadata={"language": "en"}, + tags=["contract", "law", "regulation"], + dense_embedding=[0.2, 0.3, 0.4], + sparse_embedding={ + 1: 0.07, 3: 3.07, 0: 0.53 + }), + MilvusITDataConstruct( + id=3, + content="وثيقة اختبار", + domain="financial", + cost=149, + metadata={"language": "ar"}, + tags=["banking", "investment", "arabic"], + dense_embedding=[0.3, 0.4, 0.5], + sparse_embedding={ + 6: 0.62, 5: 0.62 + }) + ], + "vocabulary": { + "this": 4, + "is": 2, + "test": 3, + "document": 1, + "another": 0, + "وثيقة": 6, + "اختبار": 5 + } +} + + +@dataclass +class MilvusDBContainerInfo: + container: DbContainer + host: str + port: int + user: Optional[str] = "" + password: Optional[str] = "" + token: Optional[str] = "" + id: Optional[str] = "default" + + @property + def uri(self) -> str: + return f"http://{self.host}:{self.port}" + + +class CustomMilvusContainer(MilvusContainer): + def __init__( + self, + image: str, + service_container_port, + healthcheck_container_port, + **kwargs, + ) -> None: + # Skip the parent class's constructor and go straight to + # GenericContainer. + super(MilvusContainer, self).__init__(image=image, **kwargs) + self.port = service_container_port + self.healthcheck_port = healthcheck_container_port + self.with_exposed_ports(service_container_port, healthcheck_container_port) + + # Get free host ports. + service_host_port = MilvusEnrichmentTestHelper.find_free_port() + healthcheck_host_port = MilvusEnrichmentTestHelper.find_free_port() + + # Bind container and host ports. + self.with_bind_ports(service_container_port, service_host_port) + self.with_bind_ports(healthcheck_container_port, healthcheck_host_port) + self.cmd = "milvus run standalone" + + # Set environment variables needed for Milvus. 
+ envs = { + "ETCD_USE_EMBED": "true", + "ETCD_DATA_DIR": "/var/lib/milvus/etcd", + "COMMON_STORAGETYPE": "local", + "METRICS_PORT": str(healthcheck_container_port) + } + for env, value in envs.items(): + self.with_env(env, value) + + +class MilvusEnrichmentTestHelper: + @staticmethod + def start_db_container( + image="milvusdb/milvus:v2.5.10", + max_vec_fields=5, + vector_client_max_retries=3, + tc_max_retries=TC_MAX_TRIES) -> Optional[MilvusDBContainerInfo]: + service_container_port = MilvusEnrichmentTestHelper.find_free_port() + healthcheck_container_port = MilvusEnrichmentTestHelper.find_free_port() + user_yaml_creator = MilvusEnrichmentTestHelper.create_user_yaml + with user_yaml_creator(service_container_port, max_vec_fields) as cfg: + info = None + testcontainers_config.max_tries = tc_max_retries + for i in range(vector_client_max_retries): + try: + vector_db_container = CustomMilvusContainer( + image=image, + service_container_port=service_container_port, + healthcheck_container_port=healthcheck_container_port) + vector_db_container = vector_db_container.with_volume_mapping( + cfg, "/milvus/configs/user.yaml") + vector_db_container.start() + host = vector_db_container.get_container_host_ip() + port = vector_db_container.get_exposed_port(service_container_port) + info = MilvusDBContainerInfo(vector_db_container, host, port) + testcontainers_config.max_tries = TC_MAX_TRIES + _LOGGER.info( + "milvus db container started successfully on %s.", info.uri) + break + except Exception as e: + stdout_logs, stderr_logs = vector_db_container.get_logs() + stdout_logs = stdout_logs.decode("utf-8") + stderr_logs = stderr_logs.decode("utf-8") + _LOGGER.warning( + "Retry %d/%d: Failed to start Milvus DB container. Reason: %s. " + "STDOUT logs:\n%s\nSTDERR logs:\n%s", + i + 1, + vector_client_max_retries, + e, + stdout_logs, + stderr_logs) + if i == vector_client_max_retries - 1: + _LOGGER.error( + "Unable to start milvus db container for I/O tests after %d " + "retries. 
Tests cannot proceed. STDOUT logs:\n%s\n" + "STDERR logs:\n%s", + vector_client_max_retries, + stdout_logs, + stderr_logs) + raise e + return info + + @staticmethod + def stop_db_container(db_info: MilvusDBContainerInfo): + if db_info is None: + _LOGGER.warning("Milvus db info is None. Skipping stop operation.") + return + try: + _LOGGER.debug("Stopping milvus db container.") + db_info.container.stop() + _LOGGER.info("milvus db container stopped successfully.") + except Exception as e: + _LOGGER.warning( + "Error encountered while stopping milvus db container: %s", e) + + @staticmethod + def initialize_db_with_data(connc_params: MilvusConnectionParameters): + # Open the connection to the milvus db. + client = MilvusClient(**connc_params.__dict__) + + # Configure schema. + field_schemas: List[FieldSchema] = cast( + List[FieldSchema], MILVUS_IT_CONFIG["fields"]) + schema = CollectionSchema( + fields=field_schemas, functions=MILVUS_IT_CONFIG["functions"]) + + # Create collection with the schema. + collection_name = MILVUS_IT_CONFIG["collection_name"] + index_function: Callable[[], IndexParams] = cast( + Callable[[], IndexParams], MILVUS_IT_CONFIG["index"]) + client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_function()) + + # Assert that collection was created. + collection_error = f"Expected collection '{collection_name}' to be created." + assert client.has_collection(collection_name), collection_error + + # Gather all fields we have excluding 'sparse_embedding_bm25' special field. + fields = list(map(lambda field: field.name, field_schemas)) + + # Prep data for indexing. Currently we can't insert sparse vectors for BM25 + # sparse embedding field as it would be automatically generated by Milvus + # through the registered BM25 function. 
+    data_ready_to_index = []
+    for doc in MILVUS_IT_CONFIG["corpus"]:
+      item = {}
+      for field in fields:
+        if field.startswith("dense_embedding"):
+          item[field] = doc["dense_embedding"]
+        elif field == "sparse_embedding_inner_product":
+          item[field] = doc["sparse_embedding"]
+        elif field == "sparse_embedding_bm25":
+          # It is automatically generated by Milvus from the content field.
+          continue
+        else:
+          item[field] = doc[field]
+      data_ready_to_index.append(item)
+
+    # Index data.
+    result = client.insert(
+        collection_name=collection_name, data=data_ready_to_index)
+
+    # Assert that the intended data has been properly indexed.
+    insertion_err = f'failed to insert the {result["insert_count"]} data points'
+    assert result["insert_count"] == len(data_ready_to_index), insertion_err
+
+    # Release the collection from memory. It will be loaded lazily when the
+    # enrichment handler is invoked.
+    client.release_collection(collection_name)
+
+    # Close the connection to the Milvus database, as no further preparation
+    # operations are needed before executing the enrichment handler.
+    client.close()
+
+    return collection_name
+
+  @staticmethod
+  def find_free_port():
+    """Find a free port on the local machine.
+
+    Returns:
+      int: A port number the OS reported as free at the time of the call.
+        NOTE(review): the socket is closed before the caller binds the port,
+        so there is an inherent (acceptable for tests) reuse race.
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+      # SO_REUSEADDR must be set *before* bind() for it to affect this
+      # socket's binding; setting it afterwards is a no-op.
+      s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+      # Bind to port 0, which asks OS to assign a free port.
+      s.bind(('', 0))
+      # Return the port number assigned by OS.
+      return s.getsockname()[1]
+
+  @staticmethod
+  @contextlib.contextmanager
+  def create_user_yaml(service_port: int, max_vector_field_num=5):
+    """Creates a temporary user.yaml file for Milvus configuration.
+
+    This user yaml file overrides Milvus default configurations. It sets
+    the Milvus service port to the specified container service port. The
+    default for maxVectorFieldNum is 4, but we need 5
+    (one unique field for each metric).
+
+    Args:
+      service_port: Port number for the Milvus service.
+ max_vector_field_num: Max number of vec fields allowed per collection. + + Yields: + str: Path to the created temporary yaml file. + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', + delete=False) as temp_file: + # Define the content for user.yaml. + user_config = { + 'proxy': { + 'maxVectorFieldNum': max_vector_field_num, 'port': service_port + } + } + + # Write the content to the file. + yaml.dump(user_config, temp_file, default_flow_style=False) + path = temp_file.name + + try: + yield path + finally: + if os.path.exists(path): + os.remove(path) + + +@pytest.mark.uses_testcontainer +@unittest.skipUnless( + platform.system() == "Linux", + "Test runs only on Linux due to lack of support, as yet, for nested " + "virtualization in CI environments on Windows/macOS. Many CI providers run " + "tests in virtualized environments, and nested virtualization " + "(Docker inside a VM) is either unavailable or has several issues on " + "non-Linux platforms.") +class TestMilvusSearchEnrichment(unittest.TestCase): + """Tests for search functionality across all search strategies""" + + _db: MilvusDBContainerInfo + _version = "milvusdb/milvus:v2.5.10" + + @classmethod + def setUpClass(cls): + try: + cls._db = MilvusEnrichmentTestHelper.start_db_container( + cls._version, vector_client_max_retries=1, tc_max_retries=1) + cls._connection_params = MilvusConnectionParameters( + uri=cls._db.uri, + user=cls._db.user, + password=cls._db.password, + db_id=cls._db.id, + token=cls._db.token) + cls._collection_load_params = MilvusCollectionLoadParameters() + cls._collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data( + cls._connection_params) + except Exception as e: + pytest.skip( + f"Skipping all tests in {cls.__name__} due to DB startup failure: {e}" + ) + + @classmethod + def tearDownClass(cls): + MilvusEnrichmentTestHelper.stop_db_container(cls._db) + cls._db = None + + def test_invalid_query_on_non_existent_collection(self): + non_existent_collection = 
"nonexistent_collection" + existent_field = "dense_embedding_cosine" + + test_chunks = [ + Chunk( + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content()) + ] + + search_parameters = MilvusSearchParameters( + collection_name=non_existent_collection, + search_strategy=VectorSearchParameters(anns_field=existent_field)) + + collection_load_parameters = MilvusCollectionLoadParameters() + + handler = MilvusSearchEnrichmentHandler( + self._connection_params, + search_parameters, + collection_load_parameters=collection_load_parameters) + + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | Enrichment(handler)) + + expect_err_msg_contains = "collection not found" + self.assertIn(expect_err_msg_contains, str(context.exception)) + + def test_invalid_query_on_non_existent_field(self): + non_existent_field = "nonexistent_column" + existent_collection = MILVUS_IT_CONFIG["collection_name"] + + test_chunks = [ + Chunk( + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content()) + ] + + search_parameters = MilvusSearchParameters( + collection_name=existent_collection, + search_strategy=VectorSearchParameters(anns_field=non_existent_field)) + + collection_load_parameters = MilvusCollectionLoadParameters() + + handler = MilvusSearchEnrichmentHandler( + self._connection_params, + search_parameters, + collection_load_parameters=collection_load_parameters) + + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | Enrichment(handler)) + + expect_err_msg_contains = f"fieldName({non_existent_field}) not found" + self.assertIn(expect_err_msg_contains, str(context.exception)) + + def test_empty_input_chunks(self): + test_chunks = [] + anns_field = "dense_embedding_cosine" + + search_parameters = MilvusSearchParameters( + collection_name=MILVUS_IT_CONFIG["collection_name"], + search_strategy=VectorSearchParameters(anns_field=anns_field)) 
+ + collection_load_parameters = MilvusCollectionLoadParameters() + + handler = MilvusSearchEnrichmentHandler( + self._connection_params, + search_parameters, + collection_load_parameters=collection_load_parameters) + + expected_chunks = [] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + assert_that( + result, + lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + + def test_filtered_search_with_cosine_similarity_and_batching(self): + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content()), + Chunk( + id="query2", + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4]), + content=Content()), + Chunk( + id="query3", + embedding=Embedding(dense_embedding=[0.3, 0.4, 0.5]), + content=Content()) + ] + + filter_condition = 'metadata["language"] == "en"' + + anns_field = "dense_embedding_cosine" + + addition_search_params = { + "metric_type": VectorSearchMetrics.COSINE.value, "nprobe": 1 + } + + vector_search_parameters = VectorSearchParameters( + anns_field=anns_field, + limit=10, + filter=filter_condition, + search_params=addition_search_params) + + search_parameters = MilvusSearchParameters( + collection_name=MILVUS_IT_CONFIG["collection_name"], + search_strategy=vector_search_parameters, + output_fields=["id", "content", "metadata"], + round_decimal=1) + + collection_load_parameters = MilvusCollectionLoadParameters() + + # Force batching. 
+ min_batch_size, max_batch_size = 2, 2 + handler = MilvusSearchEnrichmentHandler( + connection_parameters=self._connection_params, + search_parameters=search_parameters, + collection_load_parameters=collection_load_parameters, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size) + + expected_chunks = [ + Chunk( + id='query1', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [1, 2], + 'distance': [1.0, 1.0], + 'fields': [{ + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }, + { + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }] + } + }, + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3])), + Chunk( + id='query2', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [2, 1], + 'distance': [1.0, 1.0], + 'fields': [{ + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4])), + Chunk( + id='query3', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [2, 1], + 'distance': [1.0, 1.0], + 'fields': [{ + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.3, 0.4, 0.5])) + ] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + assert_that( + result, + lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + + def test_filtered_search_with_bm25_full_text_and_batching(self): + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(sparse_embedding=None), + content=Content(text="This is a test document")), + Chunk( + id="query2", + embedding=Embedding(sparse_embedding=None), 
+ content=Content(text="Another test document")), + Chunk( + id="query3", + embedding=Embedding(sparse_embedding=None), + content=Content(text="وثيقة اختبار")) + ] + + filter_condition = 'ARRAY_CONTAINS_ANY(tags, ["healthcare", "banking"])' + + anns_field = "sparse_embedding_bm25" + + addition_search_params = {"metric_type": KeywordSearchMetrics.BM25.value} + + keyword_search_parameters = KeywordSearchParameters( + anns_field=anns_field, + limit=10, + filter=filter_condition, + search_params=addition_search_params) + + search_parameters = MilvusSearchParameters( + collection_name=MILVUS_IT_CONFIG["collection_name"], + search_strategy=keyword_search_parameters, + output_fields=["id", "content", "metadata"], + round_decimal=1) + + collection_load_parameters = MilvusCollectionLoadParameters() + + # Force batching. + min_batch_size, max_batch_size = 2, 2 + handler = MilvusSearchEnrichmentHandler( + connection_parameters=self._connection_params, + search_parameters=search_parameters, + collection_load_parameters=collection_load_parameters, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size) + + expected_chunks = [ + Chunk( + id='query1', + content=Content(text='This is a test document'), + metadata={ + 'enrichment_data': { + 'id': [1], + 'distance': [3.3], + 'fields': [{ + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding()), + Chunk( + id='query2', + content=Content(text='Another test document'), + metadata={ + 'enrichment_data': { + 'id': [1], + 'distance': [0.8], + 'fields': [{ + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding()), + Chunk( + id='query3', + content=Content(text='وثيقة اختبار'), + metadata={ + 'enrichment_data': { + 'id': [3], + 'distance': [2.3], + 'fields': [{ + 'content': 'وثيقة اختبار', + 'metadata': { + 'language': 'ar' + }, + 'id': 3 + }] + } + }, + embedding=Embedding()) + ] + + with 
TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + assert_that( + result, + lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + + def test_vector_search_with_euclidean_distance(self): + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content()), + Chunk( + id="query2", + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4]), + content=Content()), + Chunk( + id="query3", + embedding=Embedding(dense_embedding=[0.3, 0.4, 0.5]), + content=Content()) + ] + + anns_field = "dense_embedding_euclidean" + + addition_search_params = { + "metric_type": VectorSearchMetrics.EUCLIDEAN_DISTANCE.value, + "nprobe": 1 + } + + vector_search_parameters = VectorSearchParameters( + anns_field=anns_field, limit=10, search_params=addition_search_params) + + search_parameters = MilvusSearchParameters( + collection_name=MILVUS_IT_CONFIG["collection_name"], + search_strategy=vector_search_parameters, + output_fields=["id", "content", "metadata"], + round_decimal=1) + + collection_load_parameters = MilvusCollectionLoadParameters() + + handler = MilvusSearchEnrichmentHandler( + connection_parameters=self._connection_params, + search_parameters=search_parameters, + collection_load_parameters=collection_load_parameters) + + expected_chunks = [ + Chunk( + id='query1', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [1, 2, 3], + 'distance': [0.0, 0.0, 0.1], + 'fields': [{ + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }, + { + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'وثيقة اختبار', + 'metadata': { + 'language': 'ar' + }, + 'id': 3 + }] + } + }, + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3])), + Chunk( + id='query2', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [2, 3, 1], + 'distance': [0.0, 0.0, 0.0], + 'fields': 
[{ + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'وثيقة اختبار', + 'metadata': { + 'language': 'ar' + }, + 'id': 3 + }, + { + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4])), + Chunk( + id='query3', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [3, 2, 1], + 'distance': [0.0, 0.0, 0.1], + 'fields': [{ + 'content': 'وثيقة اختبار', + 'metadata': { + 'language': 'ar' + }, + 'id': 3 + }, + { + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.3, 0.4, 0.5])) + ] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + assert_that( + result, + lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + + def test_vector_search_with_inner_product_similarity(self): + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content()), + Chunk( + id="query2", + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4]), + content=Content()), + Chunk( + id="query3", + embedding=Embedding(dense_embedding=[0.3, 0.4, 0.5]), + content=Content()) + ] + + anns_field = "dense_embedding_inner_product" + + addition_search_params = { + "metric_type": VectorSearchMetrics.INNER_PRODUCT.value, "nprobe": 1 + } + + vector_search_parameters = VectorSearchParameters( + anns_field=anns_field, limit=10, search_params=addition_search_params) + + search_parameters = MilvusSearchParameters( + collection_name=MILVUS_IT_CONFIG["collection_name"], + search_strategy=vector_search_parameters, + output_fields=["id", "content", "metadata"], + round_decimal=1) + + collection_load_parameters = 
MilvusCollectionLoadParameters() + + handler = MilvusSearchEnrichmentHandler( + connection_parameters=self._connection_params, + search_parameters=search_parameters, + collection_load_parameters=collection_load_parameters) + + expected_chunks = [ + Chunk( + id='query1', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [3, 2, 1], + 'distance': [0.3, 0.2, 0.1], + 'fields': [{ + 'content': 'وثيقة اختبار', + 'metadata': { + 'language': 'ar' + }, + 'id': 3 + }, + { + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3])), + Chunk( + id='query2', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [3, 2, 1], + 'distance': [0.4, 0.3, 0.2], + 'fields': [{ + 'content': 'وثيقة اختبار', + 'metadata': { + 'language': 'ar' + }, + 'id': 3 + }, + { + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4])), + Chunk( + id='query3', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [3, 2, 1], + 'distance': [0.5, 0.4, 0.3], + 'fields': [{ + 'content': 'وثيقة اختبار', + 'metadata': { + 'language': 'ar' + }, + 'id': 3 + }, + { + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }, + { + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.3, 0.4, 0.5])) + ] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + assert_that( + result, + lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + + def 
test_keyword_search_with_inner_product_sparse_embedding(self): + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding( + sparse_embedding=([1, 2, 3, 4], [0.05, 0.41, 0.05, 0.41])), + content=Content()) + ] + + anns_field = "sparse_embedding_inner_product" + + addition_search_params = { + "metric_type": VectorSearchMetrics.INNER_PRODUCT.value, + } + + keyword_search_parameters = KeywordSearchParameters( + anns_field=anns_field, limit=3, search_params=addition_search_params) + + search_parameters = MilvusSearchParameters( + collection_name=MILVUS_IT_CONFIG["collection_name"], + search_strategy=keyword_search_parameters, + output_fields=["id", "content", "metadata"], + round_decimal=1) + + collection_load_parameters = MilvusCollectionLoadParameters() + + handler = MilvusSearchEnrichmentHandler( + connection_parameters=self._connection_params, + search_parameters=search_parameters, + collection_load_parameters=collection_load_parameters) + + expected_chunks = [ + Chunk( + id='query1', + content=Content(), + metadata={ + 'enrichment_data': { + 'id': [1, 2], + 'distance': [0.3, 0.2], + 'fields': [{ + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }, + { + 'content': 'Another test document', + 'metadata': { + 'language': 'en' + }, + 'id': 2 + }] + } + }, + embedding=Embedding( + sparse_embedding=([1, 2, 3, 4], [0.05, 0.41, 0.05, 0.41]))) + ] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + assert_that( + result, + lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + + def test_hybrid_search(self): + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="This is a test document")) + ] + + anns_vector_field = "dense_embedding_cosine" + addition_vector_search_params = { + "metric_type": VectorSearchMetrics.COSINE.value, "nprobe": 1 + } + + vector_search_parameters = 
VectorSearchParameters( + anns_field=anns_vector_field, + limit=10, + search_params=addition_vector_search_params) + + anns_keyword_field = "sparse_embedding_bm25" + addition_keyword_search_params = { + "metric_type": KeywordSearchMetrics.BM25.value + } + + keyword_search_parameters = KeywordSearchParameters( + anns_field=anns_keyword_field, + limit=10, + search_params=addition_keyword_search_params) + + hybrid_search_parameters = HybridSearchParameters( + vector=vector_search_parameters, + keyword=keyword_search_parameters, + ranker=RRFRanker(1), + limit=1) + + search_parameters = MilvusSearchParameters( + collection_name=MILVUS_IT_CONFIG["collection_name"], + search_strategy=hybrid_search_parameters, + output_fields=["id", "content", "metadata"], + round_decimal=1) + + collection_load_parameters = MilvusCollectionLoadParameters() + + handler = MilvusSearchEnrichmentHandler( + connection_parameters=self._connection_params, + search_parameters=search_parameters, + collection_load_parameters=collection_load_parameters) + + expected_chunks = [ + Chunk( + content=Content(text='This is a test document'), + id='query1', + metadata={ + 'enrichment_data': { + 'id': [1], + 'distance': [1.0], + 'fields': [{ + 'content': 'This is a test document', + 'metadata': { + 'language': 'en' + }, + 'id': 1 + }] + } + }, + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3])) + ] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + assert_that( + result, + lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + + +def parse_chunk_strings(chunk_str_list: List[str]) -> List[Chunk]: + parsed_chunks = [] + + # Define safe globals and disable built-in functions for safety. + safe_globals = { + 'Chunk': Chunk, + 'Content': Content, + 'Embedding': Embedding, + 'defaultdict': defaultdict, + 'list': list, + '__builtins__': {} + } + + for raw_str in chunk_str_list: + try: + # replace "" with actual list reference. 
+      # The repr of a defaultdict is "defaultdict(<class 'list'>, {...})",
+      # which is not valid Python source; rewrite it to "defaultdict(list, ..."
+      # so it can be evaluated against the restricted globals above.
+      cleaned_str = re.sub(
+          r"defaultdict\(<class 'list'>", "defaultdict(list", raw_str)
+
+      # Evaluate string in restricted environment.
+      chunk = eval(cleaned_str, safe_globals)  # pylint: disable=eval-used
+      if isinstance(chunk, Chunk):
+        parsed_chunks.append(chunk)
+      else:
+        raise ValueError("Parsed object is not a Chunk instance")
+    except Exception as e:
+      # Chain the original exception so the underlying parse error is kept.
+      raise ValueError(f"Error parsing string:\n{raw_str}\n{e}") from e
+
+  return parsed_chunks
+
+
+def assert_chunks_equivalent(
+    actual_chunks: List[Chunk], expected_chunks: List[Chunk]):
+  """assert_chunks_equivalent checks for presence rather than exact match"""
+  # Sort both lists by ID to ensure consistent ordering.
+  actual_sorted = sorted(actual_chunks, key=lambda c: c.id)
+  expected_sorted = sorted(expected_chunks, key=lambda c: c.id)
+
+  actual_len = len(actual_sorted)
+  expected_len = len(expected_sorted)
+  err_msg = (
+      f"Different number of chunks, actual: {actual_len}, "
+      f"expected: {expected_len}")
+  assert actual_len == expected_len, err_msg
+
+  for actual, expected in zip(actual_sorted, expected_sorted):
+    # Assert that IDs match.
+    assert actual.id == expected.id
+
+    # Assert that dense embeddings match.
+    err_msg = f"Dense embedding mismatch for chunk {actual.id}"
+    assert actual.dense_embedding == expected.dense_embedding, err_msg
+
+    # Assert that sparse embeddings match.
+    err_msg = f"Sparse embedding mismatch for chunk {actual.id}"
+    assert actual.sparse_embedding == expected.sparse_embedding, err_msg
+
+    # Assert that text content match.
+    err_msg = f"Text Content mismatch for chunk {actual.id}"
+    assert actual.content.text == expected.content.text, err_msg
+
+    # For enrichment_data, be more flexible.
+    # If "expected" has values for enrichment_data but actual doesn't, that's
+    # acceptable since vector search results can vary based on many factors
+    # including implementation details, vector database state, and slight
+    # variations in similarity calculations.
+ + # First ensure the enrichment data key exists. + err_msg = f"Missing enrichment_data key in chunk {actual.id}" + assert 'enrichment_data' in actual.metadata, err_msg + + # For enrichment_data, ensure consistent ordering of results. + actual_data = actual.metadata['enrichment_data'] + expected_data = expected.metadata['enrichment_data'] + + # If actual has enrichment data, then perform detailed validation. + if actual_data: + # Ensure the id key exist. + err_msg = f"Missing id key in metadata {actual.id}" + assert 'id' in actual_data, err_msg + + # Validate IDs have consistent ordering. + actual_ids = sorted(actual_data['id']) + expected_ids = sorted(expected_data['id']) + err_msg = f"IDs in enrichment_data don't match for chunk {actual.id}" + assert actual_ids == expected_ids, err_msg + + # Ensure the distance key exist. + err_msg = f"Missing distance key in metadata {actual.id}" + assert 'distance' in actual_data, err_msg + + # Validate distances exist and have same length as IDs. + actual_distances = actual_data['distance'] + expected_distances = expected_data['distance'] + err_msg = ( + "Number of distances doesn't match number of IDs for " + f"chunk {actual.id}") + assert len(actual_distances) == len(expected_distances), err_msg + + # Ensure the fields key exist. + err_msg = f"Missing fields key in metadata {actual.id}" + assert 'fields' in actual_data, err_msg + + # Validate fields have consistent content. + # Sort fields by 'id' to ensure consistent ordering. + actual_fields_sorted = sorted( + actual_data['fields'], key=lambda f: f.get('id', 0)) + expected_fields_sorted = sorted( + expected_data['fields'], key=lambda f: f.get('id', 0)) + + # Compare field IDs. + actual_field_ids = [f.get('id') for f in actual_fields_sorted] + expected_field_ids = [f.get('id') for f in expected_fields_sorted] + err_msg = f"Field IDs don't match for chunk {actual.id}" + assert actual_field_ids == expected_field_ids, err_msg + + # Compare field content. 
+      for a_f, e_f in zip(actual_fields_sorted, expected_fields_sorted):
+        # Ensure the id key exists; attach the message so a failure is
+        # self-describing (previously err_msg was assigned but never used).
+        err_msg = f"Missing id key in metadata.fields {actual.id}"
+        assert 'id' in a_f, err_msg
+
+        err_msg = f"Field ID mismatch chunk {actual.id}"
+        assert a_f['id'] == e_f['id'], err_msg
+
+        # Validate field metadata.
+        err_msg = f"Field Metadata doesn't match for chunk {actual.id}"
+        assert a_f['metadata'] == e_f['metadata'], err_msg
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_test.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_test.py
new file mode 100644
index 000000000000..e69915cb3e9b
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_test.py
@@ -0,0 +1,343 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# +import unittest + +from parameterized import parameterized + +try: + from apache_beam.ml.rag.types import Chunk + from apache_beam.ml.rag.types import Embedding + from apache_beam.ml.rag.types import Content + from apache_beam.ml.rag.enrichment.milvus_search import ( + MilvusSearchEnrichmentHandler, + MilvusConnectionParameters, + MilvusSearchParameters, + MilvusCollectionLoadParameters, + VectorSearchParameters, + KeywordSearchParameters, + HybridSearchParameters, + MilvusBaseRanker, + unpack_dataclass_with_kwargs) +except ImportError as e: + raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}') + + +class MockRanker(MilvusBaseRanker): + def dict(self): + return {"name": "mock_ranker"} + + +class TestMilvusSearchEnrichment(unittest.TestCase): + """Unit tests for general search functionality in the Enrichment Handler.""" + def test_invalid_connection_parameters(self): + """Test validation errors for invalid connection parameters.""" + # Empty URI in connection parameters. + with self.assertRaises(ValueError) as context: + connection_params = MilvusConnectionParameters(uri="") + search_params = MilvusSearchParameters( + collection_name="test_collection", + search_strategy=VectorSearchParameters(anns_field="embedding")) + collection_load_params = MilvusCollectionLoadParameters() + + _ = MilvusSearchEnrichmentHandler( + connection_parameters=connection_params, + search_parameters=search_params, + collection_load_parameters=collection_load_params) + + self.assertIn( + "URI must be provided for Milvus connection", str(context.exception)) + + @parameterized.expand([ + # Empty collection name. + ( + lambda: MilvusSearchParameters( + collection_name="", + search_strategy=VectorSearchParameters(anns_field="embedding")), + "Collection name must be provided" + ), + # Missing search strategy. 
+ ( + lambda: MilvusSearchParameters( + collection_name="test_collection", + search_strategy=None), # type: ignore[arg-type] + "Search strategy must be provided" + ), + ]) + def test_invalid_search_parameters(self, create_params, expected_error_msg): + """Test validation errors for invalid general search parameters.""" + with self.assertRaises(ValueError) as context: + connection_params = MilvusConnectionParameters( + uri="http://localhost:19530") + search_params = create_params() + collection_load_params = MilvusCollectionLoadParameters() + + _ = MilvusSearchEnrichmentHandler( + connection_parameters=connection_params, + search_parameters=search_params, + collection_load_parameters=collection_load_params) + + self.assertIn(expected_error_msg, str(context.exception)) + + def test_unpack_dataclass_with_kwargs(self): + """Test the unpack_dataclass_with_kwargs function.""" + # Create a test dataclass instance. + connection_params = MilvusConnectionParameters( + uri="http://localhost:19530", + user="test_user", + kwargs={"custom_param": "value"}) + + # Call the actual function. + result = unpack_dataclass_with_kwargs(connection_params) + + # Verify the function correctly unpacks the dataclass and merges kwargs. + self.assertEqual(result["uri"], "http://localhost:19530") + self.assertEqual(result["user"], "test_user") + self.assertEqual(result["custom_param"], "value") + + # Verify that kwargs take precedence over existing attributes. + connection_params_with_override = MilvusConnectionParameters( + uri="http://localhost:19530", + user="test_user", + kwargs={"user": "override_user"}) + + result_with_override = unpack_dataclass_with_kwargs( + connection_params_with_override) + self.assertEqual(result_with_override["user"], "override_user") + + +class TestMilvusVectorSearchEnrichment(unittest.TestCase): + """Unit tests specific to vector search functionality""" + + @parameterized.expand([ + # Negative limit in vector search parameters. 
+ ( + lambda: VectorSearchParameters(anns_field="embedding", limit=-1), + "Search limit must be positive, got -1" + ), + # Missing anns_field in vector search parameters. + ( + lambda: VectorSearchParameters(anns_field=None), # type: ignore[arg-type] + "Approximate Nearest Neighbor Search (ANNS) field must be provided" + ), + ]) + def test_invalid_search_parameters(self, create_params, expected_error_msg): + """Test validation errors for invalid vector search parameters.""" + with self.assertRaises(ValueError) as context: + connection_params = MilvusConnectionParameters( + uri="http://localhost:19530") + vector_search_params = create_params() + search_params = MilvusSearchParameters( + collection_name="test_collection", + search_strategy=vector_search_params) + collection_load_params = MilvusCollectionLoadParameters() + + _ = MilvusSearchEnrichmentHandler( + connection_parameters=connection_params, + search_parameters=search_params, + collection_load_parameters=collection_load_params) + + self.assertIn(expected_error_msg, str(context.exception)) + + def test_missing_dense_embedding(self): + with self.assertRaises(ValueError) as context: + chunk = Chunk( + id=1, content=None, embedding=Embedding(dense_embedding=None)) + connection_params = MilvusConnectionParameters( + uri="http://localhost:19530") + vector_search_params = VectorSearchParameters(anns_field="embedding") + search_params = MilvusSearchParameters( + collection_name="test_collection", + search_strategy=vector_search_params) + collection_load_params = MilvusCollectionLoadParameters() + handler = MilvusSearchEnrichmentHandler( + connection_parameters=connection_params, + search_parameters=search_params, + collection_load_parameters=collection_load_params) + + _ = handler._get_vector_search_data(chunk) + + err_msg = "Chunk 1 missing dense embedding required for vector search" + self.assertIn(err_msg, str(context.exception)) + + +class TestMilvusKeywordSearchEnrichment(unittest.TestCase): + """Unit tests 
specific to keyword search functionality""" + + @parameterized.expand([ + # Negative limit in keyword search parameters. + ( + lambda: KeywordSearchParameters( + anns_field="sparse_embedding", limit=-1), + "Search limit must be positive, got -1" + ), + # Missing anns_field in keyword search parameters. + ( + lambda: KeywordSearchParameters(anns_field=None), # type: ignore[arg-type] + "Approximate Nearest Neighbor Search (ANNS) field must be provided" + ), + ]) + def test_invalid_search_parameters(self, create_params, expected_error_msg): + """Test validation errors for invalid keyword search parameters.""" + with self.assertRaises(ValueError) as context: + connection_params = MilvusConnectionParameters( + uri="http://localhost:19530") + keyword_search_params = create_params() + search_params = MilvusSearchParameters( + collection_name="test_collection", + search_strategy=keyword_search_params) + collection_load_params = MilvusCollectionLoadParameters() + + _ = MilvusSearchEnrichmentHandler( + connection_parameters=connection_params, + search_parameters=search_params, + collection_load_parameters=collection_load_params) + + self.assertIn(expected_error_msg, str(context.exception)) + + def test_missing_text_content_and_sparse_embedding(self): + with self.assertRaises(ValueError) as context: + chunk = Chunk( + id=1, + content=Content(text=None), + embedding=Embedding(sparse_embedding=None)) + connection_params = MilvusConnectionParameters( + uri="http://localhost:19530") + vector_search_params = VectorSearchParameters(anns_field="embedding") + search_params = MilvusSearchParameters( + collection_name="test_collection", + search_strategy=vector_search_params) + collection_load_params = MilvusCollectionLoadParameters() + handler = MilvusSearchEnrichmentHandler( + connection_parameters=connection_params, + search_parameters=search_params, + collection_load_parameters=collection_load_params) + + _ = handler._get_keyword_search_data(chunk) + + err_msg = ( + "Chunk 1 
missing both text content and sparse embedding " + "required for keyword search") + self.assertIn(err_msg, str(context.exception)) + + def test_missing_text_content_only(self): + try: + chunk = Chunk( + id=1, + content=Content(text=None), + embedding=Embedding( + sparse_embedding=([1, 2, 3, 4], [0.05, 0.41, 0.05, 0.41]))) + connection_params = MilvusConnectionParameters( + uri="http://localhost:19530") + vector_search_params = VectorSearchParameters(anns_field="embedding") + search_params = MilvusSearchParameters( + collection_name="test_collection", + search_strategy=vector_search_params) + collection_load_params = MilvusCollectionLoadParameters() + handler = MilvusSearchEnrichmentHandler( + connection_parameters=connection_params, + search_parameters=search_params, + collection_load_parameters=collection_load_params) + + _ = handler._get_keyword_search_data(chunk) + except Exception as e: + self.fail(f"raised an unexpected exception: {e}") + + def test_missing_sparse_embedding_only(self): + try: + chunk = Chunk( + id=1, + content=Content(text="what is apache beam?"), + embedding=Embedding(sparse_embedding=None)) + connection_params = MilvusConnectionParameters( + uri="http://localhost:19530") + vector_search_params = VectorSearchParameters(anns_field="embedding") + search_params = MilvusSearchParameters( + collection_name="test_collection", + search_strategy=vector_search_params) + collection_load_params = MilvusCollectionLoadParameters() + handler = MilvusSearchEnrichmentHandler( + connection_parameters=connection_params, + search_parameters=search_params, + collection_load_parameters=collection_load_params) + + _ = handler._get_keyword_search_data(chunk) + except Exception as e: + self.fail(f"raised an unexpected exception: {e}") + pass + + +class TestMilvusHybridSearchEnrichment(unittest.TestCase): + """Tests specific to hybrid search functionality""" + + @parameterized.expand([ + # Missing vector in hybrid search parameters. 
+ ( + lambda: HybridSearchParameters( + vector=None, # type: ignore[arg-type] + keyword=KeywordSearchParameters(anns_field="sparse_embedding"), + ranker=MockRanker()), + "Vector and keyword search parameters must be provided for hybrid " + "search" + ), + # Missing keyword in hybrid search parameters. + ( + lambda: HybridSearchParameters( + vector=VectorSearchParameters(anns_field="embedding"), + keyword=None, # type: ignore[arg-type] + ranker=MockRanker()), + "Vector and keyword search parameters must be provided for hybrid " + "search" + ), + # Missing ranker in hybrid search parameters. + ( + lambda: HybridSearchParameters( + vector=VectorSearchParameters(anns_field="embedding"), + keyword=KeywordSearchParameters(anns_field="sparse_embedding"), + ranker=None), # type: ignore[arg-type] + "Ranker must be provided for hybrid search" + ), + # Negative limit in hybrid search parameters. + ( + lambda: HybridSearchParameters( + vector=VectorSearchParameters(anns_field="embedding"), + keyword=KeywordSearchParameters(anns_field="sparse_embedding"), + ranker=MockRanker(), + limit=-1), + "Search limit must be positive, got -1" + ), + ]) + def test_invalid_search_parameters(self, create_params, expected_error_msg): + """Test validation errors for invalid hybrid search parameters.""" + with self.assertRaises(ValueError) as context: + connection_params = MilvusConnectionParameters( + uri="http://localhost:19530") + hybrid_search_params = create_params() + search_params = MilvusSearchParameters( + collection_name="test_collection", + search_strategy=hybrid_search_params) + collection_load_params = MilvusCollectionLoadParameters() + + _ = MilvusSearchEnrichmentHandler( + connection_parameters=connection_params, + search_parameters=search_params, + collection_load_parameters=collection_load_params) + + self.assertIn(expected_error_msg, str(context.exception)) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/types.py 
b/sdks/python/apache_beam/ml/rag/types.py index 79429899e4c1..3bb0e01b68cc 100644 --- a/sdks/python/apache_beam/ml/rag/types.py +++ b/sdks/python/apache_beam/ml/rag/types.py @@ -44,7 +44,7 @@ class Content: @dataclass class Embedding: """Represents vector embeddings. - + Args: dense_embedding: Dense vector representation sparse_embedding: Optional sparse vector representation for hybrid @@ -58,16 +58,24 @@ class Embedding: @dataclass class Chunk: """Represents a chunk of embeddable content with metadata. - + Args: content: The actual content of the chunk id: Unique identifier for the chunk index: Index of this chunk within the original document metadata: Additional metadata about the chunk (e.g., document source) - embedding: Vector embeddings of the content + embedding: Vector embeddings of the content """ content: Content id: str = field(default_factory=lambda: str(uuid.uuid4())) index: int = 0 metadata: Dict[str, Any] = field(default_factory=dict) embedding: Optional[Embedding] = None + + @property + def dense_embedding(self): + return self.embedding.dense_embedding if self.embedding else None + + @property + def sparse_embedding(self): + return self.embedding.sparse_embedding if self.embedding else None diff --git a/sdks/python/container/license_scripts/dep_urls_py.yaml b/sdks/python/container/license_scripts/dep_urls_py.yaml index da10163fdb4f..b46fc10adf13 100644 --- a/sdks/python/container/license_scripts/dep_urls_py.yaml +++ b/sdks/python/container/license_scripts/dep_urls_py.yaml @@ -135,6 +135,8 @@ pip_dependencies: license: "https://github.com/PiotrDabkowski/pyjsparser/blob/master/LICENSE" pymongo: license: "https://raw.githubusercontent.com/mongodb/mongo-python-driver/master/LICENSE" + milvus-lite: + license: "https://raw.githubusercontent.com/milvus-io/milvus-lite/refs/heads/main/LICENSE" pyproject_hooks: license: "https://raw.githubusercontent.com/pypa/pyproject-hooks/main/LICENSE" python-gflags: diff --git 
a/sdks/python/container/py310/base_image_requirements.txt b/sdks/python/container/py310/base_image_requirements.txt index 5f69f6b11928..e9b4f1905399 100644 --- a/sdks/python/container/py310/base_image_requirements.txt +++ b/sdks/python/container/py310/base_image_requirements.txt @@ -43,7 +43,6 @@ cloud-sql-python-connector==1.18.2 crcmod==1.7 cryptography==45.0.4 Cython==3.1.2 -deprecation==2.1.0 dill==0.3.1.1 dnspython==2.7.0 docker==7.1.0 @@ -57,17 +56,17 @@ freezegun==1.5.2 frozenlist==1.7.0 future==1.0.0 google-api-core==2.25.1 -google-api-python-client==2.172.0 +google-api-python-client==2.174.0 google-apitools==0.5.31 google-auth==2.40.3 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.97.0 +google-cloud-aiplatform==1.100.0 google-cloud-bigquery==3.34.0 google-cloud-bigquery-storage==2.32.0 google-cloud-bigtable==2.31.0 google-cloud-core==2.4.3 google-cloud-datastore==2.21.0 -google-cloud-dlp==3.30.0 +google-cloud-dlp==3.31.0 google-cloud-language==2.17.2 google-cloud-profiler==4.1.0 google-cloud-pubsub==2.30.0 @@ -79,7 +78,7 @@ google-cloud-storage==2.19.0 google-cloud-videointelligence==2.16.2 google-cloud-vision==3.10.2 google-crc32c==1.7.1 -google-genai==1.20.0 +google-genai==1.23.0 google-resumable-media==2.7.2 googleapis-common-protos==1.70.0 greenlet==3.2.3 @@ -93,27 +92,27 @@ hdfs==2.7.3 httpcore==1.0.9 httplib2==0.22.0 httpx==0.28.1 -hypothesis==6.135.10 +hypothesis==6.135.17 idna==3.10 importlib_metadata==8.7.0 iniconfig==2.1.0 jaraco.classes==3.4.0 jaraco.context==6.0.1 -jaraco.functools==4.1.0 +jaraco.functools==4.2.1 jeepney==0.9.0 Jinja2==3.1.6 joblib==1.5.1 jsonpickle==3.4.2 jsonschema==4.24.0 jsonschema-specifications==2025.4.1 -kafka-python==2.2.11 keyring==25.6.0 keyrings.google-artifactregistry-auth==1.1.2 MarkupSafe==3.0.2 +milvus-lite==2.5.1 mmh3==5.1.0 mock==5.2.0 more-itertools==10.7.0 -multidict==6.4.4 +multidict==6.6.2 mysql-connector-python==9.3.0 nltk==3.9.1 numpy==2.2.6 @@ -122,13 +121,14 @@ objsize==0.7.1 
opentelemetry-api==1.34.1 opentelemetry-sdk==1.34.1 opentelemetry-semantic-conventions==0.55b1 -oracledb==3.1.1 +oracledb==3.2.0 orjson==3.10.18 overrides==7.7.0 packaging==25.0 pandas==2.2.3 parameterized==0.9.0 pg8000==1.31.2 +pip==25.1.1 pluggy==1.6.0 propcache==0.3.2 proto-plus==1.26.1 @@ -144,7 +144,8 @@ pydantic_core==2.33.2 pydot==1.4.2 PyHamcrest==2.1.0 PyJWT==2.9.0 -pymongo==4.13.1 +pymilvus==2.5.11 +pymongo==4.13.2 PyMySQL==1.1.1 pyparsing==3.2.3 pyproject_hooks==1.2.0 @@ -152,6 +153,7 @@ pytest==7.4.4 pytest-timeout==2.4.0 pytest-xdist==3.7.0 python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 python-tds==1.16.1 pytz==2025.2 PyYAML==6.0.2 @@ -166,6 +168,7 @@ scikit-learn==1.7.0 scipy==1.15.3 scramp==1.4.5 SecretStorage==3.3.3 +setuptools==80.9.0 shapely==2.1.1 six==1.17.0 sniffio==1.3.1 @@ -175,17 +178,19 @@ SQLAlchemy==2.0.41 sqlalchemy_pytds==1.0.2 sqlparse==0.5.3 tenacity==8.5.0 -testcontainers==3.7.1 +testcontainers==4.10.0 threadpoolctl==3.6.0 tomli==2.2.1 tqdm==4.67.1 typing-inspection==0.4.1 typing_extensions==4.14.0 tzdata==2025.2 +ujson==5.10.0 uritemplate==4.2.0 -urllib3==2.4.0 +urllib3==2.5.0 virtualenv-clone==0.5.7 websockets==15.0.1 +wheel==0.45.1 wrapt==1.17.2 yarl==1.20.1 zipp==3.23.0 diff --git a/sdks/python/container/py311/base_image_requirements.txt b/sdks/python/container/py311/base_image_requirements.txt index 10d55d17f409..af2e75a54b8f 100644 --- a/sdks/python/container/py311/base_image_requirements.txt +++ b/sdks/python/container/py311/base_image_requirements.txt @@ -42,7 +42,6 @@ cloud-sql-python-connector==1.18.2 crcmod==1.7 cryptography==45.0.4 Cython==3.1.2 -deprecation==2.1.0 dill==0.3.1.1 dnspython==2.7.0 docker==7.1.0 @@ -55,17 +54,17 @@ freezegun==1.5.2 frozenlist==1.7.0 future==1.0.0 google-api-core==2.25.1 -google-api-python-client==2.172.0 +google-api-python-client==2.174.0 google-apitools==0.5.31 google-auth==2.40.3 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.97.0 +google-cloud-aiplatform==1.100.0 
google-cloud-bigquery==3.34.0 google-cloud-bigquery-storage==2.32.0 google-cloud-bigtable==2.31.0 google-cloud-core==2.4.3 google-cloud-datastore==2.21.0 -google-cloud-dlp==3.30.0 +google-cloud-dlp==3.31.0 google-cloud-language==2.17.2 google-cloud-profiler==4.1.0 google-cloud-pubsub==2.30.0 @@ -77,7 +76,7 @@ google-cloud-storage==2.19.0 google-cloud-videointelligence==2.16.2 google-cloud-vision==3.10.2 google-crc32c==1.7.1 -google-genai==1.20.0 +google-genai==1.23.0 google-resumable-media==2.7.2 googleapis-common-protos==1.70.0 greenlet==3.2.3 @@ -91,27 +90,27 @@ hdfs==2.7.3 httpcore==1.0.9 httplib2==0.22.0 httpx==0.28.1 -hypothesis==6.135.10 +hypothesis==6.135.17 idna==3.10 importlib_metadata==8.7.0 iniconfig==2.1.0 jaraco.classes==3.4.0 jaraco.context==6.0.1 -jaraco.functools==4.1.0 +jaraco.functools==4.2.1 jeepney==0.9.0 Jinja2==3.1.6 joblib==1.5.1 jsonpickle==3.4.2 jsonschema==4.24.0 jsonschema-specifications==2025.4.1 -kafka-python==2.2.11 keyring==25.6.0 keyrings.google-artifactregistry-auth==1.1.2 MarkupSafe==3.0.2 +milvus-lite==2.5.1 mmh3==5.1.0 mock==5.2.0 more-itertools==10.7.0 -multidict==6.4.4 +multidict==6.6.2 mysql-connector-python==9.3.0 nltk==3.9.1 numpy==2.2.6 @@ -120,13 +119,14 @@ objsize==0.7.1 opentelemetry-api==1.34.1 opentelemetry-sdk==1.34.1 opentelemetry-semantic-conventions==0.55b1 -oracledb==3.1.1 +oracledb==3.2.0 orjson==3.10.18 overrides==7.7.0 packaging==25.0 pandas==2.2.3 parameterized==0.9.0 pg8000==1.31.2 +pip==25.1.1 pluggy==1.6.0 propcache==0.3.2 proto-plus==1.26.1 @@ -142,7 +142,8 @@ pydantic_core==2.33.2 pydot==1.4.2 PyHamcrest==2.1.0 PyJWT==2.9.0 -pymongo==4.13.1 +pymilvus==2.5.11 +pymongo==4.13.2 PyMySQL==1.1.1 pyparsing==3.2.3 pyproject_hooks==1.2.0 @@ -150,6 +151,7 @@ pytest==7.4.4 pytest-timeout==2.4.0 pytest-xdist==3.7.0 python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 python-tds==1.16.1 pytz==2025.2 PyYAML==6.0.2 @@ -161,9 +163,10 @@ requests-mock==1.12.1 rpds-py==0.25.1 rsa==4.9.1 scikit-learn==1.7.0 -scipy==1.15.3 
+scipy==1.16.0 scramp==1.4.5 SecretStorage==3.3.3 +setuptools==80.9.0 shapely==2.1.1 six==1.17.0 sniffio==1.3.1 @@ -173,16 +176,18 @@ SQLAlchemy==2.0.41 sqlalchemy_pytds==1.0.2 sqlparse==0.5.3 tenacity==8.5.0 -testcontainers==3.7.1 +testcontainers==4.10.0 threadpoolctl==3.6.0 tqdm==4.67.1 typing-inspection==0.4.1 typing_extensions==4.14.0 tzdata==2025.2 +ujson==5.10.0 uritemplate==4.2.0 -urllib3==2.4.0 +urllib3==2.5.0 virtualenv-clone==0.5.7 websockets==15.0.1 +wheel==0.45.1 wrapt==1.17.2 yarl==1.20.1 zipp==3.23.0 diff --git a/sdks/python/container/py312/base_image_requirements.txt b/sdks/python/container/py312/base_image_requirements.txt index d4b9c8751dca..f48d350e01d3 100644 --- a/sdks/python/container/py312/base_image_requirements.txt +++ b/sdks/python/container/py312/base_image_requirements.txt @@ -41,7 +41,6 @@ cloud-sql-python-connector==1.18.2 crcmod==1.7 cryptography==45.0.4 Cython==3.1.2 -deprecation==2.1.0 dill==0.3.1.1 dnspython==2.7.0 docker==7.1.0 @@ -54,17 +53,17 @@ freezegun==1.5.2 frozenlist==1.7.0 future==1.0.0 google-api-core==2.25.1 -google-api-python-client==2.172.0 +google-api-python-client==2.174.0 google-apitools==0.5.31 google-auth==2.40.3 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.97.0 +google-cloud-aiplatform==1.100.0 google-cloud-bigquery==3.34.0 google-cloud-bigquery-storage==2.32.0 google-cloud-bigtable==2.31.0 google-cloud-core==2.4.3 google-cloud-datastore==2.21.0 -google-cloud-dlp==3.30.0 +google-cloud-dlp==3.31.0 google-cloud-language==2.17.2 google-cloud-profiler==4.1.0 google-cloud-pubsub==2.30.0 @@ -76,7 +75,7 @@ google-cloud-storage==2.19.0 google-cloud-videointelligence==2.16.2 google-cloud-vision==3.10.2 google-crc32c==1.7.1 -google-genai==1.20.0 +google-genai==1.23.0 google-resumable-media==2.7.2 googleapis-common-protos==1.70.0 greenlet==3.2.3 @@ -90,27 +89,27 @@ hdfs==2.7.3 httpcore==1.0.9 httplib2==0.22.0 httpx==0.28.1 -hypothesis==6.135.10 +hypothesis==6.135.17 idna==3.10 importlib_metadata==8.7.0 
iniconfig==2.1.0 jaraco.classes==3.4.0 jaraco.context==6.0.1 -jaraco.functools==4.1.0 +jaraco.functools==4.2.1 jeepney==0.9.0 Jinja2==3.1.6 joblib==1.5.1 jsonpickle==3.4.2 jsonschema==4.24.0 jsonschema-specifications==2025.4.1 -kafka-python==2.2.11 keyring==25.6.0 keyrings.google-artifactregistry-auth==1.1.2 MarkupSafe==3.0.2 +milvus-lite==2.5.1 mmh3==5.1.0 mock==5.2.0 more-itertools==10.7.0 -multidict==6.4.4 +multidict==6.6.2 mysql-connector-python==9.3.0 nltk==3.9.1 numpy==2.2.6 @@ -119,13 +118,14 @@ objsize==0.7.1 opentelemetry-api==1.34.1 opentelemetry-sdk==1.34.1 opentelemetry-semantic-conventions==0.55b1 -oracledb==3.1.1 +oracledb==3.2.0 orjson==3.10.18 overrides==7.7.0 packaging==25.0 pandas==2.2.3 parameterized==0.9.0 pg8000==1.31.2 +pip==25.1.1 pluggy==1.6.0 propcache==0.3.2 proto-plus==1.26.1 @@ -141,7 +141,8 @@ pydantic_core==2.33.2 pydot==1.4.2 PyHamcrest==2.1.0 PyJWT==2.9.0 -pymongo==4.13.1 +pymilvus==2.5.11 +pymongo==4.13.2 PyMySQL==1.1.1 pyparsing==3.2.3 pyproject_hooks==1.2.0 @@ -149,6 +150,7 @@ pytest==7.4.4 pytest-timeout==2.4.0 pytest-xdist==3.7.0 python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 python-tds==1.16.1 pytz==2025.2 PyYAML==6.0.2 @@ -160,7 +162,7 @@ requests-mock==1.12.1 rpds-py==0.25.1 rsa==4.9.1 scikit-learn==1.7.0 -scipy==1.15.3 +scipy==1.16.0 scramp==1.4.5 SecretStorage==3.3.3 setuptools==80.9.0 @@ -173,14 +175,15 @@ SQLAlchemy==2.0.41 sqlalchemy_pytds==1.0.2 sqlparse==0.5.3 tenacity==8.5.0 -testcontainers==3.7.1 +testcontainers==4.10.0 threadpoolctl==3.6.0 tqdm==4.67.1 typing-inspection==0.4.1 typing_extensions==4.14.0 tzdata==2025.2 +ujson==5.10.0 uritemplate==4.2.0 -urllib3==2.4.0 +urllib3==2.5.0 virtualenv-clone==0.5.7 websockets==15.0.1 wheel==0.45.1 diff --git a/sdks/python/container/py39/base_image_requirements.txt b/sdks/python/container/py39/base_image_requirements.txt index 849786b95756..1c2ebc4c7a4c 100644 --- a/sdks/python/container/py39/base_image_requirements.txt +++ 
b/sdks/python/container/py39/base_image_requirements.txt @@ -43,7 +43,6 @@ cloud-sql-python-connector==1.18.2 crcmod==1.7 cryptography==45.0.4 Cython==3.1.2 -deprecation==2.1.0 dill==0.3.1.1 dnspython==2.7.0 docker==7.1.0 @@ -57,17 +56,17 @@ freezegun==1.5.2 frozenlist==1.7.0 future==1.0.0 google-api-core==2.25.1 -google-api-python-client==2.172.0 +google-api-python-client==2.174.0 google-apitools==0.5.31 google-auth==2.40.3 google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.97.0 +google-cloud-aiplatform==1.100.0 google-cloud-bigquery==3.34.0 google-cloud-bigquery-storage==2.32.0 google-cloud-bigtable==2.31.0 google-cloud-core==2.4.3 google-cloud-datastore==2.21.0 -google-cloud-dlp==3.30.0 +google-cloud-dlp==3.31.0 google-cloud-language==2.17.2 google-cloud-profiler==4.1.0 google-cloud-pubsub==2.30.0 @@ -79,7 +78,7 @@ google-cloud-storage==2.19.0 google-cloud-videointelligence==2.16.2 google-cloud-vision==3.10.2 google-crc32c==1.7.1 -google-genai==1.20.0 +google-genai==1.23.0 google-resumable-media==2.7.2 googleapis-common-protos==1.70.0 greenlet==3.2.3 @@ -93,27 +92,27 @@ hdfs==2.7.3 httpcore==1.0.9 httplib2==0.22.0 httpx==0.28.1 -hypothesis==6.135.10 +hypothesis==6.135.17 idna==3.10 importlib_metadata==8.7.0 iniconfig==2.1.0 jaraco.classes==3.4.0 jaraco.context==6.0.1 -jaraco.functools==4.1.0 +jaraco.functools==4.2.1 jeepney==0.9.0 Jinja2==3.1.6 joblib==1.5.1 jsonpickle==3.4.2 jsonschema==4.24.0 jsonschema-specifications==2025.4.1 -kafka-python==2.2.11 keyring==25.6.0 keyrings.google-artifactregistry-auth==1.1.2 MarkupSafe==3.0.2 +milvus-lite==2.5.1 mmh3==5.1.0 mock==5.2.0 more-itertools==10.7.0 -multidict==6.4.4 +multidict==6.6.2 mysql-connector-python==9.3.0 nltk==3.9.1 numpy==2.0.2 @@ -122,13 +121,14 @@ objsize==0.7.1 opentelemetry-api==1.34.1 opentelemetry-sdk==1.34.1 opentelemetry-semantic-conventions==0.55b1 -oracledb==3.1.1 +oracledb==3.2.0 orjson==3.10.18 overrides==7.7.0 packaging==25.0 pandas==2.2.3 parameterized==0.9.0 pg8000==1.31.2 +pip==25.1.1 
pluggy==1.6.0 propcache==0.3.2 proto-plus==1.26.1 @@ -144,7 +144,8 @@ pydantic_core==2.33.2 pydot==1.4.2 PyHamcrest==2.1.0 PyJWT==2.9.0 -pymongo==4.13.1 +pymilvus==2.5.11 +pymongo==4.13.2 PyMySQL==1.1.1 pyparsing==3.2.3 pyproject_hooks==1.2.0 @@ -152,6 +153,7 @@ pytest==7.4.4 pytest-timeout==2.4.0 pytest-xdist==3.7.0 python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 python-tds==1.16.1 pytz==2025.2 PyYAML==6.0.2 @@ -166,6 +168,7 @@ scikit-learn==1.6.1 scipy==1.13.1 scramp==1.4.5 SecretStorage==3.3.3 +setuptools==80.9.0 shapely==2.0.7 six==1.17.0 sniffio==1.3.1 @@ -175,17 +178,19 @@ SQLAlchemy==2.0.41 sqlalchemy_pytds==1.0.2 sqlparse==0.5.3 tenacity==8.5.0 -testcontainers==3.7.1 +testcontainers==4.10.0 threadpoolctl==3.6.0 tomli==2.2.1 tqdm==4.67.1 typing-inspection==0.4.1 typing_extensions==4.14.0 tzdata==2025.2 +ujson==5.10.0 uritemplate==4.2.0 -urllib3==2.4.0 +urllib3==2.5.0 virtualenv-clone==0.5.7 websockets==15.0.1 +wheel==0.45.1 wrapt==1.17.2 yarl==1.20.1 zipp==3.23.0 diff --git a/sdks/python/container/run_generate_requirements.sh b/sdks/python/container/run_generate_requirements.sh index 6c160bc6ac9e..23964d10e7b4 100755 --- a/sdks/python/container/run_generate_requirements.sh +++ b/sdks/python/container/run_generate_requirements.sh @@ -72,7 +72,7 @@ pip uninstall -y apache-beam echo "Checking for broken dependencies:" pip check echo "Installed dependencies:" -pip freeze +pip freeze --all PY_IMAGE="py${PY_VERSION//.}" REQUIREMENTS_FILE=$PWD/sdks/python/container/$PY_IMAGE/base_image_requirements.txt @@ -103,7 +103,7 @@ cat < "$REQUIREMENTS_FILE" EOT # Remove pkg_resources to guard against # https://stackoverflow.com/questions/39577984/what-is-pkg-resources-0-0-0-in-output-of-pip-freeze-command -pip freeze | grep -v pkg_resources >> "$REQUIREMENTS_FILE" +pip freeze --all | grep -v pkg_resources >> "$REQUIREMENTS_FILE" if grep -q "tensorflow==" "$REQUIREMENTS_FILE"; then # Get the version of tensorflow from the .txt file. 
diff --git a/sdks/python/setup.py b/sdks/python/setup.py index a0bbc301435b..d309a7ea4a64 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -401,6 +401,7 @@ def get_portability_package_data(): 'typing-extensions>=3.7.0', 'zstandard>=0.18.0,<1', 'pyyaml>=3.12,<7.0.0', + 'pymilvus>=2.5.10,<3.0.0', # Dynamic dependencies must be specified in a separate list, otherwise # Dependabot won't be able to parse the main list. Any dynamic # dependencies will not receive updates from Dependabot. @@ -434,11 +435,10 @@ def get_portability_package_data(): 'pytest-xdist>=2.5.0,<4', 'pytest-timeout>=2.1.0,<3', 'scikit-learn>=0.20.0', - 'setuptools', 'sqlalchemy>=1.3,<3.0', 'psycopg2-binary>=2.8.5,<2.9.10; python_version <= "3.9"', 'psycopg2-binary>=2.8.5,<3.0; python_version >= "3.10"', - 'testcontainers[mysql,kafka]>=3.0.3,<4.0.0', + 'testcontainers[mysql,kafka,milvus]>=4.0.0,<5.0.0', 'cryptography>=41.0.2', 'hypothesis>5.0.0,<7.0.0', 'virtualenv-clone>=0.5,<1.0',