From 522e43a39e48fbd280a2568c198125c89bf9f508 Mon Sep 17 00:00:00 2001 From: Jamie Stumme <3059647+StummeJ@users.noreply.github.com> Date: Mon, 26 Feb 2024 16:23:45 -0800 Subject: [PATCH 1/9] =?UTF-8?q?=F0=9F=90=9B=20fix:=20allow=20factory=20met?= =?UTF-8?q?hod=20to=20get=20requested=20type=20as=20well=20as=20activating?= =?UTF-8?q?=20type?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rodi/__init__.py | 33 +++++++++++++++------ tests/test_services.py | 65 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 86 insertions(+), 12 deletions(-) diff --git a/rodi/__init__.py b/rodi/__init__.py index 0fc539e..c784c7c 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -336,9 +336,9 @@ def __init__(self, _type, factory): self._type = _type self.factory = factory - def __call__(self, context: ActivationScope, parent_type): + def __call__(self, context: ActivationScope, parent_type: type): assert isinstance(context, ActivationScope) - return self.factory(context, parent_type) + return self.factory(context, parent_type, self._type) class SingletonFactoryTypeProvider: @@ -349,9 +349,9 @@ def __init__(self, _type, factory): self.factory = factory self.instance = None - def __call__(self, context: ActivationScope, parent_type): + def __call__(self, context: ActivationScope, parent_type: Type): if self.instance is None: - self.instance = self.factory(context, parent_type) + self.instance = self.factory(context, parent_type, self._type) return self.instance @@ -362,11 +362,11 @@ def __init__(self, _type, factory): self._type = _type self.factory = factory - def __call__(self, context: ActivationScope, parent_type): + def __call__(self, context: ActivationScope, parent_type: Type): if self._type in context.scoped_services: return context.scoped_services[self._type] - instance = self.factory(context, parent_type) + instance = self.factory(context, parent_type, self._type) context.scoped_services[self._type] = instance return instance @@ -412,7 +412,7 @@ def get_annotations_type_provider( life_style: ServiceLifeStyle, resolver_context: ResolutionContext, ): - def factory(context, parent_type): + def factory(context, parent_type, requested_type): instance = concrete_type() for name, resolver in resolvers.items(): setattr(instance, name, resolver(context, parent_type)) @@ -784,10 +784,12 @@ def exec( FactoryCallableNoArguments = Callable[[], Any] FactoryCallableSingleArgument = Callable[[ActivationScope], Any] FactoryCallableTwoArguments = Callable[[ActivationScope, Type], Any] +FactoryCallableThreeArguments = Callable[[ActivationScope, Type, Type], Any] FactoryCallableType = Union[ FactoryCallableNoArguments, FactoryCallableSingleArgument, FactoryCallableTwoArguments, + FactoryCallableThreeArguments, ] @@ -797,7 +799,7 @@ class FactoryWrapperNoArgs: def __init__(self, factory): self.factory = factory - def __call__(self, context, activating_type): + def __call__(self, context, activating_type, parent_type): return self.factory() @@ -807,10 +809,20 @@ class FactoryWrapperContextArg: def __init__(self, factory): self.factory = factory - def __call__(self, context, activating_type): + def __call__(self, context, activating_type, parent_type): return self.factory(context) +class FactoryWrapperPartentArg: + __slots__ = ("factory",) + + def __init__(self, factory): + self.factory = factory + + def __call__(self, context, activating_type, parent_type): + return self.factory(context, activating_type) + + class Container(ContainerProtocol): """ Configuration class for a collection of services. @@ -1097,6 +1109,9 @@ def _check_factory(factory, signature, handled_type) -> Callable: return FactoryWrapperContextArg(factory) if params_len == 2: + return FactoryWrapperPartentArg(factory) + + if params_len == 3: return factory raise InvalidFactory(handled_type) diff --git a/tests/test_services.py b/tests/test_services.py index 866eae0..ee9daed 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -754,19 +754,21 @@ def test_invalid_factory_too_many_arguments_throws(method_name): container = Container() method = getattr(container, method_name) - def factory(context, activating_type, extra_argument_mistake): + def factory(context, activating_type, requested_type, extra_argument_mistake): return Cat("Celine") with raises(InvalidFactory): method(factory, Cat) - def factory(context, activating_type, extra_argument_mistake, two): + def factory(context, activating_type, requested_type, extra_argument_mistake, two): return Cat("Celine") with raises(InvalidFactory): method(factory, Cat) - def factory(context, activating_type, extra_argument_mistake, two, three): + def factory( + context, activating_type, requested_type, extra_argument_mistake, two, three + ): return Cat("Celine") with raises(InvalidFactory): @@ -949,6 +951,15 @@ def cat_factory_with_context_and_activating_type(context, activating_type) -> Ca return Cat("Celine") +def cat_factory_with_context_activating_type_and_requested_type( + context, activating_type, requested_type +) -> Cat: + assert isinstance(context, ActivationScope) + assert activating_type is Cat + assert requested_type is Cat + return Cat("Celine") + + @pytest.mark.parametrize( "method_name,factory", [ @@ -962,6 +973,7 @@ def cat_factory_with_context_and_activating_type(context, activating_type) -> Ca cat_factory_no_args, cat_factory_with_context, cat_factory_with_context_and_activating_type, + cat_factory_with_context_activating_type_and_requested_type, ] ], ) @@ -1156,6 +1168,53 @@ def factory(_, activating_type) -> Logger: ) +@pytest.mark.parametrize( + "method_name", ["add_transient_by_factory", "add_scoped_by_factory"] +) +def test_factory_can_receive_requested_type_as_parameter(method_name): + class Db: + def __init__(self, activating, requested): + self.activating = activating + self.requested = requested + + class Fetcher: + def __init__(self, db: Db): + self.db = db + + container = Container() + container._add_exact_transient(Foo) + + def factory(self, activating_type, requested_type) -> Db: + return Db( + activating_type.__module__ + "." + activating_type.__name__, + requested_type.__module__ + "." + requested_type.__name__, + ) + + method = getattr(container, method_name) + method(factory, Db) + + container._add_exact_transient(Fetcher) + + provider = container.build_provider() + + db = provider.get(Db) + + assert db is not None + assert db.activating is not None + assert db.activating == "tests.test_services.Db" + assert db.requested is not None + assert db.requested == "tests.test_services.Db" + + fetcher = provider.get(Fetcher) + + assert fetcher is not None + assert fetcher.db is not None + assert fetcher.db.activating is not None + assert fetcher.db.activating == "tests.test_services.Fetcher" + assert fetcher.db.requested is not None + assert fetcher.db.requested == "tests.test_services.Db" + + def test_service_provider_supports_set_by_class(): provider = Services() From b03cb80f30df0d10c84a866662d1260e704a76ee Mon Sep 17 00:00:00 2001 From: Jamie Stumme <3059647+StummeJ@users.noreply.github.com> Date: Mon, 26 Feb 2024 16:46:18 -0800 Subject: [PATCH 2/9] =?UTF-8?q?=F0=9F=9A=A8=20fix:=20lint=20issues?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rodi/__init__.py | 10 +++++----- tests/test_services.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/rodi/__init__.py b/rodi/__init__.py index c784c7c..4ccdfda 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -412,7 +412,7 @@ def get_annotations_type_provider( life_style: ServiceLifeStyle, resolver_context: ResolutionContext, ): - def factory(context, parent_type, requested_type): + def factory(context, parent_type, registered_type): instance = concrete_type() for name, resolver in resolvers.items(): setattr(instance, name, resolver(context, parent_type)) @@ -799,7 +799,7 @@ class FactoryWrapperNoArgs: def __init__(self, factory): self.factory = factory - def __call__(self, context, activating_type, parent_type): + def __call__(self, context, activating_type, registered_type): return self.factory() @@ -809,7 +809,7 @@ class FactoryWrapperContextArg: def __init__(self, factory): self.factory = factory - def __call__(self, context, activating_type, parent_type): + def __call__(self, context, activating_type, registered_type): return self.factory(context) @@ -819,7 +819,7 @@ class FactoryWrapperPartentArg: def __init__(self, factory): self.factory = factory - def __call__(self, context, activating_type, parent_type): + def __call__(self, context, activating_type, registered_type): return self.factory(context, activating_type) @@ -1110,7 +1110,7 @@ def _check_factory(factory, signature, handled_type) -> Callable: if params_len == 2: return FactoryWrapperPartentArg(factory) - + if params_len == 3: return factory diff --git a/tests/test_services.py b/tests/test_services.py index ee9daed..47882b3 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -952,7 +952,7 @@ def cat_factory_with_context_and_activating_type(context, activating_type) -> Ca def cat_factory_with_context_activating_type_and_requested_type( - context, activating_type, requested_type + context, activating_type, requested_type ) -> Cat: assert isinstance(context, ActivationScope) assert activating_type is Cat From cdf4ff0c694b8af546192df4ea2f5cddd5988d93 Mon Sep 17 00:00:00 2001 From: Jamie Stumme <3059647+StummeJ@users.noreply.github.com> Date: Mon, 26 Feb 2024 16:51:15 -0800 Subject: [PATCH 3/9] =?UTF-8?q?=F0=9F=9A=A8=20fix:=20lint=20issues?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_fn_exec.py | 1 + tests/test_services.py | 24 ++++++++---------------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/tests/test_fn_exec.py b/tests/test_fn_exec.py index 541e800..91ff62e 100644 --- a/tests/test_fn_exec.py +++ b/tests/test_fn_exec.py @@ -2,6 +2,7 @@ Functions exec tests. exec functions are designed to enable executing any function injecting parameters. """ + import pytest from rodi import Container, inject diff --git a/tests/test_services.py b/tests/test_services.py index 47882b3..29c11fc 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -2337,8 +2337,7 @@ def factory() -> annotation: def test_factory_without_locals_raises(): - def factory_without_context() -> None: - ... + def factory_without_context() -> None: ... with pytest.raises(FactoryMissingContextException): _get_factory_annotations_or_throw(factory_without_context) @@ -2346,8 +2345,7 @@ def factory_without_context() -> None: def test_factory_with_locals_get_annotations(): @inject() - def factory_without_context() -> "Cat": - ... + def factory_without_context() -> "Cat": ... annotations = _get_factory_annotations_or_throw(factory_without_context) @@ -2364,21 +2362,17 @@ def test_deps_github_scenario(): └── HTTPClient """ - class HTTPClient: - ... + class HTTPClient: ... - class CommentsService: - ... + class CommentsService: ... - class ChecksService: - ... + class ChecksService: ... class CLAHandler: comments_service: CommentsService checks_service: ChecksService - class GitHubSettings: - ... + class GitHubSettings: ... class GitHubAuthHandler: settings: GitHubSettings @@ -2494,11 +2488,9 @@ def test_provide_protocol_generic() -> None: T = TypeVar("T") class P(Protocol[T]): - def foo(self, t: T) -> T: - ... + def foo(self, t: T) -> T: ... - class A: - ... + class A: ... class Impl(P[A]): def foo(self, t: A) -> A: From d8cecac32418fef090b8716d60421a3da5baa8f4 Mon Sep 17 00:00:00 2001 From: Jamie Stumme <3059647+StummeJ@users.noreply.github.com> Date: Mon, 26 Feb 2024 17:04:35 -0800 Subject: [PATCH 4/9] =?UTF-8?q?=E2=9C=8F=EF=B8=8F=20fix:=20message=20for?= =?UTF-8?q?=20`rodi.InvalidFactory`=20exception?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rodi/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rodi/__init__.py b/rodi/__init__.py index 4ccdfda..72d6db8 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -203,7 +203,9 @@ def __init__(self, _type): super().__init__( f"The factory specified for type {class_name(_type)} is not " f"valid, it must be a function with either these signatures: " - f"def example_factory(context, type): " + f"def example_factory(context, activating_type, registered_type): " + f"or," + f"def example_factory(context, activating_type): " f"or," f"def example_factory(context): " f"or," From 1e30e6ef3345c07fa12c871dcd7c8127515199f9 Mon Sep 17 00:00:00 2001 From: Jamie Stumme <3059647+StummeJ@users.noreply.github.com> Date: Mon, 26 Feb 2024 17:08:28 -0800 Subject: [PATCH 5/9] =?UTF-8?q?=F0=9F=9A=A8=20fix:=20lint=20issues?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_services.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/test_services.py b/tests/test_services.py index c6b09ac..e9cc3d4 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -2530,8 +2530,7 @@ class B: def test_provide_protocol_with_attribute_dependency() -> None: class P(Protocol): - def foo(self) -> Any: - ... + def foo(self) -> Any: ... class Dependency: pass @@ -2558,8 +2557,7 @@ def foo(self) -> Any: def test_provide_protocol_with_init_dependency() -> None: class P(Protocol): - def foo(self) -> Any: - ... + def foo(self) -> Any: ... class Dependency: pass @@ -2612,11 +2610,9 @@ def test_provide_protocol_generic_with_inner_dependency() -> None: T = TypeVar("T") class P(Protocol[T]): - def foo(self, t: T) -> T: - ... + def foo(self, t: T) -> T: ... - class A: - ... + class A: ... class Dependency: pass From 8cd73643c09912c841e140d67ac10bc2edeb1ea3 Mon Sep 17 00:00:00 2001 From: Jamie Stumme <3059647+StummeJ@users.noreply.github.com> Date: Mon, 26 Feb 2024 17:12:06 -0800 Subject: [PATCH 6/9] =?UTF-8?q?=F0=9F=94=A7=20chore:=20make=20black=20matc?= =?UTF-8?q?h=20flake8=20rules?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .flake8 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.flake8 b/.flake8 index 0f6f5a8..6703612 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,6 @@ [flake8] exclude = __pycache__,built,build,venv -ignore = E203, E266, W503 +ignore = E203, E266, W503, E704, E701 max-line-length = 88 max-complexity = 18 select = B,C,E,F,W,T4,B9 From 1a13accb4ed4b96799effa1add6d922da5bc4ee4 Mon Sep 17 00:00:00 2001 From: Jamie Stumme <3059647+StummeJ@users.noreply.github.com> Date: Mon, 26 Feb 2024 17:27:23 -0800 Subject: [PATCH 7/9] =?UTF-8?q?=F0=9F=94=96=20chore:=20update=20version=20?= =?UTF-8?q?to=20allow=20installing=20over=202.0.6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rodi/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rodi/__about__.py b/rodi/__about__.py index ff6ef86..1e0555b 100644 --- a/rodi/__about__.py +++ b/rodi/__about__.py @@ -1 +1 @@ -__version__ = "2.0.6" +__version__ = "2.0.7.dev1" From b16073a275d4c2ff758340065cfb305098e087bc Mon Sep 17 00:00:00 2001 From: Jamie Stumme <3059647+StummeJ@users.noreply.github.com> Date: Fri, 18 Oct 2024 11:44:54 -0500 Subject: [PATCH 8/9] =?UTF-8?q?=F0=9F=90=9B=20fix:=20not=20resolving=20ali?= =?UTF-8?q?ases?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rodi/__init__.py | 10 ++++++- tests/test_services.py | 63 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 60 insertions(+), 13 deletions(-) diff --git a/rodi/__init__.py b/rodi/__init__.py index 318bdc2..5dd0145 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -511,7 +511,7 @@ def _get_resolvers_for_parameters( # but at least Optional could be supported in the future raise UnsupportedUnionTypeException(param_name, concrete_type) - if param_type is _empty: + if param_type is _empty or param_type not in services._map: if services.strict: raise CannotResolveParameterException(param_name, concrete_type) @@ -523,6 +523,14 @@ def _get_resolvers_for_parameters( else: aliases = services._aliases[param_name] + if not aliases: + cls_name = class_name(param_type) + aliases = ( + services._aliases[cls_name] + or services._aliases[cls_name.lower()] + or services._aliases[to_standard_param_name(cls_name)] + ) + if aliases: assert ( len(aliases) == 1 diff --git a/tests/test_services.py b/tests/test_services.py index e9cc3d4..2e67f88 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -687,6 +687,33 @@ def __init__(self, cats_controller, service_settings): assert isinstance(u.cats_controller.cat_request_handler, GetCatRequestHandler) +def test_alias_dep_resolving(): + container = arrange_cats_example() + + class BaseClass: + pass + + class DerivedClass(BaseClass): + pass + + class UsingAliasByType: + def __init__(self, example: BaseClass): + self.example = example + + def resolve_derived_class(_) -> DerivedClass: + return DerivedClass() + + container.add_scoped_by_factory(resolve_derived_class, DerivedClass) + container.add_alias("BaseClass", DerivedClass) + container.add_scoped(UsingAliasByType) + + provider = container.build_provider() + u = provider.get(UsingAliasByType) + + assert isinstance(u, UsingAliasByType) + assert isinstance(u.example, DerivedClass) + + def test_get_service_by_name_or_alias(): container = arrange_cats_example() container.add_alias("k", CatsController) @@ -2381,7 +2408,8 @@ def factory() -> annotation: def test_factory_without_locals_raises(): - def factory_without_context() -> None: ... + def factory_without_context() -> None: + pass with pytest.raises(FactoryMissingContextException): _get_factory_annotations_or_throw(factory_without_context) @@ -2389,7 +2417,8 @@ def factory_without_context() -> None: ... def test_factory_with_locals_get_annotations(): @inject() - def factory_without_context() -> "Cat": ... + def factory_without_context() -> "Cat": + pass annotations = _get_factory_annotations_or_throw(factory_without_context) @@ -2406,17 +2435,21 @@ def test_deps_github_scenario(): └── HTTPClient """ - class HTTPClient: ... + class HTTPClient: + pass - class CommentsService: ... + class CommentsService: + pass - class ChecksService: ... + class ChecksService: + pass class CLAHandler: comments_service: CommentsService checks_service: ChecksService - class GitHubSettings: ... + class GitHubSettings: + pass class GitHubAuthHandler: settings: GitHubSettings @@ -2530,7 +2563,8 @@ class B: def test_provide_protocol_with_attribute_dependency() -> None: class P(Protocol): - def foo(self) -> Any: ... + def foo(self) -> Any: + pass class Dependency: pass @@ -2557,7 +2591,8 @@ def foo(self) -> Any: def test_provide_protocol_with_init_dependency() -> None: class P(Protocol): - def foo(self) -> Any: ... + def foo(self) -> Any: + pass class Dependency: pass @@ -2586,9 +2621,11 @@ def test_provide_protocol_generic() -> None: T = TypeVar("T") class P(Protocol[T]): - def foo(self, t: T) -> T: ... + def foo(self, t: T) -> T: + pass - class A: ... + class A: + pass class Impl(P[A]): def foo(self, t: A) -> A: @@ -2610,9 +2647,11 @@ def test_provide_protocol_generic_with_inner_dependency() -> None: T = TypeVar("T") class P(Protocol[T]): - def foo(self, t: T) -> T: ... + def foo(self, t: T) -> T: + pass - class A: ... + class A: + pass class Dependency: pass From d092b4b3e67d51fa8828f165cdb14a16a7d7f2df Mon Sep 17 00:00:00 2001 From: Jamie Stumme <3059647+StummeJ@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:58:23 -0500 Subject: [PATCH 9/9] =?UTF-8?q?=F0=9F=90=9B=20fix:=20handle=20resolving=20?= =?UTF-8?q?alias=20directly?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rodi/__init__.py | 7 +++++++ tests/test_services.py | 3 +++ 2 files changed, 10 insertions(+) diff --git a/rodi/__init__.py b/rodi/__init__.py index 5dd0145..ee30708 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -746,6 +746,13 @@ def get( scope = ActivationScope(self) resolver = self._map.get(desired_type) + if not resolver: + cls_name = class_name(desired_type) + resolver = ( + self._map.get(cls_name) + or self._map.get(cls_name.lower()) + or self._map.get(to_standard_param_name(cls_name)) + ) scoped_service = scope.scoped_services.get(desired_type) if scope else None if not resolver and not scoped_service: diff --git a/tests/test_services.py b/tests/test_services.py index 2e67f88..bdd3fe3 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -713,6 +713,9 @@ def resolve_derived_class(_) -> DerivedClass: assert isinstance(u, UsingAliasByType) assert isinstance(u.example, DerivedClass) + b = provider.get(BaseClass) + assert isinstance(b, DerivedClass) + def test_get_service_by_name_or_alias(): container = arrange_cats_example()