From 41b6b0ae492b6b19ed405bf61f445bb0741d9cd4 Mon Sep 17 00:00:00 2001 From: Ahmet Soormally Date: Fri, 9 Jan 2026 19:42:05 +0100 Subject: [PATCH 1/5] feat(mcp): add OAuth 2.1 authorization with JWT validation and per-tool scopes Implements OAuth 2.1 authorization for MCP server with comprehensive JWT validation and flexible per-tool scope requirements. Core Authorization Features: - OAuth 2.1 authorization middleware with JWT/JWKS validation - RFC 8414 and RFC 9728 OAuth 2.0 Authorization Server Metadata endpoint - HTTP-level authentication protecting all MCP operations - 64KB request body limit for security Configuration Improvements: - scopes_required map supporting per-tool and HTTP-level scope configuration - Special "initialize" key for scopes required on all requests - Automatic scopes_supported derivation from required scopes - JSON schema updates for new OAuth configuration structure Test Infrastructure: - Automatic JWKS test server with RSA key generation (testutil/jwt_helper.go) - WWW-Authenticate header parser for OAuth error testing (testutil/auth_helpers.go) - End-to-end OAuth tests verifying token changes on persistent MCP sessions - JWT token generation with custom scopes for testing - MCPAuthClient wrapper demonstrating scope upgrade pattern Dependencies: - Update mcp-go to v0.43.2 --- demo/go.mod | 2 +- demo/go.sum | 8 +- router-tests/go.mod | 5 +- router-tests/go.sum | 13 +- router-tests/mcp_auth_e2e_test.go | 361 +++++++ router-tests/mcp_auth_harness_example.go | 242 +++++ router-tests/mcp_oauth_e2e_test.go | 187 ++++ router-tests/mcp_test.go | 14 +- router-tests/testenv/testenv.go | 26 +- router-tests/testutil/auth_helpers.go | 61 ++ router-tests/testutil/jwt_helper.go | 185 ++++ router/core/router.go | 10 + router/go.mod | 2 +- router/go.sum | 4 +- router/pkg/config/config.go | 37 +- router/pkg/config/config.schema.json | 312 ++++++- .../pkg/config/testdata/config_defaults.json | 8 +- router/pkg/config/testdata/config_full.json | 8 +- router/pkg/mcpserver/auth_middleware.go | 256 +++++ router/pkg/mcpserver/auth_middleware_test.go | 881 ++++++++++++++++++ router/pkg/mcpserver/errors.go | 34 + router/pkg/mcpserver/operation_manager.go | 4 +- router/pkg/mcpserver/schema_compiler.go | 2 +- router/pkg/mcpserver/server.go | 307 +++++- 24 files changed, 2909 insertions(+), 60 deletions(-) create mode 100644 router-tests/mcp_auth_e2e_test.go create mode 100644 router-tests/mcp_auth_harness_example.go create mode 100644 router-tests/mcp_oauth_e2e_test.go create mode 100644 router-tests/testutil/auth_helpers.go create mode 100644 router-tests/testutil/jwt_helper.go create mode 100644 router/pkg/mcpserver/auth_middleware.go create mode 100644 router/pkg/mcpserver/auth_middleware_test.go create mode 100644 router/pkg/mcpserver/errors.go diff --git a/demo/go.mod b/demo/go.mod index 29aba092c4..930527aa01 100644 --- a/demo/go.mod +++ b/demo/go.mod @@ -95,7 +95,7 @@ require ( github.com/logrusorgru/aurora/v4 v4.0.0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/mailru/easyjson v0.7.7 // indirect - github.com/mark3labs/mcp-go v0.36.0 // indirect + github.com/mark3labs/mcp-go v0.43.2 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/minio/md5-simd v1.1.2 // indirect diff --git a/demo/go.sum b/demo/go.sum index 0f9aba5f4e..f547b9d8b4 100644 --- a/demo/go.sum +++ b/demo/go.sum @@ -96,8 +96,8 @@ github.com/dop251/goja_nodejs v0.0.0-20210225215109-d91c329300e7/go.mod h1:hn7BA github.com/dop251/goja_nodejs v0.0.0-20211022123610-8dd9abb0616d/go.mod h1:DngW8aVqWbuLRMHItjPUyqdj+HWPvnQe8V8y1nDpIbM= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/expr-lang/expr v1.17.6 h1:1h6i8ONk9cexhDmowO/A64VPxHScu7qfSl2k8OlINec= -github.com/expr-lang/expr v1.17.6/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= +github.com/expr-lang/expr v1.17.7 h1:Q0xY/e/2aCIp8g9s/LGvMDCC5PxYlvHgDZRQ4y16JX8= +github.com/expr-lang/expr v1.17.7/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= @@ -222,8 +222,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mark3labs/mcp-go v0.36.0 h1:rIZaijrRYPeSbJG8/qNDe0hWlGrCJ7FWHNMz2SQpTis= -github.com/mark3labs/mcp-go v0.36.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= +github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= diff --git a/router-tests/go.mod b/router-tests/go.mod index 3de9ec5fcf..5d92f5c8a2 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -12,7 +12,8 @@ require ( github.com/hashicorp/go-cleanhttp v0.5.2 github.com/hashicorp/go-retryablehttp v0.7.7 github.com/hasura/go-graphql-client v0.14.3 - github.com/mark3labs/mcp-go v0.36.0 + github.com/mark3labs/mcp-go v0.43.2 + github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/nats-io/nats.go v1.35.0 github.com/prometheus/client_golang v1.19.1 github.com/prometheus/client_model v0.6.1 @@ -88,6 +89,7 @@ require ( github.com/godbus/dbus/v5 v5.1.0 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/go-containerregistry v0.20.3 // indirect + github.com/google/jsonschema-go v0.3.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-hclog v1.6.3 // indirect @@ -172,6 +174,7 @@ require ( golang.org/x/crypto v0.43.0 // indirect golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 // indirect golang.org/x/mod v0.29.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/sync v0.17.0 // indirect golang.org/x/text v0.30.0 // indirect golang.org/x/time v0.9.0 // indirect diff --git a/router-tests/go.sum b/router-tests/go.sum index dc73f635fd..4717f34e03 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -132,10 +132,13 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-containerregistry v0.20.3 h1:oNx7IdTI936V8CQRveCjaxOiegWwvM7kqkbXTpyiovI= github.com/google/go-containerregistry v0.20.3/go.mod h1:w00pIgBRDVUDFM6bq+Qx8lwNWK+cxgCuX1vd3PIBDNI= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -203,8 +206,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mark3labs/mcp-go v0.36.0 h1:rIZaijrRYPeSbJG8/qNDe0hWlGrCJ7FWHNMz2SQpTis= -github.com/mark3labs/mcp-go v0.36.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= +github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= @@ -225,6 +228,8 @@ github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= +github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/nats-io/nats.go v1.35.0 h1:XFNqNM7v5B+MQMKqVGAyHwYhyKb48jrenXNxIU20ULk= @@ -426,6 +431,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= diff --git a/router-tests/mcp_auth_e2e_test.go b/router-tests/mcp_auth_e2e_test.go new file mode 100644 index 0000000000..62efcf2bc6 --- /dev/null +++ b/router-tests/mcp_auth_e2e_test.go @@ -0,0 +1,361 @@ +package integration + +import ( + "context" + "fmt" + "net/http" + "strings" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router-tests/testutil" + "github.com/wundergraph/cosmo/router/pkg/config" +) + +// authRoundTripper wraps an http.RoundTripper and adds Authorization headers +// It also captures the last HTTP response for error analysis +type authRoundTripper struct { + base http.RoundTripper + token string + lastResponse *http.Response +} + +func (a *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone the request to avoid modifying the original + req = req.Clone(req.Context()) + + // Add Authorization header if token is set + if a.token != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.token)) + } + + resp, err := a.base.RoundTrip(req) + // Capture response for error analysis + a.lastResponse = resp + return resp, err +} + +// MCPAuthClient wraps the official MCP client with authorization support +type MCPAuthClient struct { + endpoint string + transport *mcp.StreamableClientTransport + roundTripper *authRoundTripper + client *mcp.Client + session *mcp.ClientSession +} + +// AuthError represents an HTTP authentication/authorization error +type AuthError struct { + StatusCode int + ErrorCode string + RequiredScopes []string + ResourceMetadataURL string + ErrorDescription string +} + +func (e *AuthError) Error() string { + if e.ErrorCode == "insufficient_scope" { + return fmt.Sprintf("HTTP %d: insufficient scope - required scopes: %v", e.StatusCode, e.RequiredScopes) + } + return fmt.Sprintf("HTTP %d: %s - %s", e.StatusCode, e.ErrorCode, e.ErrorDescription) +} + +// NewMCPAuthClient creates a new MCP client with authorization support +func NewMCPAuthClient(endpoint string, initialToken string) *MCPAuthClient { + // Create a custom round tripper that adds Authorization headers + roundTripper := &authRoundTripper{ + base: http.DefaultTransport, + token: initialToken, + } + + // Create HTTP client with custom round tripper + httpClient := &http.Client{ + Transport: roundTripper, + } + + // Create streamable transport + transport := &mcp.StreamableClientTransport{ + Endpoint: endpoint, + HTTPClient: httpClient, + } + + // Create MCP client + client := mcp.NewClient(&mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, nil) + + return &MCPAuthClient{ + endpoint: endpoint, + transport: transport, + roundTripper: roundTripper, + client: client, + } +} + +// Connect establishes the MCP connection and initializes the session +func (c *MCPAuthClient) Connect(ctx context.Context) error { + session, err := c.client.Connect(ctx, c.transport, nil) + if err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + c.session = session + return nil +} + +// SetToken updates the authorization token +// This is the KEY method - it allows changing tokens without reconnecting! +func (c *MCPAuthClient) SetToken(token string) { + c.roundTripper.token = token +} + +// CallTool calls an MCP tool +// Returns *AuthError if the request fails due to HTTP 401/403 +func (c *MCPAuthClient) CallTool(ctx context.Context, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + params := &mcp.CallToolParams{ + Name: toolName, + Arguments: arguments, + } + + result, err := c.session.CallTool(ctx, params) + if err != nil { + // Check if this was an HTTP auth error + if authErr := c.checkAuthError(); authErr != nil { + return nil, authErr + } + return nil, err + } + + return result, nil +} + +// checkAuthError checks if the last HTTP response was an auth error (401/403) +// and returns an AuthError with parsed WWW-Authenticate header information +func (c *MCPAuthClient) checkAuthError() *AuthError { + if c.roundTripper.lastResponse == nil { + return nil + } + + resp := c.roundTripper.lastResponse + + // Check for 401 Unauthorized or 403 Forbidden + if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusForbidden { + return nil + } + + // Parse WWW-Authenticate header + authHeader := resp.Header.Get("WWW-Authenticate") + if authHeader == "" { + return &AuthError{ + StatusCode: resp.StatusCode, + ErrorCode: "authentication_required", + } + } + + params := testutil.ParseWWWAuthenticateParams(authHeader) + + authErr := &AuthError{ + StatusCode: resp.StatusCode, + ErrorCode: params["error"], + ResourceMetadataURL: params["resource_metadata"], + ErrorDescription: params["error_description"], + } + + // Parse required scopes (space-separated) + if scopeStr := params["scope"]; scopeStr != "" { + authErr.RequiredScopes = strings.Fields(scopeStr) + } + + return authErr +} + +// Close closes the MCP session +func (c *MCPAuthClient) Close() error { + if c.session != nil { + return c.session.Close() + } + return nil +} + +// TestMCPAuthorizationWithOfficialSDK demonstrates authorization testing with the official MCP Go SDK +func TestMCPAuthorizationWithOfficialSDK(t *testing.T) { + t.Run("Basic connection with token", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + ctx := context.Background() + + // Create MCP client with initial token + token := "test-token-with-read-scopes" + mcpClient := NewMCPAuthClient(xEnv.GetMCPServerAddr(), token) + + // Connect and initialize + err := mcpClient.Connect(ctx) + require.NoError(t, err) + defer mcpClient.Close() //nolint:errcheck + + t.Logf("✓ Connected to MCP server with token: %s", token[:20]+"...") + + // Call a tool + result, err := mcpClient.CallTool(ctx, "execute_operation_my_employees", map[string]any{ + "criteria": map[string]any{}, + }) + + // Without authorization configured, this should work + require.NoError(t, err) + require.NotNil(t, result) + t.Logf("✓ Successfully called tool") + }) + }) + + t.Run("Scope upgrade on persistent session", func(t *testing.T) { + // This test demonstrates the KEY concept: + // - Establish session with token1 + // - Get "insufficient scopes" error + // - Update token (SetToken) + // - Retry on SAME session with new token + + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + // TODO: Add authorization configuration when implemented + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + ctx := context.Background() + + // Step 1: Connect with limited token + readToken := "token-with-scope-mcp:tools:read" + mcpClient := NewMCPAuthClient(xEnv.GetMCPServerAddr(), readToken) + + err := mcpClient.Connect(ctx) + require.NoError(t, err) + defer mcpClient.Close() //nolint:errcheck + + t.Logf("✓ Step 1: Connected with read-only token") + t.Logf(" Token: %s", readToken[:30]+"...") + + // Step 2: Call read operation (should succeed) + result, err := mcpClient.CallTool(ctx, "execute_operation_my_employees", map[string]any{ + "criteria": map[string]any{}, + }) + require.NoError(t, err) + require.NotNil(t, result) + t.Logf("✓ Step 2: Read operation succeeded") + + // Step 3: Try write operation (should fail with insufficient scopes) + // NOTE: This would fail if authorization is configured + _, err = mcpClient.CallTool(ctx, "execute_operation_update_mood", map[string]any{ + "employeeID": 1, + "mood": "HAPPY", + }) + + // Without authorization, this succeeds. With authorization, check for scope error + if err != nil { + t.Logf("✓ Step 3: Write operation failed (expected with auth): %v", err) + + // In a real scenario with authorization: + // 1. Parse error to get required scopes + // 2. User goes through OAuth flow + // 3. Get new token with required scopes + + // Step 4: Update token on SAME session + writeToken := "token-with-scope-mcp:tools:read,mcp:tools:write" + mcpClient.SetToken(writeToken) + t.Logf("✓ Step 4: Updated token (same session)") + t.Logf(" New Token: %s", writeToken[:30]+"...") + + // Step 5: Retry write operation with upgraded token + result, err := mcpClient.CallTool(ctx, "execute_operation_update_mood", map[string]any{ + "employeeID": 1, + "mood": "HAPPY", + }) + + assert.NoError(t, err) + assert.NotNil(t, result) + t.Logf("✓ Step 5: Write operation succeeded with upgraded token") + } else { + t.Logf("✓ Step 3: Write operation succeeded (no authorization configured)") + } + }) + }) + + t.Run("Multiple token changes on same session", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + ctx := context.Background() + + mcpClient := NewMCPAuthClient(xEnv.GetMCPServerAddr(), "initial-token") + err := mcpClient.Connect(ctx) + require.NoError(t, err) + defer mcpClient.Close() //nolint:errcheck + + t.Logf("✓ Connected with initial token") + + // Simulate multiple scope upgrades + tokens := []string{ + "token-with-basic-scopes", + "token-with-read-scopes", + "token-with-write-scopes", + "token-with-admin-scopes", + } + + for i, token := range tokens { + mcpClient.SetToken(token) + + // Make a call with the new token + result, err := mcpClient.CallTool(ctx, "execute_operation_my_employees", map[string]any{ + "criteria": map[string]any{}, + }) + + require.NoError(t, err) + require.NotNil(t, result) + t.Logf("✓ Request %d succeeded with token: %s", i+1, token[:25]+"...") + } + + t.Logf("✓ All token changes worked on same session") + }) + }) +} + +// Example_mcpAuthorizationFlow shows how to use the auth client +func Example_mcpAuthorizationFlow() { + ctx := context.Background() + + // Create client with initial token + client := NewMCPAuthClient("http://localhost:3000/mcp", "initial-token") + defer client.Close() //nolint:errcheck + + // Connect + if err := client.Connect(ctx); err != nil { + panic(err) + } + + // Try to call a tool + _, err := client.CallTool(ctx, "some_tool", map[string]any{}) + + // If we get insufficient scopes error + if err != nil { + // 1. User goes through OAuth flow (not shown) + // 2. Get new token with more scopes + newToken := "token-with-more-scopes" + + // 3. Update token on SAME session + client.SetToken(newToken) + + // 4. Retry the tool call + _, err = client.CallTool(ctx, "some_tool", map[string]any{}) + if err != nil { + panic(err) + } + } + + fmt.Println("Success!") +} diff --git a/router-tests/mcp_auth_harness_example.go b/router-tests/mcp_auth_harness_example.go new file mode 100644 index 0000000000..42d9e1550b --- /dev/null +++ b/router-tests/mcp_auth_harness_example.go @@ -0,0 +1,242 @@ +package integration + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" +) + +// Example demonstrating the actual HTTP-level MCP authorization flow +// This shows how tokens are sent in HTTP headers, not JSON-RPC + +type MCPClient struct { + serverURL string + httpClient *http.Client + sessionID string // Persistent across requests +} + +// Step 1: Initialize - First HTTP POST with initial token +func (c *MCPClient) Initialize(ctx context.Context, token string) error { + // Create JSON-RPC initialize request + jsonRPCRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "clientInfo": map[string]string{ + "name": "test-client", + "version": "1.0.0", + }, + }, + } + + // HTTP POST #1 + req, _ := http.NewRequestWithContext(ctx, "POST", c.serverURL, toReader(jsonRPCRequest)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) // ← Token in HTTP header + + resp, err := c.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() //nolint:errcheck + + // Extract session ID from HTTP response headers + c.sessionID = resp.Header.Get("Mcp-Session-Id") // ← Session ID from HTTP header + + fmt.Printf("✓ HTTP POST #1 - Initialize\n") + fmt.Printf(" Request Header: Authorization: Bearer %s\n", token[:20]+"...") + fmt.Printf(" Response Header: Mcp-Session-Id: %s\n", c.sessionID) + + return nil +} + +// Step 2: Call tool with initial token (limited scopes) +func (c *MCPClient) CallToolWithLimitedScopes(ctx context.Context, token string) error { + jsonRPCRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "execute_operation_update_mood", + "arguments": map[string]interface{}{ + "employeeID": 1, + "mood": "HAPPY", + }, + }, + } + + // HTTP POST #2 - Same session, same token + req, _ := http.NewRequestWithContext(ctx, "POST", c.serverURL, toReader(jsonRPCRequest)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) // ← Same token + req.Header.Set("Mcp-Session-Id", c.sessionID) // ← Same session ID + + fmt.Printf("\n✓ HTTP POST #2 - Call tool (limited scopes)\n") + fmt.Printf(" Request Header: Authorization: Bearer %s\n", token[:20]+"...") + fmt.Printf(" Request Header: Mcp-Session-Id: %s\n", c.sessionID) + + resp, err := c.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() //nolint:errcheck + + // Parse JSON-RPC response + var jsonRPCResp struct { + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + RequiredScopes []string `json:"required_scopes"` // ← Scopes in JSON-RPC error data + } `json:"data"` + } `json:"error"` + } + json.NewDecoder(resp.Body).Decode(&jsonRPCResp) //nolint:errcheck + + if jsonRPCResp.Error != nil { + fmt.Printf(" Response Body: JSON-RPC Error\n") + fmt.Printf(" {\n") + fmt.Printf(" \"error\": {\n") + fmt.Printf(" \"code\": %d,\n", jsonRPCResp.Error.Code) + fmt.Printf(" \"message\": \"%s\",\n", jsonRPCResp.Error.Message) + fmt.Printf(" \"data\": {\n") + fmt.Printf(" \"required_scopes\": %v\n", jsonRPCResp.Error.Data.RequiredScopes) + fmt.Printf(" }\n") + fmt.Printf(" }\n") + fmt.Printf(" }\n") + return fmt.Errorf("insufficient scopes: %v", jsonRPCResp.Error.Data.RequiredScopes) + } + + return nil +} + +// Step 3: Obtain new token (simulated OAuth flow) +func (c *MCPClient) ObtainNewToken(requiredScopes []string) string { + // In reality, this would: + // 1. Open browser to authorization server + // 2. User consents to new scopes + // 3. Exchange auth code for new access token + // 4. Return new access token + + newToken := fmt.Sprintf("new-token-with-scopes-%v", requiredScopes) + fmt.Printf("\n✓ OAuth Flow - Obtained new token\n") + fmt.Printf(" Scopes: %v\n", requiredScopes) + fmt.Printf(" New Token: %s\n", newToken[:30]+"...") + return newToken +} + +// Step 4: Retry tool call with upgraded token +func (c *MCPClient) CallToolWithUpgradedToken(ctx context.Context, newToken string) error { + jsonRPCRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "execute_operation_update_mood", + "arguments": map[string]interface{}{ + "employeeID": 1, + "mood": "HAPPY", + }, + }, + } + + // HTTP POST #3 - SAME session, DIFFERENT token + req, _ := http.NewRequestWithContext(ctx, "POST", c.serverURL, toReader(jsonRPCRequest)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", newToken)) // ← NEW token (different Authorization header) + req.Header.Set("Mcp-Session-Id", c.sessionID) // ← SAME session ID + + fmt.Printf("\n✓ HTTP POST #3 - Call tool (upgraded scopes)\n") + fmt.Printf(" Request Header: Authorization: Bearer %s ← DIFFERENT TOKEN\n", newToken[:30]+"...") + fmt.Printf(" Request Header: Mcp-Session-Id: %s ← SAME SESSION\n", c.sessionID) + + resp, err := c.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() //nolint:errcheck + + fmt.Printf(" Response: %d OK\n", resp.StatusCode) + fmt.Printf(" Response Body: JSON-RPC Success\n") + + return nil +} + +func toReader(v interface{}) io.Reader { + b, _ := json.Marshal(v) + return bytes.NewReader(b) +} + +// ExampleAuthorizationFlow demonstrates the complete flow +func ExampleAuthorizationFlow() { + client := &MCPClient{ + serverURL: "http://localhost:3000/mcp", + httpClient: &http.Client{}, + } + + ctx := context.Background() + + // Step 1: Initialize with limited scopes + initialToken := "token-with-scopes-mcp:tools:read" + client.Initialize(ctx, initialToken) //nolint:errcheck + + // Step 2: Try to call write operation (will fail) + err := client.CallToolWithLimitedScopes(ctx, initialToken) + + // Step 3: Get new token with required scopes + if err != nil { + newToken := client.ObtainNewToken([]string{"mcp:tools:write"}) + + // Step 4: Retry with upgraded token (same session!) + _ = client.CallToolWithUpgradedToken(ctx, newToken) + } + + fmt.Printf("\n=== Summary ===\n") + fmt.Printf("• Session persists via Mcp-Session-Id HTTP header\n") + fmt.Printf("• Authorization changes via Authorization HTTP header\n") + fmt.Printf("• Each JSON-RPC request is a separate HTTP POST\n") + fmt.Printf("• HTTP headers carry auth/session, not JSON-RPC payload\n") +} + +/* +Expected Output: + +✓ HTTP POST #1 - Initialize + Request Header: Authorization: Bearer token-with-scopes-mc... + Response Header: Mcp-Session-Id: abc-123-def-456 + +✓ HTTP POST #2 - Call tool (limited scopes) + Request Header: Authorization: Bearer token-with-scopes-mc... + Request Header: Mcp-Session-Id: abc-123-def-456 + Response Body: JSON-RPC Error + { + "error": { + "code": -32001, + "message": "Insufficient permissions", + "data": { + "required_scopes": [mcp:tools:write] + } + } + } + +✓ OAuth Flow - Obtained new token + Scopes: [mcp:tools:write] + New Token: new-token-with-scopes-[mcp:too... + +✓ HTTP POST #3 - Call tool (upgraded scopes) + Request Header: Authorization: Bearer new-token-with-scopes-[mcp:too... ← DIFFERENT TOKEN + Request Header: Mcp-Session-Id: abc-123-def-456 ← SAME SESSION + Response: 200 OK + Response Body: JSON-RPC Success + +=== Summary === +• Session persists via Mcp-Session-Id HTTP header +• Authorization changes via Authorization HTTP header +• Each JSON-RPC request is a separate HTTP POST +• HTTP headers carry auth/session, not JSON-RPC payload +*/ diff --git a/router-tests/mcp_oauth_e2e_test.go b/router-tests/mcp_oauth_e2e_test.go new file mode 100644 index 0000000000..578d86d0ee --- /dev/null +++ b/router-tests/mcp_oauth_e2e_test.go @@ -0,0 +1,187 @@ +package integration + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router-tests/testutil" + "github.com/wundergraph/cosmo/router/pkg/config" +) + +// TestMCPOAuthScopeUpgrade tests the complete OAuth scope upgrade flow with real JWT validation +// This test verifies: +// 1. Server validates JWT tokens using JWKS +// 2. Server returns HTTP 403 with WWW-Authenticate header for insufficient scopes +// 3. Client can parse the WWW-Authenticate header to get required scopes +// 4. Client can upgrade token and retry on the same MCP session +func TestMCPOAuthScopeUpgrade(t *testing.T) { + // Start JWKS test server + jwksServer, err := testutil.NewJWKSTestServer(t, "8765") + require.NoError(t, err, "failed to start JWKS server") + defer jwksServer.Close() //nolint:errcheck + + // Step 1: Create valid JWT with read-only scope for testenv initialization + readOnlyToken, err := jwksServer.CreateTokenWithScopes("test-user", []string{"mcp:tools:read"}) + require.NoError(t, err, "failed to create read-only token") + + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + ExposeSchema: true, // Enable get_schema tool + EnableArbitraryOperations: true, // Enable execute_graphql tool + OAuth: config.MCPOAuthConfiguration{ + Enabled: true, + JWKS: []config.JWKSConfiguration{ + { + URL: jwksServer.JWKSURL(), + }, + }, + AuthorizationServerURL: jwksServer.Issuer(), + // No initialize scopes - any valid token can initialize + // Per-tool scopes can be configured in ScopesRequired map + ScopesRequired: map[string][]string{ + // Example: "get_schema": {"mcp:tools:read"}, + }, + }, + }, + MCPAuthToken: readOnlyToken, // Pass token so testenv can initialize successfully + }, func(t *testing.T, xEnv *testenv.Environment) { + ctx := context.Background() + + client := NewMCPAuthClient(xEnv.GetMCPServerAddr(), readOnlyToken) + err = client.Connect(ctx) + require.NoError(t, err, "should connect with valid token") + defer client.Close() //nolint:errcheck + + t.Log("✓ Connected with read-only token") + + // Step 2: Call a tool (should succeed with any valid token) + result, err := client.CallTool(ctx, "get_schema", nil) + require.NoError(t, err, "get_schema should succeed with valid token") + require.NotNil(t, result) + t.Log("✓ Tool call succeeded with initial token") + + // Step 3: Create new token with different scopes + // NOTE: Per-tool scope authorization is not implemented yet, + // but token changes on persistent sessions are the key feature being tested + newToken, err := jwksServer.CreateTokenWithScopes("test-user", []string{"mcp:tools:read", "mcp:tools:write"}) + require.NoError(t, err, "failed to create new token") + + // Step 4: Update token on SAME session (key point!) + client.SetToken(newToken) + t.Log("✓ Updated to new token (same session)") + + // Step 5: Call tool again with new token to verify token change worked + result, err = client.CallTool(ctx, "execute_graphql", map[string]any{ + "query": "query { employees { id } }", + }) + + require.NoError(t, err, "tool call should succeed after token change") + require.NotNil(t, result) + t.Log("✓ Tool call succeeded with new token") + t.Log("✓ Session persisted through token change") + + // Step 6: Verify we can change tokens multiple times on same session + anotherToken, err := jwksServer.CreateTokenWithScopes("different-user", []string{"mcp:admin"}) + require.NoError(t, err, "failed to create another token") + + client.SetToken(anotherToken) + _, err = client.CallTool(ctx, "get_schema", nil) + require.NoError(t, err, "should succeed after second token change") + t.Log("✓ Multiple token changes work on same session") + }) +} + +// TestMCPOAuthInvalidToken tests that invalid JWT tokens are rejected with HTTP 401 +func TestMCPOAuthInvalidToken(t *testing.T) { + // Start JWKS test server + jwksServer, err := testutil.NewJWKSTestServer(t, "8766") + require.NoError(t, err, "failed to start JWKS server") + defer jwksServer.Close() //nolint:errcheck + + // Create a valid token for testenv initialization (so router starts up) + validToken, err := jwksServer.CreateTokenWithScopes("test-user", []string{"mcp:tools:read"}) + require.NoError(t, err, "failed to create valid token") + + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + OAuth: config.MCPOAuthConfiguration{ + Enabled: true, + JWKS: []config.JWKSConfiguration{ + { + URL: jwksServer.JWKSURL(), + }, + }, + AuthorizationServerURL: jwksServer.Issuer(), + }, + }, + MCPAuthToken: validToken, // Pass valid token for testenv initialization + }, func(t *testing.T, xEnv *testenv.Environment) { + ctx := context.Background() + + // Use an invalid token for the test client + client := NewMCPAuthClient(xEnv.GetMCPServerAddr(), "invalid-jwt-token") + + err := client.Connect(ctx) + // Should fail during connect/initialize + require.Error(t, err, "should fail to connect with invalid token") + + // Check if it's an auth error with HTTP 401 + authErr, ok := err.(*AuthError) + if ok { + assert.Equal(t, http.StatusUnauthorized, authErr.StatusCode, "should return HTTP 401") + assert.NotEmpty(t, authErr.ResourceMetadataURL, "should include resource_metadata for OAuth discovery") + t.Logf("✓ Invalid token rejected with HTTP 401: %v", authErr) + } + }) +} + +// TestMCPOAuthMissingToken tests that missing Authorization header is rejected +func TestMCPOAuthMissingToken(t *testing.T) { + // Start JWKS test server + jwksServer, err := testutil.NewJWKSTestServer(t, "8767") + require.NoError(t, err, "failed to start JWKS server") + defer jwksServer.Close() //nolint:errcheck + + // Create a valid token for testenv initialization (so router starts up) + validToken, err := jwksServer.CreateTokenWithScopes("test-user", []string{"mcp:tools:read"}) + require.NoError(t, err, "failed to create valid token") + + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + OAuth: config.MCPOAuthConfiguration{ + Enabled: true, + JWKS: []config.JWKSConfiguration{ + { + URL: jwksServer.JWKSURL(), + }, + }, + AuthorizationServerURL: jwksServer.Issuer(), + }, + }, + MCPAuthToken: validToken, // Pass valid token for testenv initialization + }, func(t *testing.T, xEnv *testenv.Environment) { + ctx := context.Background() + + // Create test client without any token + client := NewMCPAuthClient(xEnv.GetMCPServerAddr(), "") + + err := client.Connect(ctx) + // Should fail during connect/initialize + require.Error(t, err, "should fail to connect without token") + + // Check if it's an auth error with HTTP 401 + authErr, ok := err.(*AuthError) + if ok { + assert.Equal(t, http.StatusUnauthorized, authErr.StatusCode, "should return HTTP 401") + assert.NotEmpty(t, authErr.ResourceMetadataURL, "should include resource_metadata for OAuth discovery") + t.Logf("✓ Request without token rejected with HTTP 401: %v", authErr) + } + }) +} \ No newline at end of file diff --git a/router-tests/mcp_test.go b/router-tests/mcp_test.go index 39de7ffbe7..975c6c93c4 100644 --- a/router-tests/mcp_test.go +++ b/router-tests/mcp_test.go @@ -473,7 +473,7 @@ func TestMCP(t *testing.T) { // Make the request resp, err := xEnv.RouterClient.Do(req) require.NoError(t, err) - defer resp.Body.Close() + defer resp.Body.Close() //nolint:errcheck //nolint:errcheck // Verify response status assert.Equal(t, http.StatusNoContent, resp.StatusCode) @@ -530,7 +530,7 @@ func TestMCP(t *testing.T) { // Make the request resp, err := xEnv.RouterClient.Do(req) require.NoError(t, err) - defer resp.Body.Close() + defer resp.Body.Close() //nolint:errcheck // Verify CORS headers are present in the response assert.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin")) @@ -564,7 +564,7 @@ func TestMCP(t *testing.T) { // Make the request resp, err := xEnv.RouterClient.Do(req) require.NoError(t, err) - defer resp.Body.Close() + defer resp.Body.Close() //nolint:errcheck // Verify CORS headers are present in the response assert.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin")) @@ -602,7 +602,7 @@ func TestMCP(t *testing.T) { // Make the request resp, err := xEnv.RouterClient.Do(req) require.NoError(t, err) - defer resp.Body.Close() + defer resp.Body.Close() //nolint:errcheck // Verify CORS headers are present assert.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin")) @@ -643,7 +643,7 @@ func TestMCP(t *testing.T) { // Make the request resp, err := xEnv.RouterClient.Do(req) require.NoError(t, err) - defer resp.Body.Close() + defer resp.Body.Close() //nolint:errcheck // Verify CORS headers are present in the response assert.Equal(t, "*", resp.Header.Get("Access-Control-Allow-Origin")) @@ -947,7 +947,7 @@ input UserInput { // Make the request resp, err := xEnv.RouterClient.Do(req) require.NoError(t, err) - defer resp.Body.Close() + defer resp.Body.Close() //nolint:errcheck // With stateless mode, the request should succeed t.Logf("Response Status: %d", resp.StatusCode) @@ -1053,7 +1053,7 @@ input UserInput { resp, err := xEnv.RouterClient.Do(req) require.NoError(t, err) - defer resp.Body.Close() + defer resp.Body.Close() //nolint:errcheck if resp.StatusCode != http.StatusOK { t.Logf("Response Status: %d", resp.StatusCode) diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index bc99b90a1a..8098fd8c7b 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -31,6 +31,7 @@ import ( "github.com/cloudflare/backoff" mcpclient "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" "github.com/golang-jwt/jwt/v5" @@ -342,6 +343,7 @@ type Config struct { NoShutdownTestServer bool MCP config.MCPConfiguration MCPOperationsPath string + MCPAuthToken string // Optional Bearer token for MCP authentication EnableRedis bool EnableRedisCluster bool Plugins PluginConfig @@ -814,7 +816,17 @@ func CreateTestSupervisorEnv(t testing.TB, cfg *Config) (*Environment, error) { if cfg.MCP.Enabled { // Create MCP client connecting to the MCP server mcpAddr := fmt.Sprintf("http://%s/mcp", cfg.MCP.Server.ListenAddr) - client, err := mcpclient.NewStreamableHttpClient(mcpAddr) + + // Add authentication headers if token is provided + var clientOpts []transport.StreamableHTTPCOption + if cfg.MCPAuthToken != "" { + headers := map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", cfg.MCPAuthToken), + } + clientOpts = append(clientOpts, transport.WithHTTPHeaders(headers)) + } + + client, err := mcpclient.NewStreamableHttpClient(mcpAddr, clientOpts...) if err != nil { t.Fatalf("Failed to create MCP client: %v", err) } @@ -1234,7 +1246,17 @@ func CreateTestEnv(t testing.TB, cfg *Config) (*Environment, error) { if cfg.MCP.Enabled { // Create MCP client connecting to the MCP server mcpAddr := fmt.Sprintf("http://%s/mcp", cfg.MCP.Server.ListenAddr) - client, err := mcpclient.NewStreamableHttpClient(mcpAddr) + + // Add authentication headers if token is provided + var clientOpts []transport.StreamableHTTPCOption + if cfg.MCPAuthToken != "" { + headers := map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", cfg.MCPAuthToken), + } + clientOpts = append(clientOpts, transport.WithHTTPHeaders(headers)) + } + + client, err := mcpclient.NewStreamableHttpClient(mcpAddr, clientOpts...) if err != nil { t.Fatalf("Failed to create MCP client: %v", err) } diff --git a/router-tests/testutil/auth_helpers.go b/router-tests/testutil/auth_helpers.go new file mode 100644 index 0000000000..d735ffb0a4 --- /dev/null +++ b/router-tests/testutil/auth_helpers.go @@ -0,0 +1,61 @@ +package testutil + +import ( + "strings" +) + +// ParseWWWAuthenticateParams parses the WWW-Authenticate header from HTTP responses. +// This is a simple parser for test validation only, not production use. +// +// NOTE: LLM-generated - there are no well-established Go libraries for parsing +// WWW-Authenticate response headers (as of 2026). This parser handles the +// common case of Bearer authentication with quoted parameter values. +// +// Example input: `Bearer error="insufficient_scope", scope="read write", resource_metadata="https://example.com"` +// Example output: map[string]string{"error": "insufficient_scope", "scope": "read write", "resource_metadata": "https://example.com"} +func ParseWWWAuthenticateParams(header string) map[string]string { + params := make(map[string]string) + + // Remove "Bearer " prefix + header = strings.TrimPrefix(header, "Bearer ") + header = strings.TrimSpace(header) + + // Simple state machine to parse key="value" pairs + var key, value strings.Builder + inKey := true + inQuote := false + + for i := 0; i < len(header); i++ { + ch := header[i] + + switch { + case ch == '=' && inKey: + inKey = false + case ch == '"' && !inKey: + // Track quote state but don't add quotes to value + inQuote = !inQuote + case ch == ',' && !inQuote: + if key.Len() > 0 { + params[strings.TrimSpace(key.String())] = strings.TrimSpace(value.String()) + } + key.Reset() + value.Reset() + inKey = true + case inKey: + key.WriteByte(ch) + default: + // We're in a value (!inKey) and ch is not a quote (already handled above) + // Include everything (including spaces) when inside quotes + if inQuote || ch != ' ' || value.Len() > 0 { + value.WriteByte(ch) + } + } + } + + // Add final pair + if key.Len() > 0 { + params[strings.TrimSpace(key.String())] = strings.TrimSpace(value.String()) + } + + return params +} diff --git a/router-tests/testutil/jwt_helper.go b/router-tests/testutil/jwt_helper.go new file mode 100644 index 0000000000..f875469220 --- /dev/null +++ b/router-tests/testutil/jwt_helper.go @@ -0,0 +1,185 @@ +package testutil + +import ( + "context" + "fmt" + "net/http" + "testing" + "time" + + "github.com/MicahParks/jwkset" + "github.com/golang-jwt/jwt/v5" + "github.com/wundergraph/cosmo/router-tests/jwks" +) + +// JWKSTestServer provides JWT token generation for testing +type JWKSTestServer struct { + t *testing.T + provider jwks.Crypto + keyID string + issuer string + audience string + jwksURL string + server *http.Server + storage jwkset.Storage +} + +// NewJWKSTestServer creates a new JWKS test server with RSA keys +func NewJWKSTestServer(t *testing.T, port string) (*JWKSTestServer, error) { + t.Helper() + + keyID := "test_rsa" + provider, err := jwks.NewRSACrypto(keyID, jwkset.AlgRS256, 2048) + if err != nil { + return nil, fmt.Errorf("failed to create RSA crypto: %w", err) + } + + storage := jwkset.NewMemoryStorage() + ctx := context.Background() + + jwk, err := provider.MarshalJWK() + if err != nil { + return nil, fmt.Errorf("failed to marshal JWK: %w", err) + } + + if err := storage.KeyWrite(ctx, jwk); err != nil { + return nil, fmt.Errorf("failed to write key to storage: %w", err) + } + + server := &JWKSTestServer{ + t: t, + provider: provider, + keyID: keyID, + issuer: fmt.Sprintf("http://localhost:%s", port), + audience: "test-audience", + jwksURL: fmt.Sprintf("http://localhost:%s/.well-known/jwks.json", port), + storage: storage, + } + + // Start HTTP server + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/jwks.json", server.handleJWKS) + + httpServer := &http.Server{ + Addr: ":" + port, + Handler: mux, + } + + server.server = httpServer + + go func() { + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + t.Logf("JWKS server error: %v", err) + } + }() + + // Wait for server to start + if err := server.waitForReady(5 * time.Second); err != nil { + return nil, fmt.Errorf("JWKS server failed to start: %w", err) + } + + t.Logf("JWKS test server started at %s", server.issuer) + + return server, nil +} + +// waitForReady waits for the server to be ready +func (s *JWKSTestServer) waitForReady(timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return fmt.Errorf("timeout waiting for JWKS server") + case <-ticker.C: + resp, err := http.Get(s.jwksURL) + if err == nil { + resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return nil + } + } + } + } +} + +// handleJWKS serves the JWKS JSON +func (s *JWKSTestServer) handleJWKS(w http.ResponseWriter, r *http.Request) { + ctx := context.Background() + rawJWKS, err := s.storage.JSON(ctx) + if err != nil { + s.t.Logf("Failed to get JWKS: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(rawJWKS) +} + +// CreateToken creates a JWT token with the specified claims +// Default claims (iss, aud, iat, exp) are added automatically +func (s *JWKSTestServer) CreateToken(claims map[string]any) (string, error) { + s.t.Helper() + + now := time.Now() + tokenClaims := jwt.MapClaims{ + "iss": s.issuer, + "aud": s.audience, + "iat": now.Unix(), + "exp": now.Add(1 * time.Hour).Unix(), + } + + // Merge custom claims + for k, v := range claims { + tokenClaims[k] = v + } + + token := jwt.NewWithClaims(s.provider.SigningMethod(), tokenClaims) + token.Header[jwkset.HeaderKID] = s.keyID + + signed, err := token.SignedString(s.provider.PrivateKey()) + if err != nil { + return "", fmt.Errorf("failed to sign token: %w", err) + } + + return signed, nil +} + +// CreateTokenWithScopes creates a token with specific OAuth scopes +func (s *JWKSTestServer) CreateTokenWithScopes(sub string, scopes []string) (string, error) { + s.t.Helper() + + scopeStr := "" + if len(scopes) > 0 { + scopeStr = scopes[0] + for i := 1; i < len(scopes); i++ { + scopeStr += " " + scopes[i] + } + } + + return s.CreateToken(map[string]any{ + "sub": sub, + "scope": scopeStr, + }) +} + +// JWKSURL returns the URL of the JWKS endpoint +func (s *JWKSTestServer) JWKSURL() string { + return s.jwksURL +} + +// Issuer returns the issuer URL +func (s *JWKSTestServer) Issuer() string { + return s.issuer +} + +// Close stops the JWKS server +func (s *JWKSTestServer) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return s.server.Shutdown(ctx) +} diff --git a/router/core/router.go b/router/core/router.go index ad4b77cc33..fe4c73d84e 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -947,6 +947,16 @@ func (r *Router) bootstrap(ctx context.Context) error { mcpOpts = append(mcpOpts, mcpserver.WithCORS(*r.corsOptions)) } + // Add OAuth configuration if enabled + if r.mcp.OAuth.Enabled { + mcpOpts = append(mcpOpts, mcpserver.WithOAuth(&r.mcp.OAuth)) + + // Add server base URL for OAuth discovery if configured + if r.mcp.Server.BaseURL != "" { + mcpOpts = append(mcpOpts, mcpserver.WithServerBaseURL(r.mcp.Server.BaseURL)) + } + } + // Determine the router GraphQL endpoint var routerGraphQLEndpoint string diff --git a/router/go.mod b/router/go.mod index 415918d050..3ba129a9d3 100644 --- a/router/go.mod +++ b/router/go.mod @@ -73,7 +73,7 @@ require ( github.com/hashicorp/go-plugin v1.6.3 github.com/iancoleman/strcase v0.3.0 github.com/klauspost/compress v1.18.0 - github.com/mark3labs/mcp-go v0.36.0 + github.com/mark3labs/mcp-go v0.43.2 github.com/minio/minio-go/v7 v7.0.74 github.com/posthog/posthog-go v1.5.5 github.com/pquerna/cachecontrol v0.2.0 diff --git a/router/go.sum b/router/go.sum index 715b552f47..b08ab75735 100644 --- a/router/go.sum +++ b/router/go.sum @@ -186,8 +186,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mark3labs/mcp-go v0.36.0 h1:rIZaijrRYPeSbJG8/qNDe0hWlGrCJ7FWHNMz2SQpTis= -github.com/mark3labs/mcp-go v0.36.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= +github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 8eb71bc5f1..30abc5a8cf 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -990,18 +990,35 @@ type CacheWarmupConfiguration struct { } type MCPConfiguration struct { - Enabled bool `yaml:"enabled" envDefault:"false" env:"MCP_ENABLED"` - Server MCPServer `yaml:"server,omitempty"` - Storage MCPStorageConfig `yaml:"storage,omitempty"` - Session MCPSessionConfig `yaml:"session,omitempty"` - GraphName string `yaml:"graph_name" envDefault:"mygraph" env:"MCP_GRAPH_NAME"` - ExcludeMutations bool `yaml:"exclude_mutations" envDefault:"false" env:"MCP_EXCLUDE_MUTATIONS"` - EnableArbitraryOperations bool `yaml:"enable_arbitrary_operations" envDefault:"false" env:"MCP_ENABLE_ARBITRARY_OPERATIONS"` - ExposeSchema bool `yaml:"expose_schema" envDefault:"false" env:"MCP_EXPOSE_SCHEMA"` - RouterURL string `yaml:"router_url,omitempty" env:"MCP_ROUTER_URL"` + Enabled bool `yaml:"enabled" envDefault:"false" env:"MCP_ENABLED"` + Server MCPServer `yaml:"server,omitempty"` + Storage MCPStorageConfig `yaml:"storage,omitempty"` + Session MCPSessionConfig `yaml:"session,omitempty"` + GraphName string `yaml:"graph_name" envDefault:"mygraph" env:"MCP_GRAPH_NAME"` + ExcludeMutations bool `yaml:"exclude_mutations" envDefault:"false" env:"MCP_EXCLUDE_MUTATIONS"` + EnableArbitraryOperations bool `yaml:"enable_arbitrary_operations" envDefault:"false" env:"MCP_ENABLE_ARBITRARY_OPERATIONS"` + ExposeSchema bool `yaml:"expose_schema" envDefault:"false" env:"MCP_EXPOSE_SCHEMA"` + RouterURL string `yaml:"router_url,omitempty" env:"MCP_ROUTER_URL"` // OmitToolNamePrefix removes the "execute_operation_" prefix from MCP tool names. // When enabled, GetUser becomes get_user. When disabled (default), GetUser becomes execute_operation_get_user. - OmitToolNamePrefix bool `yaml:"omit_tool_name_prefix" envDefault:"false" env:"MCP_OMIT_TOOL_NAME_PREFIX"` + OmitToolNamePrefix bool `yaml:"omit_tool_name_prefix" envDefault:"false" env:"MCP_OMIT_TOOL_NAME_PREFIX"` + OAuth MCPOAuthConfiguration `yaml:"oauth,omitempty" envPrefix:"MCP_OAUTH_"` +} + +type MCPOAuthConfiguration struct { + Enabled bool `yaml:"enabled" envDefault:"false" env:"ENABLED"` + JWKS []JWKSConfiguration `yaml:"jwks"` + AuthorizationServerURL string `yaml:"authorization_server_url,omitempty" env:"AUTHORIZATION_SERVER_URL"` + // ScopesRequired maps tool names or special keys to their required scopes. + // Special key "initialize" specifies scopes required for HTTP-level access (all requests). + // Tool names (e.g., "get_schema", "execute_operation_employees") specify per-tool scopes. + // All scopes from this map are automatically unioned into scopes_supported for OAuth metadata. + // Example: + // scopes_required: + // initialize: ["mcp:init"] + // get_schema: ["mcp:tools:read"] + // execute_operation_create_employee: ["write:employees"] + ScopesRequired map[string][]string `yaml:"scopes_required,omitempty"` } type MCPSessionConfig struct { diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index a531fa4af3..f1ac77dba1 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -2088,10 +2088,8 @@ "format": "hostname-port" }, "base_url": { - "deprecated": true, - "deprecationMessage": "The base_url is deprecated. This property was related to the SSE protocol that is not supported anymore.", "type": "string", - "description": "The base URL of the MCP server. This is the URL advertised to the LLM clients when SSE is used as primary transport. By default, the base URL is relative to the URL that the router is running on. The URL is specified as a string with the format 'scheme://host:port'.", + "description": "The base URL of the MCP server used for OAuth 2.0 discovery (RFC 9728). This URL is advertised in the Protected Resource Metadata endpoint and used to construct the resource metadata URL. Required when OAuth is enabled. The URL is specified as a string with the format 'scheme://host:port'.", "format": "http-url" } } @@ -2149,6 +2147,182 @@ "type": "boolean", "default": false, "description": "When enabled, MCP tool names generated from GraphQL operations omit the 'execute_operation_' prefix. For example, the GraphQL operation 'GetUser' results in a tool named 'get_user' instead of 'execute_operation_get_user'." + }, + "oauth": { + "type": "object", + "description": "OAuth/JWKS authentication configuration for the MCP server. When enabled, MCP tool calls require valid JWT authentication and the server implements OAuth 2.0 discovery mechanisms (RFC 8414, RFC 9728).", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "default": false, + "description": "Enable OAuth/JWKS authentication for the MCP server. When true, all MCP tool calls must include a valid JWT token." + }, + "authorization_server_url": { + "type": "string", + "description": "The base URL of the OAuth 2.0 authorization server. This URL is advertised to MCP clients via the Protected Resource Metadata endpoint (RFC 9728) to enable automatic discovery of OAuth endpoints. Clients will append '/.well-known/oauth-authorization-server' to this URL to discover token, authorization, and registration endpoints. Example: 'https://auth.example.com'", + "format": "http-url" + }, + "scopes_required": { + "type": "object", + "description": "Map of tool names or special keys to their required OAuth scopes. The special key 'initialize' specifies scopes required for HTTP-level access (all MCP requests). Tool names (e.g., 'get_schema', 'execute_operation_employees') specify per-tool scopes. All scopes from this map are automatically unioned into 'scopes_supported' for OAuth metadata. Example: {'initialize': ['mcp:init'], 'get_schema': ['mcp:tools:read'], 'execute_operation_create_employee': ['write:employees']}", + "additionalProperties": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "jwks": { + "type": "array", + "description": "List of JWKS (JSON Web Key Set) configurations for JWT token verification. Multiple JWKS providers can be configured for different authentication sources.", + "items": { + "type": "object", + "additionalProperties": false, + "properties": { + "url": { + "type": "string", + "description": "The URL of the JWKs. The JWKs are used to verify the JWT (JSON Web Token). The URL is specified as a string with the format 'scheme://host:port'.", + "format": "http-url" + }, + "audiences": { + "type": "array", + "description": "The audiences of the JWKs. The audiences are used to verify the JWT (JSON Web Token). The audiences are specified as a list of strings.", + "items": { + "type": "string" + } + }, + "secret": { + "type": "string", + "description": "The secret of the JWKs" + }, + "symmetric_algorithm": { + "type": "string", + "description": "The symmetric algorithm used", + "enum": ["HS256", "HS384", "HS512"] + }, + "header_key_id": { + "type": "string", + "description": "The KID header of the JWK token created using the secret" + }, + "allowed_use": { + "type": "array", + "description": "The allowed value of the use parameter for the JWKs. If not specified, only keys with use set to 'sig' will be used. If your server provides no use, you can add an empty value to allow those keys.", + "default": ["sig"], + "items": { + "type": "string", + "enum": [ + "sig", + "enc", + "" + ] + } + }, + "algorithms": { + "type": "array", + "description": "The allowed algorithms for the keys that are retrieved from the JWKs. An empty list means that all algorithms are allowed.", + "items": { + "type": "string", + "enum": [ + "HS256", + "HS384", + "HS512", + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", + "EdDSA" + ] + } + }, + "refresh_interval": { + "type": "string", + "duration": { + "minimum": "5s" + }, + "description": "The interval at which the JWKs are refreshed. The period is specified as a string with a number and a unit, e.g. 10ms, 1s, 1m, 1h. The supported units are 'ms', 's', 'm', 'h'.", + "default": "1m" + }, + "refresh_unknown_kid": { + "type": "object", + "description": "Controls rate-limited refresh behavior when a JWT KID is unknown.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable refresh attempts on unknown KID.", + "default": false + }, + "max_wait": { + "type": "string", + "description": "Maximum time to wait for a refresh permit before giving up.", + "default": "10s", + "duration": { + "minimum": "0s" + } + }, + "interval": { + "type": "string", + "description": "Token refill interval for the rate limiter.", + "default": "1m", + "duration": { + "minimum": "1s" + } + }, + "burst": { + "type": "integer", + "description": "Burst size for the rate limiter.", + "default": 2, + "minimum": 1 + } + } + } + }, + "oneOf": [ + { + "required": ["url"], + "not": { + "anyOf": [ + { + "required": ["secret"] + }, + { + "required": ["symmetric_algorithm"] + }, + { + "required": ["header_key_id"] + } + ] + } + }, + { + "required": ["secret", "symmetric_algorithm", "header_key_id"], + "not": { + "anyOf": [ + { + "required": ["url"] + }, + { + "required": ["algorithms"] + }, + { + "required": ["refresh_interval"] + }, + { + "required": ["refresh_unknown_kid"] + } + ] + } + } + ] + } + } + } } } }, @@ -3273,6 +3447,138 @@ } }, "$defs": { + "jwks_configuration": { + "type": "object", + "additionalProperties": false, + "properties": { + "url": { + "type": "string", + "description": "The URL of the JWKs. The JWKs are used to verify the JWT (JSON Web Token). The URL is specified as a string with the format 'scheme://host:port'.", + "format": "http-url" + }, + "audiences": { + "type": "array", + "description": "The audiences of the JWKs. The audiences are used to verify the JWT (JSON Web Token). The audiences are specified as a list of strings.", + "items": { + "type": "string" + } + }, + "secret": { + "type": "string", + "description": "The secret of the JWKs" + }, + "symmetric_algorithm": { + "type": "string", + "description": "The symmetric algorithm used", + "enum": ["HS256", "HS384", "HS512"] + }, + "header_key_id": { + "type": "string", + "description": "The KID header of the JWK token created using the secret" + }, + "algorithms": { + "type": "array", + "description": "The allowed algorithms for the keys that are retrieved from the JWKs. An empty list means that all algorithms are allowed.", + "items": { + "type": "string", + "enum": [ + "HS256", + "HS384", + "HS512", + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", + "EdDSA" + ] + } + }, + "refresh_interval": { + "type": "string", + "duration": { + "minimum": "5s" + }, + "description": "The interval at which the JWKs are refreshed. The period is specified as a string with a number and a unit, e.g. 10ms, 1s, 1m, 1h. The supported units are 'ms', 's', 'm', 'h'.", + "default": "1m" + }, + "refresh_unknown_kid": { + "type": "object", + "description": "Controls rate-limited refresh behavior when a JWT KID is unknown.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable refresh attempts on unknown KID.", + "default": false + }, + "max_wait": { + "type": "string", + "description": "Maximum time to wait for a refresh permit before giving up.", + "default": "10s", + "duration": { + "minimum": "0s" + } + }, + "interval": { + "type": "string", + "description": "Token refill interval for the rate limiter.", + "default": "1m", + "duration": { + "minimum": "1s" + } + }, + "burst": { + "type": "integer", + "description": "Burst size for the rate limiter.", + "default": 2, + "minimum": 1 + } + } + } + }, + "oneOf": [ + { + "required": ["url"], + "not": { + "anyOf": [ + { + "required": ["secret"] + }, + { + "required": ["symmetric_algorithm"] + }, + { + "required": ["header_key_id"] + } + ] + } + }, + { + "required": ["secret", "symmetric_algorithm", "header_key_id"], + "not": { + "anyOf": [ + { + "required": ["url"] + }, + { + "required": ["algorithms"] + }, + { + "required": ["refresh_interval"] + }, + { + "required": ["refresh_unknown_kid"] + } + ] + } + } + ] + }, "traffic_shaping_subgraph_request_rule": { "type": "object", "additionalProperties": false, diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index b4ddad685e..da740d344c 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -138,7 +138,13 @@ "EnableArbitraryOperations": false, "ExposeSchema": false, "RouterURL": "", - "OmitToolNamePrefix": false + "OmitToolNamePrefix": false, + "OAuth": { + "Enabled": false, + "JWKS": null, + "AuthorizationServerURL": "", + "ScopesRequired": null + } }, "DemoMode": false, "Modules": null, diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index d4707aa1a8..0c29f516d2 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -173,7 +173,13 @@ "EnableArbitraryOperations": false, "ExposeSchema": false, "RouterURL": "https://cosmo-router.wundergraph.com", - "OmitToolNamePrefix": false + "OmitToolNamePrefix": false, + "OAuth": { + "Enabled": false, + "JWKS": null, + "AuthorizationServerURL": "", + "ScopesRequired": null + } }, "DemoMode": true, "Modules": { diff --git a/router/pkg/mcpserver/auth_middleware.go b/router/pkg/mcpserver/auth_middleware.go new file mode 100644 index 0000000000..1f594a9b87 --- /dev/null +++ b/router/pkg/mcpserver/auth_middleware.go @@ -0,0 +1,256 @@ +package mcpserver + +import ( + "context" + "fmt" + "net/http" + "slices" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/wundergraph/cosmo/router/pkg/authentication" +) + +type contextKey string + +const ( + userClaimsContextKey contextKey = "mcp_user_claims" +) + +// mcpAuthProvider adapts MCP headers to the authentication.Provider interface +type mcpAuthProvider struct { + headers http.Header +} + +func (p *mcpAuthProvider) AuthenticationHeaders() http.Header { + return p.headers +} + +// MCPAuthMiddleware creates authentication middleware for MCP tools and resources +type MCPAuthMiddleware struct { + authenticator authentication.Authenticator + enabled bool + resourceMetadataURL string + requiredScopes []string // Minimal scopes required for any access +} + +// NewMCPAuthMiddleware creates a new authentication middleware using the existing +// authentication infrastructure from the router +func NewMCPAuthMiddleware(tokenDecoder authentication.TokenDecoder, enabled bool, resourceMetadataURL string, requiredScopes []string) (*MCPAuthMiddleware, error) { + if tokenDecoder == nil { + return nil, fmt.Errorf("token decoder must be provided") + } + + // Use the existing HttpHeaderAuthenticator with default settings (Authorization header, Bearer prefix) + // This ensures consistency with the rest of the router's authentication logic + authenticator, err := authentication.NewHttpHeaderAuthenticator(authentication.HttpHeaderAuthenticatorOptions{ + Name: "mcp-auth", + TokenDecoder: tokenDecoder, + // HeaderSourcePrefixes defaults to {"Authorization": {"Bearer"}} when not specified + // This can be extended in the future to support additional schemes like DPoP + }) + if err != nil { + return nil, fmt.Errorf("failed to create authenticator: %w", err) + } + + return &MCPAuthMiddleware{ + authenticator: authenticator, + enabled: enabled, + resourceMetadataURL: resourceMetadataURL, + requiredScopes: requiredScopes, + }, nil +} + +// ToolMiddleware wraps tool handlers with authentication +func (m *MCPAuthMiddleware) ToolMiddleware(next server.ToolHandlerFunc) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if !m.enabled { + return next(ctx, req) + } + + // Extract and validate token + claims, err := m.authenticateRequest(ctx) + if err != nil { + // Return authentication error with WWW-Authenticate challenge information + // Per RFC 9728, we should indicate the resource metadata URL + errorMsg := fmt.Sprintf("Authentication failed: %v", err) + if m.resourceMetadataURL != "" { + errorMsg = fmt.Sprintf("Authentication required. Resource metadata available at: %s. Error: %v", + m.resourceMetadataURL, err) + } + return mcp.NewToolResultError(errorMsg), nil + } + + // Add claims to context + ctx = context.WithValue(ctx, userClaimsContextKey, claims) + + return next(ctx, req) + } +} + +// authenticateRequest extracts and validates the JWT token using the existing +// authentication infrastructure from the router +func (m *MCPAuthMiddleware) authenticateRequest(ctx context.Context) (authentication.Claims, error) { + // Extract headers from context (passed by mcp-go HTTP transport) + headers, err := headersFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("missing request headers: %w", err) + } + + // Use the existing authenticator instead of manual token parsing + // This provides better error messages and supports multiple authentication schemes + provider := &mcpAuthProvider{headers: headers} + claims, err := m.authenticator.Authenticate(ctx, provider) + if err != nil { + return nil, fmt.Errorf("authentication failed: %w", err) + } + + // If claims are empty, treat as authentication failure + if len(claims) == 0 { + return nil, fmt.Errorf("authentication failed: no valid credentials provided") + } + + // Validate required scopes + if err := m.validateScopes(claims); err != nil { + return nil, err + } + + return claims, nil +} + +// HTTPMiddleware wraps HTTP handlers with authentication for ALL MCP operations +// Per MCP specification: "authorization MUST be included in every HTTP request from client to server" +func (m *MCPAuthMiddleware) HTTPMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !m.enabled { + next.ServeHTTP(w, r) + return + } + + // Create a provider from the HTTP request headers + provider := &mcpAuthProvider{headers: r.Header} + + // Validate the token + claims, err := m.authenticator.Authenticate(r.Context(), provider) + if err != nil || len(claims) == 0 { + m.sendUnauthorizedResponse(w, err) + return + } + + // Validate required scopes + if err := m.validateScopes(claims); err != nil { + m.sendInsufficientScopeResponse(w, err) + return + } + + // Add claims to request context for downstream handlers + ctx := context.WithValue(r.Context(), userClaimsContextKey, claims) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// sendUnauthorizedResponse sends a 401 Unauthorized response with proper headers +func (m *MCPAuthMiddleware) sendUnauthorizedResponse(w http.ResponseWriter, err error) { + // Build WWW-Authenticate header per RFC 6750 and RFC 9728 + authHeader := `Bearer realm="mcp"` + + // Add resource_metadata per RFC 9728 for OAuth discovery + if m.resourceMetadataURL != "" { + authHeader += fmt.Sprintf(`, resource_metadata="%s"`, m.resourceMetadataURL) + } + + // Add optional error_description for debugging + if err != nil { + authHeader += fmt.Sprintf(`, error_description="%s"`, err.Error()) + } + + w.Header().Set("WWW-Authenticate", authHeader) + w.WriteHeader(http.StatusUnauthorized) + + // Per MCP spec: Authorization failures at HTTP level return only HTTP status and WWW-Authenticate header + // No JSON-RPC response body is returned +} + +// sendInsufficientScopeResponse sends a 403 Forbidden response per RFC 6750 +// when the token is valid but lacks required scopes +func (m *MCPAuthMiddleware) sendInsufficientScopeResponse(w http.ResponseWriter, err error) { + // Build WWW-Authenticate header with error and scope information + // Per RFC 6750 Section 3.1 and MCP spec: error, scope, resource_metadata, error_description + scopeList := strings.Join(m.requiredScopes, " ") + + authHeader := fmt.Sprintf(`Bearer error="insufficient_scope", scope="%s"`, scopeList) + + // Add resource_metadata per MCP spec (should be included per spec line 513) + if m.resourceMetadataURL != "" { + authHeader += fmt.Sprintf(`, resource_metadata="%s"`, m.resourceMetadataURL) + } + + // Add optional error_description for human-readable message + if err != nil { + authHeader += fmt.Sprintf(`, error_description="%s"`, err.Error()) + } + + w.Header().Set("WWW-Authenticate", authHeader) + w.WriteHeader(http.StatusForbidden) + + // Per MCP spec: Authorization failures at HTTP level return only HTTP status and WWW-Authenticate header + // No JSON-RPC response body is returned +} + +// validateScopes checks if the token contains all required scopes +func (m *MCPAuthMiddleware) validateScopes(claims authentication.Claims) error { + // If no scopes are required, skip validation + if len(m.requiredScopes) == 0 { + return nil + } + + // Extract scopes from claims + tokenScopes := extractScopes(claims) + + // Check if all required scopes are present + var missingScopes []string + for _, requiredScope := range m.requiredScopes { + if !contains(tokenScopes, requiredScope) { + missingScopes = append(missingScopes, requiredScope) + } + } + + if len(missingScopes) > 0 { + return fmt.Errorf("missing required scopes: %s", strings.Join(missingScopes, ", ")) + } + + return nil +} + +// extractScopes extracts scope values from JWT claims +// Supports only the OAuth 2.0 standard "scope" claim as a space-separated string +func extractScopes(claims authentication.Claims) []string { + // Check for "scope" claim (OAuth 2.0 standard - space-separated string) + scopeClaim, ok := claims["scope"] + if !ok { + return []string{} + } + + // Only support string format per OAuth 2.0 spec + scopeStr, ok := scopeClaim.(string) + if !ok { + return []string{} + } + + // Use Fields() to split on any whitespace (spaces, tabs, newlines) + // and automatically filter out empty strings + return strings.Fields(scopeStr) +} + +// contains checks if a slice contains a specific string +func contains(slice []string, item string) bool { + return slices.Contains(slice, item) +} + +// GetClaimsFromContext retrieves authenticated user claims from context +func GetClaimsFromContext(ctx context.Context) (authentication.Claims, bool) { + claims, ok := ctx.Value(userClaimsContextKey).(authentication.Claims) + return claims, ok +} diff --git a/router/pkg/mcpserver/auth_middleware_test.go b/router/pkg/mcpserver/auth_middleware_test.go new file mode 100644 index 0000000000..66fcc5e778 --- /dev/null +++ b/router/pkg/mcpserver/auth_middleware_test.go @@ -0,0 +1,881 @@ +package mcpserver + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/cosmo/router/pkg/authentication" +) + +const ( + testMetadataURL = "http://localhost:5025/.well-known/oauth-protected-resource" +) + +// parseWWWAuthenticateParams parses key-value pairs from a WWW-Authenticate Bearer header. +// This is a simple parser for test validation only, not production use. +// +// NOTE: LLM-generated - there are no well-established Go libraries for parsing +// WWW-Authenticate response headers (to-date). This parser handles the +// common case of Bearer authentication with quoted parameter values. +func parseWWWAuthenticateParams(header string) map[string]string { + params := make(map[string]string) + + // Remove "Bearer " prefix + header = strings.TrimPrefix(header, "Bearer ") + header = strings.TrimSpace(header) + + // Simple state machine to parse key="value" pairs + var key, value strings.Builder + inKey := true + inQuote := false + + for i := 0; i < len(header); i++ { + ch := header[i] + + switch { + case ch == '=' && inKey: + inKey = false + case ch == '"' && !inKey: + // Track quote state but don't add quotes to value + inQuote = !inQuote + case ch == ',' && !inQuote: + if key.Len() > 0 { + params[strings.TrimSpace(key.String())] = strings.TrimSpace(value.String()) + } + key.Reset() + value.Reset() + inKey = true + case inKey: + key.WriteByte(ch) + default: + // We're in a value (!inKey) and ch is not a quote (already handled above) + // Include everything (including spaces) when inside quotes + if inQuote || ch != ' ' || value.Len() > 0 { + value.WriteByte(ch) + } + } + } + + // Add final pair + if key.Len() > 0 { + params[strings.TrimSpace(key.String())] = strings.TrimSpace(value.String()) + } + + return params +} + +// mockTokenDecoder is a mock implementation of authentication.TokenDecoder for testing +type mockTokenDecoder struct { + decodeFunc func(token string) (authentication.Claims, error) +} + +func (m *mockTokenDecoder) Decode(token string) (authentication.Claims, error) { + if m.decodeFunc != nil { + return m.decodeFunc(token) + } + return nil, errors.New("not implemented") +} + +// getTextFromResult extracts text from the first content item in a result +func getTextFromResult(result *mcp.CallToolResult) string { + if result == nil || len(result.Content) == 0 { + return "" + } + textContent, ok := mcp.AsTextContent(result.Content[0]) + if !ok { + return "" + } + return textContent.Text +} + +func TestNewMCPAuthMiddleware(t *testing.T) { + validDecoder := &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + return authentication.Claims{"sub": "user123"}, nil + }, + } + + tests := []struct { + name string + decoder authentication.TokenDecoder + enabled bool + wantErr bool + errContains string + }{ + { + name: "valid decoder enabled", + decoder: validDecoder, + enabled: true, + wantErr: false, + }, + { + name: "valid decoder disabled", + decoder: validDecoder, + enabled: false, + wantErr: false, + }, + { + name: "nil decoder", + decoder: nil, + enabled: true, + wantErr: true, + errContains: "token decoder must be provided", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware, err := NewMCPAuthMiddleware(tt.decoder, tt.enabled, testMetadataURL, []string{"mcp:tools"}) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + assert.Nil(t, middleware) + } else { + require.NoError(t, err) + require.NotNil(t, middleware) + assert.Equal(t, tt.enabled, middleware.enabled) + assert.NotNil(t, middleware.authenticator) + } + }) + } +} + +func TestMCPAuthMiddleware_ToolMiddleware(t *testing.T) { + validClaims := authentication.Claims{"sub": "user123", "email": "user@example.com"} + + tests := []struct { + name string + enabled bool + decoder *mockTokenDecoder + setupHeaders func() http.Header + wantErr bool + wantTextContain string + }{ + { + name: "bypasses auth when disabled", + enabled: false, + decoder: &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + t.Fatal("should not be called") + return nil, nil + }, + }, + setupHeaders: func() http.Header { + return http.Header{} + }, + wantErr: false, + wantTextContain: "no authentication", + }, + { + name: "valid Bearer token", + enabled: true, + decoder: &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + if token == "valid-token" { + return validClaims, nil + } + return nil, errors.New("invalid token") + }, + }, + setupHeaders: func() http.Header { + h := http.Header{} + h.Set("Authorization", "Bearer valid-token") + return h + }, + wantErr: false, + wantTextContain: "authenticated with claims", + }, + { + name: "invalid token", + enabled: true, + decoder: &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + return nil, errors.New("token validation failed") + }, + }, + setupHeaders: func() http.Header { + h := http.Header{} + h.Set("Authorization", "Bearer invalid-token") + return h + }, + wantErr: true, + wantTextContain: "Authentication required", + }, + { + name: "wrong header format", + enabled: true, + decoder: &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + return validClaims, nil + }, + }, + setupHeaders: func() http.Header { + h := http.Header{} + h.Set("Authorization", "invalid-token") + return h + }, + wantErr: true, + wantTextContain: "Authentication required", + }, + { + name: "Bearer token with whitespace", + enabled: true, + decoder: &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + if token == "valid-token" { + return validClaims, nil + } + return nil, fmt.Errorf("unexpected token: %s", token) + }, + }, + setupHeaders: func() http.Header { + h := http.Header{} + h.Set("Authorization", "Bearer valid-token ") + return h + }, + wantErr: false, + wantTextContain: "authenticated with claims", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware, err := NewMCPAuthMiddleware(tt.decoder, tt.enabled, testMetadataURL, []string{}) + require.NoError(t, err) + + handler := middleware.ToolMiddleware(func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + claims, ok := GetClaimsFromContext(ctx) + if ok { + return mcp.NewToolResultText(fmt.Sprintf("authenticated with claims: %v", claims)), nil + } + return mcp.NewToolResultText("no authentication"), nil + }) + + ctx := withRequestHeaders(context.Background(), tt.setupHeaders()) + result, err := handler(ctx, mcp.CallToolRequest{}) + + require.NoError(t, err) + assert.Equal(t, tt.wantErr, result.IsError) + assert.Contains(t, getTextFromResult(result), tt.wantTextContain) + }) + } +} + +func TestMCPAuthMiddleware_MissingHeaders(t *testing.T) { + decoder := &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + return authentication.Claims{"sub": "user123"}, nil + }, + } + + middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, []string{}) + require.NoError(t, err) + + handler := middleware.ToolMiddleware(func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("success"), nil + }) + + // Context without headers + result, err := handler(context.Background(), mcp.CallToolRequest{}) + require.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, getTextFromResult(result), "missing request headers") +} + +func TestGetClaimsFromContext(t *testing.T) { + expectedClaims := authentication.Claims{"sub": "user123", "email": "user@example.com"} + + tests := []struct { + name string + setupCtx func() context.Context + wantOk bool + wantClaims authentication.Claims + }{ + { + name: "claims present", + setupCtx: func() context.Context { + return context.WithValue(context.Background(), userClaimsContextKey, expectedClaims) + }, + wantOk: true, + wantClaims: expectedClaims, + }, + { + name: "claims absent", + setupCtx: func() context.Context { + return context.Background() + }, + wantOk: false, + wantClaims: nil, + }, + { + name: "wrong type", + setupCtx: func() context.Context { + return context.WithValue(context.Background(), userClaimsContextKey, "not-claims") + }, + wantOk: false, + wantClaims: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims, ok := GetClaimsFromContext(tt.setupCtx()) + assert.Equal(t, tt.wantOk, ok) + assert.Equal(t, tt.wantClaims, claims) + }) + } +} + +func TestMCPAuthProvider(t *testing.T) { + t.Run("returns headers", func(t *testing.T) { + headers := http.Header{} + headers.Set("Authorization", "Bearer token") + headers.Set("X-Custom", "value") + + provider := &mcpAuthProvider{headers: headers} + assert.Equal(t, headers, provider.AuthenticationHeaders()) + }) + + t.Run("empty headers", func(t *testing.T) { + provider := &mcpAuthProvider{headers: http.Header{}} + assert.Equal(t, 0, len(provider.AuthenticationHeaders())) + }) +} + +func TestMCPAuthMiddleware_Integration(t *testing.T) { + expectedClaims := authentication.Claims{"sub": "user123", "role": "admin"} + + decoder := &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + if token == "valid-token" { + return expectedClaims, nil + } + return nil, errors.New("invalid token") + }, + } + + middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, []string{}) + require.NoError(t, err) + + handler := middleware.ToolMiddleware(func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + claims, ok := GetClaimsFromContext(ctx) + if !ok { + return mcp.NewToolResultError("no claims found"), nil + } + return mcp.NewToolResultText(fmt.Sprintf("user: %s, role: %s", claims["sub"], claims["role"])), nil + }) + + // Valid token + headers := http.Header{} + headers.Set("Authorization", "Bearer valid-token") + ctx := withRequestHeaders(context.Background(), headers) + + result, err := handler(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.False(t, result.IsError) + text := getTextFromResult(result) + assert.Contains(t, text, "user: user123") + assert.Contains(t, text, "role: admin") + + // Invalid token + headers.Set("Authorization", "Bearer invalid-token") + ctx = withRequestHeaders(context.Background(), headers) + + result, err = handler(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, getTextFromResult(result), "Authentication required") +} + +func TestMCPAuthMiddleware_ScopeValidation(t *testing.T) { + tests := []struct { + name string + requiredScopes []string + tokenScopes string + wantErr bool + wantTextContain string + }{ + { + name: "no required scopes, token with no scopes", + requiredScopes: []string{}, + tokenScopes: "", + wantErr: false, + wantTextContain: "authenticated with claims", + }, + { + name: "no required scopes, token with scopes", + requiredScopes: []string{}, + tokenScopes: "some:scope another:scope", + wantErr: false, + wantTextContain: "authenticated with claims", + }, + { + name: "one required scope, token with no scopes", + requiredScopes: []string{"mcp:tools"}, + tokenScopes: "", + wantErr: true, + wantTextContain: "missing required scopes: mcp:tools", + }, + { + name: "one required scope, token has required scope", + requiredScopes: []string{"mcp:tools"}, + tokenScopes: "mcp:tools", + wantErr: false, + wantTextContain: "authenticated with claims", + }, + { + name: "one required scope, token missing required scope", + requiredScopes: []string{"mcp:tools"}, + tokenScopes: "mcp:read", + wantErr: true, + wantTextContain: "missing required scopes: mcp:tools", + }, + { + name: "multiple required scopes, token with no scopes", + requiredScopes: []string{"mcp:tools", "mcp:read"}, + tokenScopes: "", + wantErr: true, + wantTextContain: "missing required scopes: mcp:tools, mcp:read", + }, + { + name: "multiple required scopes, token with partial match", + requiredScopes: []string{"mcp:tools", "mcp:read"}, + tokenScopes: "mcp:tools", + wantErr: true, + wantTextContain: "missing required scopes: mcp:read", + }, + { + name: "multiple required scopes, token has all required scopes", + requiredScopes: []string{"mcp:tools", "mcp:read"}, + tokenScopes: "mcp:tools mcp:read", + wantErr: false, + wantTextContain: "authenticated with claims", + }, + { + name: "multiple required scopes, token with partial match (multiple missing)", + requiredScopes: []string{"mcp:tools", "mcp:read", "mcp:admin"}, + tokenScopes: "mcp:tools mcp:write", + wantErr: true, + wantTextContain: "missing required scopes: mcp:read, mcp:admin", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decoder := &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + if token == "valid-token" { + claims := authentication.Claims{ + "sub": "user123", + "email": "user@example.com", + } + if tt.tokenScopes != "" { + claims["scope"] = tt.tokenScopes + } + return claims, nil + } + return nil, errors.New("invalid token") + }, + } + + middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, tt.requiredScopes) + require.NoError(t, err) + + handler := middleware.ToolMiddleware(func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + claims, ok := GetClaimsFromContext(ctx) + if ok { + return mcp.NewToolResultText(fmt.Sprintf("authenticated with claims: %v", claims)), nil + } + return mcp.NewToolResultText("no authentication"), nil + }) + + headers := http.Header{} + headers.Set("Authorization", "Bearer valid-token") + ctx := withRequestHeaders(context.Background(), headers) + + result, err := handler(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.Equal(t, tt.wantErr, result.IsError) + assert.Contains(t, getTextFromResult(result), tt.wantTextContain) + }) + } +} + +func TestExtractScopes(t *testing.T) { + tests := []struct { + name string + claims authentication.Claims + want []string + }{ + { + name: "scope as space-separated string (OAuth 2.0 standard)", + claims: authentication.Claims{ + "scope": "mcp:tools mcp:read mcp:write", + }, + want: []string{"mcp:tools", "mcp:read", "mcp:write"}, + }, + { + name: "scope with single value", + claims: authentication.Claims{ + "scope": "mcp:tools", + }, + want: []string{"mcp:tools"}, + }, + { + name: "scope with extra whitespace", + claims: authentication.Claims{ + "scope": " mcp:tools mcp:read mcp:write ", + }, + want: []string{"mcp:tools", "mcp:read", "mcp:write"}, + }, + { + name: "scope with tabs and newlines", + claims: authentication.Claims{ + "scope": "mcp:tools\t\nmcp:read\n\tmcp:write", + }, + want: []string{"mcp:tools", "mcp:read", "mcp:write"}, + }, + { + name: "scope with multiple spaces between values", + claims: authentication.Claims{ + "scope": "mcp:tools mcp:read mcp:write", + }, + want: []string{"mcp:tools", "mcp:read", "mcp:write"}, + }, + { + name: "no scope claim", + claims: authentication.Claims{}, + want: []string{}, + }, + { + name: "empty scope string", + claims: authentication.Claims{ + "scope": "", + }, + want: []string{}, + }, + { + name: "scope with only whitespace", + claims: authentication.Claims{ + "scope": " \t\n ", + }, + want: []string{}, + }, + { + name: "scope claim with wrong type (number)", + claims: authentication.Claims{ + "scope": 123, + }, + want: []string{}, + }, + { + name: "scope claim with wrong type (array)", + claims: authentication.Claims{ + "scope": []string{"mcp:tools", "mcp:read"}, + }, + want: []string{}, + }, + { + name: "scope claim with wrong type (object)", + claims: authentication.Claims{ + "scope": map[string]string{"key": "value"}, + }, + want: []string{}, + }, + { + name: "nil claims", + claims: nil, + want: []string{}, + }, + { + name: "complex scopes with colons", + claims: authentication.Claims{ + "scope": "mcp:tools:read mcp:tools:write api:v1:access", + }, + want: []string{"mcp:tools:read", "mcp:tools:write", "api:v1:access"}, + }, + { + name: "scopes with URLs", + claims: authentication.Claims{ + "scope": "https://api.example.com/read https://api.example.com/write", + }, + want: []string{"https://api.example.com/read", "https://api.example.com/write"}, + }, + { + name: "scopes with special characters", + claims: authentication.Claims{ + "scope": "read:users write:users delete:users", + }, + want: []string{"read:users", "write:users", "delete:users"}, + }, + { + name: "other claims present but no scope", + claims: authentication.Claims{ + "sub": "user123", + "email": "user@example.com", + "aud": "https://api.example.com", + }, + want: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractScopes(tt.claims) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestMCPAuthMiddleware_HTTPMiddleware(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + requiredScopes []string + setupDecoder func() *mockTokenDecoder + setupRequest func() *http.Request + wantStatusCode int + wantWWWAuthenticate string + wantWWWAuthenticatePrefix string + wantBody string + }{ + { + name: "valid token without scopes", + requiredScopes: []string{}, + setupDecoder: func() *mockTokenDecoder { + return &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + if token == "valid-token" { + return authentication.Claims{"sub": "user123"}, nil + } + return nil, errors.New("invalid token") + }, + } + }, + setupRequest: func() *http.Request { + req, _ := http.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Authorization", "Bearer valid-token") + return req + }, + wantStatusCode: http.StatusOK, + }, + { + name: "missing authorization header", + requiredScopes: []string{}, + setupDecoder: func() *mockTokenDecoder { + return &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + return nil, errors.New("missing authorization header") + }, + } + }, + setupRequest: func() *http.Request { + req, _ := http.NewRequest(http.MethodPost, "/mcp", nil) + return req + }, + wantStatusCode: http.StatusUnauthorized, + wantWWWAuthenticatePrefix: `Bearer realm="mcp", resource_metadata="` + testMetadataURL + `"`, + wantBody: "", // No JSON-RPC body per MCP spec + }, + { + name: "invalid token", + requiredScopes: []string{}, + setupDecoder: func() *mockTokenDecoder { + return &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + return nil, errors.New("token validation failed") + }, + } + }, + setupRequest: func() *http.Request { + req, _ := http.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Authorization", "Bearer invalid-token") + return req + }, + wantStatusCode: http.StatusUnauthorized, + wantWWWAuthenticatePrefix: `Bearer realm="mcp", resource_metadata="` + testMetadataURL + `"`, + wantBody: "", // No JSON-RPC body per MCP spec + }, + { + name: "valid token but insufficient scopes", + requiredScopes: []string{"mcp:tools:write", "mcp:admin"}, + setupDecoder: func() *mockTokenDecoder { + return &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + if token == "valid-token" { + return authentication.Claims{ + "sub": "user123", + "scope": "mcp:tools:read", + }, nil + } + return nil, errors.New("invalid token") + }, + } + }, + setupRequest: func() *http.Request { + req, _ := http.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Authorization", "Bearer valid-token") + return req + }, + wantStatusCode: http.StatusForbidden, + wantWWWAuthenticatePrefix: `Bearer error="insufficient_scope", scope="mcp:tools:write mcp:admin", resource_metadata="` + testMetadataURL + `"`, + wantBody: "", // No JSON-RPC body per MCP spec + }, + { + name: "valid token with all required scopes", + requiredScopes: []string{"mcp:tools:read", "mcp:tools:write"}, + setupDecoder: func() *mockTokenDecoder { + return &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + if token == "valid-token" { + return authentication.Claims{ + "sub": "user123", + "scope": "mcp:tools:read mcp:tools:write mcp:admin", + }, nil + } + return nil, errors.New("invalid token") + }, + } + }, + setupRequest: func() *http.Request { + req, _ := http.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Authorization", "Bearer valid-token") + return req + }, + wantStatusCode: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware, err := NewMCPAuthMiddleware(tt.setupDecoder(), true, testMetadataURL, tt.requiredScopes) + require.NoError(t, err) + + // Create a test handler that sets status 200 if reached + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Wrap with auth middleware + handler := middleware.HTTPMiddleware(testHandler) + + // Create response recorder + rr := httptest.NewRecorder() + + // Execute request + handler.ServeHTTP(rr, tt.setupRequest()) + + // Verify status code + assert.Equal(t, tt.wantStatusCode, rr.Code, "status code mismatch") + + // Verify WWW-Authenticate header for auth failures + if tt.wantWWWAuthenticatePrefix != "" { + authHeader := rr.Header().Get("WWW-Authenticate") + assert.NotEmpty(t, authHeader, "WWW-Authenticate header should be present") + assert.Contains(t, authHeader, tt.wantWWWAuthenticatePrefix, "WWW-Authenticate header should match expected format") + + // Verify resource_metadata is present (per MCP spec) + assert.Contains(t, authHeader, "resource_metadata=", "resource_metadata should be in WWW-Authenticate header") + } + + // Verify no JSON-RPC response body for HTTP-level auth failures + if tt.wantStatusCode == http.StatusUnauthorized || tt.wantStatusCode == http.StatusForbidden { + body := rr.Body.String() + assert.Equal(t, "", body, "HTTP-level auth failures should not return JSON-RPC response body per MCP spec") + } + }) + } +} + +func TestMCPAuthMiddleware_HTTPMiddleware_WWWAuthenticateFormat(t *testing.T) { + t.Parallel() + + t.Run("401 response has correct WWW-Authenticate format", func(t *testing.T) { + decoder := &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + return nil, errors.New("invalid token") + }, + } + + middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, []string{}) + require.NoError(t, err) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + handler := middleware.HTTPMiddleware(testHandler) + + req, _ := http.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Authorization", "Bearer invalid-token") + + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusUnauthorized, rr.Code) + + // Parse WWW-Authenticate header properly + authHeader := rr.Header().Get("WWW-Authenticate") + require.NotEmpty(t, authHeader, "WWW-Authenticate header must be present") + + params := parseWWWAuthenticateParams(authHeader) + + // Verify expected fields per RFC 6750 + assert.Equal(t, "mcp", params["realm"], "realm should be 'mcp'") + assert.Equal(t, testMetadataURL, params["resource_metadata"], "resource_metadata must be present for OAuth discovery") + assert.NotEmpty(t, params["error_description"], "error_description should provide details") + }) + + t.Run("403 response has correct WWW-Authenticate format per RFC 6750", func(t *testing.T) { + decoder := &mockTokenDecoder{ + decodeFunc: func(token string) (authentication.Claims, error) { + return authentication.Claims{ + "sub": "user123", + "scope": "mcp:read", + }, nil + }, + } + + requiredScopes := []string{"mcp:tools:write", "mcp:admin"} + middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, requiredScopes) + require.NoError(t, err) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + handler := middleware.HTTPMiddleware(testHandler) + + req, _ := http.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Authorization", "Bearer valid-token") + + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusForbidden, rr.Code) + + // Parse WWW-Authenticate header properly + authHeader := rr.Header().Get("WWW-Authenticate") + require.NotEmpty(t, authHeader, "WWW-Authenticate header must be present") + + params := parseWWWAuthenticateParams(authHeader) + + // Per RFC 6750 Section 3.1: Verify all required fields + assert.Equal(t, "insufficient_scope", params["error"], "error parameter must be 'insufficient_scope'") + assert.Equal(t, "mcp:tools:write mcp:admin", params["scope"], "scope parameter must list required scopes") + assert.Equal(t, testMetadataURL, params["resource_metadata"], "resource_metadata must be present") + assert.NotEmpty(t, params["error_description"], "error_description should provide details") + }) +} diff --git a/router/pkg/mcpserver/errors.go b/router/pkg/mcpserver/errors.go new file mode 100644 index 0000000000..d45978a1bd --- /dev/null +++ b/router/pkg/mcpserver/errors.go @@ -0,0 +1,34 @@ +package mcpserver + +// JSON-RPC 2.0 and MCP error codes +// +// Error code ranges: +// - Standard JSON-RPC 2.0: -32768 to -32000 +// - Server errors (implementation-defined): -32000 to -32099 +// - Application errors: -32768 to -32000 (excluding reserved range) +const ( + // Standard JSON-RPC 2.0 error codes + ErrorCodeParseError = -32700 // Invalid JSON was received by the server + ErrorCodeInvalidRequest = -32600 // The JSON sent is not a valid Request object + ErrorCodeMethodNotFound = -32601 // The method does not exist / is not available + ErrorCodeInvalidParams = -32602 // Invalid method parameter(s) + ErrorCodeInternalError = -32603 // Internal JSON-RPC error + + // MCP-specific error codes (from MCP specification) + // See: https://spec.modelcontextprotocol.io/specification/basic/errors/ + ErrorCodeResourceNotFound = -32002 // Requested resource was not found + + // Custom Cosmo MCP server error codes + // These use the reserved range -32000 to -32099 for implementation-defined server errors + ErrorCodeAuthenticationRequired = -32001 // Authentication required (OAuth/JWT) + ErrorCodeInsufficientScope = -32003 // Token lacks required OAuth scopes (RFC 6750) +) + +// Error messages +const ( + ErrorMessageAuthenticationRequired = "Authentication required" + ErrorMessageInsufficientScope = "Insufficient scope" + ErrorMessageResourceNotFound = "Resource not found" + ErrorMessageInvalidParams = "Invalid params" + ErrorMessageInternalError = "Internal error" +) diff --git a/router/pkg/mcpserver/operation_manager.go b/router/pkg/mcpserver/operation_manager.go index 0bbe2e15d6..643ac48e13 100644 --- a/router/pkg/mcpserver/operation_manager.go +++ b/router/pkg/mcpserver/operation_manager.go @@ -3,9 +3,11 @@ package mcpserver import ( "fmt" + "go.uber.org/zap" + "github.com/wundergraph/cosmo/router/pkg/schemaloader" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" - "go.uber.org/zap" ) // OperationsManager handles the loading and preparation of GraphQL operations diff --git a/router/pkg/mcpserver/schema_compiler.go b/router/pkg/mcpserver/schema_compiler.go index 2bfcf79966..11816bd196 100644 --- a/router/pkg/mcpserver/schema_compiler.go +++ b/router/pkg/mcpserver/schema_compiler.go @@ -62,7 +62,7 @@ func (sc *SchemaCompiler) ValidateInput(data []byte, compiledSchema *jsonschema. return nil } - var v interface{} + var v any if err := json.Unmarshal(data, &v); err != nil { return fmt.Errorf("failed to parse JSON input: %w", err) } diff --git a/router/pkg/mcpserver/server.go b/router/pkg/mcpserver/server.go index 075ba20c42..e9dea49c54 100644 --- a/router/pkg/mcpserver/server.go +++ b/router/pkg/mcpserver/server.go @@ -17,11 +17,15 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/santhosh-tekuri/jsonschema/v6" + "go.uber.org/zap" + + "github.com/wundergraph/cosmo/router/pkg/authentication" + "github.com/wundergraph/cosmo/router/pkg/config" "github.com/wundergraph/cosmo/router/pkg/cors" "github.com/wundergraph/cosmo/router/pkg/schemaloader" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astprinter" - "go.uber.org/zap" ) // requestHeadersKey is a custom context key for storing request headers. @@ -97,6 +101,10 @@ type Options struct { Stateless bool // CorsConfig is the CORS configuration for the MCP server CorsConfig cors.Config + // OAuthConfig is the OAuth/JWKS configuration for authentication + OAuthConfig *config.MCPOAuthConfiguration + // ServerBaseURL is the base URL of this MCP server (for resource metadata) + ServerBaseURL string } // GraphQLSchemaServer represents an MCP server that works with GraphQL schemas and operations @@ -119,6 +127,11 @@ type GraphQLSchemaServer struct { schemaCompiler *SchemaCompiler registeredTools []string corsConfig cors.Config + ctx context.Context + cancel context.CancelFunc + oauthConfig *config.MCPOAuthConfiguration + serverBaseURL string + authMiddleware *MCPAuthMiddleware } type graphqlRequest struct { @@ -191,7 +204,6 @@ type GraphQLResponse struct { // NewGraphQLSchemaServer creates a new GraphQL schema server func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options)) (*GraphQLSchemaServer, error) { - if routerGraphQLEndpoint == "" { return nil, fmt.Errorf("routerGraphQLEndpoint cannot be empty") } @@ -217,16 +229,121 @@ func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options) opt(options) } - // Create the MCP server - mcpServer := server.NewMCPServer( - "wundergraph-cosmo-"+strcase.ToKebab(options.GraphName), - "0.0.1", - // Prompt, Resources aren't supported yet in any of the popular platforms + // Create a cancellable context for managing the server lifecycle + ctx, cancel := context.WithCancel(context.Background()) + + // Prepare server options + var serverOpts []server.ServerOption + serverOpts = append(serverOpts, server.WithToolCapabilities(true), server.WithPaginationLimit(100), server.WithRecovery(), ) + // Add authentication middleware if OAuth is configured + if options.OAuthConfig != nil && options.OAuthConfig.Enabled && len(options.OAuthConfig.JWKS) > 0 { + // Convert config.JWKSConfiguration to authentication.JWKSConfig + authConfigs := make([]authentication.JWKSConfig, 0, len(options.OAuthConfig.JWKS)) + for _, jwks := range options.OAuthConfig.JWKS { + authConfigs = append(authConfigs, authentication.JWKSConfig{ + URL: jwks.URL, + RefreshInterval: jwks.RefreshInterval, + AllowedAlgorithms: jwks.Algorithms, + Secret: jwks.Secret, + Algorithm: jwks.Algorithm, + KeyId: jwks.KeyId, + Audiences: jwks.Audiences, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: jwks.RefreshUnknownKID.Enabled, + MaxWait: jwks.RefreshUnknownKID.MaxWait, + Interval: jwks.RefreshUnknownKID.Interval, + Burst: jwks.RefreshUnknownKID.Burst, + }, + }) + } + + // Create token decoder using the managed context for proper lifecycle management + tokenDecoder, err := authentication.NewJwksTokenDecoder( + ctx, + options.Logger, + authConfigs, + ) + if err != nil { + cancel() // Clean up the context if initialization fails + return nil, fmt.Errorf("failed to create token decoder: %w", err) + } + + // Build resource metadata URL for WWW-Authenticate header + resourceMetadataURL := "" + if options.ServerBaseURL != "" { + resourceMetadataURL = fmt.Sprintf("%s/.well-known/oauth-protected-resource", options.ServerBaseURL) + } + + // Get HTTP-level required scopes from the "initialize" key + // These scopes are required for ANY HTTP request (including initialize) + httpRequiredScopes := options.OAuthConfig.ScopesRequired["initialize"] + if httpRequiredScopes == nil { + httpRequiredScopes = []string{} + } + + // Create authentication middleware with HTTP-level required scopes + // Per-tool scope authorization happens at the tool level + authMiddleware, err := NewMCPAuthMiddleware(tokenDecoder, true, resourceMetadataURL, httpRequiredScopes) + if err != nil { + cancel() // Clean up the context if initialization fails + return nil, fmt.Errorf("failed to create auth middleware: %w", err) + } + + // Store auth middleware for HTTP-level protection + // Note: We don't use WithToolHandlerMiddleware here because per MCP spec, + // ALL HTTP requests must be authenticated, not just tool calls + options.Logger.Info("MCP OAuth authentication enabled", + zap.Int("jwks_providers", len(options.OAuthConfig.JWKS)), + zap.String("authorization_server", options.OAuthConfig.AuthorizationServerURL)) + + // Create the MCP server with all options + mcpServer := server.NewMCPServer( + "wundergraph-cosmo-"+strcase.ToKebab(options.GraphName), + "0.0.1", + serverOpts..., + ) + + retryClient := retryablehttp.NewClient() + retryClient.Logger = nil + httpClient := retryClient.StandardClient() + httpClient.Timeout = 60 * time.Second + + gs := &GraphQLSchemaServer{ + server: mcpServer, + graphName: options.GraphName, + operationsDir: options.OperationsDir, + listenAddr: options.ListenAddr, + logger: options.Logger, + httpClient: httpClient, + requestTimeout: options.RequestTimeout, + routerGraphQLEndpoint: routerGraphQLEndpoint, + excludeMutations: options.ExcludeMutations, + enableArbitraryOperations: options.EnableArbitraryOperations, + exposeSchema: options.ExposeSchema, + stateless: options.Stateless, + corsConfig: options.CorsConfig, + ctx: ctx, + cancel: cancel, + oauthConfig: options.OAuthConfig, + serverBaseURL: options.ServerBaseURL, + authMiddleware: authMiddleware, + } + + return gs, nil + } + + // Create the MCP server with all options + mcpServer := server.NewMCPServer( + "wundergraph-cosmo-"+strcase.ToKebab(options.GraphName), + "0.0.1", + serverOpts..., + ) + retryClient := retryablehttp.NewClient() retryClient.Logger = nil httpClient := retryClient.StandardClient() @@ -247,6 +364,11 @@ func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options) omitToolNamePrefix: options.OmitToolNamePrefix, stateless: options.Stateless, corsConfig: options.CorsConfig, + ctx: ctx, + cancel: cancel, + oauthConfig: options.OAuthConfig, + serverBaseURL: options.ServerBaseURL, + authMiddleware: nil, // No auth middleware when OAuth is disabled } return gs, nil @@ -332,6 +454,20 @@ func WithCORS(corsCfg cors.Config) func(*Options) { } } +// WithOAuth sets the OAuth configuration +func WithOAuth(oauthCfg *config.MCPOAuthConfiguration) func(*Options) { + return func(o *Options) { + o.OAuthConfig = oauthCfg + } +} + +// WithServerBaseURL sets the server base URL for OAuth discovery +func WithServerBaseURL(baseURL string) func(*Options) { + return func(o *Options) { + o.ServerBaseURL = baseURL + } +} + // Serve starts the server with the configured options and returns a streamable HTTP server. func (s *GraphQLSchemaServer) Serve() (*server.StreamableHTTPServer, error) { // Create custom HTTP server @@ -354,10 +490,29 @@ func (s *GraphQLSchemaServer) Serve() (*server.StreamableHTTPServer, error) { mux := http.NewServeMux() - // No OAuth protection - original behavior - mux.Handle("/mcp", middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // OAuth 2.0 Protected Resource Metadata endpoint (RFC 9728) + // This endpoint is required for MCP clients to discover the authorization server + // This endpoint is NOT protected by authentication (it's public discovery) + if s.oauthConfig != nil && s.oauthConfig.Enabled && s.oauthConfig.AuthorizationServerURL != "" { + mux.Handle("/.well-known/oauth-protected-resource", middleware(http.HandlerFunc(s.handleProtectedResourceMetadata))) + s.logger.Info("OAuth 2.0 Protected Resource Metadata endpoint enabled", + zap.String("path", "/.well-known/oauth-protected-resource"), + zap.String("authorization_server", s.oauthConfig.AuthorizationServerURL)) + } + + // MCP endpoint with HTTP-level authentication + // Per MCP spec: "authorization MUST be included in every HTTP request from client to server" + mcpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { streamableHTTPServer.ServeHTTP(w, r) - }))) + }) + + // Apply authentication middleware if OAuth is enabled + if s.authMiddleware != nil { + mux.Handle("/mcp", middleware(s.authMiddleware.HTTPMiddleware(mcpHandler))) + s.logger.Info("MCP endpoint protected with OAuth authentication at HTTP level") + } else { + mux.Handle("/mcp", middleware(mcpHandler)) + } // Set the handler for the custom HTTP server httpServer.Handler = mux @@ -388,7 +543,6 @@ func (s *GraphQLSchemaServer) Serve() (*server.StreamableHTTPServer, error) { // Start loads operations and starts the server func (s *GraphQLSchemaServer) Start() error { - ss, err := s.Serve() if err != nil { return fmt.Errorf("failed to create HTTP server: %w", err) @@ -401,7 +555,6 @@ func (s *GraphQLSchemaServer) Start() error { // Reload reloads the operations and schema func (s *GraphQLSchemaServer) Reload(schema *ast.Document) error { - if s.server == nil { return fmt.Errorf("server is not started") } @@ -432,6 +585,11 @@ func (s *GraphQLSchemaServer) Stop(ctx context.Context) error { s.logger.Debug("shutting down MCP server") + // Cancel the server's context to stop background operations (e.g., JWKS key refresh) + if s.cancel != nil { + s.cancel() + } + // Create a shutdown context with timeout shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() @@ -445,18 +603,28 @@ func (s *GraphQLSchemaServer) Stop(ctx context.Context) error { // registerTools registers all tools for the MCP server func (s *GraphQLSchemaServer) registerTools() error { - // Only register the schema tool if exposeSchema is enabled if s.exposeSchema { + // Create a schema with empty properties since get_schema takes no input + // Note: We omit "required" field to get nil instead of empty array + getSchemaInputSchema := []byte(`{ + "type": "object", + "properties": {} + }`) + + tool := mcp.NewToolWithRawSchema( + "get_schema", + "Provides the full GraphQL schema of the API.", + getSchemaInputSchema, + ) + + tool.Annotations = mcp.ToolAnnotation{ + Title: "Get GraphQL Schema", + ReadOnlyHint: mcp.ToBoolPtr(true), + } + s.server.AddTool( - mcp.NewTool( - "get_schema", - mcp.WithDescription("Provides the full GraphQL schema of the API."), - mcp.WithToolAnnotation(mcp.ToolAnnotation{ - Title: "Get GraphQL Schema", - ReadOnlyHint: mcp.ToBoolPtr(true), - }), - ), + tool, s.handleGetGraphQLSchema(), ) @@ -508,7 +676,6 @@ func (s *GraphQLSchemaServer) registerTools() error { ) s.registeredTools = append(s.registeredTools, "execute_graphql") - } // Get operations filtered by the excludeMutations setting @@ -616,6 +783,13 @@ func (s *GraphQLSchemaServer) registerTools() error { // handleOperation handles a specific operation func (s *GraphQLSchemaServer) handleOperation(handler *operationHandler) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Log authenticated user if OAuth is enabled + if claims, ok := GetClaimsFromContext(ctx); ok { + s.logger.Debug("operation called by authenticated user", + zap.String("sub", getClaimString(claims, "sub")), + zap.String("email", getClaimString(claims, "email")), + zap.String("operation", handler.operation.Name)) + } jsonBytes, err := json.Marshal(request.GetArguments()) if err != nil { @@ -791,6 +965,13 @@ func (s *GraphQLSchemaServer) executeGraphQLQuery(ctx context.Context, query str // handleExecuteGraphQL returns a handler function that executes arbitrary GraphQL queries func (s *GraphQLSchemaServer) handleExecuteGraphQL() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Log authenticated user if OAuth is enabled + if claims, ok := GetClaimsFromContext(ctx); ok { + s.logger.Debug("arbitrary GraphQL query called by authenticated user", + zap.String("sub", getClaimString(claims, "sub")), + zap.String("email", getClaimString(claims, "email"))) + } + // Parse the JSON input jsonBytes, err := json.Marshal(request.GetArguments()) if err != nil { @@ -828,3 +1009,85 @@ func (s *GraphQLSchemaServer) handleGetGraphQLSchema() func(ctx context.Context, return mcp.NewToolResultText(schemaStr), nil } } + +// getClaimString safely extracts a string value from claims +func getClaimString(claims authentication.Claims, key string) string { + if val, ok := claims[key]; ok { + if str, ok := val.(string); ok { + return str + } + } + return "" +} + +// ProtectedResourceMetadata represents the OAuth 2.0 Protected Resource Metadata (RFC 9728) +type ProtectedResourceMetadata struct { + Resource string `json:"resource"` + AuthorizationServers []string `json:"authorization_servers"` + BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` + ResourceDocumentation string `json:"resource_documentation,omitempty"` + ScopesSupported []string `json:"scopes_supported"` +} + +// handleProtectedResourceMetadata handles the OAuth 2.0 Protected Resource Metadata endpoint +// as specified in RFC 9728. This endpoint allows MCP clients to discover the authorization +// server(s) associated with this resource server. +func (s *GraphQLSchemaServer) handleProtectedResourceMetadata(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Determine the resource URL (this MCP server's base URL) + resourceURL := s.serverBaseURL + if resourceURL == "" { + // Fallback: construct from request if not configured + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + resourceURL = fmt.Sprintf("%s://%s", scheme, r.Host) + } + + // Build scopes_supported from all required scopes (union of all scopes in the map) + scopesSet := make(map[string]bool) + for _, requiredScopes := range s.oauthConfig.ScopesRequired { + for _, scope := range requiredScopes { + scopesSet[scope] = true + } + } + + // Convert set to sorted slice for consistent output + scopes := make([]string, 0, len(scopesSet)) + for scope := range scopesSet { + scopes = append(scopes, scope) + } + if len(scopes) == 0 { + scopes = []string{} // Ensure non-nil for JSON encoding + } + + metadata := ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: []string{s.oauthConfig.AuthorizationServerURL}, + BearerMethodsSupported: []string{"header"}, + ResourceDocumentation: fmt.Sprintf("%s/mcp", resourceURL), + ScopesSupported: scopes, // Automatically derived from required scopes + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + if err := json.NewEncoder(w).Encode(metadata); err != nil { + s.logger.Error("failed to encode protected resource metadata", zap.Error(err)) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } +} + +// GetResourceMetadataURL returns the URL for the OAuth 2.0 Protected Resource Metadata endpoint +func (s *GraphQLSchemaServer) GetResourceMetadataURL() string { + if s.serverBaseURL != "" { + return fmt.Sprintf("%s/.well-known/oauth-protected-resource", s.serverBaseURL) + } + return "" +} From b2c158b482e96a384d1bbe38783cef52cb64bd36 Mon Sep 17 00:00:00 2001 From: Ahmet Soormally Date: Wed, 21 Jan 2026 10:14:32 +0000 Subject: [PATCH 2/5] feat(mcp): add per-tool scope verification and random port allocation for tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements per-tool OAuth scope validation at HTTP level by parsing JSON-RPC requests before processing. This allows granular authorization control for individual MCP tools while maintaining HTTP-level auth per spec. Key Changes: - Parse JSON-RPC request in HTTP middleware to extract tool name from tools/call - Check per-tool scopes from scopesRequired map when tools are called - Return HTTP 403 with WWW-Authenticate header for insufficient per-tool scopes - Simplified API: replaced separate requiredScopes with unified scopesRequired map - Special "initialize" key in scopesRequired for HTTP-level scopes (all requests) Test Infrastructure Improvements: - JWKSTestServer now uses freeport for automatic random port allocation - Prevents port conflicts when running tests in parallel - Added comprehensive e2e tests for per-tool scope enforcement: * HTTP-level scope verification on all requests * Per-tool scope verification on tool calls * Different tools requiring different scopes * Token scope upgrade on persistent sessions Test Coverage: - TestMCPOAuthPerToolScopes: 5 subtests verifying per-tool authorization - Verifies HTTP 403 responses with correct WWW-Authenticate headers - Tests token upgrade workflow: connect → insufficient scopes → upgrade → retry --- router-tests/mcp_oauth_e2e_test.go | 167 ++++++++++++++++++- router-tests/testutil/jwt_helper.go | 14 +- router/pkg/mcpserver/auth_middleware.go | 67 ++++++-- router/pkg/mcpserver/auth_middleware_test.go | 25 ++- router/pkg/mcpserver/server.go | 15 +- 5 files changed, 248 insertions(+), 40 deletions(-) diff --git a/router-tests/mcp_oauth_e2e_test.go b/router-tests/mcp_oauth_e2e_test.go index 578d86d0ee..39cfd8abd1 100644 --- a/router-tests/mcp_oauth_e2e_test.go +++ b/router-tests/mcp_oauth_e2e_test.go @@ -20,7 +20,7 @@ import ( // 4. Client can upgrade token and retry on the same MCP session func TestMCPOAuthScopeUpgrade(t *testing.T) { // Start JWKS test server - jwksServer, err := testutil.NewJWKSTestServer(t, "8765") + jwksServer, err := testutil.NewJWKSTestServer(t) require.NoError(t, err, "failed to start JWKS server") defer jwksServer.Close() //nolint:errcheck @@ -99,7 +99,7 @@ func TestMCPOAuthScopeUpgrade(t *testing.T) { // TestMCPOAuthInvalidToken tests that invalid JWT tokens are rejected with HTTP 401 func TestMCPOAuthInvalidToken(t *testing.T) { // Start JWKS test server - jwksServer, err := testutil.NewJWKSTestServer(t, "8766") + jwksServer, err := testutil.NewJWKSTestServer(t) require.NoError(t, err, "failed to start JWKS server") defer jwksServer.Close() //nolint:errcheck @@ -144,7 +144,7 @@ func TestMCPOAuthInvalidToken(t *testing.T) { // TestMCPOAuthMissingToken tests that missing Authorization header is rejected func TestMCPOAuthMissingToken(t *testing.T) { // Start JWKS test server - jwksServer, err := testutil.NewJWKSTestServer(t, "8767") + jwksServer, err := testutil.NewJWKSTestServer(t) require.NoError(t, err, "failed to start JWKS server") defer jwksServer.Close() //nolint:errcheck @@ -184,4 +184,165 @@ func TestMCPOAuthMissingToken(t *testing.T) { t.Logf("✓ Request without token rejected with HTTP 401: %v", authErr) } }) +} + +// TestMCPOAuthPerToolScopes tests per-tool scope requirements +// This test verifies: +// 1. HTTP-level scopes (from "initialize" key) are checked on all requests +// 2. Per-tool scopes are checked when specific tools are called +// 3. HTTP 403 with WWW-Authenticate header is returned for insufficient scopes +func TestMCPOAuthPerToolScopes(t *testing.T) { + // Start JWKS test server + jwksServer, err := testutil.NewJWKSTestServer(t) + require.NoError(t, err, "failed to start JWKS server") + defer jwksServer.Close() //nolint:errcheck + + // Create token with basic scopes for initialization + initToken, err := jwksServer.CreateTokenWithScopes("test-user", []string{"mcp:connect"}) + require.NoError(t, err, "failed to create init token") + + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + ExposeSchema: true, // Enable get_schema tool + EnableArbitraryOperations: true, // Enable execute_graphql tool + OAuth: config.MCPOAuthConfiguration{ + Enabled: true, + JWKS: []config.JWKSConfiguration{ + { + URL: jwksServer.JWKSURL(), + }, + }, + AuthorizationServerURL: jwksServer.Issuer(), + ScopesRequired: map[string][]string{ + "initialize": {"mcp:connect"}, // HTTP-level: required for all requests + "get_schema": {"mcp:tools:read"}, // Per-tool: read-only tool + "execute_graphql": {"mcp:tools:write"}, // Per-tool: write tool + }, + }, + }, + MCPAuthToken: initToken, // Pass token for testenv initialization + }, func(t *testing.T, xEnv *testenv.Environment) { + ctx := context.Background() + + t.Run("HTTP-level scopes are enforced on all requests", func(t *testing.T) { + // Token without "mcp:connect" scope should fail at HTTP level + noConnectToken, err := jwksServer.CreateTokenWithScopes("test-user", []string{"mcp:tools:read"}) + require.NoError(t, err) + + client := NewMCPAuthClient(xEnv.GetMCPServerAddr(), noConnectToken) + err = client.Connect(ctx) + require.Error(t, err, "should fail to connect without HTTP-level scopes") + + // Check if it's an auth error with HTTP 403 + authErr, ok := err.(*AuthError) + if ok { + // Could be 401 or 403 depending on whether token is valid + assert.True(t, authErr.StatusCode == http.StatusUnauthorized || authErr.StatusCode == http.StatusForbidden) + t.Logf("✓ HTTP-level scope enforcement: %v", authErr) + } + }) + + t.Run("Per-tool scopes are enforced on tool calls", func(t *testing.T) { + // Token with connect but no read scope + connectOnlyToken, err := jwksServer.CreateTokenWithScopes("test-user", []string{"mcp:connect"}) + require.NoError(t, err) + + client := NewMCPAuthClient(xEnv.GetMCPServerAddr(), connectOnlyToken) + err = client.Connect(ctx) + require.NoError(t, err, "should connect with HTTP-level scopes") + defer client.Close() //nolint:errcheck + + t.Log("✓ Connected with HTTP-level scopes only") + + // Try to call get_schema (requires mcp:tools:read) + _, err = client.CallTool(ctx, "get_schema", nil) + require.Error(t, err, "should fail without per-tool scopes") + + authErr, ok := err.(*AuthError) + require.True(t, ok, "should return AuthError") + assert.Equal(t, http.StatusForbidden, authErr.StatusCode, "should return HTTP 403") + assert.Equal(t, "insufficient_scope", authErr.ErrorCode) + assert.Contains(t, authErr.RequiredScopes, "mcp:tools:read") + t.Logf("✓ Per-tool scope enforcement: %v", authErr) + }) + + t.Run("Token with correct per-tool scopes succeeds", func(t *testing.T) { + // Token with both connect and read scopes + readToken, err := jwksServer.CreateTokenWithScopes("test-user", []string{"mcp:connect", "mcp:tools:read"}) + require.NoError(t, err) + + client := NewMCPAuthClient(xEnv.GetMCPServerAddr(), readToken) + err = client.Connect(ctx) + require.NoError(t, err) + defer client.Close() //nolint:errcheck + + // Call get_schema (requires mcp:tools:read) - should succeed + result, err := client.CallTool(ctx, "get_schema", nil) + require.NoError(t, err, "should succeed with correct scopes") + require.NotNil(t, result) + t.Log("✓ Tool call succeeded with correct per-tool scopes") + }) + + t.Run("Different tools require different scopes", func(t *testing.T) { + // Token with read but no write scopes + readToken, err := jwksServer.CreateTokenWithScopes("test-user", []string{"mcp:connect", "mcp:tools:read"}) + require.NoError(t, err) + + client := NewMCPAuthClient(xEnv.GetMCPServerAddr(), readToken) + err = client.Connect(ctx) + require.NoError(t, err) + defer client.Close() //nolint:errcheck + + // Call get_schema (read) - should succeed + _, err = client.CallTool(ctx, "get_schema", nil) + require.NoError(t, err, "read tool should succeed") + t.Log("✓ Read tool succeeded") + + // Call execute_graphql (write) - should fail + _, err = client.CallTool(ctx, "execute_graphql", map[string]any{ + "query": "query { __typename }", + }) + require.Error(t, err, "write tool should fail without write scopes") + + authErr, ok := err.(*AuthError) + require.True(t, ok) + assert.Equal(t, http.StatusForbidden, authErr.StatusCode) + assert.Contains(t, authErr.RequiredScopes, "mcp:tools:write") + t.Log("✓ Write tool rejected without write scopes") + }) + + t.Run("Scope upgrade on same session works", func(t *testing.T) { + // Start with read-only token + readToken, err := jwksServer.CreateTokenWithScopes("test-user", []string{"mcp:connect", "mcp:tools:read"}) + require.NoError(t, err) + + client := NewMCPAuthClient(xEnv.GetMCPServerAddr(), readToken) + err = client.Connect(ctx) + require.NoError(t, err) + defer client.Close() //nolint:errcheck + + // Try write operation - should fail + _, err = client.CallTool(ctx, "execute_graphql", map[string]any{ + "query": "query { __typename }", + }) + require.Error(t, err, "should fail without write scopes") + t.Log("✓ Write operation failed with read-only token") + + // Upgrade to token with write scopes + writeToken, err := jwksServer.CreateTokenWithScopes("test-user", []string{"mcp:connect", "mcp:tools:read", "mcp:tools:write"}) + require.NoError(t, err) + + client.SetToken(writeToken) + t.Log("✓ Upgraded token on same session") + + // Retry write operation - should succeed + result, err := client.CallTool(ctx, "execute_graphql", map[string]any{ + "query": "query { __typename }", + }) + require.NoError(t, err, "should succeed after scope upgrade") + require.NotNil(t, result) + t.Log("✓ Write operation succeeded after token upgrade") + }) + }) } \ No newline at end of file diff --git a/router-tests/testutil/jwt_helper.go b/router-tests/testutil/jwt_helper.go index f875469220..830a51e9bd 100644 --- a/router-tests/testutil/jwt_helper.go +++ b/router-tests/testutil/jwt_helper.go @@ -9,6 +9,7 @@ import ( "github.com/MicahParks/jwkset" "github.com/golang-jwt/jwt/v5" + "github.com/wundergraph/cosmo/router-tests/freeport" "github.com/wundergraph/cosmo/router-tests/jwks" ) @@ -25,9 +26,14 @@ type JWKSTestServer struct { } // NewJWKSTestServer creates a new JWKS test server with RSA keys -func NewJWKSTestServer(t *testing.T, port string) (*JWKSTestServer, error) { +// The server will automatically allocate a free port and return it when the test ends +func NewJWKSTestServer(t *testing.T) (*JWKSTestServer, error) { t.Helper() + // Get a free port using the freeport package + port := freeport.GetOne(t) + portStr := fmt.Sprintf("%d", port) + keyID := "test_rsa" provider, err := jwks.NewRSACrypto(keyID, jwkset.AlgRS256, 2048) if err != nil { @@ -50,9 +56,9 @@ func NewJWKSTestServer(t *testing.T, port string) (*JWKSTestServer, error) { t: t, provider: provider, keyID: keyID, - issuer: fmt.Sprintf("http://localhost:%s", port), + issuer: fmt.Sprintf("http://localhost:%s", portStr), audience: "test-audience", - jwksURL: fmt.Sprintf("http://localhost:%s/.well-known/jwks.json", port), + jwksURL: fmt.Sprintf("http://localhost:%s/.well-known/jwks.json", portStr), storage: storage, } @@ -61,7 +67,7 @@ func NewJWKSTestServer(t *testing.T, port string) (*JWKSTestServer, error) { mux.HandleFunc("/.well-known/jwks.json", server.handleJWKS) httpServer := &http.Server{ - Addr: ":" + port, + Addr: ":" + portStr, Handler: mux, } diff --git a/router/pkg/mcpserver/auth_middleware.go b/router/pkg/mcpserver/auth_middleware.go index 1f594a9b87..558279c9f0 100644 --- a/router/pkg/mcpserver/auth_middleware.go +++ b/router/pkg/mcpserver/auth_middleware.go @@ -1,8 +1,11 @@ package mcpserver import ( + "bytes" "context" + "encoding/json" "fmt" + "io" "net/http" "slices" "strings" @@ -33,12 +36,12 @@ type MCPAuthMiddleware struct { authenticator authentication.Authenticator enabled bool resourceMetadataURL string - requiredScopes []string // Minimal scopes required for any access + scopesRequired map[string][]string // Per-tool scope requirements; "initialize" key = HTTP-level scopes } // NewMCPAuthMiddleware creates a new authentication middleware using the existing // authentication infrastructure from the router -func NewMCPAuthMiddleware(tokenDecoder authentication.TokenDecoder, enabled bool, resourceMetadataURL string, requiredScopes []string) (*MCPAuthMiddleware, error) { +func NewMCPAuthMiddleware(tokenDecoder authentication.TokenDecoder, enabled bool, resourceMetadataURL string, scopesRequired map[string][]string) (*MCPAuthMiddleware, error) { if tokenDecoder == nil { return nil, fmt.Errorf("token decoder must be provided") } @@ -59,7 +62,7 @@ func NewMCPAuthMiddleware(tokenDecoder authentication.TokenDecoder, enabled bool authenticator: authenticator, enabled: enabled, resourceMetadataURL: resourceMetadataURL, - requiredScopes: requiredScopes, + scopesRequired: scopesRequired, }, nil } @@ -112,10 +115,8 @@ func (m *MCPAuthMiddleware) authenticateRequest(ctx context.Context) (authentica return nil, fmt.Errorf("authentication failed: no valid credentials provided") } - // Validate required scopes - if err := m.validateScopes(claims); err != nil { - return nil, err - } + // Note: Scope validation is now handled at HTTP level, not here + // This is per MCP spec: authorization must be at HTTP level return claims, nil } @@ -139,11 +140,45 @@ func (m *MCPAuthMiddleware) HTTPMiddleware(next http.Handler) http.Handler { return } - // Validate required scopes - if err := m.validateScopes(claims); err != nil { - m.sendInsufficientScopeResponse(w, err) + // Step 1: Validate HTTP-level required scopes (from "initialize" key) + initScopes := m.scopesRequired["initialize"] + if len(initScopes) > 0 { + if err := m.validateScopesForRequest(claims, initScopes); err != nil { + m.sendInsufficientScopeResponse(w, initScopes, err) + return + } + } + + // Step 2: Parse JSON-RPC request to check for tool-specific scopes + // Read body to extract tool name + body, err := io.ReadAll(r.Body) + if err != nil { + m.sendUnauthorizedResponse(w, fmt.Errorf("failed to read request body")) return } + // Restore body for downstream handlers + r.Body = io.NopCloser(bytes.NewBuffer(body)) + + // Try to parse as JSON-RPC request + var jsonRPCReq struct { + Method string `json:"method"` + Params json.RawMessage `json:"params"` + } + if err := json.Unmarshal(body, &jsonRPCReq); err == nil && jsonRPCReq.Method == "tools/call" { + // Extract tool name from params + var toolCallParams struct { + Name string `json:"name"` + } + if err := json.Unmarshal(jsonRPCReq.Params, &toolCallParams); err == nil && toolCallParams.Name != "" { + // Check if this tool has specific scope requirements + if toolScopes, exists := m.scopesRequired[toolCallParams.Name]; exists && len(toolScopes) > 0 { + if err := m.validateScopesForRequest(claims, toolScopes); err != nil { + m.sendInsufficientScopeResponse(w, toolScopes, err) + return + } + } + } + } // Add claims to request context for downstream handlers ctx := context.WithValue(r.Context(), userClaimsContextKey, claims) @@ -175,10 +210,10 @@ func (m *MCPAuthMiddleware) sendUnauthorizedResponse(w http.ResponseWriter, err // sendInsufficientScopeResponse sends a 403 Forbidden response per RFC 6750 // when the token is valid but lacks required scopes -func (m *MCPAuthMiddleware) sendInsufficientScopeResponse(w http.ResponseWriter, err error) { +func (m *MCPAuthMiddleware) sendInsufficientScopeResponse(w http.ResponseWriter, requiredScopes []string, err error) { // Build WWW-Authenticate header with error and scope information // Per RFC 6750 Section 3.1 and MCP spec: error, scope, resource_metadata, error_description - scopeList := strings.Join(m.requiredScopes, " ") + scopeList := strings.Join(requiredScopes, " ") authHeader := fmt.Sprintf(`Bearer error="insufficient_scope", scope="%s"`, scopeList) @@ -199,10 +234,10 @@ func (m *MCPAuthMiddleware) sendInsufficientScopeResponse(w http.ResponseWriter, // No JSON-RPC response body is returned } -// validateScopes checks if the token contains all required scopes -func (m *MCPAuthMiddleware) validateScopes(claims authentication.Claims) error { +// validateScopesForRequest checks if the token contains all required scopes +func (m *MCPAuthMiddleware) validateScopesForRequest(claims authentication.Claims, requiredScopes []string) error { // If no scopes are required, skip validation - if len(m.requiredScopes) == 0 { + if len(requiredScopes) == 0 { return nil } @@ -211,7 +246,7 @@ func (m *MCPAuthMiddleware) validateScopes(claims authentication.Claims) error { // Check if all required scopes are present var missingScopes []string - for _, requiredScope := range m.requiredScopes { + for _, requiredScope := range requiredScopes { if !contains(tokenScopes, requiredScope) { missingScopes = append(missingScopes, requiredScope) } diff --git a/router/pkg/mcpserver/auth_middleware_test.go b/router/pkg/mcpserver/auth_middleware_test.go index 66fcc5e778..2321eea294 100644 --- a/router/pkg/mcpserver/auth_middleware_test.go +++ b/router/pkg/mcpserver/auth_middleware_test.go @@ -134,7 +134,7 @@ func TestNewMCPAuthMiddleware(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - middleware, err := NewMCPAuthMiddleware(tt.decoder, tt.enabled, testMetadataURL, []string{"mcp:tools"}) + middleware, err := NewMCPAuthMiddleware(tt.decoder, tt.enabled, testMetadataURL, map[string][]string{"initialize": {"mcp:tools"}}) if tt.wantErr { require.Error(t, err) assert.Contains(t, err.Error(), tt.errContains) @@ -277,7 +277,7 @@ func TestMCPAuthMiddleware_MissingHeaders(t *testing.T) { }, } - middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, []string{}) + middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, map[string][]string{}) require.NoError(t, err) handler := middleware.ToolMiddleware(func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -363,7 +363,7 @@ func TestMCPAuthMiddleware_Integration(t *testing.T) { }, } - middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, []string{}) + middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, map[string][]string{}) require.NoError(t, err) handler := middleware.ToolMiddleware(func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -487,7 +487,12 @@ func TestMCPAuthMiddleware_ScopeValidation(t *testing.T) { }, } - middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, tt.requiredScopes) + // Convert requiredScopes array to map format for new API + scopesRequired := map[string][]string{} + if len(tt.requiredScopes) > 0 { + scopesRequired["initialize"] = tt.requiredScopes + } + middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, scopesRequired) require.NoError(t, err) handler := middleware.ToolMiddleware(func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -758,7 +763,12 @@ func TestMCPAuthMiddleware_HTTPMiddleware(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - middleware, err := NewMCPAuthMiddleware(tt.setupDecoder(), true, testMetadataURL, tt.requiredScopes) + // Convert requiredScopes array to map format for new API + scopesRequired := map[string][]string{} + if len(tt.requiredScopes) > 0 { + scopesRequired["initialize"] = tt.requiredScopes + } + middleware, err := NewMCPAuthMiddleware(tt.setupDecoder(), true, testMetadataURL, scopesRequired) require.NoError(t, err) // Create a test handler that sets status 200 if reached @@ -807,7 +817,7 @@ func TestMCPAuthMiddleware_HTTPMiddleware_WWWAuthenticateFormat(t *testing.T) { }, } - middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, []string{}) + middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, map[string][]string{}) require.NoError(t, err) testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -848,7 +858,8 @@ func TestMCPAuthMiddleware_HTTPMiddleware_WWWAuthenticateFormat(t *testing.T) { } requiredScopes := []string{"mcp:tools:write", "mcp:admin"} - middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, requiredScopes) + scopesRequired := map[string][]string{"initialize": requiredScopes} + middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, scopesRequired) require.NoError(t, err) testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/router/pkg/mcpserver/server.go b/router/pkg/mcpserver/server.go index e9dea49c54..02891c0fc4 100644 --- a/router/pkg/mcpserver/server.go +++ b/router/pkg/mcpserver/server.go @@ -279,16 +279,11 @@ func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options) resourceMetadataURL = fmt.Sprintf("%s/.well-known/oauth-protected-resource", options.ServerBaseURL) } - // Get HTTP-level required scopes from the "initialize" key - // These scopes are required for ANY HTTP request (including initialize) - httpRequiredScopes := options.OAuthConfig.ScopesRequired["initialize"] - if httpRequiredScopes == nil { - httpRequiredScopes = []string{} - } - - // Create authentication middleware with HTTP-level required scopes - // Per-tool scope authorization happens at the tool level - authMiddleware, err := NewMCPAuthMiddleware(tokenDecoder, true, resourceMetadataURL, httpRequiredScopes) + // Create authentication middleware with per-tool scope configuration + // The middleware will check: + // - "initialize" key scopes for all HTTP requests (HTTP-level auth) + // - Per-tool scopes when tools are called (by parsing JSON-RPC request) + authMiddleware, err := NewMCPAuthMiddleware(tokenDecoder, true, resourceMetadataURL, options.OAuthConfig.ScopesRequired) if err != nil { cancel() // Clean up the context if initialization fails return nil, fmt.Errorf("failed to create auth middleware: %w", err) From 157b7a64c5901aa4a9316d4be9aab33d006a03f2 Mon Sep 17 00:00:00 2001 From: Ahmet Soormally Date: Wed, 21 Jan 2026 10:53:27 +0000 Subject: [PATCH 3/5] docs(mcp): clarify error code ranges to avoid JSON-RPC conflicts Fix confusing comment that stated application errors use -32768 to -32000, which is the same as the reserved JSON-RPC range. Now clearly states that application-specific error codes must use values OUTSIDE the reserved range to avoid conflicts with JSON-RPC protocol error codes. --- router/pkg/mcpserver/errors.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/router/pkg/mcpserver/errors.go b/router/pkg/mcpserver/errors.go index d45978a1bd..e91a64cf88 100644 --- a/router/pkg/mcpserver/errors.go +++ b/router/pkg/mcpserver/errors.go @@ -3,9 +3,9 @@ package mcpserver // JSON-RPC 2.0 and MCP error codes // // Error code ranges: -// - Standard JSON-RPC 2.0: -32768 to -32000 -// - Server errors (implementation-defined): -32000 to -32099 -// - Application errors: -32768 to -32000 (excluding reserved range) +// - Standard JSON-RPC 2.0: -32768 to -32000 (reserved by JSON-RPC spec) +// - Server errors (implementation-defined): -32000 to -32099 (within JSON-RPC reserved range) +// - Application errors: Must use codes outside -32768 to -32000 to avoid conflicts with JSON-RPC reserved codes const ( // Standard JSON-RPC 2.0 error codes ErrorCodeParseError = -32700 // Invalid JSON was received by the server From 01c30037de2c849f29db3f51f3478b78a19ced3e Mon Sep 17 00:00:00 2001 From: Ahmet Soormally Date: Wed, 21 Jan 2026 16:47:47 +0000 Subject: [PATCH 4/5] feat(mcp): migrate to official modelcontextprotocol/go-sdk Replace mark3labs/mcp-go with official SDK v1.2.0 --- router/go.mod | 13 +- router/go.sum | 27 +- router/pkg/mcpserver/auth_middleware.go | 79 +-- router/pkg/mcpserver/auth_middleware_test.go | 674 ++----------------- router/pkg/mcpserver/server.go | 320 ++++----- 5 files changed, 233 insertions(+), 880 deletions(-) diff --git a/router/go.mod b/router/go.mod index 3ba129a9d3..be85169ad8 100644 --- a/router/go.mod +++ b/router/go.mod @@ -73,7 +73,6 @@ require ( github.com/hashicorp/go-plugin v1.6.3 github.com/iancoleman/strcase v0.3.0 github.com/klauspost/compress v1.18.0 - github.com/mark3labs/mcp-go v0.43.2 github.com/minio/minio-go/v7 v7.0.74 github.com/posthog/posthog-go v1.5.5 github.com/pquerna/cachecontrol v0.2.0 @@ -87,9 +86,14 @@ require ( golang.org/x/time v0.9.0 ) +require ( + github.com/frankban/quicktest v1.14.6 // indirect + github.com/google/jsonschema-go v0.3.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect +) + require ( github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect - github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/benbjohnson/clock v1.3.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bufbuild/protocompile v0.14.1 // indirect @@ -120,16 +124,15 @@ require ( github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/hashicorp/yamux v0.1.1 // indirect - github.com/invopop/jsonschema v0.13.0 // indirect github.com/jensneuse/byte-template v0.0.0-20231025215717-69252eb3ed56 // indirect github.com/kingledion/go-tools v0.6.0 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect - github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/minio/md5-simd v1.1.2 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect + github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/nats-io/nkeys v0.4.7 // indirect github.com/oklog/run v1.0.0 // indirect @@ -149,7 +152,6 @@ require ( github.com/sergi/go-diff v1.3.1 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect - github.com/spf13/cast v1.7.1 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect @@ -157,7 +159,6 @@ require ( github.com/tklauser/numcpus v0.6.1 // indirect github.com/twmb/franz-go/pkg/kmsg v1.7.0 // indirect github.com/vbatts/tar-split v0.12.1 // indirect - github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect diff --git a/router/go.sum b/router/go.sum index b08ab75735..ed1013c585 100644 --- a/router/go.sum +++ b/router/go.sum @@ -17,8 +17,6 @@ github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQg github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= -github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= -github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -51,6 +49,7 @@ github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRcc github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9WBVE8cOlQmXAbPN9VEQpBBeJIuOipU= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -117,10 +116,13 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-containerregistry v0.20.3 h1:oNx7IdTI936V8CQRveCjaxOiegWwvM7kqkbXTpyiovI= github.com/google/go-containerregistry v0.20.3/go.mod h1:w00pIgBRDVUDFM6bq+Qx8lwNWK+cxgCuX1vd3PIBDNI= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -151,8 +153,6 @@ github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= -github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= -github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jensneuse/abstractlogger v0.0.4 h1:sa4EH8fhWk3zlTDbSncaWKfwxYM8tYSlQ054ETLyyQY= github.com/jensneuse/abstractlogger v0.0.4/go.mod h1:6WuamOHuykJk8zED/R0LNiLhWR6C7FIAo43ocUEB3mo= github.com/jensneuse/byte-template v0.0.0-20231025215717-69252eb3ed56 h1:wo26fh6a6Za0cOMZIopD2sfH/kq83SJ89ixUWl7pCWc= @@ -163,7 +163,6 @@ github.com/jhump/protoreflect v1.15.1 h1:HUMERORf3I3ZdX05WaQ6MIpd/NJ434hTp5YiKgf github.com/jhump/protoreflect v1.15.1/go.mod h1:jD/2GMKKE6OqX8qTjhADU1e6DShO+gavG9e0Q693nKo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= -github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kingledion/go-tools v0.6.0 h1:y8C/4mWoHgLkO45dB+Y/j0o4Y4WUB5lDTAcMPMtFpTg= github.com/kingledion/go-tools v0.6.0/go.mod h1:qcDJQxBui/H/hterGb90GMlLs9Yi7QrwaJL8OGdbsms= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -184,10 +183,6 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= -github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= -github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= -github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= @@ -204,6 +199,8 @@ github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= +github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/nats-io/nats.go v1.35.0 h1:XFNqNM7v5B+MQMKqVGAyHwYhyKb48jrenXNxIU20ULk= @@ -226,6 +223,7 @@ github.com/phf/go-queue v0.0.0-20170504031614-9abe38d0371d h1:U+PMnTlV2tu7RuMK5e github.com/phf/go-queue v0.0.0-20170504031614-9abe38d0371d/go.mod h1:lXfE4PvvTW5xOjO6Mba8zDPyw8M93B6AQ7frTGnMlA8= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -252,6 +250,7 @@ github.com/r3labs/sse/v2 v2.8.1/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktEm github.com/redis/go-redis/v9 v9.4.0 h1:Yzoz33UZw9I/mFhx4MNrB6Fk+XHO1VukNcCa1+lwyKk= github.com/redis/go-redis/v9 v9.4.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= @@ -274,8 +273,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sosodev/duration v1.3.1 h1:qtHBDMQ6lvMQsL15g4aopM4HEfOaYuhWBw3NPTtlqq4= github.com/sosodev/duration v1.3.1/go.mod h1:RQIBBX0+fMLc/D9+Jb/fwvVmo0eZvDDEERAikUR6SDg= -github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= -github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= @@ -318,8 +315,6 @@ github.com/vbatts/tar-split v0.12.1 h1:CqKoORW7BUWBe7UL/iqTVvkTBOF8UvOMKOIZykxnn github.com/vbatts/tar-split v0.12.1/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA= github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= github.com/vektah/gqlparser/v2 v2.5.30/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= -github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= -github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.245 h1:MYewlXgIhI9jusocPUeyo346J3M5cqzc6ddru1qp+S8= @@ -392,6 +387,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= @@ -422,6 +419,8 @@ golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3 golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= +golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= diff --git a/router/pkg/mcpserver/auth_middleware.go b/router/pkg/mcpserver/auth_middleware.go index 558279c9f0..d082fa4bf8 100644 --- a/router/pkg/mcpserver/auth_middleware.go +++ b/router/pkg/mcpserver/auth_middleware.go @@ -10,9 +10,6 @@ import ( "slices" "strings" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" - "github.com/wundergraph/cosmo/router/pkg/authentication" ) @@ -66,33 +63,6 @@ func NewMCPAuthMiddleware(tokenDecoder authentication.TokenDecoder, enabled bool }, nil } -// ToolMiddleware wraps tool handlers with authentication -func (m *MCPAuthMiddleware) ToolMiddleware(next server.ToolHandlerFunc) server.ToolHandlerFunc { - return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - if !m.enabled { - return next(ctx, req) - } - - // Extract and validate token - claims, err := m.authenticateRequest(ctx) - if err != nil { - // Return authentication error with WWW-Authenticate challenge information - // Per RFC 9728, we should indicate the resource metadata URL - errorMsg := fmt.Sprintf("Authentication failed: %v", err) - if m.resourceMetadataURL != "" { - errorMsg = fmt.Sprintf("Authentication required. Resource metadata available at: %s. Error: %v", - m.resourceMetadataURL, err) - } - return mcp.NewToolResultError(errorMsg), nil - } - - // Add claims to context - ctx = context.WithValue(ctx, userClaimsContextKey, claims) - - return next(ctx, req) - } -} - // authenticateRequest extracts and validates the JWT token using the existing // authentication infrastructure from the router func (m *MCPAuthMiddleware) authenticateRequest(ctx context.Context) (authentication.Claims, error) { @@ -150,31 +120,36 @@ func (m *MCPAuthMiddleware) HTTPMiddleware(next http.Handler) http.Handler { } // Step 2: Parse JSON-RPC request to check for tool-specific scopes - // Read body to extract tool name - body, err := io.ReadAll(r.Body) - if err != nil { - m.sendUnauthorizedResponse(w, fmt.Errorf("failed to read request body")) - return + // Read body to extract tool name (only if body exists) + var body []byte + if r.Body != nil { + body, err = io.ReadAll(r.Body) + if err != nil { + m.sendUnauthorizedResponse(w, fmt.Errorf("failed to read request body")) + return + } + // Restore body for downstream handlers + r.Body = io.NopCloser(bytes.NewBuffer(body)) } - // Restore body for downstream handlers - r.Body = io.NopCloser(bytes.NewBuffer(body)) - // Try to parse as JSON-RPC request - var jsonRPCReq struct { - Method string `json:"method"` - Params json.RawMessage `json:"params"` - } - if err := json.Unmarshal(body, &jsonRPCReq); err == nil && jsonRPCReq.Method == "tools/call" { - // Extract tool name from params - var toolCallParams struct { - Name string `json:"name"` + // Try to parse as JSON-RPC request (only if we have body content) + if len(body) > 0 { + var jsonRPCReq struct { + Method string `json:"method"` + Params json.RawMessage `json:"params"` } - if err := json.Unmarshal(jsonRPCReq.Params, &toolCallParams); err == nil && toolCallParams.Name != "" { - // Check if this tool has specific scope requirements - if toolScopes, exists := m.scopesRequired[toolCallParams.Name]; exists && len(toolScopes) > 0 { - if err := m.validateScopesForRequest(claims, toolScopes); err != nil { - m.sendInsufficientScopeResponse(w, toolScopes, err) - return + if err := json.Unmarshal(body, &jsonRPCReq); err == nil && jsonRPCReq.Method == "tools/call" { + // Extract tool name from params + var toolCallParams struct { + Name string `json:"name"` + } + if err := json.Unmarshal(jsonRPCReq.Params, &toolCallParams); err == nil && toolCallParams.Name != "" { + // Check if this tool has specific scope requirements + if toolScopes, exists := m.scopesRequired[toolCallParams.Name]; exists && len(toolScopes) > 0 { + if err := m.validateScopesForRequest(claims, toolScopes); err != nil { + m.sendInsufficientScopeResponse(w, toolScopes, err) + return + } } } } diff --git a/router/pkg/mcpserver/auth_middleware_test.go b/router/pkg/mcpserver/auth_middleware_test.go index 2321eea294..5cd0c54e4f 100644 --- a/router/pkg/mcpserver/auth_middleware_test.go +++ b/router/pkg/mcpserver/auth_middleware_test.go @@ -3,76 +3,15 @@ package mcpserver import ( "context" "errors" - "fmt" "net/http" "net/http/httptest" - "strings" "testing" - "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router/pkg/authentication" ) -const ( - testMetadataURL = "http://localhost:5025/.well-known/oauth-protected-resource" -) - -// parseWWWAuthenticateParams parses key-value pairs from a WWW-Authenticate Bearer header. -// This is a simple parser for test validation only, not production use. -// -// NOTE: LLM-generated - there are no well-established Go libraries for parsing -// WWW-Authenticate response headers (to-date). This parser handles the -// common case of Bearer authentication with quoted parameter values. -func parseWWWAuthenticateParams(header string) map[string]string { - params := make(map[string]string) - - // Remove "Bearer " prefix - header = strings.TrimPrefix(header, "Bearer ") - header = strings.TrimSpace(header) - - // Simple state machine to parse key="value" pairs - var key, value strings.Builder - inKey := true - inQuote := false - - for i := 0; i < len(header); i++ { - ch := header[i] - - switch { - case ch == '=' && inKey: - inKey = false - case ch == '"' && !inKey: - // Track quote state but don't add quotes to value - inQuote = !inQuote - case ch == ',' && !inQuote: - if key.Len() > 0 { - params[strings.TrimSpace(key.String())] = strings.TrimSpace(value.String()) - } - key.Reset() - value.Reset() - inKey = true - case inKey: - key.WriteByte(ch) - default: - // We're in a value (!inKey) and ch is not a quote (already handled above) - // Include everything (including spaces) when inside quotes - if inQuote || ch != ' ' || value.Len() > 0 { - value.WriteByte(ch) - } - } - } - - // Add final pair - if key.Len() > 0 { - params[strings.TrimSpace(key.String())] = strings.TrimSpace(value.String()) - } - - return params -} - // mockTokenDecoder is a mock implementation of authentication.TokenDecoder for testing type mockTokenDecoder struct { decodeFunc func(token string) (authentication.Claims, error) @@ -82,19 +21,7 @@ func (m *mockTokenDecoder) Decode(token string) (authentication.Claims, error) { if m.decodeFunc != nil { return m.decodeFunc(token) } - return nil, errors.New("not implemented") -} - -// getTextFromResult extracts text from the first content item in a result -func getTextFromResult(result *mcp.CallToolResult) string { - if result == nil || len(result.Content) == 0 { - return "" - } - textContent, ok := mcp.AsTextContent(result.Content[0]) - if !ok { - return "" - } - return textContent.Text + return nil, errors.New("decode not implemented") } func TestNewMCPAuthMiddleware(t *testing.T) { @@ -105,192 +32,45 @@ func TestNewMCPAuthMiddleware(t *testing.T) { } tests := []struct { - name string - decoder authentication.TokenDecoder - enabled bool - wantErr bool - errContains string + name string + decoder authentication.TokenDecoder + enabled bool + wantErr bool }{ { - name: "valid decoder enabled", + name: "valid decoder and enabled", decoder: validDecoder, enabled: true, wantErr: false, }, { - name: "valid decoder disabled", + name: "valid decoder and disabled", decoder: validDecoder, enabled: false, wantErr: false, }, { - name: "nil decoder", - decoder: nil, - enabled: true, - wantErr: true, - errContains: "token decoder must be provided", + name: "nil decoder", + decoder: nil, + enabled: true, + wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - middleware, err := NewMCPAuthMiddleware(tt.decoder, tt.enabled, testMetadataURL, map[string][]string{"initialize": {"mcp:tools"}}) + middleware, err := NewMCPAuthMiddleware(tt.decoder, tt.enabled, "http://localhost:5025/.well-known/oauth-protected-resource", map[string][]string{}) if tt.wantErr { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errContains) + assert.Error(t, err) assert.Nil(t, middleware) } else { - require.NoError(t, err) - require.NotNil(t, middleware) - assert.Equal(t, tt.enabled, middleware.enabled) - assert.NotNil(t, middleware.authenticator) + assert.NoError(t, err) + assert.NotNil(t, middleware) } }) } } -func TestMCPAuthMiddleware_ToolMiddleware(t *testing.T) { - validClaims := authentication.Claims{"sub": "user123", "email": "user@example.com"} - - tests := []struct { - name string - enabled bool - decoder *mockTokenDecoder - setupHeaders func() http.Header - wantErr bool - wantTextContain string - }{ - { - name: "bypasses auth when disabled", - enabled: false, - decoder: &mockTokenDecoder{ - decodeFunc: func(token string) (authentication.Claims, error) { - t.Fatal("should not be called") - return nil, nil - }, - }, - setupHeaders: func() http.Header { - return http.Header{} - }, - wantErr: false, - wantTextContain: "no authentication", - }, - { - name: "valid Bearer token", - enabled: true, - decoder: &mockTokenDecoder{ - decodeFunc: func(token string) (authentication.Claims, error) { - if token == "valid-token" { - return validClaims, nil - } - return nil, errors.New("invalid token") - }, - }, - setupHeaders: func() http.Header { - h := http.Header{} - h.Set("Authorization", "Bearer valid-token") - return h - }, - wantErr: false, - wantTextContain: "authenticated with claims", - }, - { - name: "invalid token", - enabled: true, - decoder: &mockTokenDecoder{ - decodeFunc: func(token string) (authentication.Claims, error) { - return nil, errors.New("token validation failed") - }, - }, - setupHeaders: func() http.Header { - h := http.Header{} - h.Set("Authorization", "Bearer invalid-token") - return h - }, - wantErr: true, - wantTextContain: "Authentication required", - }, - { - name: "wrong header format", - enabled: true, - decoder: &mockTokenDecoder{ - decodeFunc: func(token string) (authentication.Claims, error) { - return validClaims, nil - }, - }, - setupHeaders: func() http.Header { - h := http.Header{} - h.Set("Authorization", "invalid-token") - return h - }, - wantErr: true, - wantTextContain: "Authentication required", - }, - { - name: "Bearer token with whitespace", - enabled: true, - decoder: &mockTokenDecoder{ - decodeFunc: func(token string) (authentication.Claims, error) { - if token == "valid-token" { - return validClaims, nil - } - return nil, fmt.Errorf("unexpected token: %s", token) - }, - }, - setupHeaders: func() http.Header { - h := http.Header{} - h.Set("Authorization", "Bearer valid-token ") - return h - }, - wantErr: false, - wantTextContain: "authenticated with claims", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - middleware, err := NewMCPAuthMiddleware(tt.decoder, tt.enabled, testMetadataURL, []string{}) - require.NoError(t, err) - - handler := middleware.ToolMiddleware(func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - claims, ok := GetClaimsFromContext(ctx) - if ok { - return mcp.NewToolResultText(fmt.Sprintf("authenticated with claims: %v", claims)), nil - } - return mcp.NewToolResultText("no authentication"), nil - }) - - ctx := withRequestHeaders(context.Background(), tt.setupHeaders()) - result, err := handler(ctx, mcp.CallToolRequest{}) - - require.NoError(t, err) - assert.Equal(t, tt.wantErr, result.IsError) - assert.Contains(t, getTextFromResult(result), tt.wantTextContain) - }) - } -} - -func TestMCPAuthMiddleware_MissingHeaders(t *testing.T) { - decoder := &mockTokenDecoder{ - decodeFunc: func(token string) (authentication.Claims, error) { - return authentication.Claims{"sub": "user123"}, nil - }, - } - - middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, map[string][]string{}) - require.NoError(t, err) - - handler := middleware.ToolMiddleware(func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return mcp.NewToolResultText("success"), nil - }) - - // Context without headers - result, err := handler(context.Background(), mcp.CallToolRequest{}) - require.NoError(t, err) - assert.True(t, result.IsError) - assert.Contains(t, getTextFromResult(result), "missing request headers") -} - func TestGetClaimsFromContext(t *testing.T) { expectedClaims := authentication.Claims{"sub": "user123", "email": "user@example.com"} @@ -335,186 +115,6 @@ func TestGetClaimsFromContext(t *testing.T) { } } -func TestMCPAuthProvider(t *testing.T) { - t.Run("returns headers", func(t *testing.T) { - headers := http.Header{} - headers.Set("Authorization", "Bearer token") - headers.Set("X-Custom", "value") - - provider := &mcpAuthProvider{headers: headers} - assert.Equal(t, headers, provider.AuthenticationHeaders()) - }) - - t.Run("empty headers", func(t *testing.T) { - provider := &mcpAuthProvider{headers: http.Header{}} - assert.Equal(t, 0, len(provider.AuthenticationHeaders())) - }) -} - -func TestMCPAuthMiddleware_Integration(t *testing.T) { - expectedClaims := authentication.Claims{"sub": "user123", "role": "admin"} - - decoder := &mockTokenDecoder{ - decodeFunc: func(token string) (authentication.Claims, error) { - if token == "valid-token" { - return expectedClaims, nil - } - return nil, errors.New("invalid token") - }, - } - - middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, map[string][]string{}) - require.NoError(t, err) - - handler := middleware.ToolMiddleware(func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - claims, ok := GetClaimsFromContext(ctx) - if !ok { - return mcp.NewToolResultError("no claims found"), nil - } - return mcp.NewToolResultText(fmt.Sprintf("user: %s, role: %s", claims["sub"], claims["role"])), nil - }) - - // Valid token - headers := http.Header{} - headers.Set("Authorization", "Bearer valid-token") - ctx := withRequestHeaders(context.Background(), headers) - - result, err := handler(ctx, mcp.CallToolRequest{}) - require.NoError(t, err) - assert.False(t, result.IsError) - text := getTextFromResult(result) - assert.Contains(t, text, "user: user123") - assert.Contains(t, text, "role: admin") - - // Invalid token - headers.Set("Authorization", "Bearer invalid-token") - ctx = withRequestHeaders(context.Background(), headers) - - result, err = handler(ctx, mcp.CallToolRequest{}) - require.NoError(t, err) - assert.True(t, result.IsError) - assert.Contains(t, getTextFromResult(result), "Authentication required") -} - -func TestMCPAuthMiddleware_ScopeValidation(t *testing.T) { - tests := []struct { - name string - requiredScopes []string - tokenScopes string - wantErr bool - wantTextContain string - }{ - { - name: "no required scopes, token with no scopes", - requiredScopes: []string{}, - tokenScopes: "", - wantErr: false, - wantTextContain: "authenticated with claims", - }, - { - name: "no required scopes, token with scopes", - requiredScopes: []string{}, - tokenScopes: "some:scope another:scope", - wantErr: false, - wantTextContain: "authenticated with claims", - }, - { - name: "one required scope, token with no scopes", - requiredScopes: []string{"mcp:tools"}, - tokenScopes: "", - wantErr: true, - wantTextContain: "missing required scopes: mcp:tools", - }, - { - name: "one required scope, token has required scope", - requiredScopes: []string{"mcp:tools"}, - tokenScopes: "mcp:tools", - wantErr: false, - wantTextContain: "authenticated with claims", - }, - { - name: "one required scope, token missing required scope", - requiredScopes: []string{"mcp:tools"}, - tokenScopes: "mcp:read", - wantErr: true, - wantTextContain: "missing required scopes: mcp:tools", - }, - { - name: "multiple required scopes, token with no scopes", - requiredScopes: []string{"mcp:tools", "mcp:read"}, - tokenScopes: "", - wantErr: true, - wantTextContain: "missing required scopes: mcp:tools, mcp:read", - }, - { - name: "multiple required scopes, token with partial match", - requiredScopes: []string{"mcp:tools", "mcp:read"}, - tokenScopes: "mcp:tools", - wantErr: true, - wantTextContain: "missing required scopes: mcp:read", - }, - { - name: "multiple required scopes, token has all required scopes", - requiredScopes: []string{"mcp:tools", "mcp:read"}, - tokenScopes: "mcp:tools mcp:read", - wantErr: false, - wantTextContain: "authenticated with claims", - }, - { - name: "multiple required scopes, token with partial match (multiple missing)", - requiredScopes: []string{"mcp:tools", "mcp:read", "mcp:admin"}, - tokenScopes: "mcp:tools mcp:write", - wantErr: true, - wantTextContain: "missing required scopes: mcp:read, mcp:admin", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - decoder := &mockTokenDecoder{ - decodeFunc: func(token string) (authentication.Claims, error) { - if token == "valid-token" { - claims := authentication.Claims{ - "sub": "user123", - "email": "user@example.com", - } - if tt.tokenScopes != "" { - claims["scope"] = tt.tokenScopes - } - return claims, nil - } - return nil, errors.New("invalid token") - }, - } - - // Convert requiredScopes array to map format for new API - scopesRequired := map[string][]string{} - if len(tt.requiredScopes) > 0 { - scopesRequired["initialize"] = tt.requiredScopes - } - middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, scopesRequired) - require.NoError(t, err) - - handler := middleware.ToolMiddleware(func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - claims, ok := GetClaimsFromContext(ctx) - if ok { - return mcp.NewToolResultText(fmt.Sprintf("authenticated with claims: %v", claims)), nil - } - return mcp.NewToolResultText("no authentication"), nil - }) - - headers := http.Header{} - headers.Set("Authorization", "Bearer valid-token") - ctx := withRequestHeaders(context.Background(), headers) - - result, err := handler(ctx, mcp.CallToolRequest{}) - require.NoError(t, err) - assert.Equal(t, tt.wantErr, result.IsError) - assert.Contains(t, getTextFromResult(result), tt.wantTextContain) - }) - } -} - func TestExtractScopes(t *testing.T) { tests := []struct { name string @@ -522,7 +122,7 @@ func TestExtractScopes(t *testing.T) { want []string }{ { - name: "scope as space-separated string (OAuth 2.0 standard)", + name: "scope with multiple values", claims: authentication.Claims{ "scope": "mcp:tools mcp:read mcp:write", }, @@ -535,27 +135,6 @@ func TestExtractScopes(t *testing.T) { }, want: []string{"mcp:tools"}, }, - { - name: "scope with extra whitespace", - claims: authentication.Claims{ - "scope": " mcp:tools mcp:read mcp:write ", - }, - want: []string{"mcp:tools", "mcp:read", "mcp:write"}, - }, - { - name: "scope with tabs and newlines", - claims: authentication.Claims{ - "scope": "mcp:tools\t\nmcp:read\n\tmcp:write", - }, - want: []string{"mcp:tools", "mcp:read", "mcp:write"}, - }, - { - name: "scope with multiple spaces between values", - claims: authentication.Claims{ - "scope": "mcp:tools mcp:read mcp:write", - }, - want: []string{"mcp:tools", "mcp:read", "mcp:write"}, - }, { name: "no scope claim", claims: authentication.Claims{}, @@ -568,69 +147,6 @@ func TestExtractScopes(t *testing.T) { }, want: []string{}, }, - { - name: "scope with only whitespace", - claims: authentication.Claims{ - "scope": " \t\n ", - }, - want: []string{}, - }, - { - name: "scope claim with wrong type (number)", - claims: authentication.Claims{ - "scope": 123, - }, - want: []string{}, - }, - { - name: "scope claim with wrong type (array)", - claims: authentication.Claims{ - "scope": []string{"mcp:tools", "mcp:read"}, - }, - want: []string{}, - }, - { - name: "scope claim with wrong type (object)", - claims: authentication.Claims{ - "scope": map[string]string{"key": "value"}, - }, - want: []string{}, - }, - { - name: "nil claims", - claims: nil, - want: []string{}, - }, - { - name: "complex scopes with colons", - claims: authentication.Claims{ - "scope": "mcp:tools:read mcp:tools:write api:v1:access", - }, - want: []string{"mcp:tools:read", "mcp:tools:write", "api:v1:access"}, - }, - { - name: "scopes with URLs", - claims: authentication.Claims{ - "scope": "https://api.example.com/read https://api.example.com/write", - }, - want: []string{"https://api.example.com/read", "https://api.example.com/write"}, - }, - { - name: "scopes with special characters", - claims: authentication.Claims{ - "scope": "read:users write:users delete:users", - }, - want: []string{"read:users", "write:users", "delete:users"}, - }, - { - name: "other claims present but no scope", - claims: authentication.Claims{ - "sub": "user123", - "email": "user@example.com", - "aud": "https://api.example.com", - }, - want: []string{}, - }, } for _, tt := range tests { @@ -644,15 +160,15 @@ func TestExtractScopes(t *testing.T) { func TestMCPAuthMiddleware_HTTPMiddleware(t *testing.T) { t.Parallel() + const testMetadataURL = "http://localhost:5025/.well-known/oauth-protected-resource" + tests := []struct { name string requiredScopes []string setupDecoder func() *mockTokenDecoder setupRequest func() *http.Request wantStatusCode int - wantWWWAuthenticate string wantWWWAuthenticatePrefix string - wantBody string }{ { name: "valid token without scopes", @@ -668,11 +184,11 @@ func TestMCPAuthMiddleware_HTTPMiddleware(t *testing.T) { } }, setupRequest: func() *http.Request { - req, _ := http.NewRequest(http.MethodPost, "/mcp", nil) + req, _ := http.NewRequest("POST", "/mcp", nil) req.Header.Set("Authorization", "Bearer valid-token") return req }, - wantStatusCode: http.StatusOK, + wantStatusCode: 200, }, { name: "missing authorization header", @@ -685,12 +201,11 @@ func TestMCPAuthMiddleware_HTTPMiddleware(t *testing.T) { } }, setupRequest: func() *http.Request { - req, _ := http.NewRequest(http.MethodPost, "/mcp", nil) + req, _ := http.NewRequest("POST", "/mcp", nil) return req }, - wantStatusCode: http.StatusUnauthorized, + wantStatusCode: 401, wantWWWAuthenticatePrefix: `Bearer realm="mcp", resource_metadata="` + testMetadataURL + `"`, - wantBody: "", // No JSON-RPC body per MCP spec }, { name: "invalid token", @@ -703,13 +218,12 @@ func TestMCPAuthMiddleware_HTTPMiddleware(t *testing.T) { } }, setupRequest: func() *http.Request { - req, _ := http.NewRequest(http.MethodPost, "/mcp", nil) + req, _ := http.NewRequest("POST", "/mcp", nil) req.Header.Set("Authorization", "Bearer invalid-token") return req }, - wantStatusCode: http.StatusUnauthorized, + wantStatusCode: 401, wantWWWAuthenticatePrefix: `Bearer realm="mcp", resource_metadata="` + testMetadataURL + `"`, - wantBody: "", // No JSON-RPC body per MCP spec }, { name: "valid token but insufficient scopes", @@ -728,13 +242,12 @@ func TestMCPAuthMiddleware_HTTPMiddleware(t *testing.T) { } }, setupRequest: func() *http.Request { - req, _ := http.NewRequest(http.MethodPost, "/mcp", nil) + req, _ := http.NewRequest("POST", "/mcp", nil) req.Header.Set("Authorization", "Bearer valid-token") return req }, - wantStatusCode: http.StatusForbidden, - wantWWWAuthenticatePrefix: `Bearer error="insufficient_scope", scope="mcp:tools:write mcp:admin", resource_metadata="` + testMetadataURL + `"`, - wantBody: "", // No JSON-RPC body per MCP spec + wantStatusCode: 403, + wantWWWAuthenticatePrefix: `Bearer error="insufficient_scope", scope="mcp:tools:write mcp:admin"`, }, { name: "valid token with all required scopes", @@ -753,140 +266,33 @@ func TestMCPAuthMiddleware_HTTPMiddleware(t *testing.T) { } }, setupRequest: func() *http.Request { - req, _ := http.NewRequest(http.MethodPost, "/mcp", nil) + req, _ := http.NewRequest("POST", "/mcp", nil) req.Header.Set("Authorization", "Bearer valid-token") return req }, - wantStatusCode: http.StatusOK, + wantStatusCode: 200, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Convert requiredScopes array to map format for new API - scopesRequired := map[string][]string{} - if len(tt.requiredScopes) > 0 { - scopesRequired["initialize"] = tt.requiredScopes - } - middleware, err := NewMCPAuthMiddleware(tt.setupDecoder(), true, testMetadataURL, scopesRequired) - require.NoError(t, err) + decoder := tt.setupDecoder() + middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, map[string][]string{"initialize": tt.requiredScopes}) + assert.NoError(t, err) - // Create a test handler that sets status 200 if reached - testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) + handler := middleware.HTTPMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) - // Wrap with auth middleware - handler := middleware.HTTPMiddleware(testHandler) - - // Create response recorder + req := tt.setupRequest() rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) - // Execute request - handler.ServeHTTP(rr, tt.setupRequest()) - - // Verify status code - assert.Equal(t, tt.wantStatusCode, rr.Code, "status code mismatch") - - // Verify WWW-Authenticate header for auth failures + assert.Equal(t, tt.wantStatusCode, rr.Code) if tt.wantWWWAuthenticatePrefix != "" { - authHeader := rr.Header().Get("WWW-Authenticate") - assert.NotEmpty(t, authHeader, "WWW-Authenticate header should be present") - assert.Contains(t, authHeader, tt.wantWWWAuthenticatePrefix, "WWW-Authenticate header should match expected format") - - // Verify resource_metadata is present (per MCP spec) - assert.Contains(t, authHeader, "resource_metadata=", "resource_metadata should be in WWW-Authenticate header") - } - - // Verify no JSON-RPC response body for HTTP-level auth failures - if tt.wantStatusCode == http.StatusUnauthorized || tt.wantStatusCode == http.StatusForbidden { - body := rr.Body.String() - assert.Equal(t, "", body, "HTTP-level auth failures should not return JSON-RPC response body per MCP spec") + wwwAuth := rr.Header().Get("WWW-Authenticate") + assert.Contains(t, wwwAuth, tt.wantWWWAuthenticatePrefix) } }) } } - -func TestMCPAuthMiddleware_HTTPMiddleware_WWWAuthenticateFormat(t *testing.T) { - t.Parallel() - - t.Run("401 response has correct WWW-Authenticate format", func(t *testing.T) { - decoder := &mockTokenDecoder{ - decodeFunc: func(token string) (authentication.Claims, error) { - return nil, errors.New("invalid token") - }, - } - - middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, map[string][]string{}) - require.NoError(t, err) - - testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - - handler := middleware.HTTPMiddleware(testHandler) - - req, _ := http.NewRequest(http.MethodPost, "/mcp", nil) - req.Header.Set("Authorization", "Bearer invalid-token") - - rr := httptest.NewRecorder() - - handler.ServeHTTP(rr, req) - - assert.Equal(t, http.StatusUnauthorized, rr.Code) - - // Parse WWW-Authenticate header properly - authHeader := rr.Header().Get("WWW-Authenticate") - require.NotEmpty(t, authHeader, "WWW-Authenticate header must be present") - - params := parseWWWAuthenticateParams(authHeader) - - // Verify expected fields per RFC 6750 - assert.Equal(t, "mcp", params["realm"], "realm should be 'mcp'") - assert.Equal(t, testMetadataURL, params["resource_metadata"], "resource_metadata must be present for OAuth discovery") - assert.NotEmpty(t, params["error_description"], "error_description should provide details") - }) - - t.Run("403 response has correct WWW-Authenticate format per RFC 6750", func(t *testing.T) { - decoder := &mockTokenDecoder{ - decodeFunc: func(token string) (authentication.Claims, error) { - return authentication.Claims{ - "sub": "user123", - "scope": "mcp:read", - }, nil - }, - } - - requiredScopes := []string{"mcp:tools:write", "mcp:admin"} - scopesRequired := map[string][]string{"initialize": requiredScopes} - middleware, err := NewMCPAuthMiddleware(decoder, true, testMetadataURL, scopesRequired) - require.NoError(t, err) - - testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - - handler := middleware.HTTPMiddleware(testHandler) - - req, _ := http.NewRequest(http.MethodPost, "/mcp", nil) - req.Header.Set("Authorization", "Bearer valid-token") - - rr := httptest.NewRecorder() - - handler.ServeHTTP(rr, req) - - assert.Equal(t, http.StatusForbidden, rr.Code) - - // Parse WWW-Authenticate header properly - authHeader := rr.Header().Get("WWW-Authenticate") - require.NotEmpty(t, authHeader, "WWW-Authenticate header must be present") - - params := parseWWWAuthenticateParams(authHeader) - - // Per RFC 6750 Section 3.1: Verify all required fields - assert.Equal(t, "insufficient_scope", params["error"], "error parameter must be 'insufficient_scope'") - assert.Equal(t, "mcp:tools:write mcp:admin", params["scope"], "scope parameter must list required scopes") - assert.Equal(t, testMetadataURL, params["resource_metadata"], "resource_metadata must be present") - assert.NotEmpty(t, params["error_description"], "error_description should provide details") - }) -} diff --git a/router/pkg/mcpserver/server.go b/router/pkg/mcpserver/server.go index 02891c0fc4..62f2a5a4b2 100644 --- a/router/pkg/mcpserver/server.go +++ b/router/pkg/mcpserver/server.go @@ -14,8 +14,7 @@ import ( "github.com/hashicorp/go-retryablehttp" "github.com/iancoleman/strcase" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/santhosh-tekuri/jsonschema/v6" "go.uber.org/zap" @@ -109,7 +108,7 @@ type Options struct { // GraphQLSchemaServer represents an MCP server that works with GraphQL schemas and operations type GraphQLSchemaServer struct { - server *server.MCPServer + server *mcp.Server graphName string operationsDir string listenAddr string @@ -117,7 +116,7 @@ type GraphQLSchemaServer struct { httpClient *http.Client requestTimeout time.Duration routerGraphQLEndpoint string - httpServer *server.StreamableHTTPServer + httpServer *http.Server excludeMutations bool enableArbitraryOperations bool exposeSchema bool @@ -232,15 +231,8 @@ func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options) // Create a cancellable context for managing the server lifecycle ctx, cancel := context.WithCancel(context.Background()) - // Prepare server options - var serverOpts []server.ServerOption - serverOpts = append(serverOpts, - server.WithToolCapabilities(true), - server.WithPaginationLimit(100), - server.WithRecovery(), - ) - // Add authentication middleware if OAuth is configured + var authMiddleware *MCPAuthMiddleware if options.OAuthConfig != nil && options.OAuthConfig.Enabled && len(options.OAuthConfig.JWKS) > 0 { // Convert config.JWKSConfiguration to authentication.JWKSConfig authConfigs := make([]authentication.JWKSConfig, 0, len(options.OAuthConfig.JWKS)) @@ -283,60 +275,29 @@ func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options) // The middleware will check: // - "initialize" key scopes for all HTTP requests (HTTP-level auth) // - Per-tool scopes when tools are called (by parsing JSON-RPC request) - authMiddleware, err := NewMCPAuthMiddleware(tokenDecoder, true, resourceMetadataURL, options.OAuthConfig.ScopesRequired) + authMiddleware, err = NewMCPAuthMiddleware(tokenDecoder, true, resourceMetadataURL, options.OAuthConfig.ScopesRequired) if err != nil { cancel() // Clean up the context if initialization fails return nil, fmt.Errorf("failed to create auth middleware: %w", err) } // Store auth middleware for HTTP-level protection - // Note: We don't use WithToolHandlerMiddleware here because per MCP spec, + // Note: We don't use tool middleware here because per MCP spec, // ALL HTTP requests must be authenticated, not just tool calls options.Logger.Info("MCP OAuth authentication enabled", zap.Int("jwks_providers", len(options.OAuthConfig.JWKS)), zap.String("authorization_server", options.OAuthConfig.AuthorizationServerURL)) - - // Create the MCP server with all options - mcpServer := server.NewMCPServer( - "wundergraph-cosmo-"+strcase.ToKebab(options.GraphName), - "0.0.1", - serverOpts..., - ) - - retryClient := retryablehttp.NewClient() - retryClient.Logger = nil - httpClient := retryClient.StandardClient() - httpClient.Timeout = 60 * time.Second - - gs := &GraphQLSchemaServer{ - server: mcpServer, - graphName: options.GraphName, - operationsDir: options.OperationsDir, - listenAddr: options.ListenAddr, - logger: options.Logger, - httpClient: httpClient, - requestTimeout: options.RequestTimeout, - routerGraphQLEndpoint: routerGraphQLEndpoint, - excludeMutations: options.ExcludeMutations, - enableArbitraryOperations: options.EnableArbitraryOperations, - exposeSchema: options.ExposeSchema, - stateless: options.Stateless, - corsConfig: options.CorsConfig, - ctx: ctx, - cancel: cancel, - oauthConfig: options.OAuthConfig, - serverBaseURL: options.ServerBaseURL, - authMiddleware: authMiddleware, - } - - return gs, nil } // Create the MCP server with all options - mcpServer := server.NewMCPServer( - "wundergraph-cosmo-"+strcase.ToKebab(options.GraphName), - "0.0.1", - serverOpts..., + mcpServer := mcp.NewServer( + &mcp.Implementation{ + Name: "wundergraph-cosmo-" + strcase.ToKebab(options.GraphName), + Version: "0.0.1", + }, + &mcp.ServerOptions{ + PageSize: 100, + }, ) retryClient := retryablehttp.NewClient() @@ -363,7 +324,7 @@ func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options) cancel: cancel, oauthConfig: options.OAuthConfig, serverBaseURL: options.ServerBaseURL, - authMiddleware: nil, // No auth middleware when OAuth is disabled + authMiddleware: authMiddleware, } return gs, nil @@ -463,8 +424,8 @@ func WithServerBaseURL(baseURL string) func(*Options) { } } -// Serve starts the server with the configured options and returns a streamable HTTP server. -func (s *GraphQLSchemaServer) Serve() (*server.StreamableHTTPServer, error) { +// Serve starts the server with the configured options and returns the HTTP server. +func (s *GraphQLSchemaServer) Serve() (*http.Server, error) { // Create custom HTTP server httpServer := &http.Server{ Addr: s.listenAddr, @@ -473,12 +434,14 @@ func (s *GraphQLSchemaServer) Serve() (*server.StreamableHTTPServer, error) { IdleTimeout: 60 * time.Second, } - streamableHTTPServer := server.NewStreamableHTTPServer(s.server, - server.WithStreamableHTTPServer(httpServer), - server.WithLogger(NewZapAdapter(s.logger.With(zap.String("component", "mcp-server")))), - server.WithStateLess(s.stateless), - server.WithHTTPContextFunc(requestHeadersFromRequest), - server.WithHeartbeatInterval(10*time.Second), + // Create MCP streamable HTTP handler + // The getServer function returns our MCP server instance for each request + streamableHTTPHandler := mcp.NewStreamableHTTPHandler( + func(req *http.Request) *mcp.Server { + // Add request headers to context for tool handlers + return s.server + }, + nil, // Use default options ) middleware := cors.New(s.corsConfig) @@ -497,9 +460,7 @@ func (s *GraphQLSchemaServer) Serve() (*server.StreamableHTTPServer, error) { // MCP endpoint with HTTP-level authentication // Per MCP spec: "authorization MUST be included in every HTTP request from client to server" - mcpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - streamableHTTPServer.ServeHTTP(w, r) - }) + mcpHandler := http.Handler(streamableHTTPHandler) // Apply authentication middleware if OAuth is enabled if s.authMiddleware != nil { @@ -533,7 +494,7 @@ func (s *GraphQLSchemaServer) Serve() (*server.StreamableHTTPServer, error) { } }() - return streamableHTTPServer, nil + return httpServer, nil } // Start loads operations and starts the server @@ -563,7 +524,7 @@ func (s *GraphQLSchemaServer) Reload(schema *ast.Document) error { } } - s.server.DeleteTools(s.registeredTools...) + s.server.RemoveTools(s.registeredTools...) if err := s.registerTools(); err != nil { return fmt.Errorf("failed to register tools: %w", err) @@ -601,75 +562,61 @@ func (s *GraphQLSchemaServer) registerTools() error { // Only register the schema tool if exposeSchema is enabled if s.exposeSchema { // Create a schema with empty properties since get_schema takes no input - // Note: We omit "required" field to get nil instead of empty array - getSchemaInputSchema := []byte(`{ - "type": "object", - "properties": {} - }`) - - tool := mcp.NewToolWithRawSchema( - "get_schema", - "Provides the full GraphQL schema of the API.", - getSchemaInputSchema, - ) - - tool.Annotations = mcp.ToolAnnotation{ - Title: "Get GraphQL Schema", - ReadOnlyHint: mcp.ToBoolPtr(true), + getSchemaInputSchema := map[string]any{ + "type": "object", + "properties": map[string]any{}, } - s.server.AddTool( - tool, - s.handleGetGraphQLSchema(), - ) + tool := &mcp.Tool{ + Name: "get_schema", + Description: "Provides the full GraphQL schema of the API.", + InputSchema: getSchemaInputSchema, + Annotations: &mcp.ToolAnnotations{ + Title: "Get GraphQL Schema", + ReadOnlyHint: true, + }, + } + s.server.AddTool(tool, s.handleGetGraphQLSchema()) s.registeredTools = append(s.registeredTools, "get_schema") } // Only register the execute_graphql tool if enableArbitraryOperations is enabled if s.enableArbitraryOperations { // Add a tool to execute arbitrary GraphQL queries - executeGraphQLSchema := []byte(`{ - "type": "object", + executeGraphQLSchema := map[string]any{ + "type": "object", "description": "The query and variables to execute.", - "properties": { - "query": { - "type": "string", - "description": "The GraphQL query or mutation string to execute." + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "The GraphQL query or mutation string to execute.", }, - "variables": { - "type": "object", + "variables": map[string]any{ + "type": "object", "additionalProperties": true, - "description": "The variables to pass to the GraphQL query as a JSON object." - } + "description": "The variables to pass to the GraphQL query as a JSON object.", + }, }, "additionalProperties": false, - "required": ["query"] - }`) - - // Validate the schema before using it - if err := s.schemaCompiler.ValidateJSONSchema(executeGraphQLSchema); err != nil { - return fmt.Errorf("invalid schema for execute_graphql tool: %w", err) + "required": []string{"query"}, } - tool := mcp.NewToolWithRawSchema( - "execute_graphql", - "Executes a GraphQL query or mutation.", - executeGraphQLSchema, - ) - - tool.Annotations = mcp.ToolAnnotation{ - Title: "Execute GraphQL Query", - DestructiveHint: mcp.ToBoolPtr(true), - IdempotentHint: mcp.ToBoolPtr(false), - OpenWorldHint: mcp.ToBoolPtr(true), + destructiveHint := true + openWorldHint := true + tool := &mcp.Tool{ + Name: "execute_graphql", + Description: "Executes a GraphQL query or mutation.", + InputSchema: executeGraphQLSchema, + Annotations: &mcp.ToolAnnotations{ + Title: "Execute GraphQL Query", + DestructiveHint: &destructiveHint, + IdempotentHint: false, + OpenWorldHint: &openWorldHint, + }, } - s.server.AddTool( - tool, - s.handleExecuteGraphQL(), - ) - + s.server.AddTool(tool, s.handleExecuteGraphQL()) s.registeredTools = append(s.registeredTools, "execute_graphql") } @@ -732,43 +679,62 @@ func (s *GraphQLSchemaServer) registerTools() error { ) toolName = fmt.Sprintf("execute_operation_%s", operationToolName) } - tool := mcp.NewToolWithRawSchema( - toolName, - toolDescription, - op.JSONSchema, - ) + // Parse JSON schema into map for the official SDK + var inputSchema any + if len(op.JSONSchema) > 0 { + if err := json.Unmarshal(op.JSONSchema, &inputSchema); err != nil { + s.logger.Error("failed to parse JSON schema for operation", + zap.String("operation", op.Name), + zap.Error(err)) + continue + } + } else { + inputSchema = map[string]any{"type": "object", "properties": map[string]any{}} + } - tool.Annotations = mcp.ToolAnnotation{ - IdempotentHint: mcp.ToBoolPtr(op.OperationType != "mutation"), - Title: fmt.Sprintf("Execute operation %s", op.Name), - ReadOnlyHint: mcp.ToBoolPtr(op.OperationType == "query"), - OpenWorldHint: mcp.ToBoolPtr(true), + idempotent := op.OperationType != "mutation" + openWorld := true + tool := &mcp.Tool{ + Name: toolName, + Description: toolDescription, + InputSchema: inputSchema, + Annotations: &mcp.ToolAnnotations{ + IdempotentHint: op.OperationType != "mutation", + Title: fmt.Sprintf("Execute operation %s", op.Name), + ReadOnlyHint: op.OperationType == "query", + OpenWorldHint: &openWorld, + }, } - s.server.AddTool( - tool, - s.handleOperation(handler), - ) + // IdempotentHint uses the plain bool value, but keep it for later if needed + _ = idempotent + + s.server.AddTool(tool, s.handleOperation(handler)) s.registeredTools = append(s.registeredTools, toolName) } - s.server.AddTool( - mcp.NewTool( - "get_operation_info", - mcp.WithDescription("Provides instructions on how to execute the GraphQL operation via HTTP and how to integrate it into your application."), - mcp.WithToolAnnotation(mcp.ToolAnnotation{ - Title: "Get GraphQL Operation Info", - ReadOnlyHint: mcp.ToBoolPtr(true), - }), - mcp.WithString("operationName", - mcp.Required(), - mcp.Description("The exact name of the GraphQL operation to retrieve information for."), - mcp.Enum(graphqlOperationNames...), - ), - ), - s.handleGraphQLOperationInfo(), - ) + getOperationInfoTool := &mcp.Tool{ + Name: "get_operation_info", + Description: "Provides instructions on how to execute the GraphQL operation via HTTP and how to integrate it into your application.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "operationName": map[string]any{ + "type": "string", + "description": "The exact name of the GraphQL operation to retrieve information for.", + "enum": graphqlOperationNames, + }, + }, + "required": []string{"operationName"}, + }, + Annotations: &mcp.ToolAnnotations{ + Title: "Get GraphQL Operation Info", + ReadOnlyHint: true, + }, + } + + s.server.AddTool(getOperationInfoTool, s.handleGraphQLOperationInfo()) s.registeredTools = append(s.registeredTools, "get_operation_info") @@ -776,8 +742,8 @@ func (s *GraphQLSchemaServer) registerTools() error { } // handleOperation handles a specific operation -func (s *GraphQLSchemaServer) handleOperation(handler *operationHandler) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (s *GraphQLSchemaServer) handleOperation(handler *operationHandler) func(ctx context.Context, request *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request *mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Log authenticated user if OAuth is enabled if claims, ok := GetClaimsFromContext(ctx); ok { s.logger.Debug("operation called by authenticated user", @@ -786,15 +752,15 @@ func (s *GraphQLSchemaServer) handleOperation(handler *operationHandler) func(ct zap.String("operation", handler.operation.Name)) } - jsonBytes, err := json.Marshal(request.GetArguments()) - if err != nil { - return nil, fmt.Errorf("failed to marshal arguments: %w", err) - } + jsonBytes := request.Params.Arguments // Validate the JSON input against the pre-compiled schema derived from the operation input type if handler.compiledSchema != nil { if err := s.schemaCompiler.ValidateInput(jsonBytes, handler.compiledSchema); err != nil { - return mcp.NewToolResultErrorFromErr("Input validation Error", err), nil + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: fmt.Sprintf("Input validation error: %v", err)}}, + IsError: true, + }, nil } } @@ -804,13 +770,10 @@ func (s *GraphQLSchemaServer) handleOperation(handler *operationHandler) func(ct } // handleGraphQLOperationInfo returns a handler function that provides detailed info for a specific operation. -func (s *GraphQLSchemaServer) handleGraphQLOperationInfo() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (s *GraphQLSchemaServer) handleGraphQLOperationInfo() func(ctx context.Context, request *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request *mcp.CallToolRequest) (*mcp.CallToolResult, error) { var input GraphQLOperationInfoInput - inputBytes, err := json.Marshal(request.GetArguments()) - if err != nil { - return nil, fmt.Errorf("failed to marshal input arguments: %w", err) - } + inputBytes := request.Params.Arguments if err := json.Unmarshal(inputBytes, &input); err != nil { return nil, fmt.Errorf("failed to unmarshal input arguments: %w. Ensure you provide {\"operationName\": \"\"}", err) } @@ -876,7 +839,9 @@ Important Notes: // Combine all sections response := overview + schemaInfo + queryInfo + usageInstructions + requestFormat + importantNotes - return mcp.NewToolResultText(response), nil + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: response}}, + }, nil } } @@ -945,21 +910,29 @@ func (s *GraphQLSchemaServer) executeGraphQLQuery(ctx context.Context, query str // If there are errors but no data, return only the errors if len(graphqlResponse.Data) == 0 || string(graphqlResponse.Data) == "null" { - return mcp.NewToolResultErrorFromErr("Response Error", err), nil + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: fmt.Sprintf("Response error: %v", err)}}, + IsError: true, + }, nil } // If we have both errors and data, include data in the error message dataString := string(graphqlResponse.Data) combinedErrorMsg := fmt.Sprintf("Response error with partial success, Error: %s, Data: %s)", errorMessage, dataString) - return mcp.NewToolResultErrorFromErr(combinedErrorMsg, err), nil + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: combinedErrorMsg}}, + IsError: true, + }, nil } - return mcp.NewToolResultText(string(body)), nil + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: string(body)}}, + }, nil } // handleExecuteGraphQL returns a handler function that executes arbitrary GraphQL queries -func (s *GraphQLSchemaServer) handleExecuteGraphQL() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (s *GraphQLSchemaServer) handleExecuteGraphQL() func(ctx context.Context, request *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request *mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Log authenticated user if OAuth is enabled if claims, ok := GetClaimsFromContext(ctx); ok { s.logger.Debug("arbitrary GraphQL query called by authenticated user", @@ -968,10 +941,7 @@ func (s *GraphQLSchemaServer) handleExecuteGraphQL() func(ctx context.Context, r } // Parse the JSON input - jsonBytes, err := json.Marshal(request.GetArguments()) - if err != nil { - return nil, fmt.Errorf("failed to marshal arguments: %w", err) - } + jsonBytes := request.Params.Arguments var input ExecuteGraphQLInput if err := json.Unmarshal(jsonBytes, &input); err != nil { @@ -987,8 +957,8 @@ func (s *GraphQLSchemaServer) handleExecuteGraphQL() func(ctx context.Context, r } // handleGetGraphQLSchema returns a handler function that returns the full GraphQL schema -func (s *GraphQLSchemaServer) handleGetGraphQLSchema() func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (s *GraphQLSchemaServer) handleGetGraphQLSchema() func(ctx context.Context, request *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request *mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get the schema from the operations manager schema := s.operationsManager.GetSchema() if schema == nil { @@ -1001,7 +971,9 @@ func (s *GraphQLSchemaServer) handleGetGraphQLSchema() func(ctx context.Context, return nil, fmt.Errorf("failed to convert schema to string: %w", err) } - return mcp.NewToolResultText(schemaStr), nil + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: schemaStr}}, + }, nil } } From 634f40af6a7817496dd6a1069a10c11cfae38f98 Mon Sep 17 00:00:00 2001 From: Ahmet Soormally Date: Tue, 3 Feb 2026 05:00:24 +0000 Subject: [PATCH 5/5] fix(mcp): improve error handling and security hardening - Fix error return to use actual error message instead of nil err variable - Store request headers in context so headersFromContext succeeds - Make Bearer prefix parsing case-insensitive per RFC 6750 - Add bounded body read with LimitReader to prevent memory exhaustion - Buffer JSON response before writing headers to handle encoding errors - Use context-aware HTTP requests in test helpers for proper cancellation - Add defensive token preview helper to prevent panics on short tokens - Remove duplicate nolint directive --- router-tests/mcp_auth_e2e_test.go | 13 +++++++++++-- router-tests/mcp_test.go | 2 +- router-tests/testutil/auth_helpers.go | 6 ++++-- router-tests/testutil/jwt_helper.go | 6 +++++- router/pkg/mcpserver/auth_middleware.go | 14 ++++++++++++-- router/pkg/mcpserver/server.go | 13 ++++++++----- 6 files changed, 41 insertions(+), 13 deletions(-) diff --git a/router-tests/mcp_auth_e2e_test.go b/router-tests/mcp_auth_e2e_test.go index 62efcf2bc6..896c2e4afb 100644 --- a/router-tests/mcp_auth_e2e_test.go +++ b/router-tests/mcp_auth_e2e_test.go @@ -15,6 +15,15 @@ import ( "github.com/wundergraph/cosmo/router/pkg/config" ) +// previewToken returns a truncated preview of a token for logging purposes. +// Returns the full token if shorter than n characters, otherwise returns first n characters with "...". +func previewToken(token string, n int) string { + if len(token) <= n { + return token + } + return token[:n] + "..." +} + // authRoundTripper wraps an http.RoundTripper and adds Authorization headers // It also captures the last HTTP response for error analysis type authRoundTripper struct { @@ -199,7 +208,7 @@ func TestMCPAuthorizationWithOfficialSDK(t *testing.T) { require.NoError(t, err) defer mcpClient.Close() //nolint:errcheck - t.Logf("✓ Connected to MCP server with token: %s", token[:20]+"...") + t.Logf("✓ Connected to MCP server with token: %s", previewToken(token, 20)) // Call a tool result, err := mcpClient.CallTool(ctx, "execute_operation_my_employees", map[string]any{ @@ -317,7 +326,7 @@ func TestMCPAuthorizationWithOfficialSDK(t *testing.T) { require.NoError(t, err) require.NotNil(t, result) - t.Logf("✓ Request %d succeeded with token: %s", i+1, token[:25]+"...") + t.Logf("✓ Request %d succeeded with token: %s", i+1, previewToken(token, 25)) } t.Logf("✓ All token changes worked on same session") diff --git a/router-tests/mcp_test.go b/router-tests/mcp_test.go index 975c6c93c4..6db7b8099e 100644 --- a/router-tests/mcp_test.go +++ b/router-tests/mcp_test.go @@ -473,7 +473,7 @@ func TestMCP(t *testing.T) { // Make the request resp, err := xEnv.RouterClient.Do(req) require.NoError(t, err) - defer resp.Body.Close() //nolint:errcheck //nolint:errcheck + defer resp.Body.Close() //nolint:errcheck // Verify response status assert.Equal(t, http.StatusNoContent, resp.StatusCode) diff --git a/router-tests/testutil/auth_helpers.go b/router-tests/testutil/auth_helpers.go index d735ffb0a4..391bf0c52a 100644 --- a/router-tests/testutil/auth_helpers.go +++ b/router-tests/testutil/auth_helpers.go @@ -16,8 +16,10 @@ import ( func ParseWWWAuthenticateParams(header string) map[string]string { params := make(map[string]string) - // Remove "Bearer " prefix - header = strings.TrimPrefix(header, "Bearer ") + // Remove "Bearer " prefix (case-insensitive) + if len(header) >= 7 && strings.EqualFold(header[:7], "Bearer ") { + header = header[7:] + } header = strings.TrimSpace(header) // Simple state machine to parse key="value" pairs diff --git a/router-tests/testutil/jwt_helper.go b/router-tests/testutil/jwt_helper.go index 830a51e9bd..99b1929d62 100644 --- a/router-tests/testutil/jwt_helper.go +++ b/router-tests/testutil/jwt_helper.go @@ -102,7 +102,11 @@ func (s *JWKSTestServer) waitForReady(timeout time.Duration) error { case <-ctx.Done(): return fmt.Errorf("timeout waiting for JWKS server") case <-ticker.C: - resp, err := http.Get(s.jwksURL) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.jwksURL, nil) + if err != nil { + continue + } + resp, err := http.DefaultClient.Do(req) if err == nil { resp.Body.Close() if resp.StatusCode == http.StatusOK { diff --git a/router/pkg/mcpserver/auth_middleware.go b/router/pkg/mcpserver/auth_middleware.go index d082fa4bf8..6ed1670d27 100644 --- a/router/pkg/mcpserver/auth_middleware.go +++ b/router/pkg/mcpserver/auth_middleware.go @@ -17,6 +17,9 @@ type contextKey string const ( userClaimsContextKey contextKey = "mcp_user_claims" + // maxBodyBytes is the maximum size of the request body we'll read for scope checking. + // This prevents memory exhaustion from oversized payloads. + maxBodyBytes int64 = 1 << 20 // 1 MB ) // mcpAuthProvider adapts MCP headers to the authentication.Provider interface @@ -121,13 +124,19 @@ func (m *MCPAuthMiddleware) HTTPMiddleware(next http.Handler) http.Handler { // Step 2: Parse JSON-RPC request to check for tool-specific scopes // Read body to extract tool name (only if body exists) + // Use LimitReader to prevent memory exhaustion from oversized payloads var body []byte if r.Body != nil { - body, err = io.ReadAll(r.Body) + limitedReader := io.LimitReader(r.Body, maxBodyBytes+1) + body, err = io.ReadAll(limitedReader) if err != nil { m.sendUnauthorizedResponse(w, fmt.Errorf("failed to read request body")) return } + if int64(len(body)) > maxBodyBytes { + m.sendUnauthorizedResponse(w, fmt.Errorf("request body too large")) + return + } // Restore body for downstream handlers r.Body = io.NopCloser(bytes.NewBuffer(body)) } @@ -155,8 +164,9 @@ func (m *MCPAuthMiddleware) HTTPMiddleware(next http.Handler) http.Handler { } } - // Add claims to request context for downstream handlers + // Add claims and request headers to request context for downstream handlers ctx := context.WithValue(r.Context(), userClaimsContextKey, claims) + ctx = requestHeadersFromRequest(ctx, r) next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/router/pkg/mcpserver/server.go b/router/pkg/mcpserver/server.go index 62f2a5a4b2..7904bf8c50 100644 --- a/router/pkg/mcpserver/server.go +++ b/router/pkg/mcpserver/server.go @@ -911,7 +911,7 @@ func (s *GraphQLSchemaServer) executeGraphQLQuery(ctx context.Context, query str // If there are errors but no data, return only the errors if len(graphqlResponse.Data) == 0 || string(graphqlResponse.Data) == "null" { return &mcp.CallToolResult{ - Content: []mcp.Content{&mcp.TextContent{Text: fmt.Sprintf("Response error: %v", err)}}, + Content: []mcp.Content{&mcp.TextContent{Text: fmt.Sprintf("Response error: %s", errorMessage)}}, IsError: true, }, nil } @@ -1041,14 +1041,17 @@ func (s *GraphQLSchemaServer) handleProtectedResourceMetadata(w http.ResponseWri ScopesSupported: scopes, // Automatically derived from required scopes } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - - if err := json.NewEncoder(w).Encode(metadata); err != nil { + // Encode to buffer first so we can handle errors before writing headers + data, err := json.Marshal(metadata) + if err != nil { s.logger.Error("failed to encode protected resource metadata", zap.Error(err)) http.Error(w, "Internal server error", http.StatusInternalServerError) return } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(data) } // GetResourceMetadataURL returns the URL for the OAuth 2.0 Protected Resource Metadata endpoint