diff --git a/src/app/routes/oauth.py b/src/app/routes/oauth.py index 5e7e2886..11c74012 100644 --- a/src/app/routes/oauth.py +++ b/src/app/routes/oauth.py @@ -9,7 +9,7 @@ import secrets import hashlib import base64 -from typing import Dict +from typing import Dict, Optional import logging from flask import Blueprint, jsonify, request, url_for, redirect, render_template @@ -31,6 +31,27 @@ TOKEN_RESOURCES: Dict[str, str] = {} +def _resolve_client(client_id: Optional[str]) -> Optional[OAuthClient]: + """Resolve an OAuth client by id or fall back to the single registered one.""" + + if client_id: + return OAuthClient.query.filter_by(client_id=client_id).first() + + # When no client_id is provided, fall back to the sole registered client. + clients = OAuthClient.query.limit(2).all() + if len(clients) == 1: + logger.info( + "OAuth: defaulted to sole client", extra={"client_id": clients[0].client_id} + ) + return clients[0] + + logger.info( + "OAuth: unable to resolve client without id", + extra={"registered_clients": len(clients)}, + ) + return None + + def canonical_mcp_resource() -> str: """Return the canonical MCP server URI (no trailing slash). @@ -229,17 +250,19 @@ def issue_token(): if not data: data = request.get_json(silent=True) or {} - client_id = data.get('client_id') + client_id_param = data.get('client_id') client_secret = data.get('client_secret') - client = OAuthClient.query.filter_by(client_id=client_id).first() + client = _resolve_client(client_id_param) if not client or (client_secret and client.client_secret != client_secret): logger.info("OAuth: invalid_client at token endpoint", extra={ - "client_id": client_id, + "client_id": client_id_param, "has_secret": bool(client_secret), "remote_addr": request.remote_addr, }) return _json_error(401, 'invalid_client', 'Client authentication failed') + client_id = client.client_id + grant_type = data.get('grant_type', 'client_credentials') resource = data.get('resource') if not resource: @@ -260,7 +283,7 @@ def issue_token(): if grant_type == 'authorization_code': code = data.get('code') code_verifier = data.get('code_verifier') - redirect_uri = data.get('redirect_uri') + redirect_uri = data.get('redirect_uri') or client.redirect_uri if not code or not code_verifier or not redirect_uri: logger.info("OAuth: authorization_code missing fields", extra={ "client_id": client_id, @@ -339,20 +362,25 @@ def issue_token(): @login_required def authorize(): """Display a consent page and issue an authorization code.""" - client_id = request.args.get('client_id') + client_id_param = request.args.get('client_id') + client = _resolve_client(client_id_param) redirect_uri = request.args.get('redirect_uri') - code_challenge = request.values.get('code_challenge') - state = request.values.get('state') - scope = request.values.get('scope') - resource = request.values.get('resource') - client = OAuthClient.query.filter_by(client_id=client_id).first() + if client and not redirect_uri: + redirect_uri = client.redirect_uri + if not client or (client.redirect_uri and client.redirect_uri != redirect_uri): logger.info("OAuth: authorize invalid_client or redirect mismatch", extra={ - "client_id": client_id, + "client_id": client_id_param, "redirect_uri": redirect_uri, }) return jsonify({'error': 'invalid_client'}), 400 + client_id = client.client_id + code_challenge = request.values.get('code_challenge') + state = request.values.get('state') + scope = request.values.get('scope') + resource = request.values.get('resource') + if request.method == 'POST' and request.form.get('confirm') == 'yes': code = issue_auth_code( client_id=client_id, diff --git a/src/tests/test_asgi_mcp_lifespan.py b/src/tests/test_asgi_mcp_lifespan.py index e7980ef4..3cf52aa4 100644 --- a/src/tests/test_asgi_mcp_lifespan.py +++ b/src/tests/test_asgi_mcp_lifespan.py @@ -2,6 +2,16 @@ def test_parent_asgi_app_uses_mcp_lifespan(monkeypatch): + monkeypatch.setenv("SECRET_KEY", "testing") + monkeypatch.setenv("RECAPTCHA_PUBLIC_KEY", "testing") + monkeypatch.setenv("RECAPTCHA_PRIVATE_KEY", "testing") + monkeypatch.setenv("CELERY_BROKER_URL", "memory://") + monkeypatch.setenv("CELERY_RESULT_BACKEND", "cache+memory://") + import src.config.env as env + monkeypatch.setattr(env, "SECRET_KEY", "testing", raising=False) + monkeypatch.setattr(env, "RECAPTCHA_PUBLIC_KEY", "testing", raising=False) + monkeypatch.setattr(env, "RECAPTCHA_PRIVATE_KEY", "testing", raising=False) + # Create a dummy ASGI app with a recognizable lifespan callable class DummyASGI: async def __call__(self, scope, receive, send): diff --git a/src/tests/test_asgi_sse_cors.py b/src/tests/test_asgi_sse_cors.py index 5c7a20a7..ebef54c2 100644 --- a/src/tests/test_asgi_sse_cors.py +++ b/src/tests/test_asgi_sse_cors.py @@ -3,11 +3,19 @@ from starlette.routing import Route, Mount from starlette.middleware.cors import CORSMiddleware from starlette.testclient import TestClient -import pytest - - -@pytest.mark.asyncio -async def test_tasks_events_sse_includes_cors_header_direct(): +import asyncio + + +def test_tasks_events_sse_includes_cors_header_direct(monkeypatch): + monkeypatch.setenv("SECRET_KEY", "testing") + monkeypatch.setenv("RECAPTCHA_PUBLIC_KEY", "testing") + monkeypatch.setenv("RECAPTCHA_PRIVATE_KEY", "testing") + monkeypatch.setenv("CELERY_BROKER_URL", "memory://") + monkeypatch.setenv("CELERY_RESULT_BACKEND", "cache+memory://") + import src.config.env as env + monkeypatch.setattr(env, "SECRET_KEY", "testing", raising=False) + monkeypatch.setattr(env, "RECAPTCHA_PUBLIC_KEY", "testing", raising=False) + monkeypatch.setattr(env, "RECAPTCHA_PRIVATE_KEY", "testing", raising=False) # Call the SSE handler directly to validate response headers from src.asgi import sse_task_events from starlette.requests import Request @@ -18,7 +26,7 @@ async def test_tasks_events_sse_includes_cors_header_direct(): "headers": [], } req = Request(scope) - resp = await sse_task_events(req) + resp = asyncio.run(sse_task_events(req)) assert resp.headers.get("Access-Control-Allow-Origin") == "*" diff --git a/tests/test_oauth_flow.py b/tests/test_oauth_flow.py index 06fb1186..c4c80c53 100644 --- a/tests/test_oauth_flow.py +++ b/tests/test_oauth_flow.py @@ -32,7 +32,6 @@ def test_oauth_registration_and_access(app, client): resp = client.post( '/token', data={ - 'client_id': client_id, 'grant_type': 'client_credentials', 'ttl': 3600, }, @@ -102,7 +101,6 @@ def test_authorization_code_flow(app, client): '/authorize', query_string={ 'response_type': 'code', - 'client_id': 'cid', 'redirect_uri': 'https://example.com/cb', 'code_challenge': code_challenge, 'state': 'abc', @@ -114,7 +112,7 @@ def test_authorization_code_flow(app, client): token = soup.find('input', {'name': 'csrf_token'})['value'] resp = client.post( - '/authorize?client_id=cid&redirect_uri=https://example.com/cb&state=abc', + '/authorize?redirect_uri=https://example.com/cb&state=abc', data={'confirm': 'yes', 'code_challenge': code_challenge, 'csrf_token': token}, follow_redirects=False, ) @@ -131,7 +129,6 @@ def test_authorization_code_flow(app, client): 'grant_type': 'authorization_code', 'code': code, 'code_verifier': code_verifier, - 'client_id': 'cid', 'client_secret': 'secret', 'redirect_uri': 'https://example.com/cb', },