From 76b0974cc8a7d4cd6eaa046b09644c00407a63e4 Mon Sep 17 00:00:00 2001 From: Trond Hindenes Date: Mon, 26 Dec 2022 12:20:24 +0100 Subject: [PATCH] allow BYO auth backend --- piccolo_admin/endpoints.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/piccolo_admin/endpoints.py b/piccolo_admin/endpoints.py index 609bc566..791a8048 100644 --- a/piccolo_admin/endpoints.py +++ b/piccolo_admin/endpoints.py @@ -43,6 +43,8 @@ from starlette.requests import Request from starlette.responses import HTMLResponse, JSONResponse from starlette.staticfiles import StaticFiles +from starlette.authentication import AuthCredentials, AuthenticationBackend, AuthenticationError, SimpleUser +from piccolo_api.shared.auth import UnauthenticatedUser, User from .translations.data import TRANSLATIONS from .translations.models import ( @@ -358,6 +360,7 @@ def __init__( translations: t.List[Translation] = None, allowed_hosts: t.Sequence[str] = [], debug: bool = False, + auth_backend: t.Optional[AuthenticationBackend] = None, ) -> None: super().__init__( title=site_name, @@ -653,16 +656,25 @@ def __init__( app=StaticFiles(directory=os.path.join(ASSET_PATH, "js")), ) - auth_middleware = partial( - AuthenticationMiddleware, - backend=SessionsAuthBackend( - auth_table=auth_table, - session_table=session_table, - admin_only=True, - increase_expiry=increase_expiry, - ), - on_error=handle_auth_exception, - ) + + + if auth_backend: + auth_middleware = partial( + AuthenticationMiddleware, + backend=auth_backend, + on_error=handle_auth_exception, + ) + else: + auth_middleware = partial( + AuthenticationMiddleware, + backend=SessionsAuthBackend( + auth_table=auth_table, + session_table=session_table, + admin_only=True, + increase_expiry=increase_expiry, + ), + on_error=handle_auth_exception, + ) self.mount(path="/api", app=auth_middleware(private_app)) self.mount(path="/public", app=public_app) @@ -952,6 +964,7 @@ def create_admin( auto_include_related: bool = True, allowed_hosts: t.Sequence[str] = [], debug: bool = False, + auth_backend: t.Optional[AuthenticationBackend] = None, ): """ :param tables: @@ -1099,4 +1112,5 @@ def create_admin( translations=translations, allowed_hosts=allowed_hosts, debug=debug, + auth_backend=auth_backend )