diff --git a/gloe/__init__.py b/gloe/__init__.py index 1a976b62..10562657 100644 --- a/gloe/__init__.py +++ b/gloe/__init__.py @@ -8,10 +8,10 @@ from gloe.conditional import If, condition from gloe.ensurer import ensure from gloe.exceptions import UnsupportedTransformerArgException -from gloe.transformers import Transformer, MultiArgsTransformer +from gloe.transformers import Transformer from gloe.base_transformer import BaseTransformer, PreviousTransformer from gloe.base_transformer import TransformerException -from gloe.async_transformer import AsyncTransformer, MultiArgsAsyncTransformer +from gloe.async_transformer import AsyncTransformer __version__ = "0.7.0" @@ -33,5 +33,3 @@ setattr(Transformer, "__rshift__", _compose_nodes) setattr(AsyncTransformer, "__rshift__", _compose_nodes) -setattr(MultiArgsTransformer, "__rshift__", _compose_nodes) -setattr(MultiArgsAsyncTransformer, "__rshift__", _compose_nodes) diff --git a/gloe/_composition_utils.py b/gloe/_composition_utils.py index 4f35e5e6..00a77c17 100644 --- a/gloe/_composition_utils.py +++ b/gloe/_composition_utils.py @@ -2,10 +2,10 @@ from inspect import Signature from typing import TypeVar, Any, Optional, Union -from gloe.async_transformer import AsyncTransformer, MultiArgsAsyncTransformer +from gloe.async_transformer import AsyncTransformer from gloe.base_transformer import BaseTransformer from gloe.gateways._parallel import _Parallel, _ParallelAsync -from gloe.transformers import Transformer, MultiArgsTransformer +from gloe.transformers import Transformer from gloe._typing_utils import _match_types, _specify_types from gloe.exceptions import UnsupportedTransformerArgException @@ -76,55 +76,26 @@ def __len__(self): new_transformer: Optional[BaseTransformer] = None if is_transformer(transformer1) and is_transformer(transformer2): - if isinstance(transformer1, MultiArgsTransformer): + class NewTransformer1(BaseNewTransformer, Transformer[_In, _NextOut]): + def __init__(self): + super().__init__() + self._flow = transformer1._flow + transformer2._flow - class NewMultiArgsTransformer(BaseNewTransformer, MultiArgsTransformer): - def __init__(self): - super().__init__() - self._flow = transformer1._flow + transformer2._flow + def transform(self, data): + return None - def transform(self, data): - return None - - new_transformer = NewMultiArgsTransformer() - - else: - - class NewTransformer1(BaseNewTransformer, Transformer[_In, _NextOut]): - def __init__(self): - super().__init__() - self._flow = transformer1._flow + transformer2._flow - - def transform(self, data): - return None - - new_transformer = NewTransformer1() + new_transformer = NewTransformer1() else: - if isinstance(transformer1, MultiArgsAsyncTransformer): - - class NewMultiArgsAsyncTransformer( - BaseNewTransformer, MultiArgsAsyncTransformer - ): - def __init__(self): - super().__init__() - self._flow = transformer1._flow + transformer2._flow - - async def transform_async(self, data): - return None - - new_transformer = NewMultiArgsAsyncTransformer() - else: - - class NewTransformer2(BaseNewTransformer, AsyncTransformer[_In, _NextOut]): - def __init__(self): - super().__init__() - self._flow = transformer1._flow + transformer2._flow + class NewTransformer2(BaseNewTransformer, AsyncTransformer[_In, _NextOut]): + def __init__(self): + super().__init__() + self._flow = transformer1._flow + transformer2._flow - async def transform_async(self, data): - return None + async def transform_async(self, data): + return None - new_transformer = NewTransformer2() + new_transformer = NewTransformer2() new_transformer.__class__.__name__ = transformer2.__class__.__name__ new_transformer._label = transformer2.label @@ -157,67 +128,33 @@ def __len__(self): if is_transformer(incident_transformer) and is_transformer(receiving_transformers): - if isinstance(incident_transformer, MultiArgsTransformer): + class NewTransformer1(BaseNewTransformer, Transformer[_In, tuple[Any, ...]]): + def __init__(self): + super().__init__() + self._flow = incident_transformer._flow + [ + _Parallel(*receiving_transformers) + ] - class NewMultiArgsTransformer(BaseNewTransformer, MultiArgsTransformer): - def __init__(self): - super().__init__() - self._flow = incident_transformer._flow + [ - _Parallel(*receiving_transformers) - ] + def transform(self, data): + return None - def transform(self, data): - return None - - new_transformer = NewMultiArgsTransformer() - else: - - class NewTransformer1( - BaseNewTransformer, Transformer[_In, tuple[Any, ...]] - ): - def __init__(self): - super().__init__() - self._flow = incident_transformer._flow + [ - _Parallel(*receiving_transformers) - ] - - def transform(self, data): - return None - - new_transformer = NewTransformer1() + new_transformer = NewTransformer1() else: - if isinstance(incident_transformer, MultiArgsAsyncTransformer): - - class NewMultiArgsAsyncTransformer( - BaseNewTransformer, MultiArgsAsyncTransformer - ): - def __init__(self): - super().__init__() - self._flow = incident_transformer._flow + [ - _ParallelAsync(*receiving_transformers) - ] - - async def transform_async(self, data): - return None - - new_transformer = NewMultiArgsAsyncTransformer() - else: - - class NewTransformer2( - BaseNewTransformer, AsyncTransformer[_In, tuple[Any, ...]] - ): - def __init__(self): - super().__init__() - self._flow = incident_transformer._flow + [ - _ParallelAsync(*receiving_transformers) - ] + class NewTransformer2( + BaseNewTransformer, AsyncTransformer[_In, tuple[Any, ...]] + ): + def __init__(self): + super().__init__() + self._flow = incident_transformer._flow + [ + _ParallelAsync(*receiving_transformers) + ] - async def transform_async(self, data): - return None + async def transform_async(self, data): + return None - new_transformer = NewTransformer2() + new_transformer = NewTransformer2() # new_transformer._previous = cast(Transformer, receiving_transformers) new_transformer.__class__.__name__ = "Converge" diff --git a/gloe/async_transformer.py b/gloe/async_transformer.py index bd4d7008..5a94ef1b 100644 --- a/gloe/async_transformer.py +++ b/gloe/async_transformer.py @@ -106,13 +106,23 @@ def copy( @overload async def __call__(self: "AsyncTransformer[None, _Out]") -> _Out: - return await _execute_async_flow(self._flow, None) + pass + + @overload + async def __call__( + self: "AsyncTransformer[tuple[Unpack[Args]], _Out]", *args: Unpack[Args] + ) -> _Out: + pass @overload async def __call__(self, data: _In) -> _Out: - return await _execute_async_flow(self._flow, data) + pass - async def __call__(self, data=None): + async def __call__(self, *data): + if len(data) == 0: + return await _execute_async_flow(self._flow, None) + if len(data) == 1: + return await _execute_async_flow(self._flow, data[0]) return await _execute_async_flow(self._flow, data) @overload @@ -195,98 +205,3 @@ def __rshift__( def __rshift__(self, next_node): # pragma: no cover pass - - -class MultiArgsAsyncTransformer( - Generic[Unpack[Args], _Out], AsyncTransformer[tuple[Unpack[Args]], _Out] -): - @override - async def __call__( # type: ignore[override] - self: "MultiArgsAsyncTransformer[Unpack[Args], _Out]", *data: Unpack[Args] - ) -> _Out: - return await _execute_async_flow(self._flow, data) - - @overload - def __rshift__( - self, next_node: BaseTransformer[_Out, _NextOut] - ) -> "MultiArgsAsyncTransformer[Unpack[Args], _NextOut]": - pass - - @overload - def __rshift__( - self, - next_node: tuple[BaseTransformer[_Out, _NextOut], BaseTransformer[_Out, _O2]], - ) -> "MultiArgsAsyncTransformer[Unpack[Args], tuple[_NextOut, _O2]]": - pass - - @overload - def __rshift__( - self, - next_node: tuple[ - BaseTransformer[_Out, _NextOut], - BaseTransformer[_Out, _O2], - BaseTransformer[_Out, _O3], - ], - ) -> "MultiArgsAsyncTransformer[Unpack[Args], tuple[_NextOut, _O2, _O3]]": - pass - - @overload - def __rshift__( - self, - next_node: tuple[ - BaseTransformer[_Out, _NextOut], - BaseTransformer[_Out, _O2], - BaseTransformer[_Out, _O3], - BaseTransformer[_Out, _O4], - ], - ) -> "MultiArgsAsyncTransformer[Unpack[Args], tuple[_NextOut, _O2, _O3, _O4]]": - pass - - @overload - def __rshift__( - self, - next_node: tuple[ - BaseTransformer[_Out, _NextOut], - BaseTransformer[_Out, _O2], - BaseTransformer[_Out, _O3], - BaseTransformer[_Out, _O4], - BaseTransformer[_Out, _O5], - ], - ) -> "MultiArgsAsyncTransformer[Unpack[Args], tuple[_NextOut, _O2, _O3, _O4, _O5]]": - pass - - @overload - def __rshift__( - self, - next_node: tuple[ - BaseTransformer[_Out, _NextOut], - BaseTransformer[_Out, _O2], - BaseTransformer[_Out, _O3], - BaseTransformer[_Out, _O4], - BaseTransformer[_Out, _O5], - BaseTransformer[_Out, _O6], - ], - ) -> """MultiArgsAsyncTransformer[ - Unpack[Args], tuple[_NextOut, _O2, _O3, _O4, _O5, _O6] - ]""": - pass - - @overload - def __rshift__( - self, - next_node: tuple[ - BaseTransformer[_Out, _NextOut], - BaseTransformer[_Out, _O2], - BaseTransformer[_Out, _O3], - BaseTransformer[_Out, _O4], - BaseTransformer[_Out, _O5], - BaseTransformer[_Out, _O6], - BaseTransformer[_Out, _O7], - ], - ) -> """MultiArgsAsyncTransformer[ - Unpack[Args], tuple[_NextOut, _O2, _O3, _O4, _O5, _O6, _O7] - ]""": - pass - - def __rshift__(self, next_node): # pragma: no cover - pass diff --git a/gloe/functional.py b/gloe/functional.py index c072cce8..64f3fbb7 100644 --- a/gloe/functional.py +++ b/gloe/functional.py @@ -5,9 +5,8 @@ from typing_extensions import Concatenate, TypeVarTuple, Unpack, ParamSpec -from gloe.async_transformer import AsyncTransformer, MultiArgsAsyncTransformer -from gloe.exceptions import TransformerRequiresMultiArgs -from gloe.transformers import Transformer, MultiArgsTransformer +from gloe.async_transformer import AsyncTransformer +from gloe.transformers import Transformer __all__ = [ "transformer", @@ -154,14 +153,14 @@ async def transform_async(self, data: A) -> S: @overload -def transformer( - func: Callable[[A, B, Unpack[Rest]], S], -) -> MultiArgsTransformer[A, B, Unpack[Rest], S]: +def transformer(func: Callable[[], S]) -> Transformer[None, S]: pass @overload -def transformer(func: Callable[[], S]) -> Transformer[None, S]: +def transformer( + func: Callable[[A, B, Unpack[Rest]], S], +) -> Transformer[tuple[A, B, Unpack[Rest]], S]: pass @@ -192,27 +191,6 @@ def filter_subscribed_users(users: list[User]) -> list[User]: """ func_signature = inspect.signature(func) - if len(func_signature.parameters) > 1: - - class LambdaMultiArgsTransformer(MultiArgsTransformer): - __doc__ = func.__doc__ - __annotations__ = cast(FunctionType, func).__annotations__ - - def signature(self) -> Signature: - return func_signature - - def transform(self, data): - if type(data) is tuple: - if len(data) == 1: - raise TransformerRequiresMultiArgs() - return func(*data) - raise NotImplementedError() # pragma: no cover - - lambda_transformer1 = LambdaMultiArgsTransformer() - lambda_transformer1.__class__.__name__ = func.__name__ - lambda_transformer1._label = func.__name__ - return lambda_transformer1 - class LambdaTransformer(Transformer): __doc__ = func.__doc__ __annotations__ = cast(FunctionType, func).__annotations__ @@ -223,6 +201,8 @@ def signature(self) -> Signature: def transform(self, data=None): if len(func_signature.parameters) == 0: return func() + if len(func_signature.parameters) > 1: + return func(*data) return func(data) lambda_transformer2 = LambdaTransformer() @@ -232,19 +212,19 @@ def transform(self, data=None): @overload -def async_transformer( - func: Callable[[A, B, Unpack[Rest]], Awaitable[S]], -) -> MultiArgsAsyncTransformer[A, B, Unpack[Rest], S]: +def async_transformer(func: Callable[[], Awaitable[S]]) -> AsyncTransformer[None, S]: pass @overload -def async_transformer(func: Callable[[], Awaitable[S]]) -> AsyncTransformer[None, S]: +def async_transformer(func: Callable[[A], Awaitable[S]]) -> AsyncTransformer[A, S]: pass @overload -def async_transformer(func: Callable[[A], Awaitable[S]]) -> AsyncTransformer[A, S]: +def async_transformer( + func: Callable[[A, B, Unpack[Rest]], Awaitable[S]], +) -> AsyncTransformer[tuple[A, B, Unpack[Rest]], S]: pass @@ -272,27 +252,6 @@ async def get_user_by_role(role: str) -> list[User]: """ func_signature = inspect.signature(func) - if len(func_signature.parameters) > 1: - - class LambdaMultiArgsTransformer(MultiArgsAsyncTransformer): - __doc__ = func.__doc__ - __annotations__ = cast(FunctionType, func).__annotations__ - - def signature(self) -> Signature: - return func_signature - - async def transform_async(self, data): - if type(data) is tuple: - if len(data) == 1: - raise TransformerRequiresMultiArgs() - return await func(*data) - raise NotImplementedError() # pragma: no cover - - lambda_transformer1 = LambdaMultiArgsTransformer() - lambda_transformer1.__class__.__name__ = func.__name__ - lambda_transformer1._label = func.__name__ - return lambda_transformer1 - class LambdaAsyncTransformer(AsyncTransformer): __doc__ = func.__doc__ __annotations__ = cast(FunctionType, func).__annotations__ @@ -303,6 +262,8 @@ def signature(self) -> Signature: async def transform_async(self, data): if len(func_signature.parameters) == 0: return await func() + if len(func_signature.parameters) > 1: + return await func(*data) return await func(data) lambda_transformer = LambdaAsyncTransformer() diff --git a/gloe/transformers.py b/gloe/transformers.py index 446b8cea..19174a01 100644 --- a/gloe/transformers.py +++ b/gloe/transformers.py @@ -3,9 +3,9 @@ from typing import TypeVar, overload, cast, Optional, Any -from typing_extensions import TypeAlias, Unpack, TypeVarTuple, Generic, override +from typing_extensions import TypeAlias, TypeVarTuple, Unpack -from gloe.async_transformer import AsyncTransformer, MultiArgsAsyncTransformer +from gloe.async_transformer import AsyncTransformer from gloe._transformer_utils import catch_transformer_exception from gloe.base_transformer import BaseTransformer, Flow @@ -112,11 +112,21 @@ def _safe_transform(self, data: _I) -> _O: def __call__(self: "Transformer[None, _O]") -> _O: pass + @overload + def __call__( + self: "Transformer[tuple[Unpack[Args]], _O]", *args: Unpack[Args] + ) -> _O: + return _execute_flow(self._flow, args) + @overload def __call__(self, data: _I) -> _O: pass - def __call__(self, data=None): + def __call__(self, *data): + if len(data) == 0: + return _execute_flow(self._flow, None) + if len(data) == 1: + return _execute_flow(self._flow, data[0]) return _execute_flow(self._flow, data) @overload @@ -226,143 +236,3 @@ def __rshift__( def __rshift__(self, next_node): # pragma: no cover pass - - -class MultiArgsTransformer( - Generic[Unpack[Args], _O], Transformer[tuple[Unpack[Args]], _O] -): - # The below ignored override errors are recommended by the documentation itself, - # "if you decide that type safety is not necessary", which is clearly the case. - # https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides - @override - def __call__( # type: ignore[override] - self: "MultiArgsTransformer[Unpack[Args], _O]", *data: Unpack[Args] - ) -> _O: - if len(data) == 1 and type(data[0]) is tuple: # type: ignore - data = data[0] # type: ignore - return _execute_flow(self._flow, data) - - @overload # type: ignore[override] - @override - def __rshift__( - self, next_node: "Transformer[_O, O1]" - ) -> "MultiArgsTransformer[Unpack[Args], O1]": - pass - - @overload - @override - def __rshift__( - self, - next_node: tuple["Tr[_O, O1]", "Tr[_O, O2]"], - ) -> "MultiArgsTransformer[Unpack[Args], tuple[O1, O2]]": - pass - - @overload - @override - def __rshift__( - self, - next_node: tuple["Tr[_O, O1]", "Tr[_O, O2]", "Tr[_O, O3]"], - ) -> "MultiArgsTransformer[Unpack[Args], tuple[O1, O2, O3]]": - pass - - @overload - @override - def __rshift__( - self, - next_node: tuple["Tr[_O, O1]", "Tr[_O, O2]", "Tr[_O, O3]", "Tr[_O, O4]"], - ) -> "MultiArgsTransformer[Unpack[Args], tuple[O1, O2, O3, O4]]": - pass - - @overload - @override - def __rshift__( - self, - next_node: tuple[ - "Tr[_O, O1]", "Tr[_O, O2]", "Tr[_O, O3]", "Tr[_O, O4]", "Tr[_O, O5]" - ], - ) -> "MultiArgsTransformer[Unpack[Args], tuple[O1, O2, O3, O4, O5]]": - pass - - @overload - @override - def __rshift__( - self, - next_node: tuple[ - "Tr[_O, O1]", - "Tr[_O, O2]", - "Tr[_O, O3]", - "Tr[_O, O4]", - "Tr[_O, O5]", - "Tr[_O, O6]", - ], - ) -> "MultiArgsTransformer[Unpack[Args], tuple[O1, O2, O3, O4, O5, O6]]": - pass - - @overload - @override - def __rshift__( - self, - next_node: tuple[ - "Tr[_O, O1]", - "Tr[_O, O2]", - "Tr[_O, O3]", - "Tr[_O, O4]", - "Tr[_O, O5]", - "Tr[_O, O6]", - "Tr[_O, O7]", - ], - ) -> "MultiArgsTransformer[Unpack[Args], tuple[O1, O2, O3, O4, O5, O6, O7]]": - pass - - @overload - @override - def __rshift__( - self, next_node: AsyncTransformer[_O, O1] - ) -> MultiArgsAsyncTransformer[_I, O1]: - pass - - @overload - @override - def __rshift__( - self, next_node: AsyncNext2[_O, O1, O2] - ) -> MultiArgsAsyncTransformer[_I, tuple[O1, O2]]: - pass - - @overload - @override - def __rshift__( - self, next_node: AsyncNext3[_O, O1, O2, O3] - ) -> MultiArgsAsyncTransformer[_I, tuple[O1, O2, O3]]: - pass - - @overload - @override - def __rshift__( - self, next_node: AsyncNext4[_O, O1, O2, O3, O4] - ) -> MultiArgsAsyncTransformer[_I, tuple[O1, O2, O3, O4]]: - pass - - @overload - @override - def __rshift__( - self, next_node: AsyncNext5[_O, O1, O2, O3, O4, O5] - ) -> MultiArgsAsyncTransformer[_I, tuple[O1, O2, O3, O4, O5]]: - pass - - @overload - @override - def __rshift__( - self, next_node: AsyncNext6[_O, O1, O2, O3, O4, O5, O6] - ) -> MultiArgsAsyncTransformer[_I, tuple[O1, O2, O3, O4, O5, O6]]: - pass - - @overload - @override - def __rshift__( - self, next_node: AsyncNext7[_O, O1, O2, O3, O4, O5, O6, O7] - ) -> MultiArgsAsyncTransformer[_I, tuple[O1, O2, O3, O4, O5, O6, O7]]: - pass - - @override - def __rshift__(self, next_node): # pragma: no cover - pass diff --git a/tests/multiargs/test_async_multiargs_transformer.py b/tests/multiargs/test_async_multiargs_transformer.py index 4f07454b..237c2bb1 100644 --- a/tests/multiargs/test_async_multiargs_transformer.py +++ b/tests/multiargs/test_async_multiargs_transformer.py @@ -35,7 +35,7 @@ async def test_single_arg_exception(self): async def concat(arg1: str, arg2: str) -> str: return arg1 + arg2 - with self.assertRaises(TransformerRequiresMultiArgs): + with self.assertRaises(TypeError): await concat("test") # type: ignore[call-arg] async def test_composition_transform_method(self): diff --git a/tests/multiargs/test_multiargs_transformer_basic.py b/tests/multiargs/test_multiargs_transformer_basic.py index 26f4177f..c1e063ba 100644 --- a/tests/multiargs/test_multiargs_transformer_basic.py +++ b/tests/multiargs/test_multiargs_transformer_basic.py @@ -56,7 +56,7 @@ def test_single_arg_exception(self): def concat(arg1: str, arg2: str) -> str: return arg1 + arg2 - with self.assertRaises(TransformerRequiresMultiArgs): + with self.assertRaises(TypeError): concat("test") # type: ignore[call-arg] def test_noargs_basic_call(self): diff --git a/tests/multiargs/test_multiargs_transformer_types.py b/tests/multiargs/test_multiargs_transformer_types.py index 9b13bbae..6452f226 100644 --- a/tests/multiargs/test_multiargs_transformer_types.py +++ b/tests/multiargs/test_multiargs_transformer_types.py @@ -3,7 +3,7 @@ from gloe.utils import forward from typing_extensions import assert_type -from gloe import transformer, MultiArgsTransformer, Transformer +from gloe import transformer, Transformer from tests.lib.transformers import square from tests.type_utils.mypy_test_suite import MypyTestSuite @@ -21,13 +21,13 @@ def test_transformer_multiple_args(self): def sum2(num1: int, num2: float) -> float: return num1 + num2 - assert_type(sum2, MultiArgsTransformer[int, float, float]) + assert_type(sum2, Transformer[tuple[int, float], float]) @transformer def sum3(num1: int, num2: int, num3: int) -> int: return num1 + num2 + num3 - assert_type(sum3, MultiArgsTransformer[int, int, int, int]) + assert_type(sum3, Transformer[tuple[int, int, int], int]) def test_noargs_transformer(self): """ @@ -47,7 +47,7 @@ def sum2(num1: float, num2: float) -> float: pipeline = sum2 >> square - assert_type(pipeline, MultiArgsTransformer[float, float, float]) + assert_type(pipeline, Transformer[tuple[float, float], float]) pipeline2 = forward[float]() >> (square, square) >> sum2 @@ -55,4 +55,4 @@ def sum2(num1: float, num2: float) -> float: pipeline3 = sum2 >> (square, square) >> sum2 - assert_type(pipeline3, MultiArgsTransformer[float, float, float]) + assert_type(pipeline3, Transformer[tuple[float, float], float])