diff --git a/laygo/transformers/parallel.py b/laygo/transformers/parallel.py index 73def84..b058b4a 100644 --- a/laygo/transformers/parallel.py +++ b/laygo/transformers/parallel.py @@ -119,15 +119,14 @@ def __call__(self, data: Iterable[In], context: PipelineContext | None = None) - def _execute_with_context(self, data: Iterable[In], shared_context: MutableMapping[str, Any]) -> Iterator[Out]: """Helper to run the execution logic with a given context.""" - with ProcessPoolExecutor(max_workers=self.max_workers) as executor: - executor = get_reusable_executor(max_workers=self.max_workers) + executor = get_reusable_executor(max_workers=self.max_workers) - chunks_to_process = self._chunk_generator(data) - gen_func = self._ordered_generator if self.ordered else self._unordered_generator - processed_chunks_iterator = gen_func(chunks_to_process, executor, shared_context) + chunks_to_process = self._chunk_generator(data) + gen_func = self._ordered_generator if self.ordered else self._unordered_generator + processed_chunks_iterator = gen_func(chunks_to_process, executor, shared_context) - for result_chunk in processed_chunks_iterator: - yield from result_chunk + for result_chunk in processed_chunks_iterator: + yield from result_chunk # ... The rest of the file remains the same ... def _ordered_generator( diff --git a/laygo/transformers/threaded.py b/laygo/transformers/threaded.py index 8bd784e..58c661e 100644 --- a/laygo/transformers/threaded.py +++ b/laygo/transformers/threaded.py @@ -4,6 +4,7 @@ from collections.abc import Callable from collections.abc import Iterable from collections.abc import Iterator +from collections.abc import MutableMapping from concurrent.futures import FIRST_COMPLETED from concurrent.futures import Future from concurrent.futures import ThreadPoolExecutor @@ -11,6 +12,7 @@ import copy from functools import partial import itertools +from multiprocessing.managers import DictProxy import threading from typing import Any from typing import Union @@ -101,25 +103,41 @@ def from_transformer[T, U]( def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]: """ - Executes the transformer on data concurrently. - - A new `threading.Lock` is created and added to the context for each call - to ensure execution runs are isolated and thread-safe. + Executes the transformer on data concurrently. It uses the shared + context provided by the Pipeline, if available. """ - # Determine the context for this run, passing it by reference as requested. - run_context = context or self.context - # Add a per-call lock for thread safety. - run_context["lock"] = threading.Lock() - - def process_chunk(chunk: list[In], shared_context: PipelineContext) -> list[Out]: + run_context = context if context is not None else self.context + + # Detect if the context is already managed by the Pipeline. + is_managed_context = isinstance(run_context, DictProxy) + + if is_managed_context: + # Use the existing shared context and lock from the Pipeline. + shared_context = run_context + yield from self._execute_with_context(data, shared_context) + # The context is live, so no need to update it here. + # The Pipeline's __del__ will handle final state. + else: + # Fallback for standalone use: create a thread-safe context. + # Since threads share memory, we can use the context directly with a lock. + if "lock" not in run_context: + run_context["lock"] = threading.Lock() + + yield from self._execute_with_context(data, run_context) + # Context is already updated in-place for threads (shared memory) + + def _execute_with_context(self, data: Iterable[In], shared_context: MutableMapping[str, Any]) -> Iterator[Out]: + """Helper to run the execution logic with a given context.""" + + def process_chunk(chunk: list[In], shared_context: MutableMapping[str, Any]) -> list[Out]: """ Process a single chunk by passing the chunk and context explicitly to the transformer chain. This is safer and avoids mutating self. """ - return self.transformer(chunk, shared_context) + return self.transformer(chunk, shared_context) # type: ignore - # Create a partial function with the run_context "baked in". - process_chunk_with_context = partial(process_chunk, shared_context=run_context) + # Create a partial function with the shared_context "baked in". + process_chunk_with_context = partial(process_chunk, shared_context=shared_context) def _ordered_generator(chunks_iter: Iterator[list[In]], executor: ThreadPoolExecutor) -> Iterator[list[Out]]: """Generate results in their original order.""" diff --git a/tests/test_parallel_transformer.py b/tests/test_parallel_transformer.py index df1c874..e7fb67f 100644 --- a/tests/test_parallel_transformer.py +++ b/tests/test_parallel_transformer.py @@ -2,7 +2,6 @@ import multiprocessing as mp import time -from unittest.mock import patch from laygo import ErrorHandler from laygo import ParallelTransformer @@ -159,19 +158,6 @@ def test_unordered_vs_ordered_same_elements(self): assert sorted(ordered_result) == sorted(unordered_result) assert ordered_result == [x * 2 for x in data] - def test_process_pool_management(self): - """Test that process pool is properly created and cleaned up.""" - with patch("laygo.transformers.parallel.ProcessPoolExecutor") as mock_executor: - mock_executor.return_value.__enter__.return_value = mock_executor.return_value - mock_executor.return_value.__exit__.return_value = None - mock_executor.return_value.submit.return_value.result.return_value = [2, 4] - transformer = ParallelTransformer[int, int](max_workers=2, chunk_size=2) - list(transformer([1, 2])) - - mock_executor.assert_called_with(max_workers=2) - mock_executor.return_value.__enter__.assert_called_once() - mock_executor.return_value.__exit__.assert_called_once() - class TestParallelTransformerChunkingAndEdgeCases: """Test chunking behavior and edge cases."""