diff --git a/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py b/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py index 83a4fc1a778f..95bb985f52ea 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py +++ b/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py @@ -16,13 +16,15 @@ """Tests for apache_beam.ml.rag.chunking.langchain.""" +import functools import unittest import apache_beam as beam -from apache_beam.ml.rag.types import Chunk from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import BeamAssertException from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.testing.util import is_not_empty try: from apache_beam.ml.rag.chunking.langchain import LangChainChunker @@ -41,13 +43,10 @@ TRANSFORMERS_AVAILABLE = False -def chunk_equals(expected, actual): - """Custom equality function for Chunk objects.""" - if not isinstance(expected, Chunk) or not isinstance(actual, Chunk): - return False - return ( - expected.content == actual.content and expected.index == actual.index and - expected.metadata == actual.metadata) +def assert_true(elements, assert_fn, error_message_fn): + if not assert_fn(elements): + raise BeamAssertException(error_message_fn(elements)) + return True @unittest.skipIf(not LANGCHAIN_AVAILABLE, 'langchain is not installed.') @@ -83,9 +82,15 @@ def test_no_metadata_fields(self): | provider.get_ptransform_for_processing()) chunks_count = chunks | beam.combiners.Count.Globally() - assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks') + assert_that(chunks_count, is_not_empty(), 'Has chunks') - assert_that(chunks, lambda x: all(c.metadata == {} for c in x)) + assert_that( + chunks, + functools.partial( + assert_true, + assert_fn=lambda x: (all(c.metadata == {} for c in x)), + error_message_fn=lambda x: f"Expected empty metadata, actual {x}") + ) def test_multiple_metadata_fields(self): """Test chunking with multiple metadata fields.""" @@ -94,6 +99,7 @@ def test_multiple_metadata_fields(self): document_field='content', metadata_fields=['source', 'language'], text_splitter=splitter) + expected_metadata = {'source': 'simple.txt', 'language': 'en'} with TestPipeline() as p: chunks = ( @@ -102,18 +108,20 @@ def test_multiple_metadata_fields(self): | provider.get_ptransform_for_processing()) chunks_count = chunks | beam.combiners.Count.Globally() - assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks') + assert_that(chunks_count, is_not_empty(), 'Has chunks') assert_that( chunks, - lambda x: all( - c.metadata == { - 'source': 'simple.txt', 'language': 'en' - } for c in x)) + functools.partial( + assert_true, + assert_fn=lambda x: all( + c.metadata == expected_metadata for c in x), + error_message_fn=lambda x: + f"Expected metadata {expected_metadata}, actual {x}")) def test_recursive_splitter_no_overlap(self): """Test RecursiveCharacterTextSplitter with no overlap.""" splitter = RecursiveCharacterTextSplitter( - chunk_size=30, chunk_overlap=0, separators=[". "]) + chunk_size=30, chunk_overlap=0, separators=[".", " "]) provider = LangChainChunker( document_field='content', metadata_fields=['source'], @@ -126,8 +134,14 @@ def test_recursive_splitter_no_overlap(self): | provider.get_ptransform_for_processing()) chunks_count = chunks | beam.combiners.Count.Globally() - assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks') - assert_that(chunks, lambda x: all(len(c.content.text) <= 30 for c in x)) + assert_that(chunks_count, is_not_empty(), 'Has chunks') + assert_that( + chunks, + functools.partial( + assert_true, + assert_fn=lambda x: all(len(c.content.text) <= 30 for c in x), + error_message_fn=lambda x: f"Expected len(chunk) <= 30, \ + actual {[len(c.content.text) for c in x]}")) @unittest.skipIf(not TRANSFORMERS_AVAILABLE, "transformers not available") def test_huggingface_tokenizer_splitter(self): @@ -155,13 +169,13 @@ def check_token_lengths(chunks): # Verify each chunk's token length is within limits num_tokens = len(tokenizer.encode(chunk.content.text)) if not num_tokens <= 10: - raise AssertionError( + raise BeamAssertException( f"Chunk has {num_tokens} tokens, expected <= 10") return True chunks_count = chunks | beam.combiners.Count.Globally() - assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks') + assert_that(chunks_count, is_not_empty(), 'Has chunks') assert_that(chunks, check_token_lengths) def test_invalid_document_field(self): diff --git a/sdks/python/setup.py b/sdks/python/setup.py index da9e0b2e7477..5b3e3cab6db7 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -486,6 +486,7 @@ def get_portability_package_data(): 'ml_test': [ 'datatable', 'embeddings', + 'langchain', 'onnxruntime', 'sentence-transformers', 'skl2onnx', @@ -505,6 +506,7 @@ def get_portability_package_data(): 'datatable', 'embeddings', 'onnxruntime', + 'langchain', 'sentence-transformers', 'skl2onnx', 'pillow',