diff --git a/ninja_extra/controllers/base.py b/ninja_extra/controllers/base.py index 2ec563f0..4de60732 100644 --- a/ninja_extra/controllers/base.py +++ b/ninja_extra/controllers/base.py @@ -73,23 +73,31 @@ class MissingAPIControllerDecoratorException(Exception): def get_route_functions(cls: Type) -> Iterable[RouteFunction]: """ - Get all route functions from a controller class. - This function will recursively search for route functions in the base classes of the controller class - in order that they are defined. + Return fresh RouteFunction instances for a controller class. - Args: - cls (Type): The controller class. - - Returns: - Iterable[RouteFunction]: An iterable of route functions. + Each call yields a clone of the RouteFunction template stored on the + controller method, ensuring metadata is not shared across subclasses. """ - bases = inspect.getmro(cls) - for base_cls in reversed(bases): - if base_cls not in [ControllerBase, ABC, object]: - for method in base_cls.__dict__.values(): - if hasattr(method, ROUTE_FUNCTION): - yield getattr(method, ROUTE_FUNCTION) + for _, method, template in _iter_route_templates(cls): + yield template.clone(method) + + +def _iter_route_templates( + cls: Type, +) -> Iterable[Tuple[str, Callable[..., Any], RouteFunction]]: + seen: set[str] = set() + for base_cls in inspect.getmro(cls): + if base_cls in (ControllerBase, ABC, object): + continue + for attr_name, method in base_cls.__dict__.items(): + if attr_name in seen: + continue + route_template = getattr(method, ROUTE_FUNCTION, None) + if route_template is None: + continue + seen.add(attr_name) + yield attr_name, method, route_template def get_all_controller_route_function( diff --git a/ninja_extra/controllers/model/builder.py b/ninja_extra/controllers/model/builder.py index 3f5dff22..30264ea5 100644 --- a/ninja_extra/controllers/model/builder.py +++ b/ninja_extra/controllers/model/builder.py @@ -51,7 +51,8 @@ def __init__( ) def _add_to_controller(self, func: t.Callable) -> None: - route_function = getattr(func, ROUTE_FUNCTION) + route_template = getattr(func, ROUTE_FUNCTION) + route_function = route_template.clone(func) route_function.api_controller = self._api_controller_instance self._api_controller_instance.add_controller_route_function(route_function) diff --git a/ninja_extra/controllers/route/route_functions.py b/ninja_extra/controllers/route/route_functions.py index b115acf1..e17d92b4 100644 --- a/ninja_extra/controllers/route/route_functions.py +++ b/ninja_extra/controllers/route/route_functions.py @@ -2,7 +2,18 @@ import warnings from contextlib import contextmanager from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterator, + List, + Optional, + Tuple, + Type, + Union, + cast, +) from django.http import HttpRequest, HttpResponse @@ -16,6 +27,11 @@ from ninja_extra.controllers.base import APIController, ControllerBase from ninja_extra.controllers.route import Route from ninja_extra.operation import Operation + from ninja_extra.permissions import BasePermission + +RoutePermissions = Optional[ + List[Union[Type["BasePermission"], "BasePermission", Any]] +] class RouteFunctionContext: @@ -104,6 +120,28 @@ def as_view( as_view.get_route_function = lambda: self # type:ignore return as_view + def clone(self, view_func: Callable[..., Any]) -> "RouteFunction": + from ninja_extra.controllers.route import Route + + route_params = self.route.route_params.dict() + permissions: RoutePermissions + if self.route.permissions is None: + permissions = None + else: + permissions = cast(RoutePermissions, list(self.route.permissions)) + + if route_params["tags"] is not None: + route_params["tags"] = list(route_params["tags"]) + route_params["methods"] = list(route_params["methods"]) + + cloned_route = Route( + view_func, + **route_params, + permissions=permissions, + ) + + return type(self)(route=cloned_route) + def _process_view_function_result(self, result: Any) -> Any: """ This process any a returned value from view_func diff --git a/ninja_extra/helper.py b/ninja_extra/helper.py index f4b4de90..b9a1f96d 100644 --- a/ninja_extra/helper.py +++ b/ninja_extra/helper.py @@ -15,6 +15,17 @@ def get_function_name(func_class: t.Any) -> str: @t.no_type_check def get_route_function(func: t.Callable) -> t.Optional["RouteFunction"]: - if hasattr(func, ROUTE_FUNCTION): - return func.__dict__[ROUTE_FUNCTION] - return None # pragma: no cover + controller_instance = getattr(func, "__self__", None) + + if controller_instance is not None: + controller_class = controller_instance.__class__ + api_controller = controller_class.get_api_controller() + return api_controller._controller_class_route_functions.get(func.__name__) + + # Unbound function – return a clone of the template for introspection + underlying_func = getattr(func, "__func__", func) + route_template = getattr(underlying_func, ROUTE_FUNCTION, None) + if route_template is None: + return None # pragma: no cover + + return route_template.clone(underlying_func) diff --git a/ninja_extra/testing/client.py b/ninja_extra/testing/client.py index 2b313e02..50ace2c4 100644 --- a/ninja_extra/testing/client.py +++ b/ninja_extra/testing/client.py @@ -1,8 +1,9 @@ from json import dumps as json_dumps -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union from unittest.mock import Mock from urllib.parse import urlencode +from django.urls import Resolver404 from ninja import NinjaAPI, Router from ninja.responses import NinjaJSONEncoder from ninja.testing.client import NinjaClientBase, NinjaResponse @@ -42,6 +43,20 @@ def request( ) return self._call(func, request, kwargs) # type: ignore + def _resolve( + self, method: str, path: str, data: Dict, request_params: Any + ) -> Tuple[Callable, Mock, Dict]: + url_path = path.split("?")[0].lstrip("/") + for url in self.urls: + try: + match = url.resolve(url_path) + except Resolver404: + continue + if match: + request = self._build_request(method, path, data, request_params) + return match.func, request, match.kwargs + raise Exception(f'Cannot resolve "{path}"') + class TestClient(NinjaExtraClientBase): def _call(self, func: Callable, request: Mock, kwargs: Dict) -> "NinjaResponse": diff --git a/tests/test_controller.py b/tests/test_controller.py index 2904de83..9295db6c 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -82,6 +82,22 @@ def example(self): pass +class ReportControllerBase(ControllerBase): + @http_get("") + def report(self): + return {"controller": type(self).__name__} + + +@api_controller("/alpha", urls_namespace="alpha") +class AlphaReportController(ReportControllerBase): + pass + + +@api_controller("/beta", urls_namespace="beta") +class BetaReportController(ReportControllerBase): + pass + + class TestAPIController: def test_api_controller_as_decorator(self): controller_type = api_controller("prefix", tags="new_tag", auth=FakeAuth())( @@ -321,6 +337,49 @@ async def test_controller_base_aget_object_or_none_works(self): assert isinstance(ex, exceptions.PermissionDenied) +def test_controller_subclass_routes_remain_isolated(): + api = NinjaExtraAPI() + api.register_controllers(AlphaReportController) + api.register_controllers(BetaReportController) + client = testing.TestClient(api) + + alpha_response = client.get("/alpha") + beta_response = client.get("/beta") + + assert alpha_response.status_code == 200 + assert beta_response.status_code == 200 + assert alpha_response.json() == {"controller": "AlphaReportController"} + assert beta_response.json() == {"controller": "BetaReportController"} + + +def test_controller_multi_level_inheritance_routes_isolated(): + """Test that route isolation works with multi-level inheritance.""" + # Middle layer doesn't override the route + class MiddleReportController(ReportControllerBase): + pass + + @api_controller("/gamma") + class GammaReportController(MiddleReportController): + pass + + @api_controller("/delta") + class DeltaReportController(MiddleReportController): + pass + + api = NinjaExtraAPI() + api.register_controllers(GammaReportController) + api.register_controllers(DeltaReportController) + client = testing.TestClient(api) + + gamma_response = client.get("/gamma") + delta_response = client.get("/delta") + + assert gamma_response.status_code == 200 + assert delta_response.status_code == 200 + assert gamma_response.json() == {"controller": "GammaReportController"} + assert delta_response.json() == {"controller": "DeltaReportController"} + + def test_controller_registration_through_string(): assert DisableAutoImportController.get_api_controller().registered is False