diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 5e01308..0000000 --- a/.flake8 +++ /dev/null @@ -1,5 +0,0 @@ -[flake8] -# E501: let black handle line length -# W503 is incompatible with PEP 8 -ignore = E501,W503 - diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 378d268..dec80ae 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.11", "3.12"] steps: - uses: actions/checkout@v2 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d471ae..e2e4161 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,14 +1,15 @@ repos: - - repo: https://github.com/ambv/black - rev: 24.1.1 + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.9.9 hooks: - - id: black - - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 - hooks: - - id: flake8 + # Run the linter. + - id: ruff + # Run the formatter. + - id: ruff-format + - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 + rev: v1.15.0 hooks: - id: mypy additional_dependencies: diff --git a/app/_vendor/LICENSE.fastapi_versioning b/app/_vendor/LICENSE.fastapi_versioning new file mode 100644 index 0000000..d93181b --- /dev/null +++ b/app/_vendor/LICENSE.fastapi_versioning @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Dean Way + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/app/_vendor/__init__.py b/app/_vendor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/_vendor/fastapi_versioning/__init__.py b/app/_vendor/fastapi_versioning/__init__.py new file mode 100644 index 0000000..f86bcf7 --- /dev/null +++ b/app/_vendor/fastapi_versioning/__init__.py @@ -0,0 +1,8 @@ +from .routing import versioned_api_route +from .versioning import VersionedFastAPI, version + +__all__ = [ + "VersionedFastAPI", + "versioned_api_route", + "version", +] diff --git a/app/_vendor/fastapi_versioning/routing.py b/app/_vendor/fastapi_versioning/routing.py new file mode 100644 index 0000000..eeb34dc --- /dev/null +++ b/app/_vendor/fastapi_versioning/routing.py @@ -0,0 +1,18 @@ +from typing import Any, Type + +from fastapi.routing import APIRoute + + +def versioned_api_route( + major: int = 1, minor: int = 0, route_class: Type[APIRoute] = APIRoute +) -> Type[APIRoute]: + class VersionedAPIRoute(route_class): # type: ignore + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + try: + self.endpoint._api_version = (major, minor) + except AttributeError: + # Support bound methods + self.endpoint.__func__._api_version = (major, minor) + + return VersionedAPIRoute diff --git a/app/_vendor/fastapi_versioning/versioning.py b/app/_vendor/fastapi_versioning/versioning.py new file mode 100644 index 0000000..5d9be98 --- /dev/null +++ b/app/_vendor/fastapi_versioning/versioning.py @@ -0,0 +1,83 @@ +from collections import defaultdict +from typing import Any, Callable, Dict, List, Tuple, TypeVar, cast + +from fastapi import FastAPI +from fastapi.routing import APIRoute +from starlette.routing import BaseRoute + +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) + + +def version(major: int, minor: int = 0) -> Callable[[CallableT], CallableT]: + def decorator(func: CallableT) -> CallableT: + func._api_version = (major, minor) # type: ignore + return func + + return decorator + + +def version_to_route( + route: BaseRoute, + default_version: Tuple[int, int], +) -> Tuple[Tuple[int, int], APIRoute]: + api_route = cast(APIRoute, route) + version = getattr(api_route.endpoint, "_api_version", default_version) + return version, api_route + + +def VersionedFastAPI( + app: FastAPI, + version_format: str = "{major}.{minor}", + prefix_format: str = "/v{major}_{minor}", + default_version: Tuple[int, int] = (1, 0), + enable_latest: bool = False, + **kwargs: Any, +) -> FastAPI: + parent_app = FastAPI( + title=app.title, + **kwargs, + ) + version_route_mapping: Dict[Tuple[int, int], List[APIRoute]] = defaultdict(list) + version_routes = [version_to_route(route, default_version) for route in app.routes] + + for version, route in version_routes: + version_route_mapping[version].append(route) + + unique_routes = {} + versions = sorted(version_route_mapping.keys()) + for version in versions: + major, minor = version + prefix = prefix_format.format(major=major, minor=minor) + semver = version_format.format(major=major, minor=minor) + versioned_app = FastAPI( + title=app.title, + description=app.description, + version=semver, + docs_url=None, + redoc_url=None, + ) + for route in version_route_mapping[version]: + for method in route.methods: + unique_routes[route.path + "|" + method] = route + for route in unique_routes.values(): + versioned_app.router.routes.append(route) + parent_app.mount(prefix, versioned_app) + + @parent_app.get(f"{prefix}/openapi.json", name=semver, tags=["Versions"]) + @parent_app.get(f"{prefix}/docs", name=semver, tags=["Documentations"]) + def noop() -> None: ... + + if enable_latest: + prefix = "/latest" + major, minor = version + semver = version_format.format(major=major, minor=minor) + versioned_app = FastAPI( + title=app.title, + description=app.description, + version=semver, + ) + for route in unique_routes.values(): + versioned_app.router.routes.append(route) + parent_app.mount(prefix, versioned_app) + + return parent_app diff --git a/app/api/login.py b/app/api/login.py index 4d9e510..081cca7 100644 --- a/app/api/login.py +++ b/app/api/login.py @@ -1,32 +1,133 @@ +import httpx from datetime import datetime, timedelta, timezone -from fastapi import APIRouter, Depends, HTTPException, Response, status +from fastapi import APIRouter, Depends, HTTPException, Response, Request, status from fastapi.security import OAuth2PasswordRequestForm from fastapi.logger import logger from sqlalchemy.orm import Session -from .. import deps, crud, utils, auth -from ..settings import ACCESS_TOKEN_EXPIRE_MINUTES +from .. import deps, crud, utils, auth, schemas +from ..settings import ( + ACCESS_TOKEN_EXPIRE_MINUTES, + OIDC_CLIENT_SECRET, + OIDC_SCOPE, +) router = APIRouter() +def create_access_token(db, username, response) -> dict[str, str]: + db_user = crud.get_user_by_username(db, username) + if db_user is None: + db_user = crud.create_user(db, username) + response.status_code = status.HTTP_201_CREATED + expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = utils.create_access_token(db_user.username, expire=expire) + crud.update_user_login_token_expire_date(db, db_user, expire) + logger.info(f"User {username} successfully logged in") + return {"access_token": access_token, "token_type": "bearer"} + + @router.post("/login", status_code=status.HTTP_200_OK) def login( response: Response, db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends(), ): - if not auth.authenticate_user(form_data.username.lower(), form_data.password): - logger.warning(f"Authentication failed for {form_data.username.lower()}") + """Login using username/password""" + username = form_data.username.lower() + if not auth.authenticate_user(username, form_data.password): + logger.warning(f"Authentication failed for {username}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", ) - logger.info(f"User {form_data.username.lower()} successfully logged in") - db_user = crud.get_user_by_username(db, form_data.username.lower()) - if db_user is None: - db_user = crud.create_user(db, form_data.username.lower()) - response.status_code = status.HTTP_201_CREATED - expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = utils.create_access_token(db_user.username, expire=expire) - crud.update_user_login_token_expire_date(db, db_user, expire) - return {"access_token": access_token, "token_type": "bearer"} + return create_access_token(db, username, response) + + +@router.post("/open_id_connect", status_code=status.HTTP_200_OK) +async def open_id_connect( + oidc_auth: schemas.OpenIdConnectAuth, + response: Response, + request: Request, + db: Session = Depends(deps.get_db), +): + """Login using OpenID Connect Authentication Code flow from mobile client""" + oidc_config = request.state.oidc_config + data = { + "client_id": oidc_auth.client_id, + "client_secret": OIDC_CLIENT_SECRET, + "code": oidc_auth.code, + "code_verifier": oidc_auth.code_verifier, + "grant_type": "authorization_code", + "redirect_uri": oidc_auth.redirect_uri, + } + logger.info( + "Login via OIDC Authentication Code flow. " + f"Sending {data} to {oidc_config['token_endpoint']} to retrieve token." + ) + async with httpx.AsyncClient() as client: + try: + response = await client.post( + oidc_config["token_endpoint"], + data=data, + ) + response.raise_for_status() + except httpx.RequestError as exc: + logger.error( + f"An error occurred while requesting {exc.request.url!r}: {exc}." + ) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"An error occurred while requesting {exc.request.url!r}", + ) + except httpx.HTTPStatusError as exc: + logger.error(f"Failed to get OIDC token: {response.content}") + raise HTTPException( + status_code=exc.response.status_code, detail="Failed to get OIDC token" + ) + result = response.json() + access_token = result["access_token"] + id_token = result["id_token"] + logger.debug("Retrieved access and id tokens. Validating id_token.") + try: + utils.validate_id_token( + id_token, + access_token, + request.state.jwks_client, + request.state.oidc_config["id_token_signing_alg_values_supported"], + oidc_auth.client_id, + ) + except Exception as e: + logger.warning(f"id_token validation failed: {e}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="id_token validation failed", + ) + headers = {"Authorization": f"Bearer {access_token}"} + data = { + "client_id": oidc_auth.client_id, + "client_secret": OIDC_CLIENT_SECRET, + "scope": OIDC_SCOPE, + } + logger.info("Retrieving user info.") + try: + response = await client.post( + oidc_config["userinfo_endpoint"], + headers=headers, + data=data, + ) + response.raise_for_status() + except httpx.RequestError as exc: + logger.error( + f"An error occurred while requesting {exc.request.url!r}: {exc}." + ) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"An error occurred while requesting {exc.request.url!r}", + ) + except httpx.HTTPStatusError as exc: + logger.error(f"Failed to get user info: {response.content}") + raise HTTPException( + status_code=exc.response.status_code, detail="Failed to get user info" + ) + username = response.json()["preferred_username"].lower() + return create_access_token(db, username, response) diff --git a/app/api/users.py b/app/api/users.py index 9c4f216..b0695aa 100644 --- a/app/api/users.py +++ b/app/api/users.py @@ -1,5 +1,5 @@ from fastapi import APIRouter, Depends, Response, HTTPException, status -from fastapi_versioning import version +from .._vendor.fastapi_versioning import version from sqlalchemy.orm import Session from typing import List from .. import deps, crud, models, schemas diff --git a/app/auth.py b/app/auth.py index 2936746..8363812 100644 --- a/app/auth.py +++ b/app/auth.py @@ -37,10 +37,7 @@ def ldap_authenticate_user(username: str, password: str) -> bool: validate=ssl.CERT_REQUIRED, version=ssl.PROTOCOL_TLSv1_2, ciphers="ALL" ) server = ldap3.Server(LDAP_HOST, port=LDAP_PORT, use_ssl=LDAP_USE_SSL, tls=tls) - if LDAP_USER_DN: - user_search_dn = f"{LDAP_USER_DN},{LDAP_BASE_DN}" - else: - user_search_dn = LDAP_BASE_DN + user_search_dn = f"{LDAP_USER_DN},{LDAP_BASE_DN}" if LDAP_USER_DN else LDAP_BASE_DN bind_user = f"{LDAP_USER_RDN_ATTR}={username},{user_search_dn}" connection = ldap3.Connection( server=server, diff --git a/app/cookie_auth.py b/app/cookie_auth.py deleted file mode 100644 index 9e86df0..0000000 --- a/app/cookie_auth.py +++ /dev/null @@ -1,21 +0,0 @@ -from fastapi.responses import Response -from itsdangerous.serializer import Serializer -from .settings import SECRET_KEY, ACCESS_TOKEN_EXPIRE_MINUTES, AUTH_COOKIE_NAME - -serializer = Serializer(str(SECRET_KEY)) - - -def set_auth(response: Response, user_id: int): - val = serializer.dumps(user_id) - response.set_cookie( - AUTH_COOKIE_NAME, - val, - secure=False, - expires=ACCESS_TOKEN_EXPIRE_MINUTES * 60, - httponly=True, - samesite="Lax", - ) - - -def logout(response: Response): - response.delete_cookie(AUTH_COOKIE_NAME) diff --git a/app/crud.py b/app/crud.py index 30f6a62..febd5ee 100644 --- a/app/crud.py +++ b/app/crud.py @@ -215,10 +215,7 @@ def get_user_notifications( if filter_services_id is not None: query = query.filter(models.Notification.service_id.in_(filter_services_id)) query = query.order_by(desc(models.Notification.timestamp)) - if limit > 0: - query = query.limit(limit) - else: - query = query.all() + query = query.limit(limit) if limit > 0 else query.all() notifications = [un.to_user_notification() for un in query] # Sorting in ascending order is mostly for backward compatibility if sort == schemas.SortOrder.asc: diff --git a/app/deps.py b/app/deps.py index 95f5ae6..bf3a4d4 100644 --- a/app/deps.py +++ b/app/deps.py @@ -1,16 +1,29 @@ from fastapi import Depends, HTTPException, status from starlette.requests import Request -from fastapi.security import OAuth2PasswordBearer, APIKeyCookie +from fastapi.security import OAuth2PasswordBearer from fastapi.logger import logger -from itsdangerous.exc import BadSignature from sqlalchemy.orm import Session from jwt import PyJWTError, ExpiredSignatureError -from . import crud, models, utils, cookie_auth +from authlib.integrations.starlette_client import OAuth +from . import crud, models, utils from .database import SessionLocal -from .settings import AUTH_COOKIE_NAME +from .settings import ( + OIDC_NAME, + OIDC_SERVER_URL, + OIDC_CLIENT_ID, + OIDC_CLIENT_SECRET, + OIDC_SCOPE, +) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login") -cookie_sec = APIKeyCookie(name=AUTH_COOKIE_NAME) +oauth = OAuth() +oauth.register( + OIDC_NAME, + client_id=OIDC_CLIENT_ID, + client_secret=str(OIDC_CLIENT_SECRET), + server_metadata_url=OIDC_SERVER_URL, + client_kwargs={"scope": OIDC_SCOPE}, +) def get_db(): @@ -22,7 +35,7 @@ def get_db(): def get_current_user( - db: Session = Depends(get_db), token: str = Depends(oauth2_scheme) + request: Request, db: Session = Depends(get_db), token: str = Depends(oauth2_scheme) ) -> models.User: """Return the current user based on the bearer token from the header""" credentials_exception = HTTPException( @@ -30,6 +43,14 @@ def get_current_user( detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) + # Special case for swagger UI + # To avoid implementing OpenId Connect flow with the Authorize button, + # we use the cookie from the session that should be present if the user + # already logged in via the web UI + # We inject a dummy bearer token as one is expected by the oauth2_scheme + # If the user isn't logged in, this will return a 401 + if token == "swagger-ui": + return get_current_user_from_session(request, db) try: payload = utils.decode_access_token(token) except ExpiredSignatureError: @@ -73,21 +94,16 @@ def get_current_admin_user( return current_user -def get_current_user_from_cookie( +def get_current_user_from_session( request: Request, db: Session = Depends(get_db) ) -> models.User: unauthorized_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication" ) - if AUTH_COOKIE_NAME not in request.cookies: + user_id = request.session.get("user_id") + if user_id is None: raise unauthorized_exception - cookie = request.cookies[AUTH_COOKIE_NAME] - try: - user_id = cookie_auth.serializer.loads(cookie) - except BadSignature as e: - logger.warning(f"Bad Signature, invalid cookie value: {e}") - raise unauthorized_exception - user = crud.get_user(db, user_id) + user = crud.get_user(db, int(user_id)) if user is None: logger.warning(f"Unknown user id {user_id}") raise unauthorized_exception diff --git a/app/main.py b/app/main.py index 3145ecb..9d1cc24 100644 --- a/app/main.py +++ b/app/main.py @@ -1,17 +1,28 @@ +import contextlib import logging +import httpx +import jwt import sentry_sdk from pathlib import Path +from typing import AsyncIterator, TypedDict from sentry_sdk.integrations.asgi import SentryAsgiMiddleware from fastapi import FastAPI -from fastapi_versioning import VersionedFastAPI +from ._vendor.fastapi_versioning import VersionedFastAPI from fastapi.logger import logger from fastapi.staticfiles import StaticFiles from starlette.middleware import Middleware from starlette.middleware.sessions import SessionMiddleware from . import monitoring from .api import login, users, services -from .views import exceptions, account, notifications, settings -from .settings import SENTRY_DSN, ESS_NOTIFY_SERVER_ENVIRONMENT, SECRET_KEY +from .views import exceptions, account, notifications, settings, docs +from .settings import ( + SENTRY_DSN, + ESS_NOTIFY_SERVER_ENVIRONMENT, + SECRET_KEY, + SESSION_MAX_AGE, + OIDC_SERVER_URL, + OIDC_ENABLED, +) # The following logging setup assumes the app is run with gunicorn @@ -21,23 +32,55 @@ logger.handlers = gunicorn_error_logger.handlers logger.setLevel(gunicorn_error_logger.level) + +class State(TypedDict): + oidc_config: dict[str, str] + jwks_client: jwt.PyJWKClient | None + + +@contextlib.asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncIterator[State]: + if OIDC_ENABLED: + async with httpx.AsyncClient() as client: + r = await client.get(OIDC_SERVER_URL) + oidc_config = r.json() + jwks_client = jwt.PyJWKClient(oidc_config["jwks_uri"]) + else: + oidc_config = {} + jwks_client = None + yield {"oidc_config": oidc_config, "jwks_client": jwks_client} + + # Main application to serve HTML middleware = [ Middleware( - SessionMiddleware, secret_key=SECRET_KEY, session_cookie="notify_session" + SessionMiddleware, + secret_key=SECRET_KEY, + session_cookie="notify_session", + max_age=SESSION_MAX_AGE, + same_site="strict", + https_only=True, ) ] -app = FastAPI(exception_handlers=exceptions.exception_handlers, middleware=middleware) +app = FastAPI( + exception_handlers=exceptions.exception_handlers, + middleware=middleware, + lifespan=lifespan, + docs_url=None, + redoc_url=None, +) app.include_router(account.router) app.include_router(notifications.router, prefix="/notifications") app.include_router(settings.router, prefix="/settings") +app.include_router(docs.router) + # Serve static files app_dir = Path(__file__).parent.resolve() app.mount("/static", StaticFiles(directory=str(app_dir / "static")), name="static") # API mounted under /api -original_api = FastAPI() +original_api = FastAPI(docs_url=None, redoc_url=None) original_api.include_router(monitoring.router, prefix="/-", tags=["monitoring"]) original_api.include_router(login.router, tags=["login"]) original_api.include_router(users.router, prefix="/users", tags=["users"]) @@ -51,10 +94,13 @@ original_api, version_format="{major}", prefix_format="/v{major}", + docs_url=None, + redoc_url=None, ) app.mount("/api", versioned_api) + if SENTRY_DSN: sentry_sdk.init(dsn=SENTRY_DSN, environment=ESS_NOTIFY_SERVER_ENVIRONMENT) app = SentryAsgiMiddleware(app) diff --git a/app/schemas.py b/app/schemas.py index 87e361f..a7effca 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -164,3 +164,10 @@ class Aps(BaseModel): class ApnPayload(BaseModel): aps: Aps + + +class OpenIdConnectAuth(BaseModel): + code: str + code_verifier: str + client_id: str + redirect_uri: str diff --git a/app/settings.py b/app/settings.py index 171460f..c210de2 100644 --- a/app/settings.py +++ b/app/settings.py @@ -15,6 +15,7 @@ config = Config() # Should be set to "ldap" or "url" +# This is still supported for the API even when OIDC is enabled AUTHENTICATION_METHOD = config("AUTHENTICATION_METHOD", cast=str, default="ldap") # LDAP configuration LDAP_HOST = config("LDAP_HOST", cast=str, default="ldap.example.org") @@ -24,6 +25,21 @@ LDAP_USER_DN = config("LDAP_USER_DN", cast=str, default="") LDAP_USER_RDN_ATTR = config("LDAP_USER_RDN_ATTR", cast=str, default="uid") +# OpenID Connect configuration +# When enabled OIDC will be used for: +# - web login (only method supported) +# - API login (old authentication method still supported as well) +OIDC_ENABLED = config("OIDC_ENABLED", cast=bool, default=False) +OIDC_NAME = config("OIDC_NAME", cast=str, default="keycloak") +OIDC_SERVER_URL = config( + "OIDC_SERVER_URL", + cast=str, + default="https://keycloak.example.org/auth/realms/myrealm/.well-known/openid-configuration", +) +OIDC_CLIENT_ID = config("OIDC_CLIENT_ID", cast=str, default="notify") +OIDC_CLIENT_SECRET = config("OIDC_CLIENT_SECRET", cast=Secret, default="!secret") +OIDC_SCOPE = config("OIDC_SCOPE", cast=str, default="openid email profile") + # URL to use when AUTHENTICATION_METHOD is set to "url" AUTHENTICATION_URL = config( "AUTHENTICATION_URL", cast=str, default="https//auth.example.org/login" @@ -36,6 +52,8 @@ "SQLALCHEMY_DATABASE_URL", cast=str, default="sqlite:///./sql_app.db" ) SQLALCHEMY_DEBUG = config("SQLALCHEMY_DEBUG", cast=bool, default=False) +# Session expiry time in seconds: 12 hours (12 * 60 * 60 = 43200) +SESSION_MAX_AGE = config("SESSION_MAX_AGE", cast=int, default=43200) APNS_ALGORITHM = "ES256" APNS_KEY_ID = config("APNS_KEY_ID", cast=Secret, default="key-id") APNS_AUTH_KEY = config("APNS_AUTH_KEY", cast=Secret, default=DUMMY_PRIVATE_KEY) @@ -63,8 +81,6 @@ ACCESS_TOKEN_EXPIRE_MINUTES = config( "ACCESS_TOKEN_EXPIRE_MINUTES", cast=int, default=43200 ) -# Cookie name -AUTH_COOKIE_NAME = config("AUTH_COOKIE_NAME", cast=str, default="notify_token") # Number of push notifications sent in parallel NB_PARALLEL_PUSH = config("NB_PARALLEL_PUSH", cast=int, default=50) diff --git a/app/static/js/swagger-ui-custom.js b/app/static/js/swagger-ui-custom.js new file mode 100644 index 0000000..81c9e7e --- /dev/null +++ b/app/static/js/swagger-ui-custom.js @@ -0,0 +1,33 @@ +window.onload = function () { + // Extract API version from the URL (e.g., "/api/v1/docs" -> "v1") + const pathParts = window.location.pathname.split("/"); + const version = pathParts.length >= 3 ? pathParts[2] : "v1"; // Default to v1 if missing + + // Construct the OpenAPI URL dynamically + const openapiUrl = `/api/${version}/openapi.json`; + + setTimeout(() => { + fetch(openapiUrl) // Load the correct OpenAPI schema + .then(response => response.json()) + .then(spec => { + spec.host = window.location.host; + spec.schemes = [window.location.protocol.replace(':', '')]; + + spec.info.description = 'To perform authenticated requests, do not use "Authorize" but login via the web UI first.'; + + window.ui = SwaggerUIBundle({ + spec: spec, + dom_id: '#swagger-ui', + deepLinking: true, + presets: [SwaggerUIBundle.presets.apis, SwaggerUIBundle.SwaggerUIStandalonePreset], + requestInterceptor: request => { + // Add custom bearer token so that the user is retrieved from the session + // (if logged in) + request.headers['Authorization'] = "Bearer swagger-ui"; + return request; + }, + }); + }) + .catch(error => console.error(`Error loading OpenAPI spec for ${version}:`, error)); + }, 1000); +}; diff --git a/app/templates/400.html b/app/templates/400.html new file mode 100644 index 0000000..a2635a5 --- /dev/null +++ b/app/templates/400.html @@ -0,0 +1,8 @@ +{%- extends "base.html" %} + +{% block title %}Bad Request{% endblock %} + +{% block main %} +

Bad Request

+

{{ detail }}

+{%- endblock %} diff --git a/app/utils.py b/app/utils.py index 46a3e65..0378918 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,4 +1,5 @@ import asyncio +import base64 import httpx import ipaddress import uuid @@ -119,3 +120,39 @@ async def send_notification(notification_id: int) -> None: await android_client.aclose() finally: db.close() + + +def validate_id_token( + id_token: str, + access_token: str, + jwks_client: jwt.PyJWKClient, + signing_algos: list[str], + client_id: str, +) -> None: + """Raise an exception if the validation of the id token fails""" + # See https://pyjwt.readthedocs.io/en/stable/usage.html#oidc-login-flow + signing_key = jwks_client.get_signing_key_from_jwt(id_token) + # Decode and verify id_token claims + # expiration, issued at, not before, audience and issuer + data = jwt.decode_complete( + id_token, + key=signing_key, + audience=client_id, + algorithms=signing_algos, + require=["exp", "iat", "nbf", "aud", "iss"], + verify_signature=True, + ) + payload, header = data["payload"], data["header"] + alg_obj = jwt.get_algorithm_by_name(header["alg"]) + # compute at_hash, then validate + # access_token must be bytes (not str) + digest = alg_obj.compute_hash_digest(access_token.encode("utf-8")) + at_hash = ( + base64.urlsafe_b64encode(digest[: (len(digest) // 2)]) + .rstrip(b"=") + .decode("utf-8") + ) + if at_hash != payload["at_hash"]: + raise ValueError( + f"at_hash value {payload['at_hash']} doesn't match computed {at_hash}" + ) diff --git a/app/views/account.py b/app/views/account.py index 1ba2f69..099df0a 100644 --- a/app/views/account.py +++ b/app/views/account.py @@ -1,11 +1,12 @@ -from fastapi import APIRouter, Depends, status +from fastapi import APIRouter, Depends, status, HTTPException from fastapi.logger import logger from starlette.responses import HTMLResponse, RedirectResponse from starlette.requests import Request from sqlalchemy.orm import Session +from authlib.integrations.base_client.errors import OAuthError from . import templates -from .. import deps, cookie_auth, crud, auth, models -from ..settings import APP_NAME +from .. import deps, crud, auth, models +from ..settings import APP_NAME, OIDC_ENABLED router = APIRouter() @@ -13,22 +14,26 @@ @router.get("/", response_class=HTMLResponse, name="index") async def index( request: Request, - current_user: models.User = Depends(deps.get_current_user_from_cookie), + current_user: models.User = Depends(deps.get_current_user_from_session), ): return RedirectResponse(url="/notifications") @router.get("/login", response_class=HTMLResponse) async def login_get(request: Request): - return templates.TemplateResponse( - "login.html", - { - "request": request, - "username": "", - "password": "", - "error": "", - }, - ) + if OIDC_ENABLED: + redirect_uri = request.url_for("oidc_auth") + return await deps.oauth.keycloak.authorize_redirect(request, redirect_uri) + else: + return templates.TemplateResponse( + "login.html", + { + "request": request, + "username": "", + "password": "", + "error": "", + }, + ) @router.post("/login", response_class=HTMLResponse) @@ -36,6 +41,10 @@ async def login_post( request: Request, db: Session = Depends(deps.get_db), ): + if OIDC_ENABLED: + raise HTTPException( + status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Invalid method" + ) form = await request.form() username = form.get("username", "").lower().strip() password = form.get("password", "").strip() @@ -59,14 +68,35 @@ async def login_post( db_user = crud.create_user(db, username.lower()) resp = RedirectResponse("/", status_code=status.HTTP_302_FOUND) - cookie_auth.set_auth(resp, db_user.id) + request.session["user_id"] = db_user.id return resp +@router.get("/auth") +async def oidc_auth( + request: Request, + db: Session = Depends(deps.get_db), +): + try: + token = await deps.oauth.keycloak.authorize_access_token(request) + except OAuthError as e: + logger.warning(f"OAuthError on OpenID Connect redirect: {e}") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + user_info = token["userinfo"] + if user_info: + username = user_info["preferred_username"].lower() + db_user = crud.get_user_by_username(db, username) + if db_user is None: + db_user = crud.create_user(db, username) + request.session["user_id"] = db_user.id + return RedirectResponse(url=request.session.pop("next", "/")) + return RedirectResponse(url="/login") + + @router.get("/logout") -def logout(): +def logout(request: Request): response = RedirectResponse(url="/login", status_code=status.HTTP_302_FOUND) - cookie_auth.logout(response) + request.session.pop("user_id", None) return response diff --git a/app/views/docs.py b/app/views/docs.py new file mode 100644 index 0000000..5be68c0 --- /dev/null +++ b/app/views/docs.py @@ -0,0 +1,27 @@ +from fastapi import APIRouter +from starlette.responses import HTMLResponse + +router = APIRouter() + + +# Dynamic Swagger UI Route (works for `/api/v1/docs` and `/api/v2/docs`) +# Override the default Swagger UI endpoint to load some custom javascript +# and inject a bearer token +@router.get("/api/{version}/docs", include_in_schema=False) +async def custom_swagger_ui(version: str): + html = """ + + + + + Notify SwaggerUI + + +
+
+ + + + + """ + return HTMLResponse(html) diff --git a/app/views/exceptions.py b/app/views/exceptions.py index 18f8819..ba19e8c 100644 --- a/app/views/exceptions.py +++ b/app/views/exceptions.py @@ -9,6 +9,14 @@ async def not_authenticated(request: Request, exc: HTTPException): return RedirectResponse(url="/login") +async def bad_request(request: Request, exc: HTTPException): + return templates.TemplateResponse( + "400.html", + {"request": request, "detail": exc.detail}, + status_code=exc.status_code, + ) + + async def forbidden(request: Request, exc: HTTPException): return templates.TemplateResponse( "403.html", {"request": request}, status_code=exc.status_code @@ -29,4 +37,9 @@ async def server_error(request: Request, exc: HTTPException): ) -exception_handlers = {401: not_authenticated, 404: not_found, 500: server_error} +exception_handlers = { + 400: bad_request, + 401: not_authenticated, + 404: not_found, + 500: server_error, +} diff --git a/app/views/notifications.py b/app/views/notifications.py index fa25e93..67f0ab6 100644 --- a/app/views/notifications.py +++ b/app/views/notifications.py @@ -12,7 +12,7 @@ async def notifications_get( request: Request, db: Session = Depends(deps.get_db), - current_user: models.User = Depends(deps.get_current_user_from_cookie), + current_user: models.User = Depends(deps.get_current_user_from_session), ): try: notifications_limit = request.session["notifications_limit"] @@ -50,10 +50,10 @@ async def notifications_post( request: Request, notifications_limit: int = Form(50), db: Session = Depends(deps.get_db), - current_user: models.User = Depends(deps.get_current_user_from_cookie), + current_user: models.User = Depends(deps.get_current_user_from_session), ): form = await request.form() - selected_categories = [key for key in form.keys() if key != "notifications_limit"] + selected_categories = [key for key in form if key != "notifications_limit"] request.session["selected_categories"] = selected_categories request.session["notifications_limit"] = notifications_limit user_services = crud.get_user_services(db, current_user) @@ -95,7 +95,7 @@ async def notifications_post( async def notifications_update( request: Request, db: Session = Depends(deps.get_db), - current_user: models.User = Depends(deps.get_current_user_from_cookie), + current_user: models.User = Depends(deps.get_current_user_from_session), ): user_services = crud.get_user_services(db, current_user) categories = {service.id: service.category for service in user_services} diff --git a/app/views/settings.py b/app/views/settings.py index 026ea5a..9307e0b 100644 --- a/app/views/settings.py +++ b/app/views/settings.py @@ -12,7 +12,7 @@ async def settings_get( request: Request, db: Session = Depends(deps.get_db), - current_user: models.User = Depends(deps.get_current_user_from_cookie), + current_user: models.User = Depends(deps.get_current_user_from_session), ): services = crud.get_user_services(db, current_user) return templates.TemplateResponse( @@ -25,7 +25,7 @@ async def settings_get( async def settings_post( request: Request, db: Session = Depends(deps.get_db), - current_user: models.User = Depends(deps.get_current_user_from_cookie), + current_user: models.User = Depends(deps.get_current_user_from_session), ): form = await request.form() selected_categories = list(form.keys()) diff --git a/pyproject.toml b/pyproject.toml index b16ce4a..b0b1bc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,70 @@ [build-system] -requires = ["setuptools >= 42", "wheel", "setuptools_scm[toml]>=3.4"] +requires = ["setuptools>=64", "setuptools-scm>=8"] build-backend = "setuptools.build_meta" [tool.setuptools_scm] +version_file = "app/_version.py" + +[tool.setuptools] +packages = ["app"] + +[project] +name = "ess-notify" +dynamic = ["version"] +description = "ESS notification server" +readme = "README.md" +dependencies = [ + "alembic", + "aiofiles", + "authlib", + "cryptography", + "fastapi", + "pydantic>=2.3", + "google-auth", + "requests", + "h2", + "itsdangerous", + "jinja2", + "python-multipart", + "httpx", + "PyJWT>=2.10", + "ldap3", + "SQLAlchemy<1.4", + "uvicorn[standard]", + "gunicorn", + "sentry-sdk", + "typer", +] +requires-python = ">= 3.11" +license = { text = "BSD-2-Clause AND MIT" } + +[project.optional-dependencies] +postgres = ["psycopg2"] +tests = [ + "packaging", + "pytest", + "pytest-cov", + "pytest-asyncio", + "pytest-mock", + "pytest-factoryboy", + "respx", + "Faker", +] + +[project.urls] +Repository = "https://github.com/europeanspallationsource/notify-server" + +[project.scripts] +notify-server = "app.command:cli" + +[tool.ruff.lint] +select = [ + # pycodestyle + "E4", # Import + "E7", # Statement + "E9", # Runtime + # Pyflakes + "F", + # flake8-simplify + "SIM", +] diff --git a/requirements.txt b/requirements.txt index dfeba81..1bf9eb9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,50 +1,145 @@ -aiofiles==23.2.1 -alembic==1.13.1 -annotated-types==0.6.0 -anyio==4.2.0 -cachetools==5.3.2 -certifi==2024.2.2 -cffi==1.16.0 -charset-normalizer==3.3.2 -click==8.1.7 -cryptography==42.0.2 -fastapi==0.109.2 -fastapi-versioning==0.10.0 -google-auth==2.27.0 -gunicorn==21.2.0 +# This file was autogenerated by uv via the following command: +# uv pip compile pyproject.toml -o requirements.txt +aiofiles==24.1.0 + # via ess-notify (pyproject.toml) +alembic==1.14.1 + # via ess-notify (pyproject.toml) +annotated-types==0.7.0 + # via pydantic +anyio==4.8.0 + # via + # httpx + # starlette + # watchfiles +authlib==1.5.1 + # via ess-notify (pyproject.toml) +cachetools==5.5.2 + # via google-auth +certifi==2025.1.31 + # via + # httpcore + # httpx + # requests + # sentry-sdk +cffi==1.17.1 + # via cryptography +charset-normalizer==3.4.1 + # via requests +click==8.1.8 + # via + # typer + # uvicorn +cryptography==44.0.2 + # via + # ess-notify (pyproject.toml) + # authlib +fastapi==0.115.11 + # via ess-notify (pyproject.toml) +google-auth==2.38.0 + # via ess-notify (pyproject.toml) +gunicorn==23.0.0 + # via ess-notify (pyproject.toml) h11==0.14.0 -h2==4.1.0 -hpack==4.0.0 -httpcore==1.0.2 -httptools==0.6.1 -httpx==0.26.0 -hyperframe==6.0.1 -idna==3.6 -itsdangerous==2.1.2 -Jinja2==3.1.3 + # via + # httpcore + # uvicorn +h2==4.2.0 + # via ess-notify (pyproject.toml) +hpack==4.1.0 + # via h2 +httpcore==1.0.7 + # via httpx +httptools==0.6.4 + # via uvicorn +httpx==0.28.1 + # via ess-notify (pyproject.toml) +hyperframe==6.1.0 + # via h2 +idna==3.10 + # via + # anyio + # httpx + # requests +itsdangerous==2.2.0 + # via ess-notify (pyproject.toml) +jinja2==3.1.5 + # via ess-notify (pyproject.toml) ldap3==2.9.1 -Mako==1.3.2 -MarkupSafe==2.1.5 -packaging==23.2 -pyasn1==0.5.1 -pyasn1-modules==0.3.0 -pycparser==2.21 -pydantic==2.6.1 -pydantic_core==2.16.2 -PyJWT==2.8.0 + # via ess-notify (pyproject.toml) +mako==1.3.9 + # via alembic +markdown-it-py==3.0.0 + # via rich +markupsafe==3.0.2 + # via + # jinja2 + # mako +mdurl==0.1.2 + # via markdown-it-py +packaging==24.2 + # via gunicorn +pyasn1==0.6.1 + # via + # ldap3 + # pyasn1-modules + # rsa +pyasn1-modules==0.4.1 + # via google-auth +pycparser==2.22 + # via cffi +pydantic==2.10.6 + # via + # ess-notify (pyproject.toml) + # fastapi +pydantic-core==2.27.2 + # via pydantic +pygments==2.19.1 + # via rich +pyjwt==2.10.1 + # via ess-notify (pyproject.toml) python-dotenv==1.0.1 -python-multipart==0.0.7 -PyYAML==6.0.1 -requests==2.31.0 + # via uvicorn +python-multipart==0.0.20 + # via ess-notify (pyproject.toml) +pyyaml==6.0.2 + # via uvicorn +requests==2.32.3 + # via ess-notify (pyproject.toml) +rich==13.9.4 + # via typer rsa==4.9 -sentry-sdk==1.40.1 -sniffio==1.3.0 -SQLAlchemy==1.3.24 -starlette==0.36.3 -typer==0.9.0 -typing_extensions==4.9.0 -urllib3==2.2.0 -uvicorn==0.27.0.post1 -uvloop==0.19.0 -watchfiles==0.21.0 -websockets==12.0 + # via google-auth +sentry-sdk==2.22.0 + # via ess-notify (pyproject.toml) +shellingham==1.5.4 + # via typer +sniffio==1.3.1 + # via anyio +sqlalchemy==1.3.24 + # via + # ess-notify (pyproject.toml) + # alembic +starlette==0.46.0 + # via fastapi +typer==0.15.2 + # via ess-notify (pyproject.toml) +typing-extensions==4.12.2 + # via + # alembic + # anyio + # fastapi + # pydantic + # pydantic-core + # typer +urllib3==2.3.0 + # via + # requests + # sentry-sdk +uvicorn==0.34.0 + # via ess-notify (pyproject.toml) +uvloop==0.21.0 + # via uvicorn +watchfiles==1.0.4 + # via uvicorn +websockets==15.0 + # via uvicorn diff --git a/setup.py b/setup.py deleted file mode 100644 index 048db0a..0000000 --- a/setup.py +++ /dev/null @@ -1,64 +0,0 @@ -import setuptools - -with open("README.md", "r") as f: - long_description = f.read() - - -postgres_requires = ["psycopg2"] -requirements = [ - "alembic", - "aiofiles", - "cryptography", - "fastapi", - "pydantic>=2.3", - "fastapi-versioning", - "google-auth", - "requests", - "h2", - "itsdangerous", - "jinja2", - "python-multipart", - "httpx", - "PyJWT", - "ldap3", - "SQLAlchemy<1.4", - "uvicorn[standard]", - "gunicorn", - "sentry-sdk", - "typer", -] -tests_requires = [ - "packaging", - "pytest", - "pytest-cov", - "pytest-asyncio", - "pytest-mock", - "pytest-factoryboy", - "respx", - "Faker", -] - -setuptools.setup( - name="ess-notify", - description="ESS notification server", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://gitlab.esss.lu.se/ics-software/ess-notify-server", - license="BSD-2 license", - setup_requires=["setuptools_scm"], - install_requires=requirements, - packages=setuptools.find_packages(exclude=["tests", "tests.*"]), - classifiers=[ - "Intended Audience :: Developers", - "License :: OSI Approved :: BSD License", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - ], - entry_points={"console_scripts": ["notify-server=app.command:cli"]}, - extras_require={"postgres": postgres_requires, "tests": tests_requires}, - python_requires=">=3.9", -) diff --git a/tests/api/test_services.py b/tests/api/test_services.py index dd4056a..539c05a 100644 --- a/tests/api/test_services.py +++ b/tests/api/test_services.py @@ -1,8 +1,6 @@ import json import uuid import pytest -import importlib.metadata -import packaging.version from fastapi.testclient import TestClient from app import models, schemas from ..utils import no_tz_isoformat @@ -105,7 +103,6 @@ def test_update_service_invalid_color( "input": color, "msg": "Value error, Color should match [0-9a-fA-F]{6}", "type": "value_error", - "url": f"{pydantic_errors_url()}/v/value_error", } ], } @@ -190,7 +187,6 @@ def test_read_service_notifications_invalid_service_id( "msg": "Input should be a valid UUID, invalid length: expected " "length 32 for simple format, found 4", "type": "uuid_parsing", - "url": f"{pydantic_errors_url()}/v/uuid_parsing", } ], } @@ -267,9 +263,3 @@ def test_create_notification_for_service( "title": sample_notification["title"], "url": sample_notification["url"], } - - -def pydantic_errors_url(): - version_str = importlib.metadata.version("pydantic") - version = packaging.version.parse(version_str) - return f"https://errors.pydantic.dev/{version.major}.{version.minor}" diff --git a/tests/conftest.py b/tests/conftest.py index cd300a7..b9b46d6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,9 +14,7 @@ environ["LDAP_SERVER"] = "ldap.example.org" environ["APNS_KEY_ID"] = "UB40ZXKCDZ" environ["AUTHENTICATION_URL"] = "https://auth.example.org/login" -environ[ - "APNS_AUTH_KEY" -] = """-----BEGIN PRIVATE KEY----- +environ["APNS_AUTH_KEY"] = """-----BEGIN PRIVATE KEY----- MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgtAParbMemenK/+8T JYWanX1jzKaFcgmupVALPHyaKKKhRANCAARVmMAXI+WPS/vjIsFBHb3B5dQKqgT8 ytZPnlbWNLGGR7tKdB1eLzyBlIVFe9El4Wlvs19ACPRMtE7l75IlbOT+