diff --git a/app/mcp/route.ts b/app/mcp/route.ts index 465a06b9..e7d3719d 100644 --- a/app/mcp/route.ts +++ b/app/mcp/route.ts @@ -1,29 +1,23 @@ import { registerAllTools } from "@/lib/mcp/tools"; -import { createMcpHandler } from "mcp-handler"; +import { createMcpHandler, withMcpAuth } from "mcp-handler"; +import { verifyApiKey } from "@/lib/mcp/verifyApiKey"; -let handler: ReturnType | null = null; +const baseHandler = createMcpHandler( + server => { + registerAllTools(server); + }, + { + serverInfo: { + name: "recoup-mcp", + version: "0.0.1", + }, + }, +); -/** - * Gets the MCP handler for the API. - * - * @returns The MCP handler. - */ -async function getHandler(): Promise> { - if (!handler) { - handler = createMcpHandler( - server => { - registerAllTools(server); - }, - { - serverInfo: { - name: "recoup-mcp", - version: "0.0.1", - }, - }, - ); - } - return handler; -} +// Wrap with auth - API key is required for all MCP requests +const handler = withMcpAuth(baseHandler, verifyApiKey, { + required: true, +}); /** * GET handler for the MCP API. @@ -32,7 +26,6 @@ async function getHandler(): Promise> { * @returns The response from the MCP handler. */ export async function GET(req: Request) { - const handler = await getHandler(); return handler(req); } @@ -43,6 +36,5 @@ export async function GET(req: Request) { * @returns The response from the MCP handler. */ export async function POST(req: Request) { - const handler = await getHandler(); return handler(req); } diff --git a/lib/mcp/resolveAccountId.ts b/lib/mcp/resolveAccountId.ts new file mode 100644 index 00000000..100dcafd --- /dev/null +++ b/lib/mcp/resolveAccountId.ts @@ -0,0 +1,52 @@ +import { canAccessAccount } from "@/lib/organizations/canAccessAccount"; +import type { McpAuthInfo } from "@/lib/mcp/verifyApiKey"; + +export interface ResolveAccountIdParams { + authInfo: McpAuthInfo | undefined; + accountIdOverride: string | undefined; +} + +export interface ResolveAccountIdResult { + accountId: string | null; + error: string | null; +} + +/** + * Resolves the accountId from MCP auth info or an override parameter. + * Validates access when an org API key attempts to use an account_id override. + * + * @param params - The auth info and optional account_id override. + * @returns The resolved accountId or an error message. + */ +export async function resolveAccountId({ + authInfo, + accountIdOverride, +}: ResolveAccountIdParams): Promise { + const authAccountId = authInfo?.extra?.accountId; + const authOrgId = authInfo?.extra?.orgId; + + if (authAccountId) { + // If account_id override is provided, validate access (for org API keys) + if (accountIdOverride && accountIdOverride !== authAccountId) { + const hasAccess = await canAccessAccount({ + orgId: authOrgId, + targetAccountId: accountIdOverride, + }); + if (!hasAccess) { + return { accountId: null, error: "Access denied to specified account_id" }; + } + return { accountId: accountIdOverride, error: null }; + } + return { accountId: authAccountId, error: null }; + } + + if (accountIdOverride) { + return { accountId: accountIdOverride, error: null }; + } + + return { + accountId: null, + error: + "Authentication required. Provide an API key via Authorization: Bearer header, or provide account_id from the system prompt context.", + }; +} diff --git a/lib/mcp/tools/artists/__tests__/registerCreateNewArtistTool.test.ts b/lib/mcp/tools/artists/__tests__/registerCreateNewArtistTool.test.ts index 1d8ce004..42056f75 100644 --- a/lib/mcp/tools/artists/__tests__/registerCreateNewArtistTool.test.ts +++ b/lib/mcp/tools/artists/__tests__/registerCreateNewArtistTool.test.ts @@ -1,8 +1,11 @@ import { describe, it, expect, vi, beforeEach } from "vitest"; import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js"; const mockCreateArtistInDb = vi.fn(); const mockCopyRoom = vi.fn(); +const mockCanAccessAccount = vi.fn(); vi.mock("@/lib/artists/createArtistInDb", () => ({ createArtistInDb: (...args: unknown[]) => mockCreateArtistInDb(...args), @@ -12,11 +15,39 @@ vi.mock("@/lib/rooms/copyRoom", () => ({ copyRoom: (...args: unknown[]) => mockCopyRoom(...args), })); +vi.mock("@/lib/organizations/canAccessAccount", () => ({ + canAccessAccount: (...args: unknown[]) => mockCanAccessAccount(...args), +})); + import { registerCreateNewArtistTool } from "../registerCreateNewArtistTool"; +type ServerRequestHandlerExtra = RequestHandlerExtra; + +/** + * Creates a mock extra object with optional authInfo. + */ +function createMockExtra(authInfo?: { + accountId?: string; + orgId?: string | null; +}): ServerRequestHandlerExtra { + return { + authInfo: authInfo + ? { + token: "test-token", + scopes: ["mcp:tools"], + clientId: authInfo.accountId, + extra: { + accountId: authInfo.accountId, + orgId: authInfo.orgId ?? null, + }, + } + : undefined, + } as unknown as ServerRequestHandlerExtra; +} + describe("registerCreateNewArtistTool", () => { let mockServer: McpServer; - let registeredHandler: (args: unknown) => Promise; + let registeredHandler: (args: unknown, extra: ServerRequestHandlerExtra) => Promise; beforeEach(() => { vi.clearAllMocks(); @@ -40,7 +71,7 @@ describe("registerCreateNewArtistTool", () => { ); }); - it("creates an artist and returns success", async () => { + it("creates an artist and returns success with account_id", async () => { const mockArtist = { id: "artist-123", account_id: "artist-123", @@ -50,10 +81,13 @@ describe("registerCreateNewArtistTool", () => { }; mockCreateArtistInDb.mockResolvedValue(mockArtist); - const result = await registeredHandler({ - name: "Test Artist", - account_id: "owner-456", - }); + const result = await registeredHandler( + { + name: "Test Artist", + account_id: "owner-456", + }, + createMockExtra(), + ); expect(mockCreateArtistInDb).toHaveBeenCalledWith("Test Artist", "owner-456", undefined); expect(result).toEqual({ @@ -66,6 +100,34 @@ describe("registerCreateNewArtistTool", () => { }); }); + it("creates an artist using auth info accountId", async () => { + const mockArtist = { + id: "artist-123", + account_id: "artist-123", + name: "Test Artist", + account_info: [{ image: null }], + account_socials: [], + }; + mockCreateArtistInDb.mockResolvedValue(mockArtist); + + const result = await registeredHandler( + { + name: "Test Artist", + }, + createMockExtra({ accountId: "auth-account-123" }), + ); + + expect(mockCreateArtistInDb).toHaveBeenCalledWith("Test Artist", "auth-account-123", undefined); + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("Successfully created artist"), + }, + ], + }); + }); + it("copies room when active_conversation_id is provided", async () => { const mockArtist = { id: "artist-123", @@ -77,11 +139,14 @@ describe("registerCreateNewArtistTool", () => { mockCreateArtistInDb.mockResolvedValue(mockArtist); mockCopyRoom.mockResolvedValue("new-room-789"); - const result = await registeredHandler({ - name: "Test Artist", - account_id: "owner-456", - active_conversation_id: "source-room-111", - }); + const result = await registeredHandler( + { + name: "Test Artist", + account_id: "owner-456", + active_conversation_id: "source-room-111", + }, + createMockExtra(), + ); expect(mockCopyRoom).toHaveBeenCalledWith("source-room-111", "artist-123"); expect(result).toEqual({ @@ -104,11 +169,14 @@ describe("registerCreateNewArtistTool", () => { }; mockCreateArtistInDb.mockResolvedValue(mockArtist); - await registeredHandler({ - name: "Test Artist", - account_id: "owner-456", - organization_id: "org-999", - }); + await registeredHandler( + { + name: "Test Artist", + account_id: "owner-456", + organization_id: "org-999", + }, + createMockExtra(), + ); expect(mockCreateArtistInDb).toHaveBeenCalledWith("Test Artist", "owner-456", "org-999"); }); @@ -116,10 +184,13 @@ describe("registerCreateNewArtistTool", () => { it("returns error when artist creation fails", async () => { mockCreateArtistInDb.mockResolvedValue(null); - const result = await registeredHandler({ - name: "Test Artist", - account_id: "owner-456", - }); + const result = await registeredHandler( + { + name: "Test Artist", + account_id: "owner-456", + }, + createMockExtra(), + ); expect(result).toEqual({ content: [ @@ -134,16 +205,88 @@ describe("registerCreateNewArtistTool", () => { it("returns error with message when exception is thrown", async () => { mockCreateArtistInDb.mockRejectedValue(new Error("Database connection failed")); - const result = await registeredHandler({ + const result = await registeredHandler( + { + name: "Test Artist", + account_id: "owner-456", + }, + createMockExtra(), + ); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("Database connection failed"), + }, + ], + }); + }); + + it("allows account_id override for org auth with access", async () => { + mockCanAccessAccount.mockResolvedValue(true); + const mockArtist = { + id: "artist-123", + account_id: "artist-123", name: "Test Artist", - account_id: "owner-456", + account_info: [{ image: null }], + account_socials: [], + }; + mockCreateArtistInDb.mockResolvedValue(mockArtist); + + await registeredHandler( + { + name: "Test Artist", + account_id: "target-account-456", + }, + createMockExtra({ accountId: "org-account-id", orgId: "org-account-id" }), + ); + + expect(mockCanAccessAccount).toHaveBeenCalledWith({ + orgId: "org-account-id", + targetAccountId: "target-account-456", }); + expect(mockCreateArtistInDb).toHaveBeenCalledWith( + "Test Artist", + "target-account-456", + undefined, + ); + }); + + it("returns error when org auth lacks access to account_id", async () => { + mockCanAccessAccount.mockResolvedValue(false); + + const result = await registeredHandler( + { + name: "Test Artist", + account_id: "target-account-456", + }, + createMockExtra({ accountId: "org-account-id", orgId: "org-account-id" }), + ); expect(result).toEqual({ content: [ { type: "text", - text: expect.stringContaining("Database connection failed"), + text: expect.stringContaining("Access denied to specified account_id"), + }, + ], + }); + }); + + it("returns error when neither auth nor account_id is provided", async () => { + const result = await registeredHandler( + { + name: "Test Artist", + }, + createMockExtra(), + ); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("Authentication required"), }, ], }); diff --git a/lib/mcp/tools/artists/registerCreateNewArtistTool.ts b/lib/mcp/tools/artists/registerCreateNewArtistTool.ts index ea11c6ed..cad17806 100644 --- a/lib/mcp/tools/artists/registerCreateNewArtistTool.ts +++ b/lib/mcp/tools/artists/registerCreateNewArtistTool.ts @@ -1,5 +1,9 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js"; import { z } from "zod"; +import type { McpAuthInfo } from "@/lib/mcp/verifyApiKey"; +import { resolveAccountId } from "@/lib/mcp/resolveAccountId"; import { createArtistInDb, type CreateArtistResult, @@ -14,7 +18,8 @@ const createNewArtistSchema = z.object({ .string() .optional() .describe( - "The account ID to create the artist for. Only required for organization API keys creating artists on behalf of other accounts.", + "The account ID to create the artist for. Only required for organization API keys creating artists on behalf of other accounts. " + + "If not provided, the account ID will be resolved from the authenticated API key.", ), active_conversation_id: z .string() @@ -57,26 +62,36 @@ export function registerCreateNewArtistTool(server: McpServer): void { { description: "Create a new artist account in the system. " + + "Requires authentication via API key (Authorization: Bearer header). " + "The account_id parameter is optional — only provide it when using an organization API key to create artists on behalf of other accounts. " + "The active_conversation_id parameter is optional — when omitted, use the active_conversation_id from the system prompt " + "to copy the conversation. Never ask the user to provide a room ID. " + "The organization_id parameter is optional — use the organization_id from the system prompt context to link the artist to the user's selected organization.", inputSchema: createNewArtistSchema, }, - async (args: CreateNewArtistArgs) => { + async (args: CreateNewArtistArgs, extra: RequestHandlerExtra) => { try { const { name, account_id, active_conversation_id, organization_id } = args; - if (!account_id) { - return getToolResultError( - "account_id is required. Provide it from the system prompt context.", - ); + // Resolve accountId from auth or use provided account_id + const authInfo = extra.authInfo as McpAuthInfo | undefined; + const { accountId: resolvedAccountId, error } = await resolveAccountId({ + authInfo, + accountIdOverride: account_id, + }); + + if (error) { + return getToolResultError(error); + } + + if (!resolvedAccountId) { + return getToolResultError("Failed to resolve account ID"); } // Create the artist account (with optional org linking) const artist = await createArtistInDb( name, - account_id, + resolvedAccountId, organization_id ?? undefined, ); diff --git a/lib/mcp/verifyApiKey.ts b/lib/mcp/verifyApiKey.ts new file mode 100644 index 00000000..4bcc1b65 --- /dev/null +++ b/lib/mcp/verifyApiKey.ts @@ -0,0 +1,45 @@ +import type { AuthInfo } from "@modelcontextprotocol/sdk/server/auth/types.js"; +import { getApiKeyDetails } from "@/lib/keys/getApiKeyDetails"; + +export interface McpAuthInfoExtra extends Record { + accountId: string; + orgId: string | null; +} + +export interface McpAuthInfo extends AuthInfo { + extra: McpAuthInfoExtra; +} + +/** + * Verifies an API key and returns auth info with account details. + * + * @param _req - The request object (unused). + * @param bearerToken - The API key from the Authorization: Bearer header. + * @returns AuthInfo with accountId and orgId, or undefined if invalid. + */ +export async function verifyApiKey( + _req: Request, + bearerToken?: string, +): Promise { + if (!bearerToken) { + return undefined; + } + + const apiKey = bearerToken; + + const keyDetails = await getApiKeyDetails(apiKey); + + if (!keyDetails) { + return undefined; + } + + return { + token: apiKey, + scopes: ["mcp:tools"], + clientId: keyDetails.accountId, + extra: { + accountId: keyDetails.accountId, + orgId: keyDetails.orgId, + }, + }; +}