From 6e0bd8873aa144aa92d4be57b601bfebd66a4807 Mon Sep 17 00:00:00 2001 From: ringoldsdev Date: Thu, 17 Jul 2025 16:30:53 +0000 Subject: [PATCH 1/4] feat: implemented http transformer allowing distributed processing --- .github/workflows/publish.yml | 4 +- laygo/__init__.py | 14 ++--- laygo/errors.py | 2 +- laygo/pipeline.py | 5 +- laygo/transformers/http.py | 91 +++++++++++++++++++++++++++++++ laygo/transformers/parallel.py | 8 +-- laygo/transformers/transformer.py | 4 +- pyproject.toml | 5 +- uv.lock | 5 ++ 9 files changed, 118 insertions(+), 20 deletions(-) create mode 100644 laygo/transformers/http.py diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 47b8e74..280457c 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -38,9 +38,9 @@ jobs: - name: Update version in __init__.py run: | - sed -i 's|VERSION_NUMBER|${{ steps.version.outputs.VERSION }}|g' laygo/__init__.py + sed -i 's|"0.1.0"|"${{ steps.version.outputs.VERSION }}"|g' pyproject.toml echo "Updated version to ${{ steps.version.outputs.VERSION }}" - cat laygo/__init__.py + cat pyproject.toml - name: Build package run: uv build diff --git a/laygo/__init__.py b/laygo/__init__.py index ea35f02..611995c 100644 --- a/laygo/__init__.py +++ b/laygo/__init__.py @@ -2,18 +2,18 @@ Laygo - A lightweight Python library for building resilient, in-memory data pipelines """ -__version__ = "VERSION_NUMBER" - -from .errors import ErrorHandler -from .helpers import PipelineContext -from .pipeline import Pipeline -from .transformers.parallel import ParallelTransformer -from .transformers.transformer import Transformer +from laygo.errors import ErrorHandler +from laygo.helpers import PipelineContext +from laygo.pipeline import Pipeline +from laygo.transformers.http import HTTPTransformer +from laygo.transformers.parallel import ParallelTransformer +from laygo.transformers.transformer import Transformer __all__ = [ "Pipeline", "Transformer", "ParallelTransformer", + 
"HTTPTransformer", "PipelineContext", "ErrorHandler", ] diff --git a/laygo/errors.py b/laygo/errors.py index 8afd5df..0741540 100644 --- a/laygo/errors.py +++ b/laygo/errors.py @@ -1,6 +1,6 @@ from collections.abc import Callable -from laygo.helpers import PipelineContext +from laygo import PipelineContext ChunkErrorHandler = Callable[[list, Exception, PipelineContext], None] diff --git a/laygo/pipeline.py b/laygo/pipeline.py index 921cb41..b287469 100644 --- a/laygo/pipeline.py +++ b/laygo/pipeline.py @@ -6,11 +6,10 @@ from typing import TypeVar from typing import overload -from laygo.helpers import PipelineContext +from laygo import PipelineContext +from laygo import Transformer from laygo.helpers import is_context_aware -from .transformers.transformer import Transformer - T = TypeVar("T") PipelineFunction = Callable[[T], Any] diff --git a/laygo/transformers/http.py b/laygo/transformers/http.py new file mode 100644 index 0000000..215b60a --- /dev/null +++ b/laygo/transformers/http.py @@ -0,0 +1,91 @@ +""" +The final, self-sufficient DistributedTransformer. +""" + +from collections.abc import Iterable +from collections.abc import Iterator +from concurrent.futures import FIRST_COMPLETED +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import wait +import hashlib +import itertools +import pickle + +import requests + +from laygo import PipelineContext +from laygo import Transformer + + +class HTTPTransformer(Transformer): + """ + A self-sufficient, chainable transformer that manages its own + distributed execution and worker endpoint definition. 
+ """ + + def __init__(self, base_url: str, endpoint: str | None = None, max_workers: int = 8): + super().__init__() + self.base_url = base_url.rstrip("/") + self.endpoint = endpoint + self.max_workers = max_workers + self.session = requests.Session() + self._worker_url: str + + def _finalize_config(self): + """Determines the final worker URL, generating one if needed.""" + if self._worker_url: + return + + if self.endpoint: + path = self.endpoint + else: + # Using pickle to serialize the function chain and hashing for a unique ID + serialized_logic = pickle.dumps(self.transformer) + hash_id = hashlib.sha1(serialized_logic).hexdigest()[:16] + path = f"/autogen/{hash_id}" + + self.endpoint = path.lstrip("/") + self._worker_url = f"{self.base_url}/{self.endpoint}" + + def __call__(self, data: Iterable, context=None) -> Iterator: + """CLIENT-SIDE: Called by the Pipeline to start distributed processing.""" + self._finalize_config() + + def process_chunk(chunk: list) -> list: + """Target for a thread: sends one chunk to the worker.""" + try: + response = self.session.post(self._worker_url, json=chunk, timeout=300) + response.raise_for_status() + return response.json() + except requests.RequestException as e: + print(f"Error calling worker {self._worker_url}: {e}") + return [] + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + chunk_iterator = self._chunk_generator(data) + futures = {executor.submit(process_chunk, chunk) for chunk in itertools.islice(chunk_iterator, self.max_workers)} + while futures: + done, futures = wait(futures, return_when=FIRST_COMPLETED) + for future in done: + yield from future.result() + try: + new_chunk = next(chunk_iterator) + futures.add(executor.submit(process_chunk, new_chunk)) + except StopIteration: + continue + + def get_route(self): + """ + Function that returns the route for the worker. + This is used to register the worker in a Flask app or similar. 
+ + Returns: + A tuple containing the endpoint and the worker function. + """ + self._finalize_config() + + def worker_view_func(chunk: list, context: PipelineContext): + """The actual Flask view function for this transformer's logic.""" + return self.transformer(chunk, context) + + return (f"/{self.endpoint}", worker_view_func) diff --git a/laygo/transformers/parallel.py b/laygo/transformers/parallel.py index 8794cc0..0444d33 100644 --- a/laygo/transformers/parallel.py +++ b/laygo/transformers/parallel.py @@ -12,10 +12,10 @@ import itertools import threading -from .transformer import DEFAULT_CHUNK_SIZE -from .transformer import InternalTransformer -from .transformer import PipelineContext -from .transformer import Transformer +from laygo import PipelineContext +from laygo import Transformer +from laygo.transformers.transformer import DEFAULT_CHUNK_SIZE +from laygo.transformers.transformer import InternalTransformer class ParallelPipelineContextType(PipelineContext): diff --git a/laygo/transformers/transformer.py b/laygo/transformers/transformer.py index 6ed5df6..abe7669 100644 --- a/laygo/transformers/transformer.py +++ b/laygo/transformers/transformer.py @@ -9,8 +9,8 @@ from typing import Union from typing import overload -from laygo.errors import ErrorHandler -from laygo.helpers import PipelineContext +from laygo import ErrorHandler +from laygo import PipelineContext from laygo.helpers import is_context_aware from laygo.helpers import is_context_aware_reduce diff --git a/pyproject.toml b/pyproject.toml index 2cba014..1d3a9aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ path = "laygo/__init__.py" [project] name = "laygo" -dynamic = ["version"] +version = "0.1.0" description = "A lightweight Python library for building resilient, in-memory data pipelines with elegant, chainable syntax" readme = "README.md" requires-python = ">=3.12" @@ -29,12 +29,15 @@ classifiers = [ "Typing :: Typed", ] +dependencies = ["requests>=2.32"] + [project.urls] 
Homepage = "https://github.com/ringoldsdev/laygo-python" Documentation = "https://github.com/ringoldsdev/laygo-python/wiki" Repository = "https://github.com/ringoldsdev/laygo-python.git" Issues = "https://github.com/ringoldsdev/laygo-python/issues" + [project.optional-dependencies] dev = [ "pytest>=7.0.0", diff --git a/uv.lock b/uv.lock index 59fe8e3..1606c82 100644 --- a/uv.lock +++ b/uv.lock @@ -209,7 +209,11 @@ wheels = [ [[package]] name = "laygo" +version = "0.1.0" source = { editable = "." } +dependencies = [ + { name = "requests" }, +] [package.optional-dependencies] dev = [ @@ -221,6 +225,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, + { name = "requests", specifier = ">=2.32" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" }, { name = "twine", marker = "extra == 'dev'", specifier = ">=4.0.0" }, ] From 0fa429d219d4e80b53d1d93e2e4b4d8733128283 Mon Sep 17 00:00:00 2001 From: ringoldsdev Date: Thu, 17 Jul 2025 19:39:14 +0000 Subject: [PATCH 2/4] fix: http transformer types and added tests --- laygo/errors.py | 2 +- laygo/pipeline.py | 4 +- laygo/transformers/http.py | 90 ++++++++++++++++++++++++++----- laygo/transformers/parallel.py | 4 +- laygo/transformers/transformer.py | 4 +- pyproject.toml | 1 + tests/test_http_transformer.py | 59 ++++++++++++++++++++ uv.lock | 14 +++++ 8 files changed, 158 insertions(+), 20 deletions(-) create mode 100644 tests/test_http_transformer.py diff --git a/laygo/errors.py b/laygo/errors.py index 0741540..8afd5df 100644 --- a/laygo/errors.py +++ b/laygo/errors.py @@ -1,6 +1,6 @@ from collections.abc import Callable -from laygo import PipelineContext +from laygo.helpers import PipelineContext ChunkErrorHandler = Callable[[list, Exception, PipelineContext], None] diff --git a/laygo/pipeline.py b/laygo/pipeline.py index b287469..44d4a9c 100644 --- a/laygo/pipeline.py +++ b/laygo/pipeline.py @@ -6,9 +6,9 @@ from typing import TypeVar from 
typing import overload -from laygo import PipelineContext -from laygo import Transformer +from laygo.helpers import PipelineContext from laygo.helpers import is_context_aware +from laygo.transformers.transformer import Transformer T = TypeVar("T") PipelineFunction = Callable[[T], Any] diff --git a/laygo/transformers/http.py b/laygo/transformers/http.py index 215b60a..6a2fbc0 100644 --- a/laygo/transformers/http.py +++ b/laygo/transformers/http.py @@ -1,7 +1,8 @@ """ -The final, self-sufficient DistributedTransformer. +The final, self-sufficient DistributedTransformer with corrected typing. """ +from collections.abc import Callable from collections.abc import Iterable from collections.abc import Iterator from concurrent.futures import FIRST_COMPLETED @@ -10,14 +11,26 @@ import hashlib import itertools import pickle +from typing import Any +from typing import TypeVar +from typing import Union +from typing import overload import requests -from laygo import PipelineContext -from laygo import Transformer +from laygo.errors import ErrorHandler +from laygo.helpers import PipelineContext +from laygo.transformers.transformer import ChunkErrorHandler +from laygo.transformers.transformer import PipelineFunction +from laygo.transformers.transformer import Transformer +In = TypeVar("In") +Out = TypeVar("Out") +T = TypeVar("T") +U = TypeVar("U") -class HTTPTransformer(Transformer): + +class HTTPTransformer(Transformer[In, Out]): """ A self-sufficient, chainable transformer that manages its own distributed execution and worker endpoint definition. 
@@ -29,17 +42,18 @@ def __init__(self, base_url: str, endpoint: str | None = None, max_workers: int self.endpoint = endpoint self.max_workers = max_workers self.session = requests.Session() - self._worker_url: str + self._worker_url: str | None = None def _finalize_config(self): """Determines the final worker URL, generating one if needed.""" - if self._worker_url: + if hasattr(self, "_worker_url") and self._worker_url: return if self.endpoint: path = self.endpoint else: - # Using pickle to serialize the function chain and hashing for a unique ID + if not self.transformer: + raise ValueError("Cannot determine endpoint for an empty transformer.") serialized_logic = pickle.dumps(self.transformer) hash_id = hashlib.sha1(serialized_logic).hexdigest()[:16] path = f"/autogen/{hash_id}" @@ -47,14 +61,20 @@ def _finalize_config(self): self.endpoint = path.lstrip("/") self._worker_url = f"{self.base_url}/{self.endpoint}" - def __call__(self, data: Iterable, context=None) -> Iterator: + # --- Original HTTPTransformer Methods --- + + def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]: """CLIENT-SIDE: Called by the Pipeline to start distributed processing.""" self._finalize_config() def process_chunk(chunk: list) -> list: """Target for a thread: sends one chunk to the worker.""" try: - response = self.session.post(self._worker_url, json=chunk, timeout=300) + response = self.session.post( + self._worker_url, # type: ignore + json=chunk, + timeout=300, + ) response.raise_for_status() return response.json() except requests.RequestException as e: @@ -78,14 +98,58 @@ def get_route(self): """ Function that returns the route for the worker. This is used to register the worker in a Flask app or similar. - - Returns: - A tuple containing the endpoint and the worker function. 
""" self._finalize_config() def worker_view_func(chunk: list, context: PipelineContext): - """The actual Flask view function for this transformer's logic.""" + """The actual worker logic for this transformer.""" return self.transformer(chunk, context) return (f"/{self.endpoint}", worker_view_func) + + # --- Overridden Chaining Methods to Preserve Type --- + + def on_error(self, handler: ChunkErrorHandler[In, Out] | ErrorHandler) -> "HTTPTransformer[In, Out]": + super().on_error(handler) + return self + + def map[U](self, function: PipelineFunction[Out, U]) -> "HTTPTransformer[In, U]": + super().map(function) + return self # type: ignore + + def filter(self, predicate: PipelineFunction[Out, bool]) -> "HTTPTransformer[In, Out]": + super().filter(predicate) + return self + + @overload + def flatten[T](self: "HTTPTransformer[In, list[T]]") -> "HTTPTransformer[In, T]": ... + @overload + def flatten[T](self: "HTTPTransformer[In, tuple[T, ...]]") -> "HTTPTransformer[In, T]": ... + @overload + def flatten[T](self: "HTTPTransformer[In, set[T]]") -> "HTTPTransformer[In, T]": ... 
+ def flatten[T]( + self: Union["HTTPTransformer[In, list[T]]", "HTTPTransformer[In, tuple[T, ...]]", "HTTPTransformer[In, set[T]]"], + ) -> "HTTPTransformer[In, T]": + super().flatten() + return self # type: ignore + + def tap(self, function: PipelineFunction[Out, Any]) -> "HTTPTransformer[In, Out]": + super().tap(function) + return self + + def apply[T](self, t: Callable[["HTTPTransformer[In, Out]"], "Transformer[In, T]"]) -> "HTTPTransformer[In, T]": + # Note: The type hint for `t` is slightly adjusted to reflect it receives an HTTPTransformer + super().apply(t) # type: ignore + return self # type: ignore + + def catch[U]( + self, + sub_pipeline_builder: Callable[[Transformer[Out, Out]], Transformer[Out, U]], + on_error: ChunkErrorHandler[Out, U] | None = None, + ) -> "HTTPTransformer[In, U]": + super().catch(sub_pipeline_builder, on_error) + return self # type: ignore + + def short_circuit(self, function: Callable[[PipelineContext], bool | None]) -> "HTTPTransformer[In, Out]": + super().short_circuit(function) + return self diff --git a/laygo/transformers/parallel.py b/laygo/transformers/parallel.py index 0444d33..f573c3d 100644 --- a/laygo/transformers/parallel.py +++ b/laygo/transformers/parallel.py @@ -12,10 +12,10 @@ import itertools import threading -from laygo import PipelineContext -from laygo import Transformer +from laygo.helpers import PipelineContext from laygo.transformers.transformer import DEFAULT_CHUNK_SIZE from laygo.transformers.transformer import InternalTransformer +from laygo.transformers.transformer import Transformer class ParallelPipelineContextType(PipelineContext): diff --git a/laygo/transformers/transformer.py b/laygo/transformers/transformer.py index abe7669..6ed5df6 100644 --- a/laygo/transformers/transformer.py +++ b/laygo/transformers/transformer.py @@ -9,8 +9,8 @@ from typing import Union from typing import overload -from laygo import ErrorHandler -from laygo import PipelineContext +from laygo.errors import ErrorHandler +from 
laygo.helpers import PipelineContext from laygo.helpers import is_context_aware from laygo.helpers import is_context_aware_reduce diff --git a/pyproject.toml b/pyproject.toml index 1d3a9aa..70a709a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dev = [ "pytest>=7.0.0", "ruff>=0.1.0", "twine>=4.0.0", + "requests-mock>=1.12.1", ] [tool.ruff] diff --git a/tests/test_http_transformer.py b/tests/test_http_transformer.py new file mode 100644 index 0000000..97868e3 --- /dev/null +++ b/tests/test_http_transformer.py @@ -0,0 +1,59 @@ +# Assuming the classes from your latest example are in a file named `pipeline_lib.py` +# This includes Pipeline, Transformer, and your HTTPTransformer. +import requests_mock + +from laygo import HTTPTransformer +from laygo import Pipeline +from laygo import PipelineContext + + +class TestHTTPTransformer: + """ + Test suite for the HTTPTransformer class. + """ + + def test_distributed_transformer_with_mock(self): + """ + Tests the HTTPTransformer by mocking the worker endpoint. + This test validates that the client-side of the transformer correctly + calls the endpoint and processes the response from the (mocked) worker. + """ + # 1. Define the transformer's properties + base_url = "http://mock-worker.com" + endpoint = "/process/data" + worker_url = f"{base_url}{endpoint}" + + # 2. Define the transformer and its logic using the chainable API. + # This single instance holds both the client and server logic. + http_transformer = ( + HTTPTransformer(base_url=base_url, endpoint=endpoint).map(lambda x: x * 2).filter(lambda x: x > 10) + ) + + # Set a small chunk_size to ensure the client makes multiple requests + http_transformer.chunk_size = 4 + + # 3. Get the worker's logic from the transformer itself + # The `get_route` method provides the exact function the worker would run. + _, worker_view_func = http_transformer.get_route() + + # 4. 
Configure the mock endpoint to use the real worker logic + def mock_response(request, context): + """The behavior of the mocked Flask endpoint.""" + input_chunk = request.json() + # Call the actual view function logic obtained from get_route() + # We pass None for the context as it's not used in this simple case. + output_chunk = worker_view_func(chunk=input_chunk, context=PipelineContext()) + return output_chunk + + # Use requests_mock context manager + with requests_mock.Mocker() as m: + m.post(worker_url, json=mock_response) + + # 5. Run the standard Pipeline with the configured transformer + initial_data = list(range(10)) # [0, 1, 2, ..., 9] + pipeline = Pipeline(initial_data).apply(http_transformer) + result = pipeline.to_list() + + # 6. Assert the final result + expected_result = [12, 14, 16, 18] + assert sorted(result) == sorted(expected_result) diff --git a/uv.lock b/uv.lock index 1606c82..15e5f9d 100644 --- a/uv.lock +++ b/uv.lock @@ -218,6 +218,7 @@ dependencies = [ [package.optional-dependencies] dev = [ { name = "pytest" }, + { name = "requests-mock" }, { name = "ruff" }, { name = "twine" }, ] @@ -226,6 +227,7 @@ dev = [ requires-dist = [ { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, { name = "requests", specifier = ">=2.32" }, + { name = "requests-mock", marker = "extra == 'dev'", specifier = ">=1.12.1" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" }, { name = "twine", marker = "extra == 'dev'", specifier = ">=4.0.0" }, ] @@ -384,6 +386,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c", size = 64847, upload-time = "2025-06-09T16:43:05.728Z" }, ] +[[package]] +name = "requests-mock" +version = "1.12.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "requests" }, +] +sdist = { 
url = "https://files.pythonhosted.org/packages/92/32/587625f91f9a0a3d84688bf9cfc4b2480a7e8ec327cefd0ff2ac891fd2cf/requests-mock-1.12.1.tar.gz", hash = "sha256:e9e12e333b525156e82a3c852f22016b9158220d2f47454de9cae8a77d371401", size = 60901, upload-time = "2024-03-29T03:54:29.446Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/97/ec/889fbc557727da0c34a33850950310240f2040f3b1955175fdb2b36a8910/requests_mock-1.12.1-py2.py3-none-any.whl", hash = "sha256:b1e37054004cdd5e56c84454cc7df12b25f90f382159087f4b6915aaeef39563", size = 27695, upload-time = "2024-03-29T03:54:27.64Z" }, +] + [[package]] name = "requests-toolbelt" version = "1.0.0" From d2adbfb6cb0a9ee52afacbb6be5ed2731c9bf328 Mon Sep 17 00:00:00 2001 From: ringoldsdev Date: Thu, 17 Jul 2025 19:48:30 +0000 Subject: [PATCH 3/4] fix: type errors --- laygo/transformers/http.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/laygo/transformers/http.py b/laygo/transformers/http.py index 6a2fbc0..610ba24 100644 --- a/laygo/transformers/http.py +++ b/laygo/transformers/http.py @@ -127,11 +127,12 @@ def flatten[T](self: "HTTPTransformer[In, list[T]]") -> "HTTPTransformer[In, T]" def flatten[T](self: "HTTPTransformer[In, tuple[T, ...]]") -> "HTTPTransformer[In, T]": ... @overload def flatten[T](self: "HTTPTransformer[In, set[T]]") -> "HTTPTransformer[In, T]": ... 
- def flatten[T]( + # Forgive me for I have sinned, but this is necessary to avoid type errors + # Since I'm setting self type in the parent class, overriding it isn't allowed + def flatten[T]( # type: ignore self: Union["HTTPTransformer[In, list[T]]", "HTTPTransformer[In, tuple[T, ...]]", "HTTPTransformer[In, set[T]]"], ) -> "HTTPTransformer[In, T]": - super().flatten() - return self # type: ignore + return super().flatten() # type: ignore def tap(self, function: PipelineFunction[Out, Any]) -> "HTTPTransformer[In, Out]": super().tap(function) From bf0af9c635072ed72c9d7d540100a77f08fcdcfa Mon Sep 17 00:00:00 2001 From: ringoldsdev Date: Thu, 17 Jul 2025 19:53:04 +0000 Subject: [PATCH 4/4] fix: types --- laygo/transformers/http.py | 3 +- laygo/transformers/parallel.py | 57 ++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/laygo/transformers/http.py b/laygo/transformers/http.py index 610ba24..f181160 100644 --- a/laygo/transformers/http.py +++ b/laygo/transformers/http.py @@ -132,7 +132,8 @@ def flatten[T](self: "HTTPTransformer[In, set[T]]": def flatten[T]( # type: ignore self: Union["HTTPTransformer[In, list[T]]", "HTTPTransformer[In, tuple[T, ...]]", "HTTPTransformer[In, set[T]]"], ) -> "HTTPTransformer[In, T]": - return super().flatten() # type: ignore + super().flatten() # type: ignore + return self # type: ignore def tap(self, function: PipelineFunction[Out, Any]) -> "HTTPTransformer[In, Out]": super().tap(function) diff --git a/laygo/transformers/parallel.py b/laygo/transformers/parallel.py index f573c3d..c247ab5 100644 --- a/laygo/transformers/parallel.py +++ b/laygo/transformers/parallel.py @@ -1,6 +1,7 @@ """Parallel transformer implementation using multiple threads.""" from collections import deque +from collections.abc import Callable from collections.abc import Iterable from collections.abc import Iterator from concurrent.futures import FIRST_COMPLETED @@ -11,10 +12,16 @@ from
functools import partial import itertools import threading +from typing import Any +from typing import Union +from typing import overload +from laygo.errors import ErrorHandler from laygo.helpers import PipelineContext from laygo.transformers.transformer import DEFAULT_CHUNK_SIZE +from laygo.transformers.transformer import ChunkErrorHandler from laygo.transformers.transformer import InternalTransformer +from laygo.transformers.transformer import PipelineFunction from laygo.transformers.transformer import Transformer @@ -142,3 +149,53 @@ def result_iterator_manager() -> Iterator[Out]: yield from result_chunk return result_iterator_manager() + + # --- Overridden Chaining Methods to Preserve Type --- + + def on_error(self, handler: ChunkErrorHandler[In, Out] | ErrorHandler) -> "ParallelTransformer[In, Out]": + super().on_error(handler) + return self + + def map[U](self, function: PipelineFunction[Out, U]) -> "ParallelTransformer[In, U]": + super().map(function) + return self # type: ignore + + def filter(self, predicate: PipelineFunction[Out, bool]) -> "ParallelTransformer[In, Out]": + super().filter(predicate) + return self + + @overload + def flatten[T](self: "ParallelTransformer[In, list[T]]") -> "ParallelTransformer[In, T]": ... + @overload + def flatten[T](self: "ParallelTransformer[In, tuple[T, ...]]") -> "ParallelTransformer[In, T]": ... + @overload + def flatten[T](self: "ParallelTransformer[In, set[T]]") -> "ParallelTransformer[In, T]": ... 
+ def flatten[T]( # type: ignore + self: Union[ + "ParallelTransformer[In, list[T]]", "ParallelTransformer[In, tuple[T, ...]]", "ParallelTransformer[In, set[T]]" + ], + ) -> "ParallelTransformer[In, T]": + super().flatten() # type: ignore + return self # type: ignore + + def tap(self, function: PipelineFunction[Out, Any]) -> "ParallelTransformer[In, Out]": + super().tap(function) + return self + + def apply[T]( + self, t: Callable[["ParallelTransformer[In, Out]"], "Transformer[In, T]"] + ) -> "ParallelTransformer[In, T]": + super().apply(t) # type: ignore + return self # type: ignore + + def catch[U]( + self, + sub_pipeline_builder: Callable[[Transformer[Out, Out]], Transformer[Out, U]], + on_error: ChunkErrorHandler[Out, U] | None = None, + ) -> "ParallelTransformer[In, U]": + super().catch(sub_pipeline_builder, on_error) + return self # type: ignore + + def short_circuit(self, function: Callable[[PipelineContext], bool | None]) -> "ParallelTransformer[In, Out]": + super().short_circuit(function) + return self