13 changes: 6 additions & 7 deletions laygo/transformers/parallel.py
```diff
@@ -119,15 +119,14 @@ def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]:

     def _execute_with_context(self, data: Iterable[In], shared_context: MutableMapping[str, Any]) -> Iterator[Out]:
         """Helper to run the execution logic with a given context."""
-        with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
+        executor = get_reusable_executor(max_workers=self.max_workers)

-            chunks_to_process = self._chunk_generator(data)
-            gen_func = self._ordered_generator if self.ordered else self._unordered_generator
-            processed_chunks_iterator = gen_func(chunks_to_process, executor, shared_context)
+        chunks_to_process = self._chunk_generator(data)
+        gen_func = self._ordered_generator if self.ordered else self._unordered_generator
+        processed_chunks_iterator = gen_func(chunks_to_process, executor, shared_context)

-            for result_chunk in processed_chunks_iterator:
-                yield from result_chunk
+        for result_chunk in processed_chunks_iterator:
+            yield from result_chunk

     # ... The rest of the file remains the same ...
     def _ordered_generator(
```
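The change above swaps the per-call `with ProcessPoolExecutor(...)` block for loky's `get_reusable_executor`, which hands back a cached worker pool instead of spawning and tearing one down on every call. A minimal sketch of that reuse behavior, assuming loky's documented executor caching (`double` is a stand-in function, not part of laygo):

```python
from loky import get_reusable_executor


def double(x: int) -> int:
    return x * 2


if __name__ == "__main__":
    # loky caches the pool: calling the factory again with compatible
    # arguments returns the same executor rather than fresh workers.
    executor_a = get_reusable_executor(max_workers=2)
    executor_b = get_reusable_executor(max_workers=2)
    assert executor_a is executor_b

    print(list(executor_a.map(double, range(4))))  # [0, 2, 4, 6]
```

Nothing shuts the pool down explicitly; loky reclaims idle workers after its `timeout` interval, which is why the `with` block (and the teardown assertions in the test file below) could go.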
44 changes: 31 additions & 13 deletions laygo/transformers/threaded.py
```diff
@@ -4,13 +4,15 @@

 from collections.abc import Callable
 from collections.abc import Iterable
 from collections.abc import Iterator
+from collections.abc import MutableMapping
 from concurrent.futures import FIRST_COMPLETED
 from concurrent.futures import Future
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import wait
 import copy
 from functools import partial
 import itertools
+from multiprocessing.managers import DictProxy
 import threading
 from typing import Any
 from typing import Union
@@ -101,25 +103,41 @@ def from_transformer[T, U](

     def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]:
         """
-        Executes the transformer on data concurrently.
-
-        A new `threading.Lock` is created and added to the context for each call
-        to ensure execution runs are isolated and thread-safe.
+        Executes the transformer on data concurrently. It uses the shared
+        context provided by the Pipeline, if available.
         """
-        # Determine the context for this run, passing it by reference as requested.
-        run_context = context or self.context
-        # Add a per-call lock for thread safety.
-        run_context["lock"] = threading.Lock()
-
-        def process_chunk(chunk: list[In], shared_context: PipelineContext) -> list[Out]:
+        run_context = context if context is not None else self.context
+
+        # Detect if the context is already managed by the Pipeline.
+        is_managed_context = isinstance(run_context, DictProxy)
+
+        if is_managed_context:
+            # Use the existing shared context and lock from the Pipeline.
+            shared_context = run_context
+            yield from self._execute_with_context(data, shared_context)
+            # The context is live, so no need to update it here.
+            # The Pipeline's __del__ will handle final state.
+        else:
+            # Fallback for standalone use: create a thread-safe context.
+            # Since threads share memory, we can use the context directly with a lock.
+            if "lock" not in run_context:
+                run_context["lock"] = threading.Lock()
+
+            yield from self._execute_with_context(data, run_context)
+            # Context is already updated in-place for threads (shared memory)
+
+    def _execute_with_context(self, data: Iterable[In], shared_context: MutableMapping[str, Any]) -> Iterator[Out]:
+        """Helper to run the execution logic with a given context."""
+
+        def process_chunk(chunk: list[In], shared_context: MutableMapping[str, Any]) -> list[Out]:
             """
             Process a single chunk by passing the chunk and context explicitly
             to the transformer chain. This is safer and avoids mutating self.
             """
-            return self.transformer(chunk, shared_context)
+            return self.transformer(chunk, shared_context)  # type: ignore

-        # Create a partial function with the run_context "baked in".
-        process_chunk_with_context = partial(process_chunk, shared_context=run_context)
+        # Create a partial function with the shared_context "baked in".
+        process_chunk_with_context = partial(process_chunk, shared_context=shared_context)

         def _ordered_generator(chunks_iter: Iterator[list[In]], executor: ThreadPoolExecutor) -> Iterator[list[Out]]:
             """Generate results in their original order."""
```
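The new `__call__` branches on `isinstance(run_context, DictProxy)`: a Pipeline-managed context arrives as a `multiprocessing` manager proxy that already carries its own lock and lifecycle, while a plain dict used standalone just gets a local `threading.Lock`. A small sketch of that detection; the `prepare_context` helper is hypothetical and only illustrates the branch:

```python
import multiprocessing as mp
import threading
from collections.abc import MutableMapping
from multiprocessing.managers import DictProxy
from typing import Any


def prepare_context(run_context: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
    """Trust a managed DictProxy as-is; make a plain dict thread-safe."""
    if isinstance(run_context, DictProxy):
        return run_context  # Pipeline-managed: lock and final state handled elsewhere
    run_context.setdefault("lock", threading.Lock())  # standalone fallback
    return run_context


if __name__ == "__main__":
    managed = mp.Manager().dict()
    print(isinstance(managed, DictProxy))  # True -> managed branch
    print("lock" in prepare_context({}))   # True -> standalone branch
```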
14 changes: 0 additions & 14 deletions tests/test_parallel_transformer.py
```diff
@@ -2,7 +2,6 @@

 import multiprocessing as mp
 import time
-from unittest.mock import patch

 from laygo import ErrorHandler
 from laygo import ParallelTransformer
@@ -159,19 +158,6 @@ def test_unordered_vs_ordered_same_elements(self):
         assert sorted(ordered_result) == sorted(unordered_result)
         assert ordered_result == [x * 2 for x in data]

-    def test_process_pool_management(self):
-        """Test that process pool is properly created and cleaned up."""
-        with patch("laygo.transformers.parallel.ProcessPoolExecutor") as mock_executor:
-            mock_executor.return_value.__enter__.return_value = mock_executor.return_value
-            mock_executor.return_value.__exit__.return_value = None
-            mock_executor.return_value.submit.return_value.result.return_value = [2, 4]
-            transformer = ParallelTransformer[int, int](max_workers=2, chunk_size=2)
-            list(transformer([1, 2]))
-
-            mock_executor.assert_called_with(max_workers=2)
-            mock_executor.return_value.__enter__.assert_called_once()
-            mock_executor.return_value.__exit__.assert_called_once()


 class TestParallelTransformerChunkingAndEdgeCases:
     """Test chunking behavior and edge cases."""
```
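With the `with ProcessPoolExecutor` lifecycle gone, the pool-management test above had nothing left to assert, hence its removal. A hypothetical replacement would patch `get_reusable_executor` instead; this sketch mirrors the removed test's structure and is not part of this PR (it assumes the same mock-friendly submit/result path the old test relied on):

```python
from unittest.mock import patch

from laygo import ParallelTransformer


def test_reusable_executor_is_requested():
    """Hypothetical follow-up: assert the cached loky pool is requested."""
    with patch("laygo.transformers.parallel.get_reusable_executor") as mock_get:
        mock_get.return_value.submit.return_value.result.return_value = [2, 4]
        transformer = ParallelTransformer[int, int](max_workers=2, chunk_size=2)
        list(transformer([1, 2]))

    mock_get.assert_called_with(max_workers=2)
```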