diff --git a/sdks/python/apache_beam/ml/rag/embeddings/base.py b/sdks/python/apache_beam/ml/rag/embeddings/base.py index 25dc3ee47e80..0128d6a6d6fc 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/base.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/base.py @@ -14,42 +14,56 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""RAG-specific embedding adapters. + +This module provides adapters for extracting content from +EmbeddableItem instances and mapping embeddings back. Adapters +are used by EmbeddingsManager to support various input types. +""" + from collections.abc import Sequence from typing import List -from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import EmbeddableItem from apache_beam.ml.rag.types import Embedding from apache_beam.ml.transforms.base import EmbeddingTypeAdapter -def create_rag_adapter() -> EmbeddingTypeAdapter[Chunk, Chunk]: - """Creates adapter for converting between Chunk and Embedding types. - - The adapter: - - Extracts text from Chunk.content.text for embedding - - Creates Embedding objects from model output - - Sets Embedding in Chunk.embedding - - Returns: - EmbeddingTypeAdapter configured for RAG pipeline types - """ +def create_text_adapter( +) -> EmbeddingTypeAdapter[EmbeddableItem, EmbeddableItem]: + """Creates adapter for text content embedding. + + Works with any EmbeddableItem that has text content + (content.text). Extracts text for embedding and maps + results back as Embedding objects. + + Returns: + EmbeddingTypeAdapter configured for text embedding + """ return EmbeddingTypeAdapter( - input_fn=_extract_chunk_text, output_fn=_add_embedding_fn) + input_fn=_extract_text, output_fn=_add_embedding_fn) + + +# Backward compatibility alias. 
+create_rag_adapter = create_text_adapter -def _extract_chunk_text(chunks: Sequence[Chunk]) -> List[str]: - """Extract text from chunks for embedding.""" - chunk_texts = [] - for chunk in chunks: - if not chunk.content.text: - raise ValueError("Expected chunk text content.") - chunk_texts.append(chunk.content.text) - return chunk_texts +def _extract_text(items: Sequence[EmbeddableItem]) -> List[str]: + """Extract text from items for embedding.""" + texts = [] + for item in items: + if not item.content.text: + raise ValueError( + f"Expected text content in {type(item).__name__} {item.id}, " + "got None") + texts.append(item.content.text) + return texts def _add_embedding_fn( - chunks: Sequence[Chunk], embeddings: Sequence[List[float]]) -> List[Chunk]: - """Create Embeddings from chunks and embedding vectors.""" - for chunk, embedding in zip(chunks, embeddings): - chunk.embedding = Embedding(dense_embedding=embedding) - return list(chunks) + items: Sequence[EmbeddableItem], + embeddings: Sequence[List[float]]) -> List[EmbeddableItem]: + """Create Embeddings from items and embedding vectors.""" + for item, embedding in zip(items, embeddings): + item.embedding = Embedding(dense_embedding=embedding) + return list(items) diff --git a/sdks/python/apache_beam/ml/rag/embeddings/base_test.py b/sdks/python/apache_beam/ml/rag/embeddings/base_test.py index 3a27ae8e7ebb..aacdf6004ee7 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/base_test.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/base_test.py @@ -16,7 +16,7 @@ import unittest -from apache_beam.ml.rag.embeddings.base import create_rag_adapter +from apache_beam.ml.rag.embeddings.base import create_text_adapter from apache_beam.ml.rag.types import Chunk from apache_beam.ml.rag.types import Content from apache_beam.ml.rag.types import Embedding @@ -41,7 +41,7 @@ def setUp(self): def test_adapter_input_conversion(self): """Test the RAG type adapter converts correctly.""" - adapter = create_rag_adapter() + adapter = 
create_text_adapter() # Test input conversion texts = adapter.input_fn(self.test_chunks) @@ -49,10 +49,10 @@ def test_adapter_input_conversion(self): def test_adapter_input_conversion_missing_text_content(self): """Test the RAG type adapter converts correctly.""" - adapter = create_rag_adapter() + adapter = create_text_adapter() # Test input conversion - with self.assertRaisesRegex(ValueError, "Expected chunk text content"): + with self.assertRaisesRegex(ValueError, "Expected text content"): adapter.input_fn([ Chunk( content=Content(), @@ -83,7 +83,7 @@ def test_adapter_output_conversion(self): }, content=Content(text='Another example.')), ] - adapter = create_rag_adapter() + adapter = create_text_adapter() embeddings = adapter.output_fn(self.test_chunks, mock_embeddings) self.assertListEqual(embeddings, expected) diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py index 4cb0aecd6e82..8cf9298849cc 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py @@ -20,8 +20,8 @@ import apache_beam as beam from apache_beam.ml.inference.base import RunInference -from apache_beam.ml.rag.embeddings.base import create_rag_adapter -from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.embeddings.base import create_text_adapter +from apache_beam.ml.rag.types import EmbeddableItem from apache_beam.ml.transforms.base import EmbeddingsManager from apache_beam.ml.transforms.base import _TextEmbeddingHandler from apache_beam.ml.transforms.embeddings.huggingface import _SentenceTransformerModelHandler @@ -49,7 +49,7 @@ def __init__( "sentence-transformers is required to use " "HuggingfaceTextEmbeddings." 
"Please install it with using `pip install sentence-transformers`.") - super().__init__(type_adapter=create_rag_adapter(), **kwargs) + super().__init__(type_adapter=create_text_adapter(), **kwargs) self.model_name = model_name self.max_seq_length = max_seq_length self.model_class = SentenceTransformer @@ -67,8 +67,9 @@ def get_model_handler(self): def get_ptransform_for_processing( self, **kwargs - ) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]: + ) -> beam.PTransform[beam.PCollection[EmbeddableItem], + beam.PCollection[EmbeddableItem]]: """Returns PTransform that uses the RAG adapter.""" return RunInference( model_handler=_TextEmbeddingHandler(self), - inference_args=self.inference_args).with_output_types(Chunk) + inference_args=self.inference_args).with_output_types(EmbeddableItem) diff --git a/sdks/python/apache_beam/ml/rag/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/rag/embeddings/vertex_ai.py index 3495a235c114..8ef98223e69f 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/vertex_ai.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/vertex_ai.py @@ -27,8 +27,8 @@ import apache_beam as beam from apache_beam.ml.inference.base import RunInference -from apache_beam.ml.rag.embeddings.base import create_rag_adapter -from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.embeddings.base import create_text_adapter +from apache_beam.ml.rag.types import EmbeddableItem from apache_beam.ml.transforms.base import EmbeddingsManager from apache_beam.ml.transforms.base import _TextEmbeddingHandler from apache_beam.ml.transforms.embeddings.vertex_ai import DEFAULT_TASK_TYPE @@ -69,7 +69,7 @@ def __init__( "vertexai is required to use VertexAITextEmbeddings. 
" "Please install it with `pip install google-cloud-aiplatform`") - super().__init__(type_adapter=create_rag_adapter(), **kwargs) + super().__init__(type_adapter=create_text_adapter(), **kwargs) self.model_name = model_name self.title = title self.task_type = task_type @@ -90,8 +90,9 @@ def get_model_handler(self): def get_ptransform_for_processing( self, **kwargs - ) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]: + ) -> beam.PTransform[beam.PCollection[EmbeddableItem], + beam.PCollection[EmbeddableItem]]: """Returns PTransform that uses the RAG adapter.""" return RunInference( model_handler=_TextEmbeddingHandler(self), - inference_args=self.inference_args).with_output_types(Chunk) + inference_args=self.inference_args).with_output_types(EmbeddableItem) diff --git a/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py b/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py index f6117a260a34..614e5f9c0800 100644 --- a/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py +++ b/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py @@ -28,7 +28,7 @@ from google.cloud import bigquery -from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import EmbeddableItem from apache_beam.ml.rag.types import Embedding from apache_beam.transforms.enrichment import EnrichmentSourceHandler @@ -41,13 +41,13 @@ class BigQueryVectorSearchParameters: This class is used by BigQueryVectorSearchEnrichmentHandler to perform vector similarity search using BigQuery's VECTOR_SEARCH function. It - processes :class:`~apache_beam.ml.rag.types.Chunk` objects that contain - :class:`~apache_beam.ml.rag.types.Embedding` and returns similar vectors - from a BigQuery table. + processes :class:`~apache_beam.ml.rag.types.EmbeddableItem` objects that + contain :class:`~apache_beam.ml.rag.types.Embedding` and returns similar + vectors from a BigQuery table. 
BigQueryVectorSearchEnrichmentHandler is used with :class:`~apache_beam.transforms.enrichment.Enrichment` transform to enrich - Chunks with similar content from a vector database. For example: + EmbeddableItems with similar content from a vector database. For example: >>> # Create search parameters >>> params = BigQueryVectorSearchParameters( @@ -58,7 +58,7 @@ class BigQueryVectorSearchParameters: ... ) >>> # Use in pipeline >>> enriched = ( - ... chunks + ... embeddable_items ... | "Generate Embeddings" >> MLTransform(...) ... | "Find Similar" >> Enrichment( ... BigQueryVectorSearchEnrichmentHandler( @@ -89,11 +89,11 @@ class BigQueryVectorSearchParameters: ... columns=['content', 'language'], ... neighbor_count=5, ... # For column 'language', value comes from - ... # chunk.metadata['language'] + ... # embeddable_item.metadata['language'] ... metadata_restriction_template="language = '{language}'" ... ) - >>> # When processing a chunk with metadata={'language': 'en'}, - >>> # generates: WHERE language = 'en' + >>> # When processing a embeddable_item with + >>> # metadata={'language': 'en'}, generates: WHERE language = 'en' Example with nested repeated metadata: @@ -113,12 +113,14 @@ class BigQueryVectorSearchParameters: ... embedding_column='embedding', ... columns=['content', 'metadata'], ... neighbor_count=5, - ... # check_metadata(field_name, key_to_search, value_from_chunk) + ... # check_metadata(field_name, key_to_search, + ... # value_from_embeddable_item) ... metadata_restriction_template=( ... "check_metadata(metadata, 'language', '{language}')" ... ) ... ) - >>> # When processing a chunk with metadata={'language': 'en'}, + >>> # When processing a embeddable_item with + >>> # metadata={'language': 'en'}, >>> # generates: WHERE check_metadata(metadata, 'language', 'en') >>> # Searches for {key: 'language', value: 'en'} in metadata array @@ -134,13 +136,13 @@ class BigQueryVectorSearchParameters: 1. 
For flattened metadata columns: ``column_name = '{metadata_key}'`` where column_name is the BigQuery column and metadata_key is used to get the value from - chunk.metadata[metadata_key]. + embeddable_item.metadata[metadata_key]. 2. For nested repeated metadata (ARRAY>): ``check_metadata(field_name, 'key_to_match', '{metadata_key}')`` where field_name is the ARRAY column in BigQuery, key_to_match is the literal key to search for in the array, and metadata_key is used to get value from - chunk.metadata[metadata_key]. + embeddable_item.metadata[metadata_key]. Multiple conditions can be combined using AND/OR operators. For example:: @@ -150,7 +152,8 @@ class BigQueryVectorSearchParameters: ... "check_metadata(metadata, 'language', '{language}') " ... "AND source = '{source}'" ... ) - >>> # When chunk.metadata = {'language': 'en', 'source': 'web'} + >>> # When embeddable_item.metadata = {'language': 'en', + >>> # 'source': 'web'} >>> # Generates: WHERE >>> # check_metadata(metadata, 'language', 'en') >>> # AND source = 'web' @@ -165,23 +168,24 @@ class BigQueryVectorSearchParameters: embedding_column: str columns: List[str] neighbor_count: int - metadata_restriction_template: Optional[Union[str, Callable[[Chunk], - str]]] = None + metadata_restriction_template: Optional[Union[str, + Callable[[EmbeddableItem], + str]]] = None distance_type: Optional[str] = None options: Optional[Dict[str, Any]] = None include_distance: bool = False - def _format_restrict(self, chunk: Chunk) -> str: + def _format_restrict(self, item: EmbeddableItem) -> str: assert self.metadata_restriction_template is not None, ( "metadata_restriction_template cannot be None when formatting. " "This indicates a logical error in the code." 
) if callable(self.metadata_restriction_template): - return self.metadata_restriction_template(chunk) - return self.metadata_restriction_template.format(**chunk.metadata) + return self.metadata_restriction_template(item) + return self.metadata_restriction_template.format(**item.metadata) - def format_query(self, chunks: List[Chunk]) -> str: + def format_query(self, items: List[EmbeddableItem]) -> str: """Format the vector search query template.""" base_columns_str = ", ".join(f"base.{col}" for col in self.columns) columns_str = ", ".join(self.columns) @@ -204,27 +208,27 @@ def format_query(self, chunks: List[Chunk]) -> str: )); """ if self.metadata_restriction_template else "" - # Group chunks by their metadata conditions + # Group items by their metadata conditions condition_groups = defaultdict(list) if self.metadata_restriction_template: - for chunk in chunks: - condition = self._format_restrict(chunk) - condition_groups[condition].append(chunk) + for item in items: + condition = self._format_restrict(item) + condition_groups[condition].append(item) else: - # No metadata filtering - all chunks in one group - condition_groups[""] = chunks + # No metadata filtering - all items in one group + condition_groups[""] = items # Generate VECTOR_SEARCH subqueries for each condition group vector_searches = [] - for condition, group_chunks in condition_groups.items(): + for condition, group_items in condition_groups.items(): # Create embeddings subquery for this group embedding_unions = [] - for chunk in group_chunks: - if not chunk.dense_embedding: - raise ValueError(f"Chunk {chunk.id} missing embedding") + for item in group_items: + if not item.dense_embedding: + raise ValueError(f"Item {item.id} missing embedding") embedding_str = ( - f"SELECT '{chunk.id}' as id, " - f"{[float(x) for x in chunk.dense_embedding]} " + f"SELECT '{item.id}' as id, " + f"{[float(x) for x in item.dense_embedding]} " f"as embedding") embedding_unions.append(embedding_str) group_embeddings = " 
UNION ALL ".join(embedding_unions) @@ -235,10 +239,11 @@ def format_query(self, chunks: List[Chunk]) -> str: SELECT query.id, ARRAY_AGG( - STRUCT({"distance, " if self.include_distance else ""} {base_columns_str}) - ) as chunks + STRUCT({"distance, " if self.include_distance else ""}\ + {base_columns_str}) + ) as embeddable_items FROM VECTOR_SEARCH( - (SELECT {columns_str}, {self.embedding_column} + (SELECT {columns_str}, {self.embedding_column} FROM `{self.table_name}` {where_clause}), '{self.embedding_column}', @@ -262,16 +267,18 @@ def format_query(self, chunks: List[Chunk]) -> str: class BigQueryVectorSearchEnrichmentHandler( - EnrichmentSourceHandler[Union[Chunk, List[Chunk]], - List[Tuple[Chunk, Dict[str, Any]]]]): + EnrichmentSourceHandler[Union[EmbeddableItem, List[EmbeddableItem]], + List[Tuple[EmbeddableItem, Dict[str, Any]]]]): """Enrichment handler that performs vector similarity search using BigQuery. - This handler enriches Chunks by finding similar vectors in a BigQuery table - using the VECTOR_SEARCH function. It supports batching requests for efficiency - and preserves the original Chunk metadata while adding the search results. + This handler enriches EmbeddableItems by finding similar vectors in a + BigQuery table using the VECTOR_SEARCH function. It supports batching + requests for efficiency and preserves the original metadata while adding + the search results. Example: - >>> from apache_beam.ml.rag.types import Chunk, Content, Embedding + >>> from apache_beam.ml.rag.types import EmbeddableItem + >>> from apache_beam.ml.rag.types import Content, Embedding >>> >>> # Configure vector search >>> params = BigQueryVectorSearchParameters( @@ -295,7 +302,7 @@ class BigQueryVectorSearchEnrichmentHandler( ... enriched = ( ... p ... | beam.Create([ - ... Chunk( + ... EmbeddableItem( ... id='query1', ... embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), ... 
content=Content(text='test query'), @@ -307,16 +314,16 @@ class BigQueryVectorSearchEnrichmentHandler( Args: vector_search_parameters: Configuration for the vector search query - min_batch_size: Minimum number of chunks to batch before processing - max_batch_size: Maximum number of chunks to process in one batch + min_batch_size: Minimum number of items to process in one batch + max_batch_size: Maximum number of items to process in one batch log_query: Debug option to log the BigQuery query **kwargs: Additional arguments passed to bigquery.Client The handler will: - 1. Batch incoming chunks according to batch size parameters + 1. Batch incoming embeddable_items according to batch size parameters 2. Format and execute vector search query for each batch - 3. Join results back to original chunks - 4. Return tuples of (original_chunk, search_results) + 3. Join results back to original embeddable_items + 4. Return tuples of (original_embeddable_item, search_results) """ def __init__( self, @@ -339,17 +346,20 @@ def __init__( def __enter__(self): self.client = bigquery.Client(project=self.project, **self.kwargs) - def __call__(self, request: Union[Chunk, List[Chunk]], *args, - **kwargs) -> List[Tuple[Chunk, Dict[str, Any]]]: + def __call__( + self, + request: Union[EmbeddableItem, List[EmbeddableItem]], + *args, + **kwargs) -> List[Tuple[EmbeddableItem, Dict[str, Any]]]: """Process request(s) using BigQuery vector search. Args: - request: Single Chunk with embedding or list of Chunk's with - embeddings to process - + request: Single EmbeddableItem with embedding or list of + EmbeddableItems with embeddings to process + Returns: - Chunk(s) where chunk.metadata['enrichment_output'] contains the - data retrieved via BigQuery VECTOR_SEARCH. + EmbeddableItem(s) where metadata['enrichment_output'] contains + the data retrieved via BigQuery VECTOR_SEARCH. 
""" # Convert single request to list for uniform processing requests = request if isinstance(request, list) else [request] @@ -361,17 +371,18 @@ def __call__(self, request: Union[Chunk, List[Chunk]], *args, query_job = self.client.query(query) results = query_job.result() - # Create results dict with empty chunks list as default + # Create results dict with empty embeddable_items list as default results_by_id = {} for result_row in results: result_dict = dict(result_row.items()) results_by_id[result_row.id] = result_dict - # Return all chunks in original order, with empty results if no matches + # Return all embeddable_items in original order, with empty results if + # no matches response = [] - for chunk in requests: - result_dict = results_by_id.get(chunk.id, {}) - response.append((chunk, result_dict)) + for embeddable_item in requests: + result_dict = results_by_id.get(embeddable_item.id, {}) + response.append((embeddable_item, result_dict)) return response diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py index 41355e8c10aa..85a63cfba21e 100644 --- a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py +++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py @@ -34,7 +34,7 @@ from pymilvus import SearchResult from pymilvus.exceptions import MilvusException -from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import EmbeddableItem from apache_beam.ml.rag.types import Embedding from apache_beam.ml.rag.utils import MilvusConnectionParameters from apache_beam.ml.rag.utils import MilvusHelpers @@ -279,7 +279,7 @@ class MilvusCollectionLoadParameters: @dataclass class MilvusSearchResult: - """Search result from Milvus per chunk. + """Search result from Milvus per embeddable_item. Args: id: List of entity IDs returned from the search. 
Can be either string or @@ -293,7 +293,8 @@ class MilvusSearchResult: fields: List[Dict[str, Any]] = field(default_factory=list) -InputT, OutputT = Union[Chunk, List[Chunk]], List[Tuple[Chunk, Dict[str, Any]]] +InputT, OutputT = (Union[EmbeddableItem, List[EmbeddableItem]], + List[Tuple[EmbeddableItem, Dict[str, Any]]]) class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]): @@ -412,8 +413,11 @@ def connect_and_load(): exception_types=(MilvusException, )) return self - def __call__(self, request: Union[Chunk, List[Chunk]], *args, - **kwargs) -> List[Tuple[Chunk, Dict[str, Any]]]: + def __call__( + self, + request: Union[EmbeddableItem, List[EmbeddableItem]], + *args, + **kwargs) -> List[Tuple[EmbeddableItem, Dict[str, Any]]]: reqs = request if isinstance(request, list) else [request] # Early return for empty requests to avoid unnecessary connection attempts if not reqs: @@ -421,9 +425,9 @@ def __call__(self, request: Union[Chunk, List[Chunk]], *args, search_result = self._search_documents(reqs) return self._get_call_response(reqs, search_result) - def _search_documents(self, chunks: List[Chunk]): + def _search_documents(self, embeddable_items: List[EmbeddableItem]): if isinstance(self.search_strategy, HybridSearchParameters): - data = self._get_hybrid_search_data(chunks) + data = self._get_hybrid_search_data(embeddable_items) return self._client.hybrid_search( collection_name=self.collection_name, partition_names=self.partition_names, @@ -435,7 +439,7 @@ def _search_documents(self, chunks: List[Chunk]): limit=self.search_strategy.limit, **self.search_strategy.kwargs) elif isinstance(self.search_strategy, VectorSearchParameters): - data = list(map(self._get_vector_search_data, chunks)) + data = list(map(self._get_vector_search_data, embeddable_items)) vector_search_params = unpack_dataclass_with_kwargs(self.search_strategy) return self._client.search( collection_name=self.collection_name, @@ -446,7 +450,7 @@ def _search_documents(self, chunks: 
List[Chunk]): data=data, **vector_search_params) elif isinstance(self.search_strategy, KeywordSearchParameters): - data = list(map(self._get_keyword_search_data, chunks)) + data = list(map(self._get_keyword_search_data, embeddable_items)) keyword_search_params = unpack_dataclass_with_kwargs(self.search_strategy) return self._client.search( collection_name=self.collection_name, @@ -460,9 +464,11 @@ def _search_documents(self, chunks: List[Chunk]): raise ValueError( f"Not supported search strategy yet: {self.search_strategy}") - def _get_hybrid_search_data(self, chunks: List[Chunk]): - vector_search_data = list(map(self._get_vector_search_data, chunks)) - keyword_search_data = list(map(self._get_keyword_search_data, chunks)) + def _get_hybrid_search_data(self, embeddable_items: List[EmbeddableItem]): + vector_search_data = list( + map(self._get_vector_search_data, embeddable_items)) + keyword_search_data = list( + map(self._get_keyword_search_data, embeddable_items)) vector_search_req = AnnSearchRequest( data=vector_search_data, @@ -481,35 +487,40 @@ def _get_hybrid_search_data(self, chunks: List[Chunk]): reqs = [vector_search_req, keyword_search_req] return reqs - def _get_vector_search_data(self, chunk: Chunk): - if not chunk.dense_embedding: + def _get_vector_search_data(self, embeddable_item: EmbeddableItem): + if not embeddable_item.dense_embedding: raise ValueError( - f"Chunk {chunk.id} missing dense embedding required for vector search" - ) - return chunk.dense_embedding - - def _get_keyword_search_data(self, chunk: Chunk): - if not chunk.content.text and not chunk.sparse_embedding: + f"Item {embeddable_item.id} missing dense embedding required for" + " vector search") + return embeddable_item.dense_embedding + + def _get_keyword_search_data(self, embeddable_item: EmbeddableItem): + has_no_text = not embeddable_item.content.text + has_no_sparse = not embeddable_item.sparse_embedding + if has_no_text and has_no_sparse: raise ValueError( - f"Chunk {chunk.id} 
missing both text content and sparse embedding " - "required for keyword search") - sparse_embedding = MilvusHelpers.sparse_embedding(chunk.sparse_embedding) - return chunk.content.text or sparse_embedding + f"Item {embeddable_item.id} missing both text content and sparse " + "embedding required for keyword search") + sparse_embedding = MilvusHelpers.sparse_embedding( + embeddable_item.sparse_embedding) + return embeddable_item.content.text or sparse_embedding def _get_call_response( - self, chunks: List[Chunk], search_result: SearchResult[Hits]): + self, + embeddable_items: List[EmbeddableItem], + search_result: SearchResult[Hits]): response = [] - for i in range(len(chunks)): - chunk = chunks[i] + for i in range(len(embeddable_items)): + embeddable_item = embeddable_items[i] hits: Hits = search_result[i] result = MilvusSearchResult() - for i in range(len(hits)): - hit: Hit = hits[i] + for j in range(len(hits)): + hit: Hit = hits[j] normalized_fields = self._normalize_milvus_fields(hit.fields) result.id.append(hit.id) result.distance.append(hit.distance) result.fields.append(normalized_fields) - response.append((chunk, result.__dict__)) + response.append((embeddable_item, result.__dict__)) return response def _normalize_milvus_fields(self, fields: Dict[str, Any]): diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_test.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_test.py index ef5af8ca4940..3fa593e4b68a 100644 --- a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_test.py +++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_test.py @@ -169,7 +169,7 @@ def test_missing_dense_embedding(self): _ = handler._get_vector_search_data(chunk) - err_msg = "Chunk 1 missing dense embedding required for vector search" + err_msg = "Item 1 missing dense embedding required for vector search" self.assertIn(err_msg, str(context.exception)) @@ -228,7 +228,7 @@ def test_missing_text_content_and_sparse_embedding(self): _ = 
handler._get_keyword_search_data(chunk) err_msg = ( - "Chunk 1 missing both text content and sparse embedding " + "Item 1 missing both text content and sparse embedding " "required for keyword search") self.assertIn(err_msg, str(context.exception)) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/alloydb.py b/sdks/python/apache_beam/ml/rag/ingestion/alloydb.py index 229c3e2bd99b..333c259f9b86 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/alloydb.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/alloydb.py @@ -149,7 +149,7 @@ def __init__( column_specs: Use :class:`~.postgres_common.ColumnSpecsBuilder` to configure how embeddings and metadata are written a database - schema. If None, uses default Chunk schema. + schema. If None, uses default EmbeddableItem schema. conflict_resolution: Optional :class:`~.postgres_common.ConflictResolution` strategy for handling insert conflicts. ON CONFLICT DO NOTHING by diff --git a/sdks/python/apache_beam/ml/rag/ingestion/base.py b/sdks/python/apache_beam/ml/rag/ingestion/base.py index d79aa7778405..5f894ce4cc86 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/base.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/base.py @@ -19,7 +19,7 @@ from typing import Any import apache_beam as beam -from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import EmbeddableItem class VectorDatabaseWriteConfig(ABC): @@ -32,7 +32,7 @@ class VectorDatabaseWriteConfig(ABC): The configuration flow: 1. Subclass provides database-specific configuration (table names, etc) 2. create_write_transform() creates appropriate PTransform for writing - 3. Transform handles converting Chunks to database-specific format + 3. Transform handles converting EmbeddableItems to database-specific format Example implementation: >>> class BigQueryVectorWriterConfig(VectorDatabaseWriteConfig): @@ -45,14 +45,15 @@ class VectorDatabaseWriteConfig(ABC): ... 
) """ @abstractmethod - def create_write_transform(self) -> beam.PTransform[Chunk, Any]: + def create_write_transform(self) -> beam.PTransform[EmbeddableItem, Any]: """Creates a PTransform that writes embeddings to the vector database. - + Returns: - A PTransform that accepts PCollection[Chunk] and writes the chunks' - embeddings and metadata to the configured vector database. + A PTransform that accepts PCollection[EmbeddableItem] + and writes the embeddings + and metadata to the configured vector database. The transform should handle: - - Converting Chunk format to database schema + - Converting EmbeddableItem format to database schema - Setting up database connection/client - Writing with appropriate batching/error handling """ @@ -71,10 +72,10 @@ class VectorDatabaseWriteTransform(beam.PTransform): ... table='project.dataset.embeddings', ... embedding_column='embedding' ... ) - >>> + >>> >>> with beam.Pipeline() as p: - ... chunks = p | beam.Create([...]) # PCollection[Chunk] - ... chunks | VectorDatabaseWriteTransform(config) + ... items = p | beam.Create([...]) # PCollection[EmbeddableItem] + ... items | VectorDatabaseWriteTransform(config) Args: database_config: Configuration for the target vector database. @@ -96,17 +97,18 @@ def __init__(self, database_config: VectorDatabaseWriteConfig): f"got {type(database_config)}") self.database_config = database_config - def expand(self, - pcoll: beam.PCollection[Chunk]) -> beam.PTransform[Chunk, Any]: + def expand( + self, pcoll: beam.PCollection[EmbeddableItem] + ) -> beam.PTransform[EmbeddableItem, Any]: """Creates and applies the database-specific write transform. - + Args: - pcoll: PCollection of Chunks with embeddings to write to the - vector database. Each Chunk must have: + pcoll: PCollection of EmbeddableItems with embeddings to write to the + vector database. 
Each EmbeddableItem must have: - An embedding - An ID - Metadata used to filter results as specified by database config - + Returns: Result of writing to database (implementation specific). """ diff --git a/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py b/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py index 35cd65ff7a94..e955c7856f50 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py @@ -24,10 +24,10 @@ from apache_beam.io.gcp.bigquery_tools import beam_row_from_dict from apache_beam.io.gcp.bigquery_tools import get_beam_typehints_from_tableschema from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig -from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import EmbeddableItem from apache_beam.typehints.row_type import RowTypeConstraint -ChunkToDictFn = Callable[[Chunk], Dict[str, any]] +EmbeddableToDictFn = Callable[[EmbeddableItem], Dict[str, any]] @dataclass @@ -48,11 +48,13 @@ class SchemaConfig: ... {'name': 'custom_field', 'type': 'STRING'} ... ] ... } - chunk_to_dict_fn: Function that converts a Chunk to a dict matching the - schema. Takes a Chunk and returns Dict[str, Any] with keys matching + embeddable_to_dict_fn: Function that converts an + EmbeddableItem to a dict matching the schema. + Takes an EmbeddableItem and returns + Dict[str, Any] with keys matching schema fields. Example: - >>> def chunk_to_dict(chunk: Chunk) -> Dict[str, Any]: + >>> def embeddable_to_dict(chunk: EmbeddableItem) -> Dict[str, Any]: ... return { ... 'id': chunk.id, ... 'embedding': chunk.embedding.dense_embedding, @@ -60,7 +62,7 @@ class SchemaConfig: ... } """ schema: Dict - chunk_to_dict_fn: ChunkToDictFn + embeddable_to_dict_fn: EmbeddableToDictFn class BigQueryVectorWriterConfig(VectorDatabaseWriteConfig): @@ -87,7 +89,7 @@ def __init__( ... {'name': 'source_url', 'type': 'STRING'} ... ] ... }, - ... chunk_to_dict_fn=lambda chunk: { + ... 
embeddable_to_dict_fn=lambda chunk: { ... 'id': chunk.id, ... 'embedding': chunk.embedding.dense_embedding, ... 'source_url': chunk.metadata.get('url') @@ -121,7 +123,7 @@ def create_write_transform(self) -> beam.PTransform: return _WriteToBigQueryVectorDatabase(self) -def _default_chunk_to_dict_fn(chunk: Chunk): +def _default_embeddable_to_dict_fn(chunk: EmbeddableItem): if chunk.embedding is None or chunk.embedding.dense_embedding is None: raise ValueError("chunk must contain dense embedding") return { @@ -161,17 +163,17 @@ class _WriteToBigQueryVectorDatabase(beam.PTransform): def __init__(self, config: BigQueryVectorWriterConfig): self.config = config - def expand(self, pcoll: beam.PCollection[Chunk]): + def expand(self, pcoll: beam.PCollection[EmbeddableItem]): schema = ( self.config.schema_config.schema if self.config.schema_config else _default_schema()) - chunk_to_dict_fn = ( - self.config.schema_config.chunk_to_dict_fn - if self.config.schema_config else _default_chunk_to_dict_fn) + embeddable_to_dict_fn = ( + self.config.schema_config.embeddable_to_dict_fn + if self.config.schema_config else _default_embeddable_to_dict_fn) return ( pcoll - | "Chunk to dict" >> beam.Map(chunk_to_dict_fn) - | "Chunk dict to schema'd row" >> beam.Map( + | "EmbeddableItem to dict" >> beam.Map(embeddable_to_dict_fn) + | "EmbeddableItem dict to schema'd row" >> beam.Map( lambda chunk_dict: beam_row_from_dict( row=chunk_dict, schema=schema)).with_output_types( RowTypeConstraint.from_fields( diff --git a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py index d3710a7f70a4..4cd6474ba348 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py @@ -164,7 +164,7 @@ def __init__( column_specs: Use :class:`~.postgres_common.ColumnSpecsBuilder` to configure how embeddings and metadata are written a database - schema. If None, uses default Chunk schema. + schema. 
If None, uses default EmbeddableItem schema. conflict_resolution: Optional :class:`~.postgres_common.ConflictResolution` strategy for handling insert conflicts. ON CONFLICT DO NOTHING by diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py index c73aba5f42e4..7d7c554cc68e 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py @@ -31,7 +31,7 @@ from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder -from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import EmbeddableItem from apache_beam.ml.rag.utils import DEFAULT_WRITE_BATCH_SIZE from apache_beam.ml.rag.utils import MilvusConnectionParameters from apache_beam.ml.rag.utils import MilvusHelpers @@ -87,11 +87,11 @@ class MilvusVectorWriterConfig(VectorDatabaseWriteConfig): This class extends VectorDatabaseWriteConfig to provide Milvus-specific configuration for ingesting vector embeddings and associated metadata. - It defines how Apache Beam chunks are converted to Milvus records and + It defines how EmbeddableItem objects are converted to Milvus records and handles the write operation parameters. The configuration includes connection parameters, write settings, and - column specifications that determine how chunk data is mapped to Milvus + column specifications that determine how data is mapped to Milvus fields. Args: @@ -99,7 +99,8 @@ class MilvusVectorWriterConfig(VectorDatabaseWriteConfig): including URI, credentials, and connection options. write_config: Configuration for write operations including collection name, partition, batch size, and timeouts. 
- column_specs: List of column specifications defining how chunk fields are + column_specs: List of column specifications defining how + EmbeddableItem fields are mapped to Milvus collection fields. Defaults to standard RAG fields (id, embedding, sparse_embedding, content, metadata). @@ -115,14 +116,15 @@ class MilvusVectorWriterConfig(VectorDatabaseWriteConfig): column_specs: List[ColumnSpec] = field( default_factory=lambda: MilvusVectorWriterConfig.default_column_specs()) - def create_converter(self) -> Callable[[Chunk], Dict[str, Any]]: - """Creates a function to convert Apache Beam Chunks to Milvus records. + def create_converter(self) -> Callable[[EmbeddableItem], Dict[str, Any]]: + """Creates a function to convert EmbeddableItem objects to Milvus records. Returns: - A function that takes a Chunk and returns a dictionary representing + A function that takes an EmbeddableItem and returns a + dictionary representing a Milvus record with fields mapped according to column_specs. """ - def convert(chunk: Chunk) -> Dict[str, Any]: + def convert(chunk: EmbeddableItem) -> Dict[str, Any]: result = {} for col in self.column_specs: result[col.column_name] = col.value_fn(chunk) @@ -134,7 +136,8 @@ def create_write_transform(self) -> beam.PTransform: """Creates the Apache Beam transform for writing to Milvus. Returns: - A PTransform that can be applied to a PCollection of Chunks to write + A PTransform that can be applied to a PCollection of + EmbeddableItem objects to write them to the configured Milvus collection. """ return _WriteToMilvusVectorDatabase(self) @@ -145,7 +148,7 @@ def default_column_specs() -> List[ColumnSpec]: Creates column mappings for standard RAG fields: id, dense embedding, sparse embedding, content text, and metadata. These specifications - define how Chunk fields are converted to Milvus-compatible formats. + define how EmbeddableItem fields are converted to Milvus-compatible formats. 
Returns: List of ColumnSpec objects defining the default field mappings. @@ -163,7 +166,8 @@ def default_column_specs() -> List[ColumnSpec]: class _WriteToMilvusVectorDatabase(beam.PTransform): """Apache Beam PTransform for writing vector data to Milvus. - This transform handles the conversion of Apache Beam Chunks to Milvus records + This transform handles the conversion of EmbeddableItem objects + to Milvus records and coordinates the write operations. It applies the configured converter function and uses a DoFn for batched writes to optimize performance. @@ -174,11 +178,11 @@ class _WriteToMilvusVectorDatabase(beam.PTransform): def __init__(self, config: MilvusVectorWriterConfig): self.config = config - def expand(self, pcoll: beam.PCollection[Chunk]): - """Expands the PTransform to convert chunks and write to Milvus. + def expand(self, pcoll: beam.PCollection[EmbeddableItem]): + """Expands the PTransform to convert and write to Milvus. Args: - pcoll: PCollection of Chunk objects to write to Milvus. + pcoll: PCollection of EmbeddableItem objects to write to Milvus. Returns: PCollection of dictionaries representing the records written to Milvus. diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql.py b/sdks/python/apache_beam/ml/rag/ingestion/mysql.py index c64c083b6c9c..45f33ea2bad5 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/mysql.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql.py @@ -33,7 +33,7 @@ from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpec from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpecsBuilder from apache_beam.ml.rag.ingestion.mysql_common import ConflictResolution -from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import EmbeddableItem _LOGGER = logging.getLogger(__name__) @@ -96,7 +96,7 @@ def __init__( *, column_specs: List[ColumnSpec], conflict_resolution: Optional[ConflictResolution] = None): - """Builds SQL queries for writing Chunks with Embeddings to MySQL. 
+ """Builds SQL queries for writing EmbeddableItems with Embeddings to MySQL. """ self.table_name = table_name @@ -132,9 +132,9 @@ def build_insert(self) -> str: _LOGGER.info("MySQL Query with placeholders %s", query) return query - def create_converter(self) -> Callable[[Chunk], NamedTuple]: - """Creates a function to convert Chunks to records.""" - def convert(chunk: Chunk) -> self.record_type: # type: ignore + def create_converter(self) -> Callable[[EmbeddableItem], NamedTuple]: + """Creates a function to convert EmbeddableItems to records.""" + def convert(chunk: EmbeddableItem) -> self.record_type: # type: ignore return self.record_type( **{col.column_name: col.value_fn(chunk) for col in self.column_specs}) # type: ignore @@ -167,7 +167,8 @@ def __init__( column_specs: Use :class:`~.mysql_common.ColumnSpecsBuilder` to configure how embeddings and metadata are written to the database - schema. If None, uses default Chunk schema with MySQL vector + schema. If None, uses default EmbeddableItem + schema with MySQL vector functions. 
conflict_resolution: Optional :class:`~.mysql_common.ConflictResolution` @@ -248,7 +249,7 @@ def __init__(self, config: MySQLVectorWriterConfig): self.connection_config = config.connection_config self.write_config = config.write_config - def expand(self, pcoll: beam.PCollection[Chunk]): + def expand(self, pcoll: beam.PCollection[EmbeddableItem]): return ( pcoll | diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py index c1ee703a5f2e..6fb88b637edb 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py @@ -24,17 +24,17 @@ from typing import Optional from typing import Type -from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import EmbeddableItem -def chunk_embedding_fn(chunk: Chunk) -> str: +def chunk_embedding_fn(chunk: EmbeddableItem) -> str: """Convert embedding to MySQL vector string format. - + Formats dense embedding as a MySQL-compatible vector string. Example: [1.0, 2.0] -> '[1.0,2.0]' - + Args: - chunk: Input Chunk object. + chunk: Input EmbeddableItem object. Returns: str: MySQL vector string representation of the embedding. @@ -49,73 +49,76 @@ def chunk_embedding_fn(chunk: Chunk) -> str: @dataclass class ColumnSpec: - """Specification for mapping Chunk fields to MySQL columns for insertion. - - Defines how to extract and format values from Chunks into MySQL database - columns, handling the full pipeline from Python value to SQL insertion. - - The insertion process works as follows: - - value_fn extracts a value from the Chunk and formats it as needed - - The value is stored in a NamedTuple field with the specified python_type - - During SQL insertion, the value is bound to a ? placeholder - - Attributes: - column_name: The column name in the database table. - python_type: Python type for the NamedTuple field that will hold the - value. 
Must be compatible with - :class:`~apache_beam.coders.row_coder.RowCoder`. - value_fn: Function to extract and format the value from a Chunk. - Takes a Chunk and returns a value of python_type. - placeholder: Optional placeholder to apply typecasts or functions to - value ? placeholder e.g. "string_to_vector(?)" for vector columns. - - Examples: - - Basic text column (uses standard JDBC type mapping): - - >>> ColumnSpec.text( - ... column_name="content", - ... value_fn=lambda chunk: chunk.content.text - ... ) - ... # Results in: INSERT INTO table (content) VALUES (?) - - Timestamp from metadata: - - >>> ColumnSpec( - ... column_name="created_at", - ... python_type=str, - ... value_fn=lambda chunk: chunk.metadata.get("timestamp") - ... ) - ... # Results in: INSERT INTO table (created_at) VALUES (?) - - - Factory Methods: - text: Creates a text column specification. - integer: Creates an integer column specification. - float: Creates a float column specification. - vector: Creates a vector column specification with string_to_vector(). - json: Creates a JSON column specification. + """Mapping of EmbeddableItem fields to SQL columns for insertion. + + Defines how to extract and format values from EmbeddableItems into MySQL + database columns, handling the full pipeline from Python value to SQL + insertion. + + The insertion process works as follows: + - value_fn extracts a value from the EmbeddableItem and formats it as needed + - The value is stored in a NamedTuple field with the specified python_type + - During SQL insertion, the value is bound to a ? placeholder + + Attributes: + column_name: The column name in the database table. + python_type: :class:`~apache_beam.coders.row_coder.RowCoder` compatible + python type. + value_fn: Function to extract and format the value from an + EmbeddableItem. + placeholder: Optional placeholder to apply typecasts or functions to + value ? placeholder e.g. "string_to_vector(?)" for vector columns. 
+ + Examples: + + Basic text column (uses standard JDBC type mapping): + + >>> ColumnSpec.text( + ... column_name="content", + ... value_fn=lambda chunk: chunk.content.text + ... ) + ... # Results in: INSERT INTO table (content) VALUES (?) + + Timestamp from metadata: + + >>> ColumnSpec( + ... column_name="created_at", + ... python_type=str, + ... value_fn=lambda chunk: chunk.metadata.get("timestamp") + ... ) + ... # Results in: INSERT INTO table (created_at) VALUES (?) + + + Factory Methods: + text: Creates a text column specification. + integer: Creates an integer column specification. + float: Creates a float column specification. + vector: Creates a vector column specification with string_to_vector(). + json: Creates a JSON column specification. """ column_name: str python_type: Type - value_fn: Callable[[Chunk], Any] + value_fn: Callable[[EmbeddableItem], Any] placeholder: str = '?' @classmethod def text( - cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + cls, column_name: str, value_fn: Callable[[EmbeddableItem], + Any]) -> 'ColumnSpec': """Create a text column specification.""" return cls(column_name, str, value_fn) @classmethod def integer( - cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + cls, column_name: str, value_fn: Callable[[EmbeddableItem], + Any]) -> 'ColumnSpec': """Create an integer column specification.""" return cls(column_name, int, value_fn) @classmethod def float( - cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + cls, column_name: str, value_fn: Callable[[EmbeddableItem], + Any]) -> 'ColumnSpec': """Create a float column specification.""" return cls(column_name, float, value_fn) @@ -123,13 +126,15 @@ def float( def vector( cls, column_name: str, - value_fn: Callable[[Chunk], Any] = chunk_embedding_fn) -> 'ColumnSpec': + value_fn: Callable[[EmbeddableItem], Any] = chunk_embedding_fn + ) -> 'ColumnSpec': """Create a vector column specification with 
string_to_vector() function.""" return cls(column_name, str, value_fn, "string_to_vector(?)") @classmethod def json( - cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + cls, column_name: str, value_fn: Callable[[EmbeddableItem], + Any]) -> 'ColumnSpec': """Create a JSON column specification.""" return cls(column_name, str, value_fn) @@ -175,7 +180,7 @@ def with_id_spec( ... convert_fn=lambda id: int(id.split('_')[1]) ... ) """ - def value_fn(chunk: Chunk) -> Any: + def value_fn(chunk: EmbeddableItem) -> Any: value = chunk.id return convert_fn(value) if convert_fn else value @@ -209,7 +214,7 @@ def with_content_spec( ... convert_fn=len # Store content length instead of content ... ) """ - def value_fn(chunk: Chunk) -> Any: + def value_fn(chunk: EmbeddableItem) -> Any: if chunk.content.text is None: raise ValueError(f'Expected chunk to contain content. {chunk}') value = chunk.content.text @@ -245,7 +250,7 @@ def with_metadata_spec( ... convert_fn=lambda meta: ','.join(meta.keys()) ... ) """ - def value_fn(chunk: Chunk) -> Any: + def value_fn(chunk: EmbeddableItem) -> Any: if convert_fn: return convert_fn(chunk.metadata) return json.dumps( @@ -279,7 +284,7 @@ def with_embedding_spec( ... for x in values) + ']' ... ) """ - def value_fn(chunk: Chunk) -> Any: + def value_fn(chunk: EmbeddableItem) -> Any: if chunk.embedding is None or chunk.embedding.dense_embedding is None: raise ValueError(f'Expected chunk to contain embedding. 
{chunk}') values = chunk.embedding.dense_embedding @@ -339,7 +344,7 @@ def add_metadata_field( """ name = column_name or field - def value_fn(chunk: Chunk) -> Any: + def value_fn(chunk: EmbeddableItem) -> Any: value = chunk.metadata.get(field, default) if value is not None and convert_fn is not None: value = convert_fn(value) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/postgres.py b/sdks/python/apache_beam/ml/rag/ingestion/postgres.py index 045579a73d28..b01e450e9bec 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/postgres.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/postgres.py @@ -32,7 +32,7 @@ from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution -from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import EmbeddableItem _LOGGER = logging.getLogger(__name__) @@ -46,7 +46,7 @@ def __init__( *, column_specs: List[ColumnSpec], conflict_resolution: Optional[ConflictResolution] = None): - """Builds SQL queries for writing Chunks with Embeddings to Postgres. + """Builds SQL queries for writing EmbeddableItems to Postgres. 
""" self.table_name = table_name @@ -93,9 +93,9 @@ def build_insert(self) -> str: _LOGGER.info("Query with placeholders %s", query) return query - def create_converter(self) -> Callable[[Chunk], NamedTuple]: - """Creates a function to convert Chunks to records.""" - def convert(chunk: Chunk) -> self.record_type: # type: ignore + def create_converter(self) -> Callable[[EmbeddableItem], NamedTuple]: + """Creates a function to convert EmbeddableItems to records.""" + def convert(chunk: EmbeddableItem) -> self.record_type: # type: ignore return self.record_type( **{col.column_name: col.value_fn(chunk) for col in self.column_specs}) # type: ignore @@ -129,7 +129,7 @@ def __init__( column_specs: Use :class:`~.postgres_common.ColumnSpecsBuilder` to configure how embeddings and metadata are written a database - schema. If None, uses default Chunk schema. + schema. If None, uses default EmbeddableItem schema. conflict_resolution: Optional :class:`~.postgres_common.ConflictResolution` strategy for handling insert conflicts. ON CONFLICT DO NOTHING by @@ -189,7 +189,7 @@ def __init__(self, config: PostgresVectorWriterConfig): self.connection_config = config.connection_config self.write_config = config.write_config - def expand(self, pcoll: beam.PCollection[Chunk]): + def expand(self, pcoll: beam.PCollection[EmbeddableItem]): return ( pcoll | diff --git a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py index 93968564f156..d789c25bb092 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py @@ -26,17 +26,17 @@ from typing import Type from typing import Union -from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import EmbeddableItem -def chunk_embedding_fn(chunk: Chunk) -> str: +def chunk_embedding_fn(chunk: EmbeddableItem) -> str: """Convert embedding to PostgreSQL array string. 
Formats dense embedding as a PostgreSQL-compatible array string. Example: [1.0, 2.0] -> '{1.0,2.0}' Args: - chunk: Input Chunk object. + chunk: Input EmbeddableItem object. Returns: str: PostgreSQL array string representation of the embedding. @@ -51,65 +51,65 @@ def chunk_embedding_fn(chunk: Chunk) -> str: @dataclass class ColumnSpec: - """Specification for mapping Chunk fields to SQL columns for insertion. - - Defines how to extract and format values from Chunks into database columns, - handling the full pipeline from Python value to SQL insertion. - - The insertion process works as follows: - - value_fn extracts a value from the Chunk and formats it as needed - - The value is stored in a NamedTuple field with the specified python_type - - During SQL insertion, the value is bound to a ? placeholder - - Attributes: - column_name: The column name in the database table. - python_type: Python type for the NamedTuple field that will hold the - value. Must be compatible with must be compatible with - :class:`~apache_beam.coders.row_coder.RowCoder`. - value_fn: Function to extract and format the value from a Chunk. - Takes a Chunk and returns a value of python_type. - sql_typecast: Optional SQL type cast to append to the ? placeholder. - Common examples: - - "::float[]" for vector arrays - - "::jsonb" for JSON data - - Examples: - Basic text column (uses standard JDBC type mapping): - >>> ColumnSpec.text( - ... column_name="content", - ... value_fn=lambda chunk: chunk.content.text - ... ) - # Results in: INSERT INTO table (content) VALUES (?) - - Vector column with explicit array casting: - >>> ColumnSpec.vector( - ... column_name="embedding", - ... value_fn=lambda chunk: '{' + - ... ','.join(map(str, chunk.embedding.dense_embedding)) + '}' - ... ) - # Results in: INSERT INTO table (embedding) VALUES (?::float[]) - # The value_fn formats [1.0, 2.0] as '{1.0,2.0}' for PostgreSQL array - - Timestamp from metadata with explicit casting: - >>> ColumnSpec( - ... 
column_name="created_at", - ... python_type=str, - ... value_fn=lambda chunk: chunk.metadata.get("timestamp"), - ... sql_typecast="::timestamp" - ... ) - # Results in: INSERT INTO table (created_at) VALUES (?::timestamp) - # Allows inserting string timestamps with proper PostgreSQL casting - - Factory Methods: - text: Creates a text column specification (no type cast). - integer: Creates an integer column specification (no type cast). - float: Creates a float column specification (no type cast). - vector: Creates a vector column specification with float[] casting. - jsonb: Creates a JSONB column specification with jsonb casting. - """ + """Mapping of EmbeddableItem fields to SQL columns for insertion. + + Defines how to extract and format values from EmbeddableItems into + database columns, handling the full pipeline from Python value to + SQL insertion. + + The insertion process works as follows: + - value_fn extracts a value from the EmbeddableItem and formats it as needed + - The value is stored in a NamedTuple field with the specified python_type + - During SQL insertion, the value is bound to a ? placeholder + + Attributes: + column_name: The column name in the database table. + python_type: :class:`~apache_beam.coders.row_coder.RowCoder` compatible + python type. + value_fn: Function to extract and format the value from an + EmbeddableItem. + sql_typecast: Optional SQL type cast to append to the ? placeholder. + Common examples: + - "::float[]" for vector arrays + - "::jsonb" for JSON data + + Examples: + Basic text column (uses standard JDBC type mapping): + >>> ColumnSpec.text( + ... column_name="content", + ... value_fn=lambda chunk: chunk.content.text + ... ) + # Results in: INSERT INTO table (content) VALUES (?) + + Vector column with explicit array casting: + >>> ColumnSpec.vector( + ... column_name="embedding", + ... value_fn=lambda chunk: '{' + + ... ','.join(map(str, chunk.embedding.dense_embedding)) + '}' + ... 
) + # Results in: INSERT INTO table (embedding) VALUES (?::float[]) + # The value_fn formats [1.0, 2.0] as '{1.0,2.0}' for PostgreSQL array + + Timestamp from metadata with explicit casting: + >>> ColumnSpec( + ... column_name="created_at", + ... python_type=str, + ... value_fn=lambda chunk: chunk.metadata.get("timestamp"), + ... sql_typecast="::timestamp" + ... ) + # Results in: INSERT INTO table (created_at) VALUES (?::timestamp) + # Allows inserting string timestamps with proper PostgreSQL casting + + Factory Methods: + text: Creates a text column specification (no type cast). + integer: Creates an integer column specification (no type cast). + float: Creates a float column specification (no type cast). + vector: Creates a vector column specification with float[] casting. + jsonb: Creates a JSONB column specification with jsonb casting. + """ column_name: str python_type: Type - value_fn: Callable[[Chunk], Any] + value_fn: Callable[[EmbeddableItem], Any] sql_typecast: Optional[str] = None @property @@ -119,19 +119,22 @@ def placeholder(self) -> str: @classmethod def text( - cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + cls, column_name: str, value_fn: Callable[[EmbeddableItem], + Any]) -> 'ColumnSpec': """Create a text column specification.""" return cls(column_name, str, value_fn) @classmethod def integer( - cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + cls, column_name: str, value_fn: Callable[[EmbeddableItem], + Any]) -> 'ColumnSpec': """Create an integer column specification.""" return cls(column_name, int, value_fn) @classmethod def float( - cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + cls, column_name: str, value_fn: Callable[[EmbeddableItem], + Any]) -> 'ColumnSpec': """Create a float column specification.""" return cls(column_name, float, value_fn) @@ -139,13 +142,15 @@ def float( def vector( cls, column_name: str, - value_fn: Callable[[Chunk], Any] = 
chunk_embedding_fn) -> 'ColumnSpec': + value_fn: Callable[[EmbeddableItem], Any] = chunk_embedding_fn + ) -> 'ColumnSpec': """Create a vector column specification.""" return cls(column_name, str, value_fn, "::float[]") @classmethod def jsonb( - cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + cls, column_name: str, value_fn: Callable[[EmbeddableItem], + Any]) -> 'ColumnSpec': """Create a JSONB column specification.""" return cls(column_name, str, value_fn, "::jsonb") @@ -187,7 +192,7 @@ def with_id_spec( ... convert_fn=lambda id: int(id.split('_')[1]) ... ) """ - def value_fn(chunk: Chunk) -> Any: + def value_fn(chunk: EmbeddableItem) -> Any: value = chunk.id return convert_fn(value) if convert_fn else value @@ -224,7 +229,7 @@ def with_content_spec( ... convert_fn=len # Store content length instead of content ... ) """ - def value_fn(chunk: Chunk) -> Any: + def value_fn(chunk: EmbeddableItem) -> Any: if chunk.content.text is None: raise ValueError(f'Expected chunk to contain content. {chunk}') value = chunk.content.text @@ -264,7 +269,7 @@ def with_metadata_spec( ... sql_typecast="::text[]" ... ) """ - def value_fn(chunk: Chunk) -> Any: + def value_fn(chunk: EmbeddableItem) -> Any: if convert_fn: return convert_fn(chunk.metadata) return json.dumps( @@ -300,7 +305,7 @@ def with_embedding_spec( ... for x in values) + '}' ... ) """ - def value_fn(chunk: Chunk) -> Any: + def value_fn(chunk: EmbeddableItem) -> Any: if chunk.embedding is None or chunk.embedding.dense_embedding is None: raise ValueError(f'Expected chunk to contain embedding. {chunk}') values = chunk.embedding.dense_embedding @@ -334,7 +339,7 @@ def with_sparse_embedding_spec( ... convert_fn=lambda sparse: dict(zip(sparse[0], sparse[1])) ... ) """ - def value_fn(chunk: Chunk) -> Any: + def value_fn(chunk: EmbeddableItem) -> Any: if chunk.embedding is None or chunk.embedding.sparse_embedding is None: raise ValueError(f'Expected chunk to contain sparse embedding. 
{chunk}') sparse_embedding = chunk.embedding.sparse_embedding @@ -389,7 +394,7 @@ def add_metadata_field( >>> builder.add_metadata_field( ... field="confidence", - ... python_type=intfloat, + ... python_type=float, ... convert_fn=lambda x: round(float(x), 2), ... default=0.0 ... ) @@ -398,14 +403,14 @@ def add_metadata_field( >>> builder.add_metadata_field( ... field="created_at", - ... python_type=intstr, + ... python_type=str, ... convert_fn=lambda ts: ts.replace('T', ' '), ... sql_typecast="::timestamp" ... ) """ name = column_name or field - def value_fn(chunk: Chunk) -> Any: + def value_fn(chunk: EmbeddableItem) -> Any: value = chunk.metadata.get(field, default) if value is not None and convert_fn is not None: value = convert_fn(value) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/spanner.py b/sdks/python/apache_beam/ml/rag/ingestion/spanner.py index f79db470bca4..49314ee226c2 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/spanner.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/spanner.py @@ -76,54 +76,55 @@ from apache_beam.coders.row_coder import RowCoder from apache_beam.io.gcp import spanner from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig -from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import EmbeddableItem @dataclass class SpannerColumnSpec: """Column specification for Spanner vector writes. - - Defines how to extract and format values from Chunks for insertion into + + Defines how to extract and format values from EmbeddableItems + for insertion into Spanner table columns. Each spec maps to one column in the target table. - + Attributes: column_name: Name of the Spanner table column python_type: Python type for the NamedTuple field (required for RowCoder) - value_fn: Function to extract value from a Chunk - + value_fn: Function to extract value from an EmbeddableItem + Examples: String column: >>> SpannerColumnSpec( ... column_name="id", ... python_type=str, - ... 
value_fn=lambda chunk: chunk.id + ... value_fn=lambda embeddable: embeddable.id ... ) - + Array column with conversion: >>> SpannerColumnSpec( ... column_name="embedding", ... python_type=List[float], - ... value_fn=lambda chunk: chunk.embedding.dense_embedding + ... value_fn=lambda embeddable: embeddable.embedding.dense_embedding ... ) """ column_name: str python_type: Type - value_fn: Callable[[Chunk], Any] + value_fn: Callable[[EmbeddableItem], Any] -def _extract_and_convert(extract_fn, convert_fn, chunk): +def _extract_and_convert(extract_fn, convert_fn, embeddable): if convert_fn: - return convert_fn(extract_fn(chunk)) - return extract_fn(chunk) + return convert_fn(extract_fn(embeddable)) + return extract_fn(embeddable) class SpannerColumnSpecsBuilder: """Builder for creating Spanner column specifications. - + Provides a fluent API for defining table schemas and how to populate them - from Chunk objects. Supports standard Chunk fields (id, embedding, content, + from EmbeddableItem objects. Supports standard EmbeddableItem fields (id, embedding, content, metadata) and flattening metadata fields into dedicated columns. - + Example: >>> specs = ( ... SpannerColumnSpecsBuilder() @@ -141,13 +142,13 @@ def __init__(self): @staticmethod def with_defaults() -> 'SpannerColumnSpecsBuilder': """Create builder with default schema. 
- + Default schema includes: - - id (STRING): Chunk ID + - id (STRING): EmbeddableItem ID - embedding (ARRAY): Dense embedding vector - - content (STRING): Chunk content text + - content (STRING): EmbeddableItem content text - metadata (JSON): Full metadata as JSON - + Returns: Builder with default column specifications """ @@ -187,7 +188,8 @@ def with_id_spec( column_name=column_name, python_type=python_type, value_fn=functools.partial( - _extract_and_convert, lambda chunk: chunk.id, convert_fn))) + _extract_and_convert, lambda embeddable: embeddable.id, + convert_fn))) return self def with_embedding_spec( @@ -219,10 +221,10 @@ def with_embedding_spec( ... convert_fn=lambda vec: [round(x, 4) for x in vec] ... ) """ - def extract_fn(chunk: Chunk) -> List[float]: - if chunk.embedding is None or chunk.embedding.dense_embedding is None: - raise ValueError(f'Chunk must contain embedding: {chunk}') - return chunk.embedding.dense_embedding + def extract_fn(embeddable: EmbeddableItem) -> List[float]: + if not embeddable.dense_embedding: + raise ValueError(f'EmbeddableItem must contain embedding: {embeddable}') + return embeddable.dense_embedding self._specs.append( SpannerColumnSpec( @@ -264,10 +266,10 @@ def with_content_spec( ... convert_fn=lambda text: text[:1000] ... 
) """ - def extract_fn(chunk: Chunk) -> str: - if chunk.content.text is None: - raise ValueError(f'Chunk must contain content: {chunk}') - return chunk.content.text + def extract_fn(embeddable: EmbeddableItem) -> str: + if embeddable.content.text is None: + raise ValueError(f'EmbeddableItem must contain content: {embeddable}') + return embeddable.content.text self._specs.append( SpannerColumnSpec( @@ -292,7 +294,7 @@ def with_metadata_spec( Note: Metadata is automatically converted to JSON string using json.dumps() """ - value_fn = lambda chunk: json.dumps(chunk.metadata) + value_fn = lambda embeddable: json.dumps(embeddable.metadata) self._specs.append( SpannerColumnSpec( column_name=column_name, python_type=str, value_fn=value_fn)) @@ -307,11 +309,11 @@ def add_metadata_field( default: Any = None) -> 'SpannerColumnSpecsBuilder': """Flatten a metadata field into its own column. - Extracts a specific field from chunk.metadata and stores it in a + Extracts a specific field from embeddable.metadata and stores it in a dedicated table column. Args: - field: Key in chunk.metadata to extract + field: Key in embeddable.metadata to extract python_type: Python type (must be explicitly specified) column_name: Column name (default: same as field) convert_fn: Optional converter for type casting/transformation @@ -355,8 +357,8 @@ def add_metadata_field( """ name = column_name or field - def value_fn(chunk: Chunk) -> Any: - return chunk.metadata.get(field, default) + def value_fn(embeddable: EmbeddableItem) -> Any: + return embeddable.metadata.get(field, default) self._specs.append( SpannerColumnSpec( @@ -370,7 +372,7 @@ def add_column( self, column_name: str, python_type: Type, - value_fn: Callable[[Chunk], Any]) -> 'SpannerColumnSpecsBuilder': + value_fn: Callable[[EmbeddableItem], Any]) -> 'SpannerColumnSpecsBuilder': """Add a custom column with full control. Args: @@ -386,14 +388,14 @@ def add_column( >>> builder.add_column( ... column_name="has_code", ... 
python_type=bool, - ... value_fn=lambda chunk: "```" in chunk.content.text + ... value_fn=lambda embeddable: "```" in embeddable.content.text ... ) Computed value: >>> builder.add_column( ... column_name="word_count", ... python_type=int, - ... value_fn=lambda chunk: len(chunk.content.text.split()) + ... value_fn=lambda embeddable: len(embeddable.content.text.split()) ... ) """ self._specs.append( @@ -444,15 +446,15 @@ def __init__(self, table_name: str, column_specs: List[SpannerColumnSpec]): # Register coder registry.register_coder(self.record_type, RowCoder) - def create_converter(self) -> Callable[[Chunk], NamedTuple]: - """Create converter function from Chunk to NamedTuple record. - + def create_converter(self) -> Callable[[EmbeddableItem], NamedTuple]: + """Create converter function from EmbeddableItem to NamedTuple record. + Returns: - Function that converts a Chunk to a NamedTuple record + Function that converts an EmbeddableItem to a NamedTuple record """ - def convert(chunk: Chunk) -> self.record_type: # type: ignore + def convert(embeddable: EmbeddableItem) -> self.record_type: # type: ignore values = { - col.column_name: col.value_fn(chunk) + col.column_name: col.value_fn(embeddable) for col in self.column_specs } return self.record_type(**values) # type: ignore @@ -608,11 +610,11 @@ def __init__(self, config: SpannerVectorWriterConfig): self.config = config self.schema_builder = config.schema_builder - def expand(self, pcoll: beam.PCollection[Chunk]): + def expand(self, pcoll: beam.PCollection[EmbeddableItem]): """Expand the transform. 
- + Args: - pcoll: PCollection of Chunks to write + pcoll: PCollection of EmbeddableItems to write """ # Select appropriate Spanner write transform based on write_mode write_transform_class = { diff --git a/sdks/python/apache_beam/ml/rag/types.py b/sdks/python/apache_beam/ml/rag/types.py index 3bb0e01b68cc..6a08bb660518 100644 --- a/sdks/python/apache_beam/ml/rag/types.py +++ b/sdks/python/apache_beam/ml/rag/types.py @@ -16,9 +16,16 @@ # """Core types for RAG pipelines. + This module contains the core dataclasses used throughout the RAG pipeline -implementation, including Chunk and Embedding types that define the data -contracts between different stages of the pipeline. +implementation. The primary type is EmbeddableItem, which represents any +content that can be embedded and stored in a vector database. + +Types: + - Content: Container for embeddable content + - Embedding: Vector embedding with optional metadata + - EmbeddableItem: Universal container for embeddable content + - Chunk: Alias for EmbeddableItem (backward compatibility) """ import uuid @@ -33,49 +40,88 @@ @dataclass class Content: - """Container for embeddable content. Add new types as when as necessary. + """Container for embeddable content. - Args: - text: Text content to be embedded - """ + Args: + text: Text content to be embedded. + """ text: Optional[str] = None @dataclass class Embedding: - """Represents vector embeddings. + """Represents vector embeddings with optional metadata. - Args: - dense_embedding: Dense vector representation - sparse_embedding: Optional sparse vector representation for hybrid - search - """ + Args: + dense_embedding: Dense vector representation. + sparse_embedding: Optional sparse vector representation for hybrid search. + """ dense_embedding: Optional[List[float]] = None - # For hybrid search sparse_embedding: Optional[Tuple[List[int], List[float]]] = None @dataclass -class Chunk: - """Represents a chunk of embeddable content with metadata. 
+class EmbeddableItem: + """Universal container for embeddable content. - Args: - content: The actual content of the chunk - id: Unique identifier for the chunk - index: Index of this chunk within the original document - metadata: Additional metadata about the chunk (e.g., document source) - embedding: Vector embeddings of the content - """ + Represents any content that can be embedded and stored in a vector database. + Use factory methods for convenient construction, or construct directly with + a Content object. + + Examples: + Text (via factory): + item = EmbeddableItem.from_text( + "hello world", metadata={'src': 'doc'}) + + Text (direct, equivalent to old Chunk usage): + item = EmbeddableItem(content=Content(text="hello"), index=3) + + Args: + content: The content to embed. + id: Unique identifier. + index: Position within source document (for chunking use cases). + metadata: Additional metadata (e.g., document source, language). + embedding: Embedding populated by the embedding step. + """ content: Content id: str = field(default_factory=lambda: str(uuid.uuid4())) index: int = 0 metadata: Dict[str, Any] = field(default_factory=dict) embedding: Optional[Embedding] = None + @classmethod + def from_text( + cls, + text: str, + *, + id: Optional[str] = None, + index: int = 0, + metadata: Optional[Dict[str, Any]] = None, + ) -> 'EmbeddableItem': + """Create an EmbeddableItem with text content. 
+ + Args: + text: The text content to embed + id: Unique identifier (auto-generated if not provided) + index: Position within source document (for chunking) + metadata: Additional metadata + """ + return cls( + content=Content(text=text), + id=id or str(uuid.uuid4()), + index=index, + metadata=metadata or {}, + ) + @property - def dense_embedding(self): + def dense_embedding(self) -> Optional[List[float]]: return self.embedding.dense_embedding if self.embedding else None @property - def sparse_embedding(self): + def sparse_embedding(self) -> Optional[Tuple[List[int], List[float]]]: return self.embedding.sparse_embedding if self.embedding else None + + +# Backward compatibility alias. Existing code using Chunk continues to work +# unchanged since Chunk IS EmbeddableItem. +Chunk = EmbeddableItem diff --git a/sdks/python/apache_beam/ml/rag/utils.py b/sdks/python/apache_beam/ml/rag/utils.py index d45e99be0ecb..e2d9962467a1 100644 --- a/sdks/python/apache_beam/ml/rag/utils.py +++ b/sdks/python/apache_beam/ml/rag/utils.py @@ -82,8 +82,8 @@ class MilvusHelpers: """Utility class providing helper methods for Milvus vector db operations.""" @staticmethod def sparse_embedding( - sparse_vector: Tuple[List[int], - List[float]]) -> Optional[Dict[int, float]]: + sparse_vector: Optional[Tuple[List[int], List[float]]] + ) -> Optional[Dict[int, float]]: if not sparse_vector: return None # Converts sparse embedding from (indices, values) tuple format to diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py index de3e5b0c6a92..f6c15eeac092 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py @@ -45,7 +45,7 @@ from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import RemoteModelHandler from apache_beam.ml.inference.base import RunInference -from apache_beam.ml.rag.types 
import Chunk +from apache_beam.ml.rag.types import EmbeddableItem from apache_beam.ml.rag.types import Embedding from apache_beam.ml.transforms.base import EmbeddingsManager from apache_beam.ml.transforms.base import EmbeddingTypeAdapter @@ -316,7 +316,7 @@ class VertexVideo: class VertexAIMultiModalInput: image: Optional[VertexImage] = None video: Optional[VertexVideo] = None - contextual_text: Optional[Chunk] = None + contextual_text: Optional[EmbeddableItem] = None class _VertexAIMultiModalEmbeddingHandler(RemoteModelHandler): @@ -387,7 +387,7 @@ def _multimodal_dict_input_fn( for item in batch: img: Optional[VertexImage] = None vid: Optional[VertexVideo] = None - text: Optional[Chunk] = None + text: Optional[EmbeddableItem] = None if image_column: img = item[image_column] if video_column: @@ -472,8 +472,8 @@ def __init__( is expected to be formatted as VertexVideo objects, containing a Vertex Video object and a VideoSegmentConfig object. text_column: The column containing text data to be embedded. This data is - expected to be formatted as Chunk objects, containing the string to be - embedded in the Chunk's content field. + expected to be formatted as EmbeddableItem objects, containing the string + to be embedded in the item's content field. dimension: The length of the embedding vector to generate. Must be one of 128, 256, 512, or 1408. If not set, Vertex AI's default value is 1408. If submitting video content, dimension *must* be 1408.