diff --git a/efemel/pipeline.py b/efemel/pipeline.py index 34492d5..aea2baa 100644 --- a/efemel/pipeline.py +++ b/efemel/pipeline.py @@ -13,10 +13,10 @@ from concurrent.futures import FIRST_COMPLETED from concurrent.futures import ThreadPoolExecutor from concurrent.futures import wait -from itertools import chain from typing import Any from typing import Self from typing import TypeVar +from typing import Union from typing import overload T = TypeVar("T") # Type variable for the elements in the pipeline @@ -40,7 +40,7 @@ class Pipeline[T]: generator: Generator[list[T], None, None] - def __init__(self, source: Iterable[T], chunk_size: int = 1000) -> None: + def __init__(self, source: Union[Iterable[T], "Pipeline[T]"], chunk_size: int = 1000) -> None: """ Initialize a new Pipeline with the given data source. @@ -49,11 +49,13 @@ def __init__(self, source: Iterable[T], chunk_size: int = 1000) -> None: If source is another Pipeline, it will be efficiently composed. chunk_size: Number of elements per chunk (default: 1000) """ - if isinstance(source, Pipeline): - # If source is another Pipeline, use its generator directly to avoid double-chunking - self.generator = source.generator - else: - self.generator = self._chunked(source, chunk_size) + match source: + case Pipeline(): + # If source is already a Pipeline, we can use its generator directly + self.generator = source.generator + case Iterable(): + # If source is an iterable, we will chunk it + self.generator = self._chunked(source, chunk_size) self.chunk_size = chunk_size @@ -97,38 +99,35 @@ def from_pipeline(cls, pipeline: "Pipeline[T]") -> "Pipeline[T]": def __iter__(self) -> Generator[T, None, None]: """Iterate over elements by flattening chunks.""" - for chunk in self.generator: - yield from chunk + return (item for chunk in self.generator for item in chunk) def to_list(self) -> list[T]: """Convert the pipeline to a list by concatenating all chunks.""" - result = [] - for chunk in self.generator: - result.extend(chunk) - return result + return [item for chunk in self.generator for item in chunk] def first(self) -> T: """Get the first element from the pipeline.""" - for chunk in self.generator: - if chunk: - return chunk[0] - raise StopIteration("Pipeline is empty") + item = next(self.generator, None) + + if item is None: + raise StopIteration("Pipeline is empty") + + return item.pop(0) def filter(self, predicate: Callable[[T], bool]) -> "Pipeline[T]": """Filter elements using a predicate, applied per chunk.""" - def filter_chunk(chunk: list[T]) -> list[T]: - return [x for x in chunk if predicate(x)] - - return Pipeline._from_chunks((filter_chunk(chunk) for chunk in self.generator), self.chunk_size) + return Pipeline._from_chunks( + ([x for x in chunk if predicate(x)] for chunk in self.generator), + self.chunk_size, + ) def map(self, function: Callable[[T], U]) -> "Pipeline[U]": """Transform elements using a function, applied per chunk.""" - - def map_chunk(chunk: list[T]) -> list[U]: - return [function(x) for x in chunk] - - return Pipeline._from_chunks((map_chunk(chunk) for chunk in self.generator), self.chunk_size) + return Pipeline._from_chunks( + ([function(x) for x in chunk] for chunk in self.generator), + self.chunk_size, + ) def reduce(self, function: Callable[[U, T], U], initial: U) -> "Pipeline[U]": """Reduce elements to a single value using the given function.""" @@ -142,34 +141,26 @@ def tap(self, function: Callable[[T], Any]) -> Self: """Apply side effect to each element without modifying data.""" def tap_chunk(chunk: list[T]) -> list[T]: - for item in chunk: - function(item) - return chunk + return [item for item in chunk if function(item) or True] - return Pipeline._from_chunks((tap_chunk(chunk) for chunk in self.generator), self.chunk_size) + return Pipeline._from_chunks(tap_chunk(chunk) for chunk in self.generator) def each(self, function: Callable[[T], Any]) -> None: """Apply function to each element (terminal operation).""" - for chunk in self.generator: - for item in chunk: - function(item) + deque((function(item) for chunk in self.generator for item in chunk), maxlen=0) def noop(self) -> None: """Consume the pipeline without any operation.""" # Consume all elements in the pipeline without any operation - for _ in chain.from_iterable(self.generator): - continue + deque(self.generator, maxlen=0) def passthrough(self) -> Self: """Return the pipeline unchanged (identity operation).""" return self - def apply(self, *functions: Callable[[Self], "Pipeline[U]"]) -> "Pipeline[U]": + def apply(self, function: Callable[[Self], "Pipeline[U]"]) -> "Pipeline[U]": """Apply sequence of transformation functions.""" - result: Pipeline[Any] = self - for function in functions: - result = function(result) - return result + return function(self) @overload def flatten(self: "Pipeline[list[U]]") -> "Pipeline[U]": ... @@ -181,7 +172,7 @@ def flatten(self: "Pipeline[tuple[U, ...]]") -> "Pipeline[U]": ... def flatten(self: "Pipeline[set[U]]") -> "Pipeline[U]": ... def flatten( - self: "Pipeline[list[U]] | Pipeline[tuple[U, ...]] | Pipeline[set[U]]", + self: Union["Pipeline[list[U]]", "Pipeline[tuple[U, ...]]", "Pipeline[set[U]]"], ) -> "Pipeline[Any]": """Flatten iterable chunks into a single pipeline of elements. @@ -201,9 +192,7 @@ def flatten( def flatten_generator() -> Generator[Any, None, None]: """Generator that yields individual flattened items.""" - for chunk in self.generator: - for iterable in chunk: - yield from iterable + return (item for chunk in self.generator for iterable in chunk for item in iterable) # Re-chunk the flattened stream to maintain consistent chunk size return Pipeline._from_chunks(self._chunked(flatten_generator(), self.chunk_size), self.chunk_size) @@ -312,6 +301,10 @@ def chain_generator(): for pipeline in pipelines: yield from pipeline.generator + # chain preserves chunk structure exactly (just concatenates generators) # Use chunk_size from the first pipeline, or default if no pipelines chunk_size = pipelines[0].chunk_size if pipelines else 1000 - return cls._from_chunks(chain_generator(), chunk_size) + new_pipeline = cls.__new__(cls) + new_pipeline.generator = chain_generator() + new_pipeline.chunk_size = chunk_size + return new_pipeline diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index cb6baf9..261d971 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -364,7 +364,7 @@ def test_passthrough(self): # Data should be unchanged assert result.to_list() == [1, 2, 3, 4, 5] - def test_apply_single_function(self): + def test_apply_function(self): """Test apply with single function.""" pipeline = Pipeline([1, 2, 3, 4, 5]) @@ -374,27 +374,6 @@ def double_pipeline(p): result = pipeline.apply(double_pipeline) assert result.to_list() == [2, 4, 6, 8, 10] - def test_apply_multiple_functions(self): - """Test apply with multiple functions.""" - pipeline = Pipeline([1, 2, 3, 4, 5]) - - def double_pipeline(p): - return p.map(lambda x: x * 2) - - def filter_even(p): - return p.filter(lambda x: x % 2 == 0) - - result = pipeline.apply(double_pipeline, filter_even) - assert result.to_list() == [2, 4, 6, 8, 10] - - def test_apply_no_functions(self): - """Test apply with no functions.""" - pipeline = Pipeline([1, 2, 3, 4, 5]) - - # Should return the same pipeline - result = pipeline.apply() - assert result.to_list() == [1, 2, 3, 4, 5] - def test_flatten_basic(self): """Test basic flatten operation.""" pipeline = Pipeline([[1, 2], [3, 4], [5]])