49 changes: 49 additions & 0 deletions efemel/pipeline/helpers.py
@@ -0,0 +1,49 @@
from collections.abc import Callable
import inspect
from typing import Any
from typing import TypeGuard


# --- Context and Type Aliases ---
class PipelineContext(dict):
    """Generic, untyped context available to all pipeline operations."""


# Define the specific callables for clarity
ContextAwareCallable = Callable[[Any, PipelineContext], Any]
ContextAwareReduceCallable = Callable[[Any, Any, PipelineContext], Any]


def get_function_param_count(func: Callable[..., Any]) -> int:
"""
Returns the number of parameters a function accepts, excluding `self` or `cls`.
This is useful for determining if a function is context-aware.
"""
try:
sig = inspect.signature(func)
params = [p for p in sig.parameters.values() if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)]
return len(params)
except (ValueError, TypeError):
return 0


def is_context_aware(func: Callable[..., Any]) -> TypeGuard[ContextAwareCallable]:
"""
Checks if a function is "context-aware" by inspecting its signature.

This function uses a TypeGuard, allowing Mypy to narrow the type of
the checked function in conditional blocks.
"""
return get_function_param_count(func) >= 2


def is_context_aware_reduce(func: Callable[..., Any]) -> TypeGuard[ContextAwareReduceCallable]:
"""
Checks if a function is "context-aware" by inspecting its signature.

This function uses a TypeGuard, allowing Mypy to narrow the type of
the checked function in conditional blocks.
"""
return get_function_param_count(func) >= 3
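
For reference, a minimal sketch of how these guards are intended to be used (the `op_plain`/`op_ctx` callables below are hypothetical, not part of this diff):

def op_plain(chunk):
    # One positional parameter: not context-aware.
    return [x * 2 for x in chunk]


def op_ctx(chunk, ctx: PipelineContext):
    # Two positional parameters: treated as context-aware.
    ctx["seen"] = ctx.get("seen", 0) + len(chunk)
    return [x * 2 for x in chunk]


assert not is_context_aware(op_plain)
assert is_context_aware(op_ctx)  # Mypy narrows op_ctx to ContextAwareCallable here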
51 changes: 46 additions & 5 deletions efemel/pipeline/pipeline.py
@@ -5,6 +5,9 @@
from typing import Any
from typing import TypedDict
from typing import TypeVar
from typing import overload

from efemel.pipeline.helpers import is_context_aware

from .transformers.transformer import Transformer

@@ -28,15 +31,53 @@ class Pipeline[T]:
def __init__(self, *data: Iterable[T]):
self.data_source: Iterable[T] = itertools.chain.from_iterable(data) if len(data) > 1 else data[0]
self.processed_data: Iterator = iter(self.data_source)
self.ctx = PipelineContext()

def context(self, ctx: PipelineContext) -> "Pipeline[T]":
"""
        Sets the shared context for the pipeline and returns `self` for chaining.
"""
self.ctx = ctx
return self

@overload
def apply[U](self, transformer: Transformer[T, U]) -> "Pipeline[U]": ...

@overload
def apply[U](self, transformer: Callable[[Iterable[T]], Iterator[U]]) -> "Pipeline[U]": ...

@overload
def apply[U](
self,
transformer: Callable[[Iterable[T], PipelineContext], Iterator[U]],
) -> "Pipeline[U]": ...

def apply[U](
self,
transformer: Transformer[T, U]
| Callable[[Iterable[T]], Iterator[U]]
| Callable[[Iterable[T], PipelineContext], Iterator[U]],
) -> "Pipeline[U]":
"""
Applies a transformer to the current data source.
"""

match transformer:
case Transformer():
# If a Transformer instance is provided, use its __call__ method
self.processed_data = transformer(self.processed_data, self.ctx) # type: ignore
            case _ if callable(transformer):
                # A plain callable: wrap context-unaware functions so every
                # transformer is invoked with (data, context).
                if is_context_aware(transformer):
                    processed_transformer = transformer
                else:
                    processed_transformer = lambda data, ctx: transformer(data)  # type: ignore # noqa: E731

self.processed_data = processed_transformer(self.processed_data, self.ctx) # type: ignore
case _:
raise TypeError("Transformer must be a Transformer instance or a callable function")

return self # type: ignore

def transform[U](self, t: Callable[[Transformer[T, T]], Transformer[T, U]]) -> "Pipeline[U]":
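
Taken together, the new `context`/`apply` API can be exercised like this (a minimal sketch; `double` and `tag_with_env` are hypothetical callables, and results are read off `processed_data` for illustration):

p = Pipeline([1, 2, 3]).context(PipelineContext(env="dev"))


def double(items):
    # Context-unaware: apply() wraps it to fit the (data, ctx) shape.
    return (x * 2 for x in items)


def tag_with_env(items, ctx):
    # Context-aware: receives the shared PipelineContext directly.
    return ((x, ctx["env"]) for x in items)


p = p.apply(double).apply(tag_with_env)
print(list(p.processed_data))  # [(2, 'dev'), (4, 'dev'), (6, 'dev')]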
147 changes: 147 additions & 0 deletions efemel/pipeline/transformers/parallel.py
@@ -0,0 +1,147 @@
"""Parallel transformer implementation using multiple threads."""

from collections import deque
from collections.abc import Iterable
from collections.abc import Iterator
from concurrent.futures import FIRST_COMPLETED
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import wait
import copy
from functools import partial
import itertools
import threading
from typing import TypedDict

from .transformer import DEFAULT_CHUNK_SIZE
from .transformer import InternalTransformer
from .transformer import PipelineContext
from .transformer import Transformer


# --- Type Definitions ---
class ParallelPipelineContextType(TypedDict):
"""A specific context type for parallel transformers that includes a lock."""

lock: threading.Lock


# --- Class Definition ---
class ParallelTransformer[In, Out](Transformer[In, Out]):
"""
A transformer that executes operations concurrently using multiple threads.
"""

def __init__(
self,
max_workers: int = 4,
ordered: bool = True,
chunk_size: int = DEFAULT_CHUNK_SIZE,
transformer: InternalTransformer[In, Out] | None = None,
):
"""
Initialize the parallel transformer.

Args:
max_workers: Maximum number of worker threads.
ordered: If True, results are yielded in order. If False, results
are yielded as they complete.
chunk_size: Size of data chunks to process.
transformer: The transformation logic chain.
"""
super().__init__(chunk_size, transformer)
self.max_workers = max_workers
self.ordered = ordered

@classmethod
def from_transformer[T, U](
cls,
transformer: Transformer[T, U],
chunk_size: int | None = None,
max_workers: int = 4,
ordered: bool = True,
) -> "ParallelTransformer[T, U]":
"""
Create a ParallelTransformer from an existing Transformer's logic.

Args:
transformer: The base transformer to copy the transformation logic from.
chunk_size: Optional chunk size override.
max_workers: Maximum number of worker threads.
ordered: If True, results are yielded in order.

Returns:
A new ParallelTransformer with the same transformation logic.
"""
return cls(
chunk_size=chunk_size or transformer.chunk_size,
transformer=copy.deepcopy(transformer.transformer), # type: ignore
max_workers=max_workers,
ordered=ordered,
)

def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]:
"""
Executes the transformer on data concurrently.

        A new `threading.Lock` is created and stored in the context under
        `"lock"` on each call, so context-aware operations can synchronize
        access to shared state across worker threads.
"""
        # Determine the context for this run; it is shared by reference with the caller.
run_context = context or self.context
# Add a per-call lock for thread safety.
run_context["lock"] = threading.Lock()

def process_chunk(chunk: list[In], shared_context: PipelineContext) -> list[Out]:
"""
Process a single chunk by passing the chunk and context explicitly
            to the transformer chain, avoiding any mutation of self.
"""
return self.transformer(chunk, shared_context)

# Create a partial function with the run_context "baked in".
process_chunk_with_context = partial(process_chunk, shared_context=run_context)

def _ordered_generator(chunks_iter: Iterator[list[In]], executor: ThreadPoolExecutor) -> Iterator[list[Out]]:
"""Generate results in their original order."""
futures: deque[Future[list[Out]]] = deque()
for _ in range(self.max_workers + 1):
try:
chunk = next(chunks_iter)
futures.append(executor.submit(process_chunk_with_context, chunk))
except StopIteration:
break
while futures:
yield futures.popleft().result()
try:
chunk = next(chunks_iter)
futures.append(executor.submit(process_chunk_with_context, chunk))
except StopIteration:
continue

def _unordered_generator(chunks_iter: Iterator[list[In]], executor: ThreadPoolExecutor) -> Iterator[list[Out]]:
"""Generate results as they complete."""
futures = {
executor.submit(process_chunk_with_context, chunk)
for chunk in itertools.islice(chunks_iter, self.max_workers + 1)
}
while futures:
done, futures = wait(futures, return_when=FIRST_COMPLETED)
for future in done:
yield future.result()
try:
chunk = next(chunks_iter)
futures.add(executor.submit(process_chunk_with_context, chunk))
except StopIteration:
continue

def result_iterator_manager() -> Iterator[Out]:
"""Manage the thread pool and yield flattened results."""
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
chunks_to_process = self._chunk_generator(data)
gen_func = _ordered_generator if self.ordered else _unordered_generator
processed_chunks_iterator = gen_func(chunks_to_process, executor)
for result_chunk in processed_chunks_iterator:
yield from result_chunk

return result_iterator_manager()
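
For reference, a minimal usage sketch (assuming `base` is an existing `Transformer[int, int]` built elsewhere; the `count_chunk` helper is hypothetical):

par = ParallelTransformer.from_transformer(base, max_workers=8, ordered=False)

ctx = PipelineContext()
for item in par(range(10_000), ctx):
    ...  # with ordered=False, results arrive as chunks complete

# Context-aware operations can use the per-call lock to guard shared state:
def count_chunk(chunk, ctx):
    with ctx["lock"]:
        ctx["count"] = ctx.get("count", 0) + len(chunk)
    return chunk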