diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 5e01308..0000000 --- a/.flake8 +++ /dev/null @@ -1,5 +0,0 @@ -[flake8] -# E501: let black handle line length -# W503 is incompatible with PEP 8 -ignore = E501,W503 - diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 378d268..dec80ae 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.11", "3.12"] steps: - uses: actions/checkout@v2 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d471ae..e2e4161 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,14 +1,15 @@ repos: - - repo: https://github.com/ambv/black - rev: 24.1.1 + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.9.9 hooks: - - id: black - - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 - hooks: - - id: flake8 + # Run the linter. + - id: ruff + # Run the formatter. + - id: ruff-format + - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 + rev: v1.15.0 hooks: - id: mypy additional_dependencies: diff --git a/app/_vendor/LICENSE.fastapi_versioning b/app/_vendor/LICENSE.fastapi_versioning new file mode 100644 index 0000000..d93181b --- /dev/null +++ b/app/_vendor/LICENSE.fastapi_versioning @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Dean Way + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/app/_vendor/__init__.py b/app/_vendor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/_vendor/fastapi_versioning/__init__.py b/app/_vendor/fastapi_versioning/__init__.py new file mode 100644 index 0000000..f86bcf7 --- /dev/null +++ b/app/_vendor/fastapi_versioning/__init__.py @@ -0,0 +1,8 @@ +from .routing import versioned_api_route +from .versioning import VersionedFastAPI, version + +__all__ = [ + "VersionedFastAPI", + "versioned_api_route", + "version", +] diff --git a/app/_vendor/fastapi_versioning/routing.py b/app/_vendor/fastapi_versioning/routing.py new file mode 100644 index 0000000..eeb34dc --- /dev/null +++ b/app/_vendor/fastapi_versioning/routing.py @@ -0,0 +1,18 @@ +from typing import Any, Type + +from fastapi.routing import APIRoute + + +def versioned_api_route( + major: int = 1, minor: int = 0, route_class: Type[APIRoute] = APIRoute +) -> Type[APIRoute]: + class VersionedAPIRoute(route_class): # type: ignore + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + try: + self.endpoint._api_version = (major, minor) + except AttributeError: + # Support bound methods + self.endpoint.__func__._api_version = (major, minor) + + return VersionedAPIRoute diff --git a/app/_vendor/fastapi_versioning/versioning.py b/app/_vendor/fastapi_versioning/versioning.py new file mode 100644 index 0000000..5d9be98 --- /dev/null +++ b/app/_vendor/fastapi_versioning/versioning.py @@ -0,0 +1,83 @@ +from collections import defaultdict +from typing import Any, Callable, Dict, List, Tuple, TypeVar, cast + +from fastapi import FastAPI +from fastapi.routing import APIRoute +from starlette.routing import BaseRoute + +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) + + +def version(major: int, minor: int = 0) -> Callable[[CallableT], CallableT]: + def decorator(func: CallableT) -> CallableT: + func._api_version = (major, minor) # type: ignore + return func + + return decorator + + +def version_to_route( + route: BaseRoute, + default_version: Tuple[int, int], +) -> Tuple[Tuple[int, int], APIRoute]: + api_route = cast(APIRoute, route) + version = getattr(api_route.endpoint, "_api_version", default_version) + return version, api_route + + +def VersionedFastAPI( + app: FastAPI, + version_format: str = "{major}.{minor}", + prefix_format: str = "/v{major}_{minor}", + default_version: Tuple[int, int] = (1, 0), + enable_latest: bool = False, + **kwargs: Any, +) -> FastAPI: + parent_app = FastAPI( + title=app.title, + **kwargs, + ) + version_route_mapping: Dict[Tuple[int, int], List[APIRoute]] = defaultdict(list) + version_routes = [version_to_route(route, default_version) for route in app.routes] + + for version, route in version_routes: + version_route_mapping[version].append(route) + + unique_routes = {} + versions = sorted(version_route_mapping.keys()) + for version in versions: + major, minor = version + prefix = prefix_format.format(major=major, minor=minor) + semver = version_format.format(major=major, minor=minor) + versioned_app = FastAPI( + title=app.title, + description=app.description, + version=semver, + docs_url=None, + redoc_url=None, + ) + for route in version_route_mapping[version]: + for method in route.methods: + unique_routes[route.path + "|" + method] = route + for route in unique_routes.values(): + versioned_app.router.routes.append(route) + parent_app.mount(prefix, versioned_app) + + @parent_app.get(f"{prefix}/openapi.json", name=semver, tags=["Versions"]) + @parent_app.get(f"{prefix}/docs", name=semver, tags=["Documentations"]) + def noop() -> None: ... + + if enable_latest: + prefix = "/latest" + major, minor = version + semver = version_format.format(major=major, minor=minor) + versioned_app = FastAPI( + title=app.title, + description=app.description, + version=semver, + ) + for route in unique_routes.values(): + versioned_app.router.routes.append(route) + parent_app.mount(prefix, versioned_app) + + return parent_app diff --git a/app/api/login.py b/app/api/login.py index 4d9e510..081cca7 100644 --- a/app/api/login.py +++ b/app/api/login.py @@ -1,32 +1,133 @@ +import httpx from datetime import datetime, timedelta, timezone -from fastapi import APIRouter, Depends, HTTPException, Response, status +from fastapi import APIRouter, Depends, HTTPException, Response, Request, status from fastapi.security import OAuth2PasswordRequestForm from fastapi.logger import logger from sqlalchemy.orm import Session -from .. import deps, crud, utils, auth -from ..settings import ACCESS_TOKEN_EXPIRE_MINUTES +from .. import deps, crud, utils, auth, schemas +from ..settings import ( + ACCESS_TOKEN_EXPIRE_MINUTES, + OIDC_CLIENT_SECRET, + OIDC_SCOPE, +) router = APIRouter() +def create_access_token(db, username, response) -> dict[str, str]: + db_user = crud.get_user_by_username(db, username) + if db_user is None: + db_user = crud.create_user(db, username) + response.status_code = status.HTTP_201_CREATED + expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = utils.create_access_token(db_user.username, expire=expire) + crud.update_user_login_token_expire_date(db, db_user, expire) + logger.info(f"User {username} successfully logged in") + return {"access_token": access_token, "token_type": "bearer"} + + @router.post("/login", status_code=status.HTTP_200_OK) def login( response: Response, db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends(), ): - if not auth.authenticate_user(form_data.username.lower(), form_data.password): - logger.warning(f"Authentication failed for {form_data.username.lower()}") + """Login using username/password""" + username = form_data.username.lower() + if not auth.authenticate_user(username, form_data.password): + logger.warning(f"Authentication failed for {username}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", ) - logger.info(f"User {form_data.username.lower()} successfully logged in") - db_user = crud.get_user_by_username(db, form_data.username.lower()) - if db_user is None: - db_user = crud.create_user(db, form_data.username.lower()) - response.status_code = status.HTTP_201_CREATED - expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = utils.create_access_token(db_user.username, expire=expire) - crud.update_user_login_token_expire_date(db, db_user, expire) - return {"access_token": access_token, "token_type": "bearer"} + return create_access_token(db, username, response) + + +@router.post("/open_id_connect", status_code=status.HTTP_200_OK) +async def open_id_connect( + oidc_auth: schemas.OpenIdConnectAuth, + response: Response, + request: Request, + db: Session = Depends(deps.get_db), +): + """Login using OpenID Connect Authentication Code flow from mobile client""" + oidc_config = request.state.oidc_config + data = { + "client_id": oidc_auth.client_id, + "client_secret": OIDC_CLIENT_SECRET, + "code": oidc_auth.code, + "code_verifier": oidc_auth.code_verifier, + "grant_type": "authorization_code", + "redirect_uri": oidc_auth.redirect_uri, + } + logger.info( + "Login via OIDC Authentication Code flow. " + f"Sending {data} to {oidc_config['token_endpoint']} to retrieve token." + ) + async with httpx.AsyncClient() as client: + try: + response = await client.post( + oidc_config["token_endpoint"], + data=data, + ) + response.raise_for_status() + except httpx.RequestError as exc: + logger.error( + f"An error occurred while requesting {exc.request.url!r}: {exc}." + ) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"An error occurred while requesting {exc.request.url!r}", + ) + except httpx.HTTPStatusError as exc: + logger.error(f"Failed to get OIDC token: {response.content}") + raise HTTPException( + status_code=exc.response.status_code, detail="Failed to get OIDC token" + ) + result = response.json() + access_token = result["access_token"] + id_token = result["id_token"] + logger.debug("Retrieved access and id tokens. Validating id_token.") + try: + utils.validate_id_token( + id_token, + access_token, + request.state.jwks_client, + request.state.oidc_config["id_token_signing_alg_values_supported"], + oidc_auth.client_id, + ) + except Exception as e: + logger.warning(f"id_token validation failed: {e}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="id_token validation failed", + ) + headers = {"Authorization": f"Bearer {access_token}"} + data = { + "client_id": oidc_auth.client_id, + "client_secret": OIDC_CLIENT_SECRET, + "scope": OIDC_SCOPE, + } + logger.info("Retrieving user info.") + try: + response = await client.post( + oidc_config["userinfo_endpoint"], + headers=headers, + data=data, + ) + response.raise_for_status() + except httpx.RequestError as exc: + logger.error( + f"An error occurred while requesting {exc.request.url!r}: {exc}." + ) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"An error occurred while requesting {exc.request.url!r}", + ) + except httpx.HTTPStatusError as exc: + logger.error(f"Failed to get user info: {response.content}") + raise HTTPException( + status_code=exc.response.status_code, detail="Failed to get user info" + ) + username = response.json()["preferred_username"].lower() + return create_access_token(db, username, response) diff --git a/app/api/users.py b/app/api/users.py index 9c4f216..b0695aa 100644 --- a/app/api/users.py +++ b/app/api/users.py @@ -1,5 +1,5 @@ from fastapi import APIRouter, Depends, Response, HTTPException, status -from fastapi_versioning import version +from .._vendor.fastapi_versioning import version from sqlalchemy.orm import Session from typing import List from .. import deps, crud, models, schemas diff --git a/app/auth.py b/app/auth.py index 2936746..8363812 100644 --- a/app/auth.py +++ b/app/auth.py @@ -37,10 +37,7 @@ def ldap_authenticate_user(username: str, password: str) -> bool: validate=ssl.CERT_REQUIRED, version=ssl.PROTOCOL_TLSv1_2, ciphers="ALL" ) server = ldap3.Server(LDAP_HOST, port=LDAP_PORT, use_ssl=LDAP_USE_SSL, tls=tls) - if LDAP_USER_DN: - user_search_dn = f"{LDAP_USER_DN},{LDAP_BASE_DN}" - else: - user_search_dn = LDAP_BASE_DN + user_search_dn = f"{LDAP_USER_DN},{LDAP_BASE_DN}" if LDAP_USER_DN else LDAP_BASE_DN bind_user = f"{LDAP_USER_RDN_ATTR}={username},{user_search_dn}" connection = ldap3.Connection( server=server, diff --git a/app/cookie_auth.py b/app/cookie_auth.py deleted file mode 100644 index 9e86df0..0000000 --- a/app/cookie_auth.py +++ /dev/null @@ -1,21 +0,0 @@ -from fastapi.responses import Response -from itsdangerous.serializer import Serializer -from .settings import SECRET_KEY, ACCESS_TOKEN_EXPIRE_MINUTES, AUTH_COOKIE_NAME - -serializer = Serializer(str(SECRET_KEY)) - - -def set_auth(response: Response, user_id: int): - val = serializer.dumps(user_id) - response.set_cookie( - AUTH_COOKIE_NAME, - val, - secure=False, - expires=ACCESS_TOKEN_EXPIRE_MINUTES * 60, - httponly=True, - samesite="Lax", - ) - - -def logout(response: Response): - response.delete_cookie(AUTH_COOKIE_NAME) diff --git a/app/crud.py b/app/crud.py index 30f6a62..febd5ee 100644 --- a/app/crud.py +++ b/app/crud.py @@ -215,10 +215,7 @@ def get_user_notifications( if filter_services_id is not None: query = query.filter(models.Notification.service_id.in_(filter_services_id)) query = query.order_by(desc(models.Notification.timestamp)) - if limit > 0: - query = query.limit(limit) - else: - query = query.all() + query = query.limit(limit) if limit > 0 else query.all() notifications = [un.to_user_notification() for un in query] # Sorting in ascending order is mostly for backward compatibility if sort == schemas.SortOrder.asc: diff --git a/app/deps.py b/app/deps.py index 95f5ae6..bf3a4d4 100644 --- a/app/deps.py +++ b/app/deps.py @@ -1,16 +1,29 @@ from fastapi import Depends, HTTPException, status from starlette.requests import Request -from fastapi.security import OAuth2PasswordBearer, APIKeyCookie +from fastapi.security import OAuth2PasswordBearer from fastapi.logger import logger -from itsdangerous.exc import BadSignature from sqlalchemy.orm import Session from jwt import PyJWTError, ExpiredSignatureError -from . import crud, models, utils, cookie_auth +from authlib.integrations.starlette_client import OAuth +from . import crud, models, utils from .database import SessionLocal -from .settings import AUTH_COOKIE_NAME +from .settings import ( + OIDC_NAME, + OIDC_SERVER_URL, + OIDC_CLIENT_ID, + OIDC_CLIENT_SECRET, + OIDC_SCOPE, +) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login") -cookie_sec = APIKeyCookie(name=AUTH_COOKIE_NAME) +oauth = OAuth() +oauth.register( + OIDC_NAME, + client_id=OIDC_CLIENT_ID, + client_secret=str(OIDC_CLIENT_SECRET), + server_metadata_url=OIDC_SERVER_URL, + client_kwargs={"scope": OIDC_SCOPE}, +) def get_db(): @@ -22,7 +35,7 @@ def get_db(): def get_current_user( - db: Session = Depends(get_db), token: str = Depends(oauth2_scheme) + request: Request, db: Session = Depends(get_db), token: str = Depends(oauth2_scheme) ) -> models.User: """Return the current user based on the bearer token from the header""" credentials_exception = HTTPException( @@ -30,6 +43,14 @@ def get_current_user( detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) + # Special case for swagger UI + # To avoid implementing OpenId Connect flow with the Authorize button, + # we use the cookie from the session that should be present if the user + # already logged in via the web UI + # We inject a dummy bearer token as one is expected by the oauth2_scheme + # If the user isn't logged in, this will return a 401 + if token == "swagger-ui": + return get_current_user_from_session(request, db) try: payload = utils.decode_access_token(token) except ExpiredSignatureError: @@ -73,21 +94,16 @@ def get_current_admin_user( return current_user -def get_current_user_from_cookie( +def get_current_user_from_session( request: Request, db: Session = Depends(get_db) ) -> models.User: unauthorized_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication" ) - if AUTH_COOKIE_NAME not in request.cookies: + user_id = request.session.get("user_id") + if user_id is None: raise unauthorized_exception - cookie = request.cookies[AUTH_COOKIE_NAME] - try: - user_id = cookie_auth.serializer.loads(cookie) - except BadSignature as e: - logger.warning(f"Bad Signature, invalid cookie value: {e}") - raise unauthorized_exception - user = crud.get_user(db, user_id) + user = crud.get_user(db, int(user_id)) if user is None: logger.warning(f"Unknown user id {user_id}") raise unauthorized_exception diff --git a/app/main.py b/app/main.py index 3145ecb..9d1cc24 100644 --- a/app/main.py +++ b/app/main.py @@ -1,17 +1,28 @@ +import contextlib import logging +import httpx +import jwt import sentry_sdk from pathlib import Path +from typing import AsyncIterator, TypedDict from sentry_sdk.integrations.asgi import SentryAsgiMiddleware from fastapi import FastAPI -from fastapi_versioning import VersionedFastAPI +from ._vendor.fastapi_versioning import VersionedFastAPI from fastapi.logger import logger from fastapi.staticfiles import StaticFiles from starlette.middleware import Middleware from starlette.middleware.sessions import SessionMiddleware from . import monitoring from .api import login, users, services -from .views import exceptions, account, notifications, settings -from .settings import SENTRY_DSN, ESS_NOTIFY_SERVER_ENVIRONMENT, SECRET_KEY +from .views import exceptions, account, notifications, settings, docs +from .settings import ( + SENTRY_DSN, + ESS_NOTIFY_SERVER_ENVIRONMENT, + SECRET_KEY, + SESSION_MAX_AGE, + OIDC_SERVER_URL, + OIDC_ENABLED, +) # The following logging setup assumes the app is run with gunicorn @@ -21,23 +32,55 @@ logger.handlers = gunicorn_error_logger.handlers logger.setLevel(gunicorn_error_logger.level) + +class State(TypedDict): + oidc_config: dict[str, str] + jwks_client: jwt.PyJWKClient | None + + +@contextlib.asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncIterator[State]: + if OIDC_ENABLED: + async with httpx.AsyncClient() as client: + r = await client.get(OIDC_SERVER_URL) + oidc_config = r.json() + jwks_client = jwt.PyJWKClient(oidc_config["jwks_uri"]) + else: + oidc_config = {} + jwks_client = None + yield {"oidc_config": oidc_config, "jwks_client": jwks_client} + + # Main application to serve HTML middleware = [ Middleware( - SessionMiddleware, secret_key=SECRET_KEY, session_cookie="notify_session" + SessionMiddleware, + secret_key=SECRET_KEY, + session_cookie="notify_session", + max_age=SESSION_MAX_AGE, + same_site="strict", + https_only=True, ) ] -app = FastAPI(exception_handlers=exceptions.exception_handlers, middleware=middleware) +app = FastAPI( + exception_handlers=exceptions.exception_handlers, + middleware=middleware, + lifespan=lifespan, + docs_url=None, + redoc_url=None, +) app.include_router(account.router) app.include_router(notifications.router, prefix="/notifications") app.include_router(settings.router, prefix="/settings") +app.include_router(docs.router) + # Serve static files app_dir = Path(__file__).parent.resolve() app.mount("/static", StaticFiles(directory=str(app_dir / "static")), name="static") # API mounted under /api -original_api = FastAPI() +original_api = FastAPI(docs_url=None, redoc_url=None) original_api.include_router(monitoring.router, prefix="/-", tags=["monitoring"]) original_api.include_router(login.router, tags=["login"]) original_api.include_router(users.router, prefix="/users", tags=["users"]) @@ -51,10 +94,13 @@ original_api, version_format="{major}", prefix_format="/v{major}", + docs_url=None, + redoc_url=None, ) app.mount("/api", versioned_api) + if SENTRY_DSN: sentry_sdk.init(dsn=SENTRY_DSN, environment=ESS_NOTIFY_SERVER_ENVIRONMENT) app = SentryAsgiMiddleware(app) diff --git a/app/schemas.py b/app/schemas.py index 87e361f..a7effca 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -164,3 +164,10 @@ class Aps(BaseModel): class ApnPayload(BaseModel): aps: Aps + + +class OpenIdConnectAuth(BaseModel): + code: str + code_verifier: str + client_id: str + redirect_uri: str diff --git a/app/settings.py b/app/settings.py index 171460f..c210de2 100644 --- a/app/settings.py +++ b/app/settings.py @@ -15,6 +15,7 @@ config = Config() # Should be set to "ldap" or "url" +# This is still supported for the API even when OIDC is enabled AUTHENTICATION_METHOD = config("AUTHENTICATION_METHOD", cast=str, default="ldap") # LDAP configuration LDAP_HOST = config("LDAP_HOST", cast=str, default="ldap.example.org") @@ -24,6 +25,21 @@ LDAP_USER_DN = config("LDAP_USER_DN", cast=str, default="") LDAP_USER_RDN_ATTR = config("LDAP_USER_RDN_ATTR", cast=str, default="uid") +# OpenID Connect configuration +# When enabled OIDC will be used for: +# - web login (only method supported) +# - API login (old authentication method still supported as well) +OIDC_ENABLED = config("OIDC_ENABLED", cast=bool, default=False) +OIDC_NAME = config("OIDC_NAME", cast=str, default="keycloak") +OIDC_SERVER_URL = config( + "OIDC_SERVER_URL", + cast=str, + default="https://keycloak.example.org/auth/realms/myrealm/.well-known/openid-configuration", +) +OIDC_CLIENT_ID = config("OIDC_CLIENT_ID", cast=str, default="notify") +OIDC_CLIENT_SECRET = config("OIDC_CLIENT_SECRET", cast=Secret, default="!secret") +OIDC_SCOPE = config("OIDC_SCOPE", cast=str, default="openid email profile") + # URL to use when AUTHENTICATION_METHOD is set to "url" AUTHENTICATION_URL = config( "AUTHENTICATION_URL", cast=str, default="https//auth.example.org/login" @@ -36,6 +52,8 @@ "SQLALCHEMY_DATABASE_URL", cast=str, default="sqlite:///./sql_app.db" ) SQLALCHEMY_DEBUG = config("SQLALCHEMY_DEBUG", cast=bool, default=False) +# Session expiry time in seconds: 12 hours (12 * 60 * 60 = 43200) +SESSION_MAX_AGE = config("SESSION_MAX_AGE", cast=int, default=43200) APNS_ALGORITHM = "ES256" APNS_KEY_ID = config("APNS_KEY_ID", cast=Secret, default="key-id") APNS_AUTH_KEY = config("APNS_AUTH_KEY", cast=Secret, default=DUMMY_PRIVATE_KEY) @@ -63,8 +81,6 @@ ACCESS_TOKEN_EXPIRE_MINUTES = config( "ACCESS_TOKEN_EXPIRE_MINUTES", cast=int, default=43200 ) -# Cookie name -AUTH_COOKIE_NAME = config("AUTH_COOKIE_NAME", cast=str, default="notify_token") # Number of push notifications sent in parallel NB_PARALLEL_PUSH = config("NB_PARALLEL_PUSH", cast=int, default=50) diff --git a/app/static/js/swagger-ui-custom.js b/app/static/js/swagger-ui-custom.js new file mode 100644 index 0000000..81c9e7e --- /dev/null +++ b/app/static/js/swagger-ui-custom.js @@ -0,0 +1,33 @@ +window.onload = function () { + // Extract API version from the URL (e.g., "/api/v1/docs" -> "v1") + const pathParts = window.location.pathname.split("/"); + const version = pathParts.length >= 3 ? pathParts[2] : "v1"; // Default to v1 if missing + + // Construct the OpenAPI URL dynamically + const openapiUrl = `/api/${version}/openapi.json`; + + setTimeout(() => { + fetch(openapiUrl) // Load the correct OpenAPI schema + .then(response => response.json()) + .then(spec => { + spec.host = window.location.host; + spec.schemes = [window.location.protocol.replace(':', '')]; + + spec.info.description = 'To perform authenticated requests, do not use "Authorize" but login via the web UI first.'; + + window.ui = SwaggerUIBundle({ + spec: spec, + dom_id: '#swagger-ui', + deepLinking: true, + presets: [SwaggerUIBundle.presets.apis, SwaggerUIBundle.SwaggerUIStandalonePreset], + requestInterceptor: request => { + // Add custom bearer token so that the user is retrieved from the session + // (if logged in) + request.headers['Authorization'] = "Bearer swagger-ui"; + return request; + }, + }); + }) + .catch(error => console.error(`Error loading OpenAPI spec for ${version}:`, error)); + }, 1000); +}; diff --git a/app/templates/400.html b/app/templates/400.html new file mode 100644 index 0000000..a2635a5 --- /dev/null +++ b/app/templates/400.html @@ -0,0 +1,8 @@ +{%- extends "base.html" %} + +{% block title %}Bad Request{% endblock %} + +{% block main %} +
{{ detail }}
+{%- endblock %} diff --git a/app/utils.py b/app/utils.py index 46a3e65..0378918 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,4 +1,5 @@ import asyncio +import base64 import httpx import ipaddress import uuid @@ -119,3 +120,39 @@ async def send_notification(notification_id: int) -> None: await android_client.aclose() finally: db.close() + + +def validate_id_token( + id_token: str, + access_token: str, + jwks_client: jwt.PyJWKClient, + signing_algos: list[str], + client_id: str, +) -> None: + """Raise an exception if the validation of the id token fails""" + # See https://pyjwt.readthedocs.io/en/stable/usage.html#oidc-login-flow + signing_key = jwks_client.get_signing_key_from_jwt(id_token) + # Decode and verify id_token claims + # expiration, issued at, not before, audience and issuer + data = jwt.decode_complete( + id_token, + key=signing_key, + audience=client_id, + algorithms=signing_algos, + require=["exp", "iat", "nbf", "aud", "iss"], + verify_signature=True, + ) + payload, header = data["payload"], data["header"] + alg_obj = jwt.get_algorithm_by_name(header["alg"]) + # compute at_hash, then validate + # access_token must be bytes (not str) + digest = alg_obj.compute_hash_digest(access_token.encode("utf-8")) + at_hash = ( + base64.urlsafe_b64encode(digest[: (len(digest) // 2)]) + .rstrip(b"=") + .decode("utf-8") + ) + if at_hash != payload["at_hash"]: + raise ValueError( + f"at_hash value {payload['at_hash']} doesn't match computed {at_hash}" + ) diff --git a/app/views/account.py b/app/views/account.py index 1ba2f69..099df0a 100644 --- a/app/views/account.py +++ b/app/views/account.py @@ -1,11 +1,12 @@ -from fastapi import APIRouter, Depends, status +from fastapi import APIRouter, Depends, status, HTTPException from fastapi.logger import logger from starlette.responses import HTMLResponse, RedirectResponse from starlette.requests import Request from sqlalchemy.orm import Session +from authlib.integrations.base_client.errors import OAuthError from . import templates -from .. import deps, cookie_auth, crud, auth, models -from ..settings import APP_NAME +from .. import deps, crud, auth, models +from ..settings import APP_NAME, OIDC_ENABLED router = APIRouter() @@ -13,22 +14,26 @@ @router.get("/", response_class=HTMLResponse, name="index") async def index( request: Request, - current_user: models.User = Depends(deps.get_current_user_from_cookie), + current_user: models.User = Depends(deps.get_current_user_from_session), ): return RedirectResponse(url="/notifications") @router.get("/login", response_class=HTMLResponse) async def login_get(request: Request): - return templates.TemplateResponse( - "login.html", - { - "request": request, - "username": "", - "password": "", - "error": "", - }, - ) + if OIDC_ENABLED: + redirect_uri = request.url_for("oidc_auth") + return await deps.oauth.keycloak.authorize_redirect(request, redirect_uri) + else: + return templates.TemplateResponse( + "login.html", + { + "request": request, + "username": "", + "password": "", + "error": "", + }, + ) @router.post("/login", response_class=HTMLResponse) @@ -36,6 +41,10 @@ async def login_post( request: Request, db: Session = Depends(deps.get_db), ): + if OIDC_ENABLED: + raise HTTPException( + status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Invalid method" + ) form = await request.form() username = form.get("username", "").lower().strip() password = form.get("password", "").strip() @@ -59,14 +68,35 @@ async def login_post( db_user = crud.create_user(db, username.lower()) resp = RedirectResponse("/", status_code=status.HTTP_302_FOUND) - cookie_auth.set_auth(resp, db_user.id) + request.session["user_id"] = db_user.id return resp +@router.get("/auth") +async def oidc_auth( + request: Request, + db: Session = Depends(deps.get_db), +): + try: + token = await deps.oauth.keycloak.authorize_access_token(request) + except OAuthError as e: + logger.warning(f"OAuthError on OpenID Connect redirect: {e}") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + user_info = token["userinfo"] + if user_info: + username = user_info["preferred_username"].lower() + db_user = crud.get_user_by_username(db, username) + if db_user is None: + db_user = crud.create_user(db, username) + request.session["user_id"] = db_user.id + return RedirectResponse(url=request.session.pop("next", "/")) + return RedirectResponse(url="/login") + + @router.get("/logout") -def logout(): +def logout(request: Request): response = RedirectResponse(url="/login", status_code=status.HTTP_302_FOUND) - cookie_auth.logout(response) + request.session.pop("user_id", None) return response diff --git a/app/views/docs.py b/app/views/docs.py new file mode 100644 index 0000000..5be68c0 --- /dev/null +++ b/app/views/docs.py @@ -0,0 +1,27 @@ +from fastapi import APIRouter +from starlette.responses import HTMLResponse + +router = APIRouter() + + +# Dynamic Swagger UI Route (works for `/api/v1/docs` and `/api/v2/docs`) +# Override the default Swagger UI endpoint to load some custom javascript +# and inject a bearer token +@router.get("/api/{version}/docs", include_in_schema=False) +async def custom_swagger_ui(version: str): + html = """ + + + + +