Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 40 additions & 26 deletions sdks/python/apache_beam/ml/rag/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,42 +14,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""RAG-specific embedding adapters.

This module provides adapters for extracting content from
EmbeddableItem instances and mapping embeddings back. Adapters
are used by EmbeddingsManager to support various input types.
"""

from collections.abc import Sequence
from typing import List

from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import EmbeddableItem
from apache_beam.ml.rag.types import Embedding
from apache_beam.ml.transforms.base import EmbeddingTypeAdapter


def create_rag_adapter() -> EmbeddingTypeAdapter[Chunk, Chunk]:
"""Creates adapter for converting between Chunk and Embedding types.
The adapter:
- Extracts text from Chunk.content.text for embedding
- Creates Embedding objects from model output
- Sets Embedding in Chunk.embedding
Returns:
EmbeddingTypeAdapter configured for RAG pipeline types
"""
def create_text_adapter(
) -> EmbeddingTypeAdapter[EmbeddableItem, EmbeddableItem]:
"""Creates adapter for text content embedding.

Works with any EmbeddableItem that has text content
(content.text). Extracts text for embedding and maps
results back as Embedding objects.

Returns:
EmbeddingTypeAdapter configured for text embedding
"""
return EmbeddingTypeAdapter(
input_fn=_extract_chunk_text, output_fn=_add_embedding_fn)
input_fn=_extract_text, output_fn=_add_embedding_fn)


# Backward compatibility alias.
create_rag_adapter = create_text_adapter


def _extract_chunk_text(chunks: Sequence[Chunk]) -> List[str]:
"""Extract text from chunks for embedding."""
chunk_texts = []
for chunk in chunks:
if not chunk.content.text:
raise ValueError("Expected chunk text content.")
chunk_texts.append(chunk.content.text)
return chunk_texts
def _extract_text(items: Sequence[EmbeddableItem]) -> List[str]:
"""Extract text from items for embedding."""
texts = []
for item in items:
if not item.content.text:
raise ValueError(
f"Expected text content in {type(item).__name__} {item.id}, "
"got None")
texts.append(item.content.text)
return texts


def _add_embedding_fn(
chunks: Sequence[Chunk], embeddings: Sequence[List[float]]) -> List[Chunk]:
"""Create Embeddings from chunks and embedding vectors."""
for chunk, embedding in zip(chunks, embeddings):
chunk.embedding = Embedding(dense_embedding=embedding)
return list(chunks)
items: Sequence[EmbeddableItem],
embeddings: Sequence[List[float]]) -> List[EmbeddableItem]:
"""Create Embeddings from items and embedding vectors."""
for item, embedding in zip(items, embeddings):
item.embedding = Embedding(dense_embedding=embedding)
return list(items)
10 changes: 5 additions & 5 deletions sdks/python/apache_beam/ml/rag/embeddings/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import unittest

from apache_beam.ml.rag.embeddings.base import create_rag_adapter
from apache_beam.ml.rag.embeddings.base import create_text_adapter
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding
Expand All @@ -41,18 +41,18 @@ def setUp(self):

def test_adapter_input_conversion(self):
"""Test the RAG type adapter converts correctly."""
adapter = create_rag_adapter()
adapter = create_text_adapter()

# Test input conversion
texts = adapter.input_fn(self.test_chunks)
self.assertEqual(texts, ["This is a test sentence.", "Another example."])

def test_adapter_input_conversion_missing_text_content(self):
"""Test the RAG type adapter converts correctly."""
adapter = create_rag_adapter()
adapter = create_text_adapter()

# Test input conversion
with self.assertRaisesRegex(ValueError, "Expected chunk text content"):
with self.assertRaisesRegex(ValueError, "Expected text content"):
adapter.input_fn([
Chunk(
content=Content(),
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_adapter_output_conversion(self):
},
content=Content(text='Another example.')),
]
adapter = create_rag_adapter()
adapter = create_text_adapter()

embeddings = adapter.output_fn(self.test_chunks, mock_embeddings)
self.assertListEqual(embeddings, expected)
Expand Down
11 changes: 6 additions & 5 deletions sdks/python/apache_beam/ml/rag/embeddings/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.rag.embeddings.base import create_rag_adapter
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.embeddings.base import create_text_adapter
from apache_beam.ml.rag.types import EmbeddableItem
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
from apache_beam.ml.transforms.embeddings.huggingface import _SentenceTransformerModelHandler
Expand Down Expand Up @@ -49,7 +49,7 @@ def __init__(
"sentence-transformers is required to use "
"HuggingfaceTextEmbeddings."
"Please install it with using `pip install sentence-transformers`.")
super().__init__(type_adapter=create_rag_adapter(), **kwargs)
super().__init__(type_adapter=create_text_adapter(), **kwargs)
self.model_name = model_name
self.max_seq_length = max_seq_length
self.model_class = SentenceTransformer
Expand All @@ -67,8 +67,9 @@ def get_model_handler(self):

def get_ptransform_for_processing(
self, **kwargs
) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]:
) -> beam.PTransform[beam.PCollection[EmbeddableItem],
beam.PCollection[EmbeddableItem]]:
"""Returns PTransform that uses the RAG adapter."""
return RunInference(
model_handler=_TextEmbeddingHandler(self),
inference_args=self.inference_args).with_output_types(Chunk)
inference_args=self.inference_args).with_output_types(EmbeddableItem)
11 changes: 6 additions & 5 deletions sdks/python/apache_beam/ml/rag/embeddings/vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.rag.embeddings.base import create_rag_adapter
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.embeddings.base import create_text_adapter
from apache_beam.ml.rag.types import EmbeddableItem
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
from apache_beam.ml.transforms.embeddings.vertex_ai import DEFAULT_TASK_TYPE
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(
"vertexai is required to use VertexAITextEmbeddings. "
"Please install it with `pip install google-cloud-aiplatform`")

super().__init__(type_adapter=create_rag_adapter(), **kwargs)
super().__init__(type_adapter=create_text_adapter(), **kwargs)
self.model_name = model_name
self.title = title
self.task_type = task_type
Expand All @@ -90,8 +90,9 @@ def get_model_handler(self):

def get_ptransform_for_processing(
self, **kwargs
) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]:
) -> beam.PTransform[beam.PCollection[EmbeddableItem],
beam.PCollection[EmbeddableItem]]:
"""Returns PTransform that uses the RAG adapter."""
return RunInference(
model_handler=_TextEmbeddingHandler(self),
inference_args=self.inference_args).with_output_types(Chunk)
inference_args=self.inference_args).with_output_types(EmbeddableItem)
Loading
Loading